from abc import ABC, abstractmethod

import strawberry
from sqlalchemy import and_, func, or_, select
from strawberry import UNSET

import hircine.db
from hircine.db.models import ComicTag
from hircine.enums import Category, Censorship, Language, Operator, Rating


class Matchable(ABC):
    """
    Interface shared by all filters.

    A filter provides two methods, include and exclude. Each one is handed an
    SQL statement, may modify it in any way it sees fit, and returns the
    resulting statement.
    """

    @abstractmethod
    def include(self, sql):
        """Return sql; the base implementation leaves it untouched."""
        return sql

    @abstractmethod
    def exclude(self, sql):
        """Return sql; the base implementation leaves it untouched."""
        return sql


@strawberry.input
class CountFilter:
    """
    Compare a count column against a fixed value.

    operator: the comparison to apply; its enum value is called with the
              column and the value to build the condition
    value: the number to compare the column against
    """

    operator: Operator | None = Operator.EQUAL
    value: int

    def _condition(self, column):
        # operator is nullable in the schema, so a client may send an explicit
        # null. Fall back to the declared default in that case instead of
        # failing on None.
        operator = Operator.EQUAL if self.operator is None else self.operator

        return operator.value(column, self.value)

    def include(self, column, sql):
        """Restrict sql to rows whose column satisfies the comparison."""
        return sql.where(self._condition(column))

    def exclude(self, column, sql):
        """Restrict sql to rows whose column does not satisfy the comparison."""
        return sql.where(~self._condition(column))


@strawberry.input
class AssociationFilter(Matchable):
    """
    Filter a parent object by the objects associated with it.

    any: at least one of the given ids is associated
    all: every one of the given ids is associated
    exact: all of the given ids are associated and nothing else is
    count: the number of associated objects satisfies the comparison

    The column and count_column attributes are not input fields. They have to
    be assigned before include or exclude is called (Root.match does this).
    """

    any: list[int] | None = strawberry.field(default_factory=lambda: None)
    all: list[int] | None = strawberry.field(default_factory=lambda: None)
    exact: list[int] | None = strawberry.field(default_factory=lambda: None)
    count: CountFilter | None = UNSET

    def _exists(self, condition):
        # The property.primaryjoin expression specifies the primary join path
        # between the parent object of the column that was handed to the
        # Matchable instance and the associated object.
        #
        # For example, if this AssociationFilter is used for the worlds field
        # of an input class that is mapped to the Comic model, the primaryjoin
        # expression is as follows:
        #
        #       comic.id = comic_worlds.comic_id
        #
        # This expression is used to correlate the subquery with the main query
        # for the parent object.
        #
        # condition specifies any additional conditions we should match on.
        # Usually these will come from the where generator, which correlates
        # the secondary objects with the user-supplied ids.
        return select(1).where((self.column.property.primaryjoin) & condition).exists()

    def _any_exist(self, items):
        return self._exists(or_(*self._collect(items)))

    def _where_any_exist(self, sql):
        return sql.where(self._any_exist(self.any))

    def _where_none_exist(self, sql):
        return sql.where(~self._any_exist(self.any))

    def _all_exist(self, items):
        # One EXISTS subquery per condition, all of which have to hold.
        #
        # The conjunction is seeded with True so that it remains valid when
        # items is empty, which happens for "exact: []". and_() needs at least
        # one argument, and True is dropped again as soon as real conditions
        # are present.
        conditions = (self._exists(c) for c in self._collect(items))

        return and_(True, *conditions)

    def _where_all_exist(self, sql):
        return sql.where(self._all_exist(self.all))

    def _where_not_all_exist(self, sql):
        return sql.where(~self._all_exist(self.all))

    def _count_of(self, column):
        # Scalar subquery counting the rows joined to the parent object.
        return (
            select(func.count(column))
            .where(self.column.property.primaryjoin)
            .scalar_subquery()
        )

    def _exact(self):
        # Every requested id must be associated, and there must be no further
        # associations beyond those.
        #
        # The expected count is the number of distinct ids. Counting repeated
        # ids would raise the expected count and let objects match that carry
        # additional associations the user did not ask for.
        return and_(
            self._all_exist(self.exact),
            self._count_of(self.remote_column) == len(set(self.exact)),
        )

    def _collect(self, ids):
        # Flatten the conditions produced by where() for each id.
        for id in ids:
            yield from self.where(id)

    @property
    def remote_column(self):
        # local_remote_pairs is unpacked as exactly two pairs; the remote side
        # of the second pair is the column the supplied ids are compared with.
        _, remote = self.column.property.local_remote_pairs
        _, remote_column = remote

        return remote_column

    def where(self, id):
        """Yield the conditions that identify the associated object by id."""
        yield self.remote_column == id

    def include(self, sql):
        """Restrict sql to objects matching every option that was given."""
        # ignore if any/all is None, but when the user specifically includes an
        # empty list, make sure to return no items
        if self.any:
            sql = self._where_any_exist(sql)
        elif self.any == []:
            sql = sql.where(False)

        if self.all:
            sql = self._where_all_exist(sql)
        elif self.all == []:
            sql = sql.where(False)

        if self.count:
            sql = self.count.include(self.count_column, sql)

        if self.exact is not None:
            sql = sql.where(self._exact())

        return sql

    def exclude(self, sql):
        """Restrict sql to objects matching none of the options given."""
        # in contrast to include() we can fully ignore if any/all is None or
        # the empty list and just return all items, since the user effectively
        # asks to exclude "nothing"
        if self.any:
            sql = self._where_none_exist(sql)
        if self.all:
            sql = self._where_not_all_exist(sql)

        if self.count:
            sql = self.count.exclude(self.count_column, sql)

        if self.exact is not None:
            sql = sql.where(~self._exact())

        return sql


@strawberry.input
class Root:
    def match(self, sql, negate):
        """
        Apply every matcher that was set on this input to the SQL statement
        and return the result.

        Fields left UNSET are skipped. A Matchable is handed its column and
        count column and is then asked to include or, if negate is set, to
        exclude. A plain boolean (like favourite, organized, etc) is compared
        with the column directly.
        """

        model = self._model

        for name, matcher in self.__dict__.items():
            if matcher is UNSET:
                continue

            column = getattr(model, name, None)

            # count columns are historically singular, so strip the final
            # character of the field name before looking them up
            count_column = getattr(model, f"{name[:-1]}_count", None)

            if isinstance(matcher, Matchable):
                matcher.column = column
                matcher.count_column = count_column

                apply = matcher.exclude if negate else matcher.include
                sql = apply(sql)
            elif isinstance(matcher, bool):
                condition = (column != matcher) if negate else (column == matcher)
                sql = sql.where(condition)

        return sql


# When resolving names for generic types, strawberry prepends the name of the
# type variable to the name of the generic class. Since all classes that extend
# this class already end in "Filter", we have to make sure not to name it
# "FilterInput" lest we end up with "ComicFilterFilterInput".
#
# For now, use the very generic "Input" name so that we end up with sane
# GraphQL type names like "ComicFilterInput".
@strawberry.input
class Input[T]:
    # Both fields hold an instance of the filter type T and default to UNSET.
    # NOTE(review): presumably include is applied via Root.match with
    # negate=False and exclude with negate=True -- confirm against the caller.
    include: T | None = UNSET
    exclude: T | None = UNSET


@strawberry.input
class StringFilter(Matchable):
    """
    Match a column against user-supplied text.

    contains: the column has to contain the given string

    The column attribute is not an input field. It has to be assigned before
    include or exclude is called (Root.match does this).
    """

    contains: str | None = UNSET

    def _conditions(self):
        # The field is nullable, so an explicit null is treated the same as
        # leaving the option out. Without this guard a comparison against a
        # null value would be added to the statement.
        if self.contains is UNSET or self.contains is None:
            return

        yield self.column.contains(self.contains)

    def include(self, sql):
        """Restrict sql to rows that satisfy every given condition."""
        conditions = list(self._conditions())
        if not conditions:
            return sql

        return sql.where(and_(*conditions))

    def exclude(self, sql):
        """Restrict sql to rows that satisfy none of the given conditions."""
        conditions = [~c for c in self._conditions()]
        if not conditions:
            return sql

        return sql.where(and_(*conditions))


@strawberry.input
class BasicCountFilter(Matchable):
    """
    Filter that only supports matching on a count.

    All work is delegated to the CountFilter in count, which is applied to
    the count_column attribute assigned to this instance.
    """

    count: CountFilter

    def include(self, sql):
        counter = self.count
        return counter.include(self.count_column, sql)

    def exclude(self, sql):
        counter = self.count
        return counter.exclude(self.count_column, sql)


@strawberry.input
class TagAssociationFilter(AssociationFilter):
    """
    Variant of AssociationFilter for tags, whose input IDs are strings rather
    than numbers.

    An input ID has the form "<namespace id>:<tag id>", where either part may
    be left empty. All of the AssociationFilter logic is reused; only the
    translation from input IDs to database IDs is overridden.
    """

    any: list[str] | None = strawberry.field(default_factory=lambda: None)
    all: list[str] | None = strawberry.field(default_factory=lambda: None)
    exact: list[str] | None = strawberry.field(default_factory=lambda: None)

    def where(self, id):
        parts = id.split(":")
        if len(parts) != 2:
            # invalid specification, force False and stop generator
            yield False
            return

        namespace_id, tag_id = parts
        candidates = (
            (ComicTag.namespace_id, namespace_id),
            (ComicTag.tag_id, tag_id),
        )

        # only match on the parts of the specification that were filled in
        predicates = [column == value for column, value in candidates if value]

        if predicates:
            yield and_(*predicates)
        else:
            # empty specification, force False
            yield False

    @property
    def remote_column(self):
        return ComicTag.comic_id


@strawberry.input
class Filter[T](Matchable):
    any: list[T] | None = strawberry.field(default_factory=lambda: None)
    empty: bool | None = None

    def _empty(self):
        if self.empty:
            return self.column.is_(None)
        else:
            return ~self.column.is_(None)

    def _any_exist(self):
        return self.column.in_(self.any)

    def include(self, sql):
        if self.any:
            sql = sql.where(self._any_exist())

        if self.empty is not None:
            sql = sql.where(self._empty())

        return sql

    def exclude(self, sql):
        if self.any:
            sql = sql.where(~self._any_exist())

        if self.empty is not None:
            sql = sql.where(~self._empty())

        return sql


@hircine.db.model("Comic")
@strawberry.input
class ComicFilter(Root):
    """
    Filter options for comics.

    Root.match resolves every field name to the attribute of the same name on
    the model, so the names below have to mirror the model's attributes.
    """

    # NOTE(review): the hircine.db.model decorator presumably provides the
    # _model attribute that Root.match reads -- confirm in hircine.db.

    # text columns
    title: StringFilter | None = UNSET
    original_title: StringFilter | None = UNSET
    url: StringFilter | None = UNSET
    language: Filter[Language] | None = UNSET
    # associations; Root.match also hands these the "<singular>_count" column
    # (e.g. "artist_count" for artists) if the model has one
    tags: TagAssociationFilter | None = UNSET
    artists: AssociationFilter | None = UNSET
    characters: AssociationFilter | None = UNSET
    circles: AssociationFilter | None = UNSET
    worlds: AssociationFilter | None = UNSET
    category: Filter[Category] | None = UNSET
    censorship: Filter[Censorship] | None = UNSET
    rating: Filter[Rating] | None = UNSET
    # plain booleans are compared with their column directly by Root.match
    favourite: bool | None = UNSET
    organized: bool | None = UNSET
    bookmarked: bool | None = UNSET


@hircine.db.model("Archive")
@strawberry.input
class ArchiveFilter(Root):
    """
    Filter options for archives.

    Root.match resolves every field name to the attribute of the same name on
    the model.
    """

    path: StringFilter | None = UNSET
    # plain boolean, compared with its column directly by Root.match
    organized: bool | None = UNSET


@hircine.db.model("Artist")
@strawberry.input
class ArtistFilter(Root):
    """
    Filter options for artists.

    Root.match resolves every field name to the attribute of the same name on
    the model.
    """

    name: StringFilter | None = UNSET
    # matched against the "comic_count" attribute of the model, which
    # Root.match derives from the field name
    comics: BasicCountFilter | None = UNSET


@hircine.db.model("Character")
@strawberry.input
class CharacterFilter(Root):
    """
    Filter options for characters.

    Root.match resolves every field name to the attribute of the same name on
    the model.
    """

    name: StringFilter | None = UNSET
    # matched against the "comic_count" attribute of the model, which
    # Root.match derives from the field name
    comics: BasicCountFilter | None = UNSET


@hircine.db.model("Circle")
@strawberry.input
class CircleFilter(Root):
    """
    Filter options for circles.

    Root.match resolves every field name to the attribute of the same name on
    the model.
    """

    name: StringFilter | None = UNSET
    # matched against the "comic_count" attribute of the model, which
    # Root.match derives from the field name
    comics: BasicCountFilter | None = UNSET


@hircine.db.model("Namespace")
@strawberry.input
class NamespaceFilter(Root):
    """
    Filter options for namespaces.

    Root.match resolves every field name to the attribute of the same name on
    the model.
    """

    name: StringFilter | None = UNSET
    # matched against the "tag_count" attribute of the model, which
    # Root.match derives from the field name
    tags: BasicCountFilter | None = UNSET


@hircine.db.model("Tag")
@strawberry.input
class TagFilter(Root):
    """
    Filter options for tags.

    Root.match resolves every field name to the attribute of the same name on
    the model.
    """

    name: StringFilter | None = UNSET
    # association with namespaces, matched by numeric id
    namespaces: AssociationFilter | None = UNSET
    # matched against the "comic_count" attribute of the model, which
    # Root.match derives from the field name
    comics: BasicCountFilter | None = UNSET


@hircine.db.model("World")
@strawberry.input
class WorldFilter(Root):
    """
    Filter options for worlds.

    Root.match resolves every field name to the attribute of the same name on
    the model.
    """

    name: StringFilter | None = UNSET
    # matched against the "comic_count" attribute of the model, which
    # Root.match derives from the field name
    comics: BasicCountFilter | None = UNSET
