from sqlapi import sql

__all__ = ['SelectResults']

try:
    _string_types = basestring
except NameError:
    # Python 3 has no ``basestring``; ``str`` covers text strings there.
    _string_types = str


class SelectResults(object):
    """Describe a SELECT query and build result values from its rows.

    ``column_specs`` is a sequence of ``(columns, builder)`` pairs.  The
    columns of all specs are concatenated, in order, into the select
    list.  For each fetched row, every builder is called with the slice
    of the row that corresponds to its own columns.
    """

    def __init__(self, column_specs, clause, return_single=True,
                 connection=None, primary_class=None, order_by=None,
                 reversed=False, distinct=False):
        """Store the query description and build ``self.sql_select``.

        :param column_specs: sequence of ``(columns, builder)`` pairs.
        :param clause: passed as the second positional argument of
            ``sql.Select`` (presumably the WHERE clause -- confirm
            against sqlapi).
        :param return_single: if true, iteration yields the single
            spec's built value rather than a list of values; requires
            exactly one spec.
        :param connection: object whose ``cursor()`` method is used to
            run queries.
        :param primary_class: class used to resolve string entries of
            ``order_by`` into columns.
        :param order_by: ``None``, a single item, or a list/tuple of
            items.  String items name columns; a leading ``'-'`` means
            descending.
        :param reversed: invert the direction of every ``order_by`` item.
        :param distinct: passed through to ``sql.Select``.
        :raises AssertionError: if ``return_single`` is true and there is
            not exactly one column spec.
        """
        self.column_specs = column_specs
        self.clause = clause
        self.connection = connection
        self.primary_class = primary_class
        # Explicit raise instead of ``assert`` so the check survives
        # ``python -O``.  The one-element tuple stops ``%`` from trying to
        # unpack a tuple of specs as several format arguments.
        if return_single and len(column_specs) != 1:
            raise AssertionError(
                "You can only use return_single when a single spec "
                "is passed in (got %r)" % (column_specs,))
        self.return_single = return_single
        self.columns = []
        for cols, _builder in self.column_specs:
            self.columns.extend(cols)
        self._order_by = self._convert_order_by(order_by, primary_class)
        if reversed:
            self._order_by = self._reverse(self._order_by)
        self._distinct = distinct
        self.sql_select = sql.Select(
            self.columns, clause, order_by=self._order_by,
            distinct=self._distinct)

    def _convert_order_by(self, order_by, order_by_class):
        """Normalize ``order_by`` into a list of sql expressions.

        Returns ``None`` when ``order_by`` is ``None``.  A lone item is
        wrapped in a list.  String items are resolved on
        ``order_by_class`` with ``_get_column``; a leading ``'-'`` wraps
        the resolved column in ``sql.Desc``.  Non-string items are kept
        unchanged.
        """
        if order_by is None:
            return None
        if not isinstance(order_by, (list, tuple)):
            order_by = [order_by]
        result = []
        for item in order_by:
            if isinstance(item, _string_types):
                desc = item.startswith('-')
                if desc:
                    item = item[1:]
                item = self._get_column(order_by_class, item)
                if desc:
                    item = sql.Desc(item)
            result.append(item)
        return result

    def _get_column(self, cls, name):
        """Return the column of ``cls`` called ``name``.

        ``name`` is first tried as an attribute of ``cls``.  Failing
        that, it is matched against the ``db_name`` of each value in
        ``cls.sqlmeta.columns``.

        :raises AttributeError: if neither lookup finds a column.
        """
        try:
            return getattr(cls, name)
        except AttributeError:
            # We'll assume we got a db-column name, not an attribute
            # name
            for col in cls.sqlmeta.columns.values():
                if col.db_name == name:
                    return col
            raise AttributeError('No column by the name %r' % name)

    def _reverse(self, order_by):
        """Flip the direction of every item in ``order_by``.

        ``sql.Desc(x)`` becomes ``x`` (its ``expr``), and any other item
        ``x`` becomes ``sql.Desc(x)``.  ``None`` is returned unchanged.
        """
        if order_by is None:
            return None
        result = []
        for item in order_by:
            if isinstance(item, sql.Desc):
                result.append(item.expr)
            else:
                result.append(sql.Desc(item))
        return result

    ########################################
    ## Modifiers
    ########################################

    def order_by(self, new_order_by):
        """Return a copy of this result set ordered by ``new_order_by``."""
        return self.clone(order_by=new_order_by)

    def distinct(self, distinct):
        """Return a copy of this result set with ``distinct`` set."""
        return self.clone(distinct=distinct)

    def clone(self, **kw):
        """Return a new instance with ``kw`` overriding current settings.

        Any constructor argument not given in ``kw`` is copied from this
        instance.  ``order_by`` defaults to this instance's
        already-converted order.  That order already reflects any
        reversal, so ``reversed`` is not carried over.
        """
        for attr, keyword in [('column_specs', None),
                              ('clause', None),
                              ('return_single', None),
                              ('connection', None),
                              ('primary_class', None),
                              ('distinct', '_distinct'),
                              ('order_by', '_order_by')]:
            kw.setdefault(attr, getattr(self, keyword or attr))
        return self.__class__(**kw)

    def __iter__(self):
        """Run the query and yield one built result per row."""
        rows = self.do_query()
        for row in rows:
            yield self.produce_from_row(row)

    def do_query(self):
        """Execute ``self.sql_select`` and return all fetched rows.

        The cursor is closed even if execution or fetching raises.
        """
        cur = self.connection.cursor()
        try:
            cur.execute(self.sql_select)
            return cur.fetchall()
        finally:
            cur.close()

    def produce_from_row(self, row):
        """Turn one database row into result value(s).

        With ``return_single`` this returns the single spec's value.
        Otherwise it returns a list with one value per spec, each built
        from consecutive slices of ``row``.
        """
        if self.return_single:
            return self.produce_spec_from_row(
                self.column_specs[0], row)[0]
        results = []
        rest = row
        for spec in self.column_specs:
            result, rest = self.produce_spec_from_row(spec, rest)
            results.append(result)
        return results

    def produce_spec_from_row(self, spec, row):
        """Build one spec's value from the front of ``row``.

        Consumes ``len(columns)`` values and returns
        ``(value, remaining_row)``.

        :raises AssertionError: if ``row`` has fewer values left than the
            spec has columns.
        """
        columns, builder = spec
        width = len(columns)
        # Only too *few* remaining values is an error; any surplus belongs
        # to the specs that follow this one.
        if len(row) < width:
            raise AssertionError(
                "Need %i columns to fill %r; only %i left"
                % (width, builder, len(row)))
        value = builder(row[:width])
        return value, row[width:]

    def __getitem__(self, item):
        """Index or slice the fully materialized results."""
        # @@: This should be lazy
        return list(self)[item]

    def accumulate_many(self, *exprs):
        """Select ``exprs`` under ``self.clause`` and return the first row.

        Only ``exprs`` and ``self.clause`` are passed to ``sql.Select``;
        ordering and ``distinct`` are not applied.  The cursor is closed
        even if execution or fetching raises.
        """
        # @@: Also accumulateMany -- has different return values
        query = sql.Select(exprs, self.clause)
        cur = self.connection.cursor()
        try:
            cur.execute(query)
            return cur.fetchone()
        finally:
            cur.close()

    def accumulate(self, expr):
        """Return the single value of ``expr`` (see ``accumulate_many``)."""
        return self.accumulate_many(expr)[0]

    def count(self):
        """Return ``COUNT`` over the columns of all specs."""
        COUNT = sql.funcs.COUNT
        star = sql.star_from([cols for cols, _builder in self.column_specs])
        return self.accumulate(COUNT(star))

    def min(self, expr):
        """Return ``MIN(expr)`` under ``self.clause``."""
        return self.accumulate(sql.funcs.MIN(expr))

    def max(self, expr):
        """Return ``MAX(expr)`` under ``self.clause``."""
        return self.accumulate(sql.funcs.MAX(expr))

    def avg(self, expr):
        """Return ``AVG(expr)`` under ``self.clause``."""
        return self.accumulate(sql.funcs.AVG(expr))

    def sum(self, expr):
        """Return ``SUM(expr)`` under ``self.clause``."""
        return self.accumulate(sql.funcs.SUM(expr))