from sqlapi import sql
from metasqlobject.boundattributes import BoundAttribute
from sqlobject.classreg import class_registry
import events


def _resolve_intermediate_column(table, column, soclass):
    """
    Return *column* as a column object of the intermediate *table*.

    *column* may be None (derive ``<soclass table>_id``), a column name
    string (looked up in *table*), or an already-resolved column, which
    is returned unchanged.
    """
    if column is None:
        # @@: Style:
        # Or base off the primary key?
        return table[soclass.sqlmeta.table + '_id']
    if isinstance(column, basestring):
        return table[column]
    return column


class ManyToMany(BoundAttribute):
    """
    Many-to-many relation between ``bound_class`` and ``other_class``,
    stored in an intermediate table with one column referring to each
    side.

    Reading the attribute on an instance returns a
    ``_ManyToManySelectWrapper`` over the related ``other_class`` rows.
    """

    # None (derive the name from both table names), a table name string,
    # or an ``sql.Table``; normalized to an ``sql.Table`` by
    # _setup_attributes().
    intermediate_table = None
    # Intermediate-table column referring to other_class: None (derive
    # ``<other table>_id``), a column name string, or a table column.
    other_intermediate_column = None
    # Intermediate-table column referring to bound_class; same options.
    bound_intermediate_column = None
    # Attributes of other_class / bound_class that the intermediate
    # columns are compared against.
    other_column = 'id'
    bound_column = 'id'

    def __init__(self, other_class=None, **attrs):
        if other_class is not None:
            attrs['other_class'] = other_class
        BoundAttribute.__init__(self, **attrs)
        if (self.bound_class is not None
                and self.other_class is not None):
            if isinstance(self.other_class, basestring):
                # other_class was given by name: finish the setup once a
                # class of that name is registered.
                class_registry.register_callback(
                    self._set_other_class, self.other_class,
                    self.bound_class.__module__)
            else:
                self._set_other_class(self.other_class)
            self.bound_class.event_hub.listen(
                events.TableSQLSignal, self, 'create_sql')

    def _set_other_class(self, other_class):
        """
        Finish setup once ``other_class`` is an actual class.

        Resolves the intermediate table/columns, builds ``self.clause`` and
        also listens for table SQL generation on ``other_class``.
        """
        self.other_class = other_class
        other_class.event_hub.listen(
            events.TableSQLSignal, self, 'create_sql')
        # Resolve the intermediate table and columns first: before this
        # they may still be None or bare strings, which must not end up
        # in the join clause.
        self._setup_attributes()
        self.clause = (
            (getattr(self.other_class, self.other_column)
             == self.other_intermediate_column)
            & (getattr(self.bound_class, self.bound_column)
               == self.bound_intermediate_column))

    def _setup_attributes(self):
        """
        This is called when .other_class has been bound, and so we
        can set up all the basic attributes
        """
        if self.intermediate_table is None:
            # Sorting makes the derived name independent of which side
            # is "bound" and which is "other".
            names = sorted([self.other_class.sqlmeta.table,
                            self.bound_class.sqlmeta.table])
            self.intermediate_table = sql.Table(
                '%s_%s' % (names[0], names[1]))
        elif isinstance(self.intermediate_table, basestring):
            self.intermediate_table = sql.Table(self.intermediate_table)
        table = self.intermediate_table
        self.other_intermediate_column = _resolve_intermediate_column(
            table, self.other_intermediate_column, self.other_class)
        self.bound_intermediate_column = _resolve_intermediate_column(
            table, self.bound_intermediate_column, self.bound_class)

    def __get__(self, obj, type):
        """
        Return this descriptor on class access, otherwise a select wrapper
        over the ``other_class`` rows linked to *obj*.
        """
        if obj is None:
            return self
        query = (
            (getattr(self.other_class, self.other_column)
             == self.other_intermediate_column)
            & (getattr(obj, self.bound_column)
               == self.bound_intermediate_column))
        select = self.other_class.select(query)
        return _ManyToManySelectWrapper(obj, self, select)

    def _create_sql_for_column(self, name, column, connection):
        """Return ``column.create_sql(...)`` for a non-primary-key *name*."""
        return column.create_sql(
            connection, name, primary_key=False)

    def create_sql(self, soclass, connection, create_sql, lazy_sql,
                   create_context):
        """
        TableSQLSignal handler: queue CREATE TABLE for the intermediate table.

        Runs at most once per table name within *create_context*, and does
        nothing if the table already exists. The CREATE TABLE statement and
        any per-column lazy SQL are appended to *lazy_sql*; *soclass* and
        *create_sql* are part of the signal signature but unused here.
        """
        intermediate_table = self.intermediate_table.table_name
        if (intermediate_table, 'created') in create_context:
            return
        create_context[(intermediate_table, 'created')] = None
        cur = connection.cursor()
        try:
            exists = connection.plugin.table_exists(intermediate_table, cur)
        finally:
            # Cursors follow DB-API close(); do not leak it.
            cur.close()
        if exists:
            return
        # @@: These need to deal with compound IDs:
        create_cols = []
        lazy_cols = []
        for col_class, this_col in [
                (self.bound_class, self.bound_intermediate_column),
                (self.other_class, self.other_intermediate_column)]:
            col_sql, col_lazy = self._create_sql_for_column(
                this_col.column_name, col_class.sqlmeta.columns['id'],
                connection)
            create_cols.append(col_sql)
            if col_lazy:
                lazy_cols.extend(col_lazy)
        lazy_sql.append(sql.CreateTable(intermediate_table, create_cols))
        lazy_sql.extend(lazy_cols)


class _SelectWrapper(object):
    """
    Shared base for relation results: wraps the select of related rows
    for one instance and delegates unknown attributes to that select.
    """

    def __init__(self, for_object, join, select):
        self.for_object = for_object
        self.join = join
        self.select = select

    def __getattr__(self, attr):
        # @@: This passes through private variable access too... should it?
        # Also magic methods, like __str__ -- implicit special-method lookup
        # on new-style classes bypasses __getattr__, hence the explicit
        # __repr__/__str__/__iter__/__getitem__ below.
        if attr == 'select':
            # Only reached when 'select' was never set (e.g. an instance
            # built without __init__); avoid infinite recursion.
            raise AttributeError(attr)
        return getattr(self.select, attr)

    def __repr__(self):
        return '<%s for: %s>' % (self.__class__.__name__,
                                 repr(self.select))

    def __str__(self):
        return str(self.select)

    def __iter__(self):
        return iter(self.select)

    def __getitem__(self, key):
        return self.select[key]


class _ManyToManySelectWrapper(_SelectWrapper):
    """Result of ``ManyToMany.__get__``; can add/remove/create links."""

    def add(self, obj):
        """Insert a link row between ``for_object`` and *obj*."""
        join = self.join
        s = sql.Insert(
            join.intermediate_table.table_name,
            {join.bound_intermediate_column.column_name:
                getattr(self.for_object, join.bound_column),
             join.other_intermediate_column.column_name:
                getattr(obj, join.other_column)})
        # 'connection' is not set on the wrapper; __getattr__ resolves it
        # from the wrapped select.
        self.connection.execute(s)

    def remove(self, obj):
        """Delete the link row(s) between ``for_object`` and *obj*."""
        join = self.join
        s = sql.Delete(
            join.intermediate_table.table_name,
            (join.bound_intermediate_column
             == getattr(self.for_object, join.bound_column))
            & (join.other_intermediate_column
               == getattr(obj, join.other_column)))
        self.connection.execute(s)

    def create(self, **kw):
        """Create an ``other_class`` instance from *kw*, link it, return it."""
        obj = self.join.other_class(**kw)
        self.add(obj)
        return obj


class OneToMany(BoundAttribute):
    """
    One-to-many relation: ``other_class`` rows whose ``join_column``
    equals the bound instance's ``joined_to_column`` attribute.
    """

    # Column on other_class referring back to bound_class; derived as
    # '<bound table>_id' when not given.
    join_column = None
    # The last derived join_column, so a derived value can be reset.
    default_join_column = None
    # Attribute of the bound instance that join_column is compared to.
    joined_to_column = 'id'

    def __init__(self, other_class=None, **attrs):
        if other_class is not None:
            attrs['other_class'] = other_class
        BoundAttribute.__init__(self, **attrs)
        if self.other_class and self.bound_class:
            if (not self.join_column
                    or self.join_column == self.default_join_column):
                self.join_column = self.bound_class.sqlmeta.table + '_id'
                # This lets us know that we can reset this value
                # if we need to:
                self.default_join_column = self.join_column
            if isinstance(self.other_class, basestring):
                # other_class was given by name: finish the setup once a
                # class of that name is registered.
                class_registry.register_callback(
                    self._set_other_class, self.other_class,
                    self.bound_class.__module__)
            else:
                self._set_other_class(self.other_class)

    def _set_other_class(self, other_class):
        """Record the resolved ``other_class`` and build ``self.clause``."""
        self.other_class = other_class
        self.clause = (
            getattr(self.other_class, self.join_column)
            == getattr(self.bound_class, self.joined_to_column))

    def __get__(self, obj, type):
        """
        Return this descriptor on class access, otherwise a select wrapper
        over the ``other_class`` rows joined to *obj*.
        """
        if obj is None:
            return self
        query = (
            getattr(self.other_class, self.join_column)
            == getattr(obj, self.joined_to_column))
        select = self.other_class.select(query)
        return _OneToManySelectWrapper(obj, self, select)


class _OneToManySelectWrapper(_SelectWrapper):
    """Result of ``OneToMany.__get__``."""

    def create(self, **kw):
        """
        Create and return an ``other_class`` instance joined to
        ``for_object``.

        Unless *kw* sets the join column explicitly, it defaults to
        ``for_object``'s ``joined_to_column`` value.
        """
        # NOTE(review): assumes other_class accepts join_column as a
        # constructor keyword, as it is an attribute of other_class
        # (see OneToMany.__get__) -- confirm.
        kw.setdefault(self.join.join_column,
                      getattr(self.for_object, self.join.joined_to_column))
        return self.join.other_class(**kw)