import random
from collections import defaultdict

from sqlalchemy import delete, func, literal_column, null, select, text, tuple_
from sqlalchemy.orm import contains_eager, undefer
from sqlalchemy.orm.util import identity_key
from strawberry import UNSET

from hircine.db.models import (
    Archive,
    ComicTag,
    Image,
    Namespace,
    Page,
    Tag,
    TagNamespaces,
)


def paginate(sql, pagination):
    """Apply LIMIT/OFFSET pagination to a select statement.

    *pagination* is expected to expose 1-based ``page`` and ``items``
    (page size) attributes; a falsy *pagination* leaves *sql* untouched.
    A non-positive ``page`` or ``items`` yields ``LIMIT 0`` (no rows)
    instead of raising.

    Returns the (possibly) limited/offset statement.
    """
    if not pagination:
        return sql

    if pagination.items < 1 or pagination.page < 1:
        return sql.limit(0)

    # The guard above guarantees page >= 1, so the offset is never
    # negative; page 1 maps to offset 0.
    offset = (pagination.page - 1) * pagination.items
    return sql.limit(pagination.items).offset(offset)


def apply_filter(sql, filter):
    """Narrow *sql* using the include/exclude parts of *filter*.

    A falsy *filter* returns *sql* unchanged. Otherwise each of
    ``filter.include`` and ``filter.exclude`` that is not ``UNSET`` has its
    ``match(sql, flag)`` method applied in turn. The flag is False for the
    include side and True for the exclude side (presumably a negate flag;
    confirm against the filter input types).
    """
    if filter:
        include = filter.include
        if include is not UNSET:
            sql = include.match(sql, False)

        exclude = filter.exclude
        if exclude is not UNSET:
            sql = exclude.match(sql, True)

    return sql


def sort_random(seed):
    """Build a seeded pseudo-random ORDER BY expression.

    The result is the textual SQL ``sin(iid + :seed)``, so the enclosing
    select must expose a column labelled ``iid``. A truthy *seed* is reduced
    modulo 1e9. A falsy one (e.g. None or 0) is replaced by a fresh value
    from ``random.randrange``, which gives a new ordering on each call.
    """
    modulus = 1000000000
    effective_seed = seed % modulus if seed else random.randrange(modulus)

    # https://www.sqlite.org/forum/forumpost/e2216583a4
    expression = text("sin(iid + :seed)")
    return expression.bindparams(seed=effective_seed)


def apply_sort(sql, sort, default, tiebreaker):
    """Attach an ORDER BY clause to *sql*.

    Without *sort*, orders by every expression in *default* and then by
    *tiebreaker*. Otherwise ``sort.direction.value`` is called on the sort
    expression; presumably it is an ``asc``/``desc`` callable, so confirm.

    - A ``"Random"`` sort orders only by the seeded ``sort_random``
      expression, with no tiebreaker.
    - Any other sort undefers ``sort.on.value`` so that attribute is loaded
      with the row, then orders by it followed by *tiebreaker*.
    """
    if not sort:
        return sql.order_by(*default, tiebreaker)

    order = sort.direction.value
    target = sort.on.value

    if target == "Random":
        return sql.order_by(order(sort_random(sort.seed)))

    return sql.options(undefer(target)).order_by(order(target), tiebreaker)


async def query_all(session, model, pagination=None, filter=None, sort=None):
    """Fetch one filtered, sorted and paginated page of *model* objects.

    Besides the entity, each selected row carries two extra columns:

    - ``count``: a ``count(id) OVER ()`` window total, evaluated before
      LIMIT/OFFSET, so it covers every matching row rather than just the page.
    - ``iid``: the primary key, which the ``sort_random`` expression refers to.

    Returns ``(count, objects)``. ``count`` is 0 whenever the page comes back
    empty, including when the requested page lies past the last row.
    """
    window_total = func.count(model.id).over().label("count")
    stmt = select(model, window_total, model.id.label("iid"))
    stmt = apply_filter(stmt, filter)
    stmt = apply_sort(stmt, sort, model.default_order(), model.id)
    stmt = paginate(stmt, pagination)

    result = await session.execute(stmt)

    objs = []
    count = 0
    for row in result:
        objs.append(row[0])
        # Every row repeats the same window total; capture it once.
        if count == 0:
            count = row.count

    return count, objs


async def has_with_name(session, model, name):
    """Return True if at least one *model* row has ``name == name``.

    Only existence matters, so the query is limited to a single row.
    """
    sql = select(model.id).where(model.name == name).limit(1)

    # Compare against None instead of relying on truthiness: a matching
    # row whose primary key is 0 is falsy but still means the name exists.
    return (await session.scalars(sql)).first() is not None


async def tag_restrictions(session, tuples=None):
    """Map each tag id to the set of namespace ids it is linked to.

    Reads ``TagNamespaces`` rows. When *tuples* is truthy, only rows whose
    ``(namespace_id, tag_id)`` pair appears in *tuples* are considered.
    When it is None or empty, every row is read.

    Returns a ``defaultdict(set)``, so looking up an unknown tag id yields
    an empty set.
    """
    stmt = select(TagNamespaces)

    if tuples:
        pair = tuple_(TagNamespaces.namespace_id, TagNamespaces.tag_id)
        stmt = stmt.where(pair.in_(tuples))

    result = await session.scalars(stmt)

    restrictions = defaultdict(set)
    for link in result.unique().all():
        restrictions[link.tag_id].add(link.namespace_id)

    return restrictions


def lookup_identity(session, model, ids):
    """Resolve *ids* against the session's identity map without querying.

    Returns ``(objects, satisfied)``. ``objects`` lists the already-loaded
    *model* instances, in the iteration order of *ids*. ``satisfied`` is
    the set of ids that were found.
    """
    found = []
    hits = set()

    for pk in ids:
        cached = session.identity_map.get(identity_key(model, pk), None)
        if cached is None:
            continue
        found.append(cached)
        hits.add(pk)

    return found, hits


async def get_all(session, model, ids, options=None, use_identity_map=False):
    """Load the *model* instances whose primary key is in *ids*.

    *options* (loader options) are applied to the query. With
    *use_identity_map*, ids already present in the session's identity map
    are served from there. The database is queried only for the remainder,
    and not at all if nothing remains.

    Returns ``(objects, missing)``, where ``missing`` holds the queried ids
    for which no row was found.
    """
    options = options or []
    pending = set(ids)
    objects = []

    if use_identity_map:
        objects, satisfied = lookup_identity(session, model, pending)
        pending -= satisfied

        if not pending:
            return objects, set()

    stmt = select(model).where(model.id.in_(pending)).options(*options)
    objects += (await session.scalars(stmt)).unique().all()

    missing = pending - {obj.id for obj in objects}
    return objects, missing


async def get_all_names(session, model, names, options=None):
    """Load the *model* instances whose ``name`` is in *names*.

    *options* (loader options) are applied to the query. Returns
    ``(objects, missing)``, where ``missing`` holds the requested names that
    matched no row.
    """
    wanted = set(names)
    stmt = select(model).where(model.name.in_(wanted)).options(*(options or []))

    objects = (await session.scalars(stmt)).unique().all()

    missing = wanted - {obj.name for obj in objects}
    return objects, missing


async def get_ctag_names(session, comic_id, tuples):
    """Find a comic's ``ComicTag`` links by ``(namespace name, tag name)``.

    The namespace and tag relationships are joined and populated eagerly
    from that same join. Returns ``(objects, missing)``, where ``missing``
    holds the pairs from *tuples* that have no matching link on *comic_id*.
    """
    stmt = (
        select(ComicTag)
        .join(ComicTag.namespace)
        .join(ComicTag.tag)
        .options(contains_eager(ComicTag.namespace), contains_eager(ComicTag.tag))
        .where(
            ComicTag.comic_id == comic_id,
            tuple_(Namespace.name, Tag.name).in_(tuples),
        )
    )

    objects = (await session.scalars(stmt)).unique().all()

    found = {(ctag.namespace.name, ctag.tag.name) for ctag in objects}
    return objects, set(tuples) - found


async def get_image_orphans(session):
    """Return ``(id, hash)`` tuples for images no ``Page`` row joins to.

    Uses a LEFT OUTER JOIN onto ``Page`` and keeps only the rows where the
    page side is NULL. The result is returned as a tuple-result object, not
    a materialised list.
    """
    stmt = select(Image.id, Image.hash).outerjoin(Page).where(Page.id.is_(None))

    result = await session.execute(stmt)
    return result.t


async def get_remaining_pages_for(session, archive_id):
    """Return ids of pages joined to archive *archive_id* with no comic.

    A page counts as "remaining" when its ``comic_id`` is NULL.
    """
    stmt = (
        select(Page.id)
        .join(Archive)
        .where(Archive.id == archive_id, Page.comic_id == null())
    )

    result = await session.execute(stmt)
    return result.scalars().all()


async def delete_all(session, model, ids):
    """Bulk-delete the *model* rows whose id is in *ids*.

    Returns the ``rowcount`` reported for the DELETE statement.
    """
    stmt = delete(model).where(model.id.in_(ids))
    outcome = await session.execute(stmt)
    return outcome.rowcount


async def count(session, model):
    """Return the number of rows in *model*'s table via ``count(1)``."""
    counted = func.count(literal_column("1"))
    stmt = select(counted).select_from(model)

    result = await session.execute(stmt)
    return result.scalar_one()
