"""
Main (um, only) unit testing for SQLObject.

Use -vv to see SQL queries, -vvv to also see output from queries,
and together with --inserts to see the SQL from the standard
insert statements (which are often boring).
"""

from __future__ import generators

import sys
# Coverage is started here, before the modules under test are imported,
# so that their import-time code is measured as well.
if '--coverage' in sys.argv:
    import coverage
    print 'Starting coverage'
    coverage.erase()
    coverage.start()

from SQLObjectTest import *
from sqlobject import *
from sqlobject.include import validators
from sqlobject import classregistry
from mx import DateTime
# Name of the database backend currently under test; assigned by the
# command-line driver at the bottom of the file and read by tests that
# must skip a backend.
global curr_db
curr_db = None
from sqlobject import cache

########################################
## Basic operation
########################################

# Table with one column declared as a class attribute and one declared
# through the _columns list.
class TestSO1(SQLObject):

    name = StringCol(length=50, dbName='name_col')
    _cacheValues = False
    _columns = [
        StringCol('passwd', length=10),
        ]

    # Custom setter: the value handed to the generated _SO_set_passwd is
    # the rot13 encoding of what the caller assigned.
    def _set_passwd(self, passwd):
        self._SO_set_passwd(passwd.encode('rot13'))

# Basic get/set tests, parameterised over MyClass so subclasses can rerun
# them against a different table class.
class TestCase1(SQLObjectTest):

    classes = [TestSO1]
    MyClass = TestSO1

    # (name, passwd) rows created by inserts().
    info = [('bob', 'god'),
            ('sally', 'sordid'),
            ('dave', 'dremel'),
            ('fred', 'forgo')]

    # Creates one MyClass row per entry of self.info.
    # NOTE(review): presumably invoked by SQLObjectTest.setUp -- confirm.
    def inserts(self):
        for name, passwd in self.info:
            self.MyClass(name=name, passwd=passwd)

    # The stored password must read back rot13-encoded.
    def testGet(self):
        bob = self.MyClass.selectBy(name='bob')[0]
        self.assertEqual(bob.name, 'bob')
        self.assertEqual(bob.passwd, 'god'.encode('rot13'))

    # Newline, backslash, quote and tab characters must round-trip.
    def testNewline(self):
        bob = self.MyClass.selectBy(name='bob')[0]
        testString = 'hey\nyou\\can\'t you see me?\t'
        bob.name = testString
        self.failUnless(bob.name == testString, (bob.name, testString))

# Same fixture as TestCase1, with testGet replaced by a set-then-get check.
class TestCaseGetSet(TestCase1):

    def testGet(self):
        bob = TestSO1.selectBy(name='bob')[0]
        self.assertEqual(bob.name, 'bob')
        bob.name = 'joe'
        self.assertEqual(bob.name, 'joe')

# Same shape as TestSO1, but both columns are class attributes.
class TestSO2(SQLObject):
    name = StringCol(length=50, dbName='name_col')
    passwd = StringCol(length=10)

    # Same rot13 setter as TestSO1.
    def _set_passwd(self, passwd):
        self._SO_set_passwd(passwd.encode('rot13'))

# Reruns every TestCase1 test against TestSO2.
class TestCase2(TestCase1):
    classes = [TestSO2]
    MyClass = TestSO2

# Two optional references to TestSO4, one declared with ForeignKey and
# one with KeyCol(foreignKey=...).
class TestSO3(SQLObject):
    name = StringCol(length=10, dbName='name_col')
    other = ForeignKey('TestSO4', default=None)
    other2 = KeyCol(foreignKey='TestSO4', default=None)

class TestSO4(SQLObject):
    me = StringCol(length=10)

class Student(SQLObject):
    is_smart = BoolCol()

class BoolColTest(SQLObjectTest):
    classes = [Student]

    # False stays False; the non-empty string 'false' is expected to
    # read back as True.
    def testBoolCol(self):
        student = Student(is_smart=False)
        self.assertEqual(student.is_smart, False)
        student2 = Student(is_smart='false')
        self.assertEqual(student2.is_smart, True)

class TestCase34(SQLObjectTest):
    classes = [TestSO4, TestSO3]

    # A foreign key may be assigned either an object or its id, both by
    # attribute assignment and as a constructor argument; the attribute
    # and its ...ID companion must agree afterwards.
    def testForeignKey(self):
        tc3 = TestSO3(name='a')
        self.assertEqual(tc3.other, None)
        self.assertEqual(tc3.other2, None)
        self.assertEqual(tc3.otherID, None)
        self.assertEqual(tc3.other2ID, None)
        tc4a = TestSO4(me='1')
        tc3.other = tc4a
        self.assertEqual(tc3.other, tc4a)
        self.assertEqual(tc3.otherID, tc4a.id)
        tc4b = TestSO4(me='2')
        tc3.other = tc4b.id
        self.assertEqual(tc3.other, tc4b)
        self.assertEqual(tc3.otherID, tc4b.id)
        tc4c = TestSO4(me='3')
        tc3.other2 = tc4c
        self.assertEqual(tc3.other2, tc4c)
        self.assertEqual(tc3.other2ID, tc4c.id)
        tc4d = TestSO4(me='4')
        tc3.other2 = tc4d.id
        self.assertEqual(tc3.other2, tc4d)
        self.assertEqual(tc3.other2ID, tc4d.id)
        tcc = TestSO3(name='b', other=tc4a)
        self.assertEqual(tcc.other, tc4a)
        tcc2 = TestSO3(name='c', other=tc4a.id)
        self.assertEqual(tcc2.other, tc4a)

# Cascading chain: TestSO5 -> TestSO6 -> TestSO7, plus TestSO5 -> TestSO7.
class TestSO5(SQLObject):
    name = StringCol(length=10, dbName='name_col')
    other = ForeignKey('TestSO6', default=None, cascade=True)
    another = ForeignKey('TestSO7', default=None, cascade=True)

class TestSO6(SQLObject):
    name = StringCol(length=10, dbName='name_col')
    other = ForeignKey('TestSO7', default=None, cascade=True)

class TestSO7(SQLObject):
    name = StringCol(length=10, dbName='name_col')

class TestCase567(SQLObjectTest):
    classes = [TestSO7, TestSO6, TestSO5]

    # Destroying a referenced row must also remove the rows that
    # reference it through a cascade=True key; the row counts after each
    # destroySelf() encode the expected cascade.
    def testForeignKeyDestroySelfCascade(self):
        tc5 = TestSO5(name='a')
        tc6a = TestSO6(name='1')
        tc5.other = tc6a
        tc7a = TestSO7(name='2')
        tc6a.other = tc7a
        tc5.another = tc7a
        self.assertEqual(tc5.other, tc6a)
        self.assertEqual(tc5.otherID, tc6a.id)
        self.assertEqual(tc6a.other, tc7a)
        self.assertEqual(tc6a.otherID, tc7a.id)
        self.assertEqual(tc5.other.other, tc7a)
        self.assertEqual(tc5.other.otherID, tc7a.id)
        self.assertEqual(tc5.another, tc7a)
        self.assertEqual(tc5.anotherID, tc7a.id)
        self.assertEqual(tc5.other.other, tc5.another)
        self.assertEqual(TestSO5.select().count(), 1)
        self.assertEqual(TestSO6.select().count(), 1)
        self.assertEqual(TestSO7.select().count(), 1)
        tc6b = TestSO6(name='3')
        tc6c = TestSO6(name='4')
        tc7b = TestSO7(name='5')
        tc6b.other = tc7b
        tc6c.other = tc7b
        self.assertEqual(TestSO5.select().count(), 1)
        self.assertEqual(TestSO6.select().count(), 3)
        self.assertEqual(TestSO7.select().count(), 2)
        # Destroying a referencing row leaves its target alone.
        tc6b.destroySelf()
        self.assertEqual(TestSO5.select().count(), 1)
        self.assertEqual(TestSO6.select().count(), 2)
        self.assertEqual(TestSO7.select().count(), 2)
        # Destroying a target takes the remaining referrer (tc6c) with it.
        tc7b.destroySelf()
        self.assertEqual(TestSO5.select().count(), 1)
        self.assertEqual(TestSO6.select().count(), 1)
        self.assertEqual(TestSO7.select().count(), 1)
        # tc7a is referenced by tc6a and tc5, so everything goes.
        tc7a.destroySelf()
        self.assertEqual(TestSO5.select().count(), 0)
        self.assertEqual(TestSO6.select().count(), 0)
        self.assertEqual(TestSO7.select().count(), 0)

    # Dropping the TestSO7 table with cascade=True must leave the rows of
    # the referencing tables in place, and cascading deletes between
    # TestSO6 and TestSO5 must keep working afterwards.
    def testForeignKeyDropTableCascade(self):
        if curr_db == 'sybase':
            # XXX This test doesn't pass with sybase.
            return
        tc5a = TestSO5(name='a')
        tc6a = TestSO6(name='1')
        tc5a.other = tc6a
        tc7a = TestSO7(name='2')
        tc6a.other = tc7a
        tc5a.another = tc7a
        tc5b = TestSO5(name='b')
        tc5c = TestSO5(name='c')
        tc6b = TestSO6(name='3')
        tc5c.other = tc6b
        self.assertEqual(TestSO5.select().count(), 3)
        self.assertEqual(TestSO6.select().count(), 2)
        self.assertEqual(TestSO7.select().count(), 1)
        TestSO7.dropTable(cascade=True)
        self.assertEqual(TestSO5.select().count(), 3)
        self.assertEqual(TestSO6.select().count(), 2)
        tc6a.destroySelf()
        self.assertEqual(TestSO5.select().count(), 2)
        self.assertEqual(TestSO6.select().count(), 1)
        tc6b.destroySelf()
        self.assertEqual(TestSO5.select().count(), 1)
        self.assertEqual(TestSO6.select().count(), 0)
        # Only tc5b, which never referenced a TestSO6 row, is left.
        self.assertEqual(iter(TestSO5.select()).next(), tc5b)
        tc6c = TestSO6(name='3')
        tc5b.other = tc6c
        self.assertEqual(TestSO5.select().count(), 1)
        self.assertEqual(TestSO6.select().count(), 1)
        tc6c.destroySelf()
        self.assertEqual(TestSO5.select().count(), 0)
        self.assertEqual(TestSO6.select().count(), 0)

# Non-cascading reference: TestSO8 -> TestSO9 with cascade=False.
class TestSO8(SQLObject):
    name = StringCol(length=10, dbName='name_col')
    other = ForeignKey('TestSO9', default=None, cascade=False)

class TestSO9(SQLObject):
    name = StringCol(length=10, dbName='name_col')

class TestCase89(SQLObjectTest):
    classes = [TestSO9, TestSO8]

    # With cascade=False, destroying a row that is still referenced must
    # raise; unreferenced rows, and rows whose referrers are already
    # gone, can be destroyed.
    def testForeignKeyDestroySelfRestrict(self):
        tc8a = TestSO8(name='a')
        tc9a = TestSO9(name='1')
        tc8a.other = tc9a
        tc8b = TestSO8(name='b')
        tc9b = TestSO9(name='2')
        self.assertEqual(tc8a.other, tc9a)
        self.assertEqual(tc8a.otherID, tc9a.id)
        self.assertEqual(TestSO8.select().count(), 2)
        self.assertEqual(TestSO9.select().count(), 2)
        self.assertRaises(Exception, tc9a.destroySelf)
        tc9b.destroySelf()
        self.assertEqual(TestSO8.select().count(), 2)
        self.assertEqual(TestSO9.select().count(), 1)
        tc8a.destroySelf()
        tc8b.destroySelf()
        tc9a.destroySelf()
        self.assertEqual(TestSO8.select().count(), 0)
        self.assertEqual(TestSO9.select().count(), 0)

########################################
## Fancy sort
######################################## class Names(SQLObject): _table = 'names_table' firstName = StringCol(length=30) lastName = StringCol(length=30) _defaultOrder = ['lastName', 'firstName'] class NamesTest(SQLObjectTest): classes = [Names] def inserts(self): for firstName, lastName in [('aj', 'baker'), ('joe', 'robbins'), ('tim', 'jackson'), ('joe', 'baker'), ('zoe', 'robbins')]: Names(firstName=firstName, lastName=lastName) def testDefaultOrder(self): self.assertEqual([(n.firstName, n.lastName) for n in Names.select()], [('aj', 'baker'), ('joe', 'baker'), ('tim', 'jackson'), ('joe', 'robbins'), ('zoe', 'robbins')]) def testOtherOrder(self): self.assertEqual([(n.firstName, n.lastName) for n in Names.select().orderBy(['firstName', 'lastName'])], [('aj', 'baker'), ('joe', 'baker'), ('joe', 'robbins'), ('tim', 'jackson'), ('zoe', 'robbins')]) def testUntranslatedColumnOrder(self): self.assertEqual([(n.firstName, n.lastName) for n in Names.select().orderBy(['first_name', 'last_name'])], [('aj', 'baker'), ('joe', 'baker'), ('joe', 'robbins'), ('tim', 'jackson'), ('zoe', 'robbins')]) def testSingleUntranslatedColumnOrder(self): self.assertEqual([n.firstName for n in Names.select().orderBy('firstName')], ['aj', 'joe', 'joe', 'tim', 'zoe']) self.assertEqual([n.firstName for n in Names.select().orderBy('first_name')], ['aj', 'joe', 'joe', 'tim', 'zoe']) self.assertEqual([n.firstName for n in Names.select().orderBy('-firstName')], ['zoe', 'tim', 'joe', 'joe', 'aj']) self.assertEqual([n.firstName for n in Names.select().orderBy('-first_name')], ['zoe', 'tim', 'joe', 'joe', 'aj']) self.assertEqual([n.firstName for n in Names.select().orderBy(Names.q.firstName)], ['aj', 'joe', 'joe', 'tim', 'zoe']) ######################################## ## Select results ######################################## class IterTest(SQLObject): name = StringCol(dbName='name_col') class IterationTestCase(SQLObjectTest): '''Test basic iteration techniques''' classes = [IterTest] names = ('a', 
'b', 'c') def inserts(self): for name in self.names: IterTest(name=name) def test_00_normal(self): count = 0 for test in IterTest.select(): count += 1 self.failIf(count != len(self.names)) def test_01_turn_to_list(self): count = 0 for test in list(IterTest.select()): count += 1 self.failIf(count != len(self.names)) def test_02_generator(self): def enumerate(iterable): i = 0 for obj in iterable: yield i, obj i += 1 all = IterTest.select() count = 0 for i, test in enumerate(all): count += 1 self.failIf(count != len(self.names)) def test_03_ranged_indexed(self): all = IterTest.select() count = 0 for i in range(all.count()): test = all[i] count += 1 self.failIf(count != len(self.names)) def test_04_indexed_ended_by_exception(self): all = IterTest.select() count = 0 try: while 1: test = all[count] count = count+1 # Stop the test if it's gone on too long if count > len(self.names): break except IndexError: pass self.assertEqual(count, len(self.names)) ######################################## ## Delete during select ######################################## class DeleteSelectTest(TestCase1): def testGet(self): return def testSelect(self): for obj in TestSO1.select('all'): obj.destroySelf() self.assertEqual(list(TestSO1.select('all')), []) ######################################## ## Delete without caching ######################################## class NoCache(SQLObject): name = StringCol() class TestNoCache(SQLObjectTest): classes=[NoCache] def setUp(self): SQLObjectTest.setUp(self) NoCache._connection.cache = cache.CacheSet(cache=False) def tearDown(self): NoCache._connection.cache = cache.CacheSet(cache=True) SQLObjectTest.tearDown(self) def testDestroySelf(self): value = NoCache(name='test') value.destroySelf() ######################################## ## Transaction test ######################################## class TestSOTrans(SQLObject): #_cacheValues = False name = StringCol(length=10, alternateID=True, dbName='name_col') _defaultOrderBy = 'name' class 
TransactionTest(SQLObjectTest): classes = [TestSOTrans] def inserts(self): TestSOTrans(name='bob') TestSOTrans(name='tim') def testTransaction(self): if not self.supportTransactions: return trans = TestSOTrans._connection.transaction() try: TestSOTrans._connection.autoCommit = 'exception' TestSOTrans(name='joe', connection=trans) trans.rollback() trans.begin() self.assertEqual([n.name for n in TestSOTrans.select(connection=trans)], ['bob', 'tim']) b = TestSOTrans.byName('bob', connection=trans) b.name = 'robert' trans.commit() self.assertEqual(b.name, 'robert') b.name = 'bob' trans.rollback() trans.begin() self.assertEqual(b.name, 'robert') finally: TestSOTrans._connection.autoCommit = True ######################################## ## Enum test ######################################## class Enum1(SQLObject): _columns = [ EnumCol('l', enumValues=['a', 'bcd', 'e']), ] class TestEnum1(SQLObjectTest): classes = [Enum1] def inserts(self): for l in ['a', 'bcd', 'a', 'e']: Enum1(l=l) def testBad(self): if self.supportRestrictedEnum: try: v = Enum1(l='b') except Exception, e: pass else: print v assert 0, "This should cause an error" ######################################## ## Slicing tests ######################################## class Counter(SQLObject): _columns = [ IntCol('number', notNull=True), ] class SliceTest(SQLObjectTest): classes = [Counter] def inserts(self): for i in range(100): Counter(number=i) def counterEqual(self, counters, value): self.assertEquals([c.number for c in counters], value) def test1(self): self.counterEqual(Counter.select('all', orderBy='number'), range(100)) def test2(self): self.counterEqual(Counter.select('all', orderBy='number')[10:20], range(10, 20)) def test3(self): self.counterEqual(Counter.select('all', orderBy='number')[20:30][:5], range(20, 25)) def test4(self): self.counterEqual(Counter.select('all', orderBy='number')[:-10], range(0, 90)) def test5(self): self.counterEqual(Counter.select('all', orderBy='number', reversed=True), 
range(99, -1, -1)) def test6(self): self.counterEqual(Counter.select('all', orderBy='-number'), range(99, -1, -1)) ######################################## ## Select tests ######################################## class Counter2(SQLObject): _columns = [ IntCol('n1', notNull=True), IntCol('n2', notNull=True), ] class SelectTest(SQLObjectTest): classes = [Counter2] def inserts(self): for i in range(10): for j in range(10): Counter2(n1=i, n2=j) def counterEqual(self, counters, value): self.assertEquals([(c.n1, c.n2) for c in counters], value) def accumulateEqual(self, func, counters, value): self.assertEqual(func([ c.n1 for c in counters]), value) def test1(self): self.accumulateEqual(sum,Counter2.select(orderBy='n1'), sum(range(10)) * 10) def test2(self): self.accumulateEqual(len,Counter2.select('all'), 100) ######################################## ## Dynamic column tests ######################################## class Person(SQLObject): _columns = [StringCol('name', length=100, dbName='name_col')] _defaultOrder = 'name' class Phone(SQLObject): _columns = [StringCol('phone', length=12)] _defaultOrder = 'phone' class PeopleTest(SQLObjectTest): classes = [Person, Phone] def inserts(self): for n in ['jane', 'tim', 'bob', 'jake']: Person(name=n) for p in ['555-555-5555', '555-394-2930', '444-382-4854']: Phone(phone=p) def testDefaultOrder(self): self.assertEqual(list(Person.select('all')), list(Person.select('all', orderBy=Person._defaultOrder))) def testDynamicColumn(self): if not self.supportDynamic: return nickname = StringCol('nickname', length=10) Person.addColumn(nickname, changeSchema=True) n = Person(name='robert', nickname='bob') self.assertEqual([p.name for p in Person.select('all')], ['bob', 'jake', 'jane', 'robert', 'tim']) Person.delColumn(nickname, changeSchema=True) def testDynamicJoin(self): if not self.supportDynamic: return col = KeyCol('personID', foreignKey='Person') Phone.addColumn(col, changeSchema=True) join = MultipleJoin('Phone') 
Person.addJoin(join) for phone in Phone.select('all'): if phone.phone.startswith('555'): phone.person = Person.selectBy(name='tim')[0] else: phone.person = Person.selectBy(name='bob')[0] l = [p.phone for p in Person.selectBy(name='tim')[0].phones] l.sort() self.assertEqual(l, ['555-394-2930', '555-555-5555']) Phone.delColumn(col, changeSchema=True) Person.delJoin(join) ######################################## ## Auto class generation ######################################## class AutoTest(SQLObjectTest): mysqlCreate = """ CREATE TABLE IF NOT EXISTS auto_test ( auto_id INT AUTO_INCREMENT PRIMARY KEY, first_name VARCHAR(100), last_name VARCHAR(200) NOT NULL, age INT DEFAULT NULL, created DATETIME NOT NULL, happy char(1) DEFAULT 'Y' NOT NULL, wannahavefun TINYINT DEFAULT 0 NOT NULL ) """ postgresCreate = """ CREATE TABLE auto_test ( auto_id SERIAL PRIMARY KEY, first_name VARCHAR(100), last_name VARCHAR(200) NOT NULL, age INT DEFAULT 0, created VARCHAR(40) NOT NULL, happy char(1) DEFAULT 'Y' NOT NULL, wannahavefun BOOL DEFAULT FALSE NOT NULL ) """ sybaseCreate = """ CREATE TABLE auto_test ( auto_id integer, first_name VARCHAR(100), last_name VARCHAR(200) NOT NULL, age INT DEFAULT 0, created VARCHAR(40) NOT NULL, happy char(1) DEFAULT 'Y' NOT NULL ) """ mysqlDrop = """ DROP TABLE IF EXISTS auto_test """ postgresDrop = """ DROP TABLE auto_test """ sybaseDrop = """ DROP TABLE auto_test """ _table = 'auto_test' def testClassCreate(self): if not self.supportAuto: return class AutoTest(SQLObject): _fromDatabase = True _idName = 'auto_id' _connection = connection() john = AutoTest(firstName='john', lastName='doe', age=10, created=DateTime.now(), wannahavefun=False) jane = AutoTest(firstName='jane', lastName='doe', happy='N', created=DateTime.now(), wannahavefun=True) self.failIf(john.wannahavefun) self.failUnless(jane.wannahavefun) del classregistry.registry(AutoTest._registry).classes['AutoTest'] ######################################## ## Joins 
######################################## class PersonJoiner(SQLObject): _columns = [StringCol('name', length=40, alternateID=True)] _joins = [RelatedJoin('AddressJoiner')] class AddressJoiner(SQLObject): _columns = [StringCol('zip', length=5, alternateID=True)] _joins = [RelatedJoin('PersonJoiner')] class JoinTest(SQLObjectTest): classes = [PersonJoiner, AddressJoiner] def inserts(self): for n in ['bob', 'tim', 'jane', 'joe', 'fred', 'barb']: PersonJoiner(name=n) for z in ['11111', '22222', '33333', '44444']: AddressJoiner(zip=z) def testJoin(self): b = PersonJoiner.byName('bob') self.assertEqual(b.addressJoiners, []) z = AddressJoiner.byZip('11111') b.addAddressJoiner(z) self.assertZipsEqual(b.addressJoiners, ['11111']) self.assertNamesEqual(z.personJoiners, ['bob']) z2 = AddressJoiner.byZip('22222') b.addAddressJoiner(z2) self.assertZipsEqual(b.addressJoiners, ['11111', '22222']) self.assertNamesEqual(z2.personJoiners, ['bob']) b.removeAddressJoiner(z) self.assertZipsEqual(b.addressJoiners, ['22222']) self.assertNamesEqual(z.personJoiners, []) def assertZipsEqual(self, zips, dest): self.assertEqual([a.zip for a in zips], dest) def assertNamesEqual(self, people, dest): self.assertEqual([p.name for p in people], dest) class PersonJoiner2(SQLObject): _columns = [StringCol('name', length=40, alternateID=True)] _joins = [MultipleJoin('AddressJoiner2')] class AddressJoiner2(SQLObject): _columns = [StringCol('zip', length=5), StringCol('plus4', length=4, default=None), ForeignKey('PersonJoiner2')] _defaultOrder = ['-zip', 'plus4'] class JoinTest2(SQLObjectTest): classes = [PersonJoiner2, AddressJoiner2] def inserts(self): p1 = PersonJoiner2(name='bob') p2 = PersonJoiner2(name='sally') for z in ['11111', '22222', '33333']: a = AddressJoiner2(zip=z, personJoiner2=p1) #p1.addAddressJoiner2(a) AddressJoiner2(zip='00000', personJoiner2=p2) def test(self): bob = PersonJoiner2.byName('bob') sally = PersonJoiner2.byName('sally') self.assertEqual(len(bob.addressJoiner2s), 3) 
self.assertEqual(len(sally.addressJoiner2s), 1) bob.addressJoiner2s[0].destroySelf() self.assertEqual(len(bob.addressJoiner2s), 2) z = bob.addressJoiner2s[0] z.zip = 'xxxxx' id = z.id del z z = AddressJoiner2.get(id) self.assertEqual(z.zip, 'xxxxx') def testDefaultOrder(self): p1 = PersonJoiner2.byName('bob') self.assertEqual([i.zip for i in p1.addressJoiner2s], ['33333', '22222', '11111']) ######################################## ## Inheritance ######################################## class Super(SQLObject): _columns = [StringCol('name', length=10)] class Sub(Super): _columns = Super._columns + [StringCol('name2', length=10)] class InheritanceTest(SQLObjectTest): classes = [Super, Sub] def testSuper(self): s1 = Super(name='one') s2 = Super(name='two') s3 = Super.get(s1.id) self.assertEqual(s1, s3) def testSub(self): s1 = Sub(name='one', name2='1') s2 = Sub(name='two', name2='2') s3 = Sub.get(s1.id) self.assertEqual(s1, s3) ######################################## ## Expiring, syncing ######################################## class SyncTest(SQLObject): name = StringCol(length=50, alternateID=True, dbName='name_col') class ExpireTest(SQLObjectTest): classes = [SyncTest] def inserts(self): SyncTest(name='bob') SyncTest(name='tim') def testExpire(self): conn = SyncTest._connection b = SyncTest.byName('bob') conn.query("UPDATE sync_test SET name_col = 'robert' WHERE id = %i" % b.id) self.assertEqual(b.name, 'bob') b.expire() self.assertEqual(b.name, 'robert') conn.query("UPDATE sync_test SET name_col = 'bobby' WHERE id = %i" % b.id) b.sync() self.assertEqual(b.name, 'bobby') ######################################## ## Validation/conversion ######################################## class SOValidation(SQLObject): name = StringCol(validator=validators.PlainText(), default='x', dbName='name_col') name2 = StringCol(validator=validators.ConfirmType(str), default='y') name3 = IntCol(validator=validators.Wrapper(fromPython=int), default=100) class ValidationTest(SQLObjectTest): 
classes = [SOValidation] def testValidate(self): t = SOValidation(name='hey') self.assertRaises(validators.InvalidField, setattr, t, 'name', '!!!') t.name = 'you' def testConfirmType(self): t = SOValidation(name2='hey') self.assertRaises(validators.InvalidField, setattr, t, 'name2', 1) t.name2 = 'you' def testWrapType(self): t = SOValidation(name3=1) self.assertRaises(validators.InvalidField, setattr, t, 'name3', 'x') t.name3 = 1L self.assertEqual(t.name3, 1) t.name3 = '1' self.assertEqual(t.name3, 1) t.name3 = 0 self.assertEqual(t.name3, 0) ######################################## ## String ID test ######################################## class SOStringID(SQLObject): _table = 'so_string_id' _idType = str val = StringCol(alternateID=True) mysqlCreate = """ CREATE TABLE IF NOT EXISTS so_string_id ( id VARCHAR(50) PRIMARY KEY, val TEXT ) """ postgresCreate = """ CREATE TABLE so_string_id ( id VARCHAR(50) PRIMARY KEY, val TEXT ) """ sybaseCreate = """ CREATE TABLE so_string_id ( id VARCHAR(50) UNIQUE, val VARCHAR(50) NULL ) """ firebirdCreate = """ CREATE TABLE so_string_id ( id VARCHAR(50) NOT NULL PRIMARY KEY, val BLOB SUB_TYPE TEXT ) """ sqliteCreate = postgresCreate mysqlDrop = """ DROP TABLE IF EXISTS so_string_id """ postgresDrop = """ DROP TABLE so_string_id """ sqliteDrop = postgresDrop firebirdDrop = postgresDrop class StringIDTest(SQLObjectTest): classes = [SOStringID] def testStringID(self): t = SOStringID(id='hey', val='whatever') t2 = SOStringID.byVal('whatever') self.assertEqual(t, t2) t3 = SOStringID(id='you', val='nowhere') t4 = SOStringID.get('you') self.assertEqual(t3, t4) class AnotherStyle(MixedCaseUnderscoreStyle): def pythonAttrToDBColumn(self, attr): if attr.lower().endswith('id'): return 'id'+MixedCaseUnderscoreStyle.pythonAttrToDBColumn(self, attr[:-2]) else: return MixedCaseUnderscoreStyle.pythonAttrToDBColumn(self, attr) class SOStyleTest1(SQLObject): a = StringCol() st2 = ForeignKey('SOStyleTest2') _style = AnotherStyle() class 
SOStyleTest2(SQLObject): b = StringCol() _style = AnotherStyle() class StyleTest(SQLObjectTest): classes = [SOStyleTest2, SOStyleTest1] def test(self): st1 = SOStyleTest1(a='something', st2=None) st2 = SOStyleTest2(b='whatever') st1.st2 = st2 self.assertEqual(st1._SO_columnDict['st2ID'].dbName, 'idst2') self.assertEqual(st1.st2, st2) ######################################## ## Lazy updates ######################################## class Lazy(SQLObject): _lazyUpdate = True name = StringCol() other = StringCol(default='nothing') third = StringCol(default='third') class LazyTest(SQLObjectTest): classes = [Lazy] def setUp(self): # All this stuff is so that we can track when the connection # does an actual update; we put in a new _SO_update method # that calls the original and sets an instance variable that # we can later check. SQLObjectTest.setUp(self) self.conn = Lazy._connection self.conn.didUpdate = False self._oldUpdate = self.conn._SO_update newUpdate = lambda so, values, s=self, c=self.conn, o=self._oldUpdate: self._alternateUpdate(so, values, c, o) self.conn._SO_update = newUpdate def tearDown(self): self.conn._SO_update = self._oldUpdate del self._oldUpdate def _alternateUpdate(self, so, values, conn, oldUpdate): conn.didUpdate = True return oldUpdate(so, values) def test(self): assert not self.conn.didUpdate obj = Lazy(name='tim') # We just did an insert, but not an update: assert not self.conn.didUpdate obj.set(name='joe') assert obj.dirty self.assertEqual(obj.name, 'joe') assert not self.conn.didUpdate obj.syncUpdate() self.assertEqual(obj.name, 'joe') assert self.conn.didUpdate assert not obj.dirty self.assertEqual(obj.name, 'joe') self.conn.didUpdate = False obj = Lazy(name='frank') obj.name = 'joe' assert not self.conn.didUpdate assert obj.dirty self.assertEqual(obj.name, 'joe') obj.name = 'joe2' assert not self.conn.didUpdate assert obj.dirty self.assertEqual(obj.name, 'joe2') obj.syncUpdate() self.assertEqual(obj.name, 'joe2') assert not obj.dirty 
assert self.conn.didUpdate self.conn.didUpdate = False obj = Lazy(name='loaded') assert not obj.dirty assert not self.conn.didUpdate self.assertEqual(obj.name, 'loaded') obj.name = 'unloaded' assert obj.dirty self.assertEqual(obj.name, 'unloaded') assert not self.conn.didUpdate obj.sync() assert not obj.dirty self.assertEqual(obj.name, 'unloaded') assert self.conn.didUpdate self.conn.didUpdate = False obj.name = 'whatever' assert obj.dirty self.assertEqual(obj.name, 'whatever') assert not self.conn.didUpdate obj._SO_loadValue('name') assert obj.dirty self.assertEqual(obj.name, 'whatever') assert not self.conn.didUpdate obj._SO_loadValue('other') self.assertEqual(obj.name, 'whatever') assert not self.conn.didUpdate obj.syncUpdate() assert self.conn.didUpdate self.conn.didUpdate = False # Now, check that get() doesn't screw # cached objects' validator state. obj_id = obj.id old_state = obj._SO_validatorState obj = Lazy.get(obj_id) assert not obj.dirty assert not self.conn.didUpdate assert obj._SO_validatorState is old_state self.assertEqual(obj.name, 'whatever') obj.name = 'unloaded' self.assertEqual(obj.name, 'unloaded') assert obj.dirty assert not self.conn.didUpdate # Fetch the object again with get() and # make sure dirty is still set, as the # object should come from the cache. obj = Lazy.get(obj_id) assert obj.dirty assert not self.conn.didUpdate self.assertEqual(obj.name, 'unloaded') obj.syncUpdate() assert self.conn.didUpdate assert not obj.dirty self.conn.didUpdate = False # Then clear the cache, and try a get() # again, to make sure stuf like _SO_createdValues # is properly initialized. 
self.conn.cache.clear() obj = Lazy.get(obj_id) assert not obj.dirty assert not self.conn.didUpdate self.assertEqual(obj.name, 'unloaded') obj.name = 'spongebob' self.assertEqual(obj.name, 'spongebob') assert obj.dirty assert not self.conn.didUpdate obj.syncUpdate() assert self.conn.didUpdate assert not obj.dirty self.conn.didUpdate = False obj = Lazy(name='last') assert not obj.dirty obj.syncUpdate() assert not self.conn.didUpdate assert not obj.dirty # Check that setting multiple values # actually works. This was broken # and just worked because we were testing # only one value at a time, so 'name' # had the right value after the for loop *wink* # Also, check that passing a name that is not # a valid column doesn't break, but instead # just does a plain setattr. obj.set(name='first', other='who', third='yes', driver='james') self.assertEqual(obj.name, 'first') self.assertEqual(obj.other, 'who') self.assertEqual(obj.third, 'yes') self.assertEqual(obj.driver, 'james') assert obj.dirty assert not self.conn.didUpdate obj.syncUpdate() assert self.conn.didUpdate assert not obj.dirty ######################################## ## Run from command-line: ######################################## def coverModules(): sys.stdout.write('Writing coverage...') sys.stdout.flush() here = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) from SQLObject import DBConnection as tmp there = os.path.dirname(os.path.abspath(tmp.__file__)) for name, mod in sys.modules.items(): if not mod: continue try: modFile = os.path.abspath(mod.__file__) except AttributeError: # Probably a C extension continue if modFile.startswith(here) or modFile.startswith(there): writeCoverage(mod, there, os.path.join(here, 'SQLObject')) coverage.erase() sys.stdout.write('done.\n') def writeCoverage(module, oldBase, newBase): filename, numbers, unexecuted, s = coverage.analysis(module) coverFilename = filename + ',cover' if coverFilename.startswith(oldBase): coverFilename = newBase + 
coverFilename[len(oldBase):] fout = open(coverFilename, 'w') fin = open(filename) i = 1 lines = 0 good = 0 while 1: line = fin.readline() if not line: break assert line[-1] == '\n' fout.write(line[:-1]) unused = i in unexecuted interesting = interestingLine(line, unused) if interesting: if unused: fout.write(' '*(72-len(line))) fout.write('#@@@@') lastUnused = True else: lastUnused = False good += 1 lines += 1 fout.write('\n') i += 1 fout.write('\n# Coverage:\n') fout.write('# %i/%i, %i%%' % ( good, lines, lines and int(good*100/lines))) fout.close() fin.close() def interestingLine(line, unused): line = line.strip() if not line: return False if line.startswith('#'): return False if line in ('"""', '"""'): return False if line.startswith('global '): return False if line.startswith('def ') and not unused: # If a def *isn't* executed, that's interesting return False if line.startswith('class ') and not unused: return False return True if __name__ == '__main__': import unittest, sys, os dbs = [] newArgs = [] doCoverage = False for arg in sys.argv[1:]: if arg.startswith('-d'): dbs.append(arg[2:]) continue if arg.startswith('--database='): dbs.append(arg[11:]) continue if arg in ('-vv', '--extra-verbose'): SQLObjectTest.debugSQL = True if arg in ('-vvv', '--super-verbose'): SQLObjectTest.debugSQL = True SQLObjectTest.debugOutput = True newArgs.append('-vv') continue if arg in ('--inserts',): SQLObjectTest.debugInserts = True continue if arg in ('--coverage',): # Handled earlier, so we get better coverage doCoverage = True continue newArgs.append(arg) sys.argv = [sys.argv[0]] + newArgs if not dbs: dbs = ['mysql'] if dbs == ['all']: dbs = supportedDatabases() for db in dbs: curr_db = db setDatabaseType(db) print 'Testing %s' % db try: unittest.main() except SystemExit: pass if doCoverage: coverage.stop() coverModules()