from sqlalchemy.testing import eq_, assert_raises, assert_raises_message, ne_
from sqlalchemy import util
import sqlalchemy as sa
from sqlalchemy.orm import class_mapper
from sqlalchemy.orm import attributes
from sqlalchemy.orm.attributes import set_attribute, \
    get_attribute, del_attribute
from sqlalchemy.orm.instrumentation import is_instrumented
from sqlalchemy.orm import clear_mappers
from sqlalchemy.testing import fixtures
from sqlalchemy.ext import instrumentation
from sqlalchemy.orm.instrumentation import register_class, manager_of_class
from sqlalchemy.testing.util import decorator
from sqlalchemy.orm import events
from sqlalchemy import event


@decorator
def modifies_instrumentation_finders(fn, *args, **kw):
    """Run *fn*, then put ``instrumentation_finders`` back as it was.

    The contents of ``instrumentation.instrumentation_finders`` are copied
    before *fn* runs. Afterwards, whether or not *fn* raised, they are written
    back in place into the list currently bound to that name. Tests can
    therefore freely insert custom finders.
    """
    saved = list(instrumentation.instrumentation_finders)
    try:
        fn(*args, **kw)
    finally:
        # Slice-assignment mutates the existing list object rather than
        # rebinding the module attribute.
        instrumentation.instrumentation_finders[:] = saved


class _ExtBase(object):
    """Mixin for the test classes in this module.

    When a test class finishes, it calls
    ``instrumentation._reinstall_default_lookups()``. Presumably this undoes
    the global lookup changes made by the extended instrumentation
    registry — confirm against ``sqlalchemy.ext.instrumentation``.
    """

    @classmethod
    def teardown_class(cls):
        """Reinstall the default instrumentation lookups."""
        reinstall_defaults = instrumentation._reinstall_default_lookups
        reinstall_defaults()


class MyTypesManager(instrumentation.InstrumentationManager):
    """Instrumentation manager that stores ORM data outside normal attributes.

    Attribute values live in a per-instance ``_goofy_dict``. The ORM state
    object is stored in the instance ``__dict__`` under ``_my_state``. The
    descriptor hooks are overridden as no-ops, so nothing is installed on the
    class itself.
    """

    def instrument_attribute(self, class_, key, attr):
        """Overridden as a no-op."""

    def install_descriptor(self, class_, key, attr):
        """Overridden as a no-op: no descriptor is set on *class_*."""

    def uninstall_descriptor(self, class_, key):
        """Overridden as a no-op."""

    def instrument_collection_class(self, class_, key, collection_class):
        """Ignore *collection_class* and always hand back ``MyListLike``."""
        return MyListLike

    def get_instance_dict(self, class_, instance):
        """Return the side dictionary holding *instance*'s attribute values."""
        return instance._goofy_dict

    def initialize_instance_dict(self, class_, instance):
        """Attach a fresh, empty side dictionary to *instance*."""
        vars(instance)['_goofy_dict'] = {}

    def install_state(self, class_, instance, state):
        """Store the ORM *state* directly in *instance*'s ``__dict__``."""
        vars(instance)['_my_state'] = state

    def state_getter(self, class_):
        """Return a callable that fetches the state stored by install_state."""
        def fetch_state(instance):
            return instance.__dict__['_my_state']
        return fetch_state


class MyListLike(list):
    """``list`` subclass that spells out the ``_sa_*`` collection hooks.

    ``MyTypesManager.instrument_collection_class`` returns this class
    unchanged, so the ORM-facing hooks are written by hand here.
    """

    # add @appender, @remover decorators as needed
    _sa_iterator = list.__iter__
    _sa_linker = None
    _sa_converter = None

    def _sa_appender(self, item, _sa_initiator=None):
        """Append *item*.

        An append event is fired first unless *_sa_initiator* is ``False``.
        """
        silent = _sa_initiator is False
        if not silent:
            self._sa_adapter.fire_append_event(item, _sa_initiator)
        list.append(self, item)

    def _sa_remover(self, item, _sa_initiator=None):
        """Remove *item*.

        The pre-remove event always fires. The remove event fires only
        when *_sa_initiator* is not ``False``.
        """
        self._sa_adapter.fire_pre_remove_event(_sa_initiator)
        silent = _sa_initiator is False
        if not silent:
            self._sa_adapter.fire_remove_event(item, _sa_initiator)
        list.remove(self, item)

    # The public list mutators are routed through the event-firing hooks.
    append = _sa_appender
    remove = _sa_remover


# Placeholders: UserDefinedExtensionTest.setup_class rebinds both names
# (via ``global``) to the base classes it defines.
MyBaseClass = MyClass = None


class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest):
    """Exercise attribute instrumentation through custom managers.

    Most tests loop over three bases:

    * plain ``object`` (default instrumentation);
    * ``MyBaseClass``, which names the stock ``InstrumentationManager``;
    * ``MyClass``, which uses ``MyTypesManager`` and keeps attribute values
      in a per-instance ``_goofy_dict``.
    """

    @classmethod
    def setup_class(cls):
        # Rebind the module-level placeholders so every test can use them.
        global MyBaseClass, MyClass

        class MyBaseClass(object):
            __sa_instrumentation_manager__ = \
                instrumentation.InstrumentationManager

        class MyClass(object):

            # This proves that a staticmethod will work here; don't
            # flatten this back to a class assignment!
            def __sa_instrumentation_manager__(cls):
                return MyTypesManager(cls)

            __sa_instrumentation_manager__ = staticmethod(
                __sa_instrumentation_manager__)

            # This proves SA can handle a class with non-string dict keys
            if not util.pypy and not util.jython:
                locals()[42] = 99   # Don't remove this line!

            def __init__(self, **kwargs):
                for k in kwargs:
                    setattr(self, k, kwargs[k])

            def __getattr__(self, key):
                # Only reached when normal lookup fails. Instrumented keys
                # are read through the ORM; anything else comes from
                # _goofy_dict (which itself lives in __dict__).
                if is_instrumented(self, key):
                    return get_attribute(self, key)
                else:
                    try:
                        return self._goofy_dict[key]
                    except KeyError:
                        raise AttributeError(key)

            def __setattr__(self, key, value):
                # Every assignment lands here: instrumented keys go through
                # the ORM, the rest into _goofy_dict.
                if is_instrumented(self, key):
                    set_attribute(self, key, value)
                else:
                    self._goofy_dict[key] = value

            # NOTE: Python has no ``__hasattr__`` protocol hook; hasattr()
            # goes through __getattr__, so this is only called explicitly.
            def __hasattr__(self, key):
                if is_instrumented(self, key):
                    return True
                else:
                    return key in self._goofy_dict

            def __delattr__(self, key):
                if is_instrumented(self, key):
                    del_attribute(self, key)
                else:
                    del self._goofy_dict[key]

    def teardown(self):
        clear_mappers()

    def test_instance_dict(self):
        # With MyTypesManager, the instance __dict__ holds only the state
        # object and the side dictionary containing attribute values.
        class User(MyClass):
            pass

        register_class(User)
        attributes.register_attribute(
            User, 'user_id', uselist=False, useobject=False)
        attributes.register_attribute(
            User, 'user_name', uselist=False, useobject=False)
        attributes.register_attribute(
            User, 'email_address', uselist=False, useobject=False)

        u = User()
        u.user_id = 7
        u.user_name = 'john'
        u.email_address = '[email protected]'
        eq_(
            u.__dict__,
            {
                '_my_state': u._my_state,
                '_goofy_dict': {
                    'user_id': 7, 'user_name': 'john',
                    'email_address': '[email protected]'}}
        )

    def test_basic(self):
        # Scalar get/set round-trips, both before and after a commit.
        for base in (object, MyBaseClass, MyClass):
            class User(base):
                pass

            register_class(User)
            attributes.register_attribute(
                User, 'user_id', uselist=False, useobject=False)
            attributes.register_attribute(
                User, 'user_name', uselist=False, useobject=False)
            attributes.register_attribute(
                User, 'email_address', uselist=False, useobject=False)

            u = User()
            u.user_id = 7
            u.user_name = 'john'
            u.email_address = '[email protected]'

            eq_(u.user_id, 7)
            eq_(u.user_name, "john")
            eq_(u.email_address, "[email protected]")
            attributes.instance_state(u)._commit_all(
                attributes.instance_dict(u))
            eq_(u.user_id, 7)
            eq_(u.user_name, "john")
            eq_(u.email_address, "[email protected]")

            # Use a value distinct from the committed one, so the check
            # below proves the write actually took effect.
            u.user_name = 'heythere'
            u.email_address = 'foo@bar.com'
            eq_(u.user_id, 7)
            eq_(u.user_name, "heythere")
            eq_(u.email_address, "foo@bar.com")

    def test_deferred(self):
        # Expired attributes are repopulated through the manager's
        # deferred_scalar_loader.
        for base in (object, MyBaseClass, MyClass):
            class Foo(base):
                pass

            data = {'a': 'this is a', 'b': 12}

            def loader(state, keys):
                for k in keys:
                    state.dict[k] = data[k]
                return attributes.ATTR_WAS_SET

            manager = register_class(Foo)
            manager.deferred_scalar_loader = loader
            attributes.register_attribute(
                Foo, 'a', uselist=False, useobject=False)
            attributes.register_attribute(
                Foo, 'b', uselist=False, useobject=False)

            # Only classes using an extension manager get an entry in the
            # extended factory's state finders.
            if base is object:
                assert Foo not in \
                    instrumentation._instrumentation_factory._state_finders
            else:
                assert Foo in \
                    instrumentation._instrumentation_factory._state_finders

            f = Foo()
            attributes.instance_state(f)._expire(
                attributes.instance_dict(f), set())
            eq_(f.a, "this is a")
            eq_(f.b, 12)

            # A locally set value is discarded by expiry.
            f.a = "this is some new a"
            attributes.instance_state(f)._expire(
                attributes.instance_dict(f), set())
            eq_(f.a, "this is a")
            eq_(f.b, 12)

            # A value set after expiry wins over the loader.
            attributes.instance_state(f)._expire(
                attributes.instance_dict(f), set())
            f.a = "this is another new a"
            eq_(f.a, "this is another new a")
            eq_(f.b, 12)

            attributes.instance_state(f)._expire(
                attributes.instance_dict(f), set())
            eq_(f.a, "this is a")
            eq_(f.b, 12)

            del f.a
            eq_(f.a, None)
            eq_(f.b, 12)

            attributes.instance_state(f)._commit_all(
                attributes.instance_dict(f))
            eq_(f.a, None)
            eq_(f.b, 12)

    def test_inheritance(self):
        """tests that attributes are polymorphic"""

        for base in (object, MyBaseClass, MyClass):
            class Foo(base):
                pass

            class Bar(Foo):
                pass

            register_class(Foo)
            register_class(Bar)

            def func1(state, passive):
                return "this is the foo attr"

            def func2(state, passive):
                return "this is the bar attr"

            def func3(state, passive):
                return "this is the shared attr"
            # Bar overrides 'element'; 'element2' is only registered on Foo.
            attributes.register_attribute(Foo, 'element',
                                          uselist=False, callable_=func1,
                                          useobject=True)
            attributes.register_attribute(Foo, 'element2',
                                          uselist=False, callable_=func3,
                                          useobject=True)
            attributes.register_attribute(Bar, 'element',
                                          uselist=False, callable_=func2,
                                          useobject=True)

            x = Foo()
            y = Bar()
            assert x.element == 'this is the foo attr'
            assert y.element == 'this is the bar attr', y.element
            assert x.element2 == 'this is the shared attr'
            assert y.element2 == 'this is the shared attr'

    def test_collection_with_backref(self):
        # A one-to-many collection and its scalar backref stay in sync.
        for base in (object, MyBaseClass, MyClass):
            class Post(base):
                pass

            class Blog(base):
                pass

            register_class(Post)
            register_class(Blog)
            attributes.register_attribute(
                Post, 'blog', uselist=False,
                backref='posts', trackparent=True, useobject=True)
            attributes.register_attribute(
                Blog, 'posts', uselist=True,
                backref='blog', trackparent=True, useobject=True)
            b = Blog()
            (p1, p2, p3) = (Post(), Post(), Post())
            b.posts.append(p1)
            b.posts.append(p2)
            b.posts.append(p3)
            eq_(b.posts, [p1, p2, p3])
            assert p2.blog is b

            p3.blog = None
            eq_(b.posts, [p1, p2])
            p4 = Post()
            p4.blog = b
            eq_(b.posts, [p1, p2, p4])

            # Re-assigning the same parent must not duplicate the child.
            p4.blog = b
            p4.blog = b
            eq_(b.posts, [p1, p2, p4])

            # assert no failure removing None
            p5 = Post()
            p5.blog = None
            del p5.blog

    def test_history(self):
        # History tuples read (added, unchanged, deleted).
        for base in (object, MyBaseClass, MyClass):
            class Foo(base):
                pass

            class Bar(base):
                pass

            register_class(Foo)
            register_class(Bar)
            attributes.register_attribute(
                Foo, "name", uselist=False, useobject=False)
            attributes.register_attribute(
                Foo, "bars", uselist=True, trackparent=True, useobject=True)
            attributes.register_attribute(
                Bar, "name", uselist=False, useobject=False)

            f1 = Foo()
            f1.name = 'f1'

            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1), 'name'),
                (['f1'], (), ()))

            b1 = Bar()
            b1.name = 'b1'
            f1.bars.append(b1)
            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1), 'bars'),
                ([b1], [], []))

            attributes.instance_state(f1)._commit_all(
                attributes.instance_dict(f1))
            attributes.instance_state(b1)._commit_all(
                attributes.instance_dict(b1))

            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1),
                    'name'),
                ((), ['f1'], ()))
            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1),
                    'bars'),
                ((), [b1], ()))

            f1.name = 'f1mod'
            b2 = Bar()
            b2.name = 'b2'
            f1.bars.append(b2)
            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1), 'name'),
                (['f1mod'], (), ['f1']))
            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1), 'bars'),
                ([b2], [b1], []))
            f1.bars.remove(b1)
            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1), 'bars'),
                ([b2], [], [b1]))

    def test_null_instrumentation(self):
        # The stock InstrumentationManager exposes the same attribute
        # objects on the class as the class manager holds.
        class Foo(MyBaseClass):
            pass
        register_class(Foo)
        attributes.register_attribute(
            Foo, "name", uselist=False, useobject=False)
        attributes.register_attribute(
            Foo, "bars", uselist=True, trackparent=True, useobject=True)

        assert Foo.name == attributes.manager_of_class(Foo)['name']
        assert Foo.bars == attributes.manager_of_class(Foo)['bars']

    def test_alternate_finders(self):
        """Ensure the generic finder front-end deals with edge cases."""

        class Unknown(object):
            pass

        class Known(MyBaseClass):
            pass

        register_class(Known)
        k, u = Known(), Unknown()

        assert instrumentation.manager_of_class(Unknown) is None
        assert instrumentation.manager_of_class(Known) is not None
        assert instrumentation.manager_of_class(None) is None

        assert attributes.instance_state(k) is not None
        assert_raises((AttributeError, KeyError),
                      attributes.instance_state, u)
        assert_raises((AttributeError, KeyError),
                      attributes.instance_state, None)

    def test_unmapped_not_type_error(self):
        """extension version of the same test in test_mapper.

        fixes #3408
        """
        assert_raises_message(
            sa.exc.ArgumentError,
            "Class object expected, got '5'.",
            class_mapper, 5
        )

    def test_unmapped_not_type_error_iter_ok(self):
        """extension version of the same test in test_mapper.

        fixes #3408
        """
        assert_raises_message(
            sa.exc.ArgumentError,
            r"Class object expected, got '\(5, 6\)'.",
            class_mapper, (5, 6)
        )


class FinderTest(_ExtBase, fixtures.ORMTest):
    """Which ClassManager type register_class() ends up installing."""

    def test_standard(self):
        # No hints at all: the stock ClassManager is used.
        class A(object):
            pass

        register_class(A)

        manager_type = type(manager_of_class(A))
        eq_(manager_type, instrumentation.ClassManager)

    def test_nativeext_interfaceexact(self):
        # Naming InstrumentationManager yields some other manager type.
        class A(object):
            __sa_instrumentation_manager__ = (
                instrumentation.InstrumentationManager)

        register_class(A)

        manager_type = type(manager_of_class(A))
        ne_(manager_type, instrumentation.ClassManager)

    def test_nativeext_submanager(self):
        # A ClassManager subclass named on the class is used directly.
        class Mine(instrumentation.ClassManager):
            pass

        class A(object):
            __sa_instrumentation_manager__ = Mine

        register_class(A)

        manager_type = type(manager_of_class(A))
        eq_(manager_type, Mine)

    @modifies_instrumentation_finders
    def test_customfinder_greedy(self):
        # A finder inserted first that always returns Mine makes A use Mine.
        class Mine(instrumentation.ClassManager):
            pass

        class A(object):
            pass

        def always_mine(cls):
            return Mine

        instrumentation.instrumentation_finders.insert(0, always_mine)
        register_class(A)

        manager_type = type(manager_of_class(A))
        eq_(manager_type, Mine)

    @modifies_instrumentation_finders
    def test_customfinder_pass(self):
        # A finder that returns None leaves A on the stock ClassManager.
        class A(object):
            pass

        def never_matches(cls):
            return None

        instrumentation.instrumentation_finders.insert(0, never_matches)
        register_class(A)

        manager_type = type(manager_of_class(A))
        eq_(manager_type, instrumentation.ClassManager)


class InstrumentationCollisionTest(_ExtBase, fixtures.ORMTest):
    """Conflicting instrumentation implementations within one hierarchy."""

    def _assert_collision(self, cls):
        # Registering ``cls`` must fail with the "multiple instrumentation
        # implementations" TypeError.
        assert_raises_message(
            TypeError, "multiple instrumentation implementations",
            register_class, cls)

    def test_none(self):
        # Unrelated classes may each use a different manager.
        class A(object):
            pass
        register_class(A)

        def mgr_factory(cls):
            return instrumentation.ClassManager(cls)

        class B(object):
            __sa_instrumentation_manager__ = staticmethod(mgr_factory)
        register_class(B)

        class C(object):
            __sa_instrumentation_manager__ = instrumentation.ClassManager
        register_class(C)

    def test_single_down(self):
        # Base registered with the default manager first, then a subclass
        # asks for its own factory.
        class A(object):
            pass
        register_class(A)

        def mgr_factory(cls):
            return instrumentation.ClassManager(cls)

        class B(A):
            __sa_instrumentation_manager__ = staticmethod(mgr_factory)

        self._assert_collision(B)

    def test_single_up(self):
        # Subclass with its own factory registered first; the base follows.
        class A(object):
            pass
        # delay registration

        def mgr_factory(cls):
            return instrumentation.ClassManager(cls)

        class B(A):
            __sa_instrumentation_manager__ = staticmethod(mgr_factory)
        register_class(B)

        self._assert_collision(A)

    # NOTE(review): in the three "diamond" tests below, ``C`` subclasses
    # ``object`` and is unrelated to A/B1/B2 as written. The conflict for B1
    # comes from its sibling B2 under the shared base A.

    def test_diamond_b1(self):
        def mgr_factory(cls):
            return instrumentation.ClassManager(cls)

        class A(object):
            pass

        class B1(A):
            pass

        class B2(A):
            __sa_instrumentation_manager__ = staticmethod(mgr_factory)

        class C(object):
            pass

        self._assert_collision(B1)

    def test_diamond_b2(self):
        def mgr_factory(cls):
            return instrumentation.ClassManager(cls)

        class A(object):
            pass

        class B1(A):
            pass

        class B2(A):
            __sa_instrumentation_manager__ = staticmethod(mgr_factory)

        class C(object):
            pass

        register_class(B2)
        self._assert_collision(B1)

    def test_diamond_c_b(self):
        def mgr_factory(cls):
            return instrumentation.ClassManager(cls)

        class A(object):
            pass

        class B1(A):
            pass

        class B2(A):
            __sa_instrumentation_manager__ = staticmethod(mgr_factory)

        class C(object):
            pass

        register_class(C)

        self._assert_collision(B1)


class ExtendedEventsTest(_ExtBase, fixtures.ORMTest):

    """Allow custom Events implementations."""

    @modifies_instrumentation_finders
    def test_subclassed(self):
        # A ClassManager subclass may carry its own dispatcher built from
        # an InstanceEvents subclass; registered classes must pick it up.
        class MyEvents(events.InstanceEvents):
            pass

        class MyClassManager(instrumentation.ClassManager):
            dispatch = event.dispatcher(MyEvents)

        def use_my_manager(cls):
            return MyClassManager

        instrumentation.instrumentation_finders.insert(0, use_my_manager)

        class A(object):
            pass

        register_class(A)
        manager = instrumentation.manager_of_class(A)
        events_cls = manager.dispatch._events
        assert issubclass(events_cls, MyEvents)