from datetime import datetime, timezone
from pathlib import Path

from strawberry import UNSET

import hircine.db as db
import hircine.db.ops as ops
import hircine.thumbnailer as thumb
from hircine.api import APIException, MutationContext
from hircine.api.inputs import (
    Fetchable,
    add_input_cls,
    update_input_cls,
    upsert_input_cls,
)
from hircine.api.responses import (
    AddComicSuccess,
    AddSuccess,
    DeleteSuccess,
    IDNotFoundError,
    InvalidParameterError,
    NameExistsError,
    UpdateSuccess,
    UpsertSuccess,
)
from hircine.config import get_dir_structure
from hircine.db.models import Comic, Image, MixinModifyDates
from hircine.enums import UpdateMode


async def fetch_fields(input, ctx: MutationContext):
    """
    Given a mutation input and a context, fetch and yield all relevant objects
    from the database.

    Yields ``(field, value, mode)`` tuples for every attribute of *input*
    except ``id`` and attributes left UNSET:

    - Fetchable inputs are resolved by awaiting ``value.fetch(ctx)`` and carry
      their own update mode (``value.update_mode()``).
    - Any other value is used "verbatim" after checking API restrictions via
      check_constraints, and is always applied with UpdateMode.REPLACE.

    Raises APIException (from check_constraints) when a plain value violates
    a constraint.
    """

    for field, value in input.__dict__.items():
        # UNSET is a sentinel singleton: compare by identity, consistent with
        # how _update filters input fields.
        if field == "id" or value is UNSET:
            continue

        if isinstance(value, Fetchable):
            yield field, await value.fetch(ctx), value.update_mode()
            continue

        # Normalize an empty string to None so check_constraints treats it
        # like an explicit "empty" value (rejected for non-nullable columns).
        if isinstance(value, str) and not value:
            value = None

        await check_constraints(ctx, field, value)
        yield field, value, UpdateMode.REPLACE


async def check_constraints(ctx, field, value):
    """
    Validate a plain (non-Fetchable) input value against the model column it
    targets.

    :param ctx: the MutationContext; its ``model``, ``multiple``, ``root`` and
        ``session`` attributes are consulted.
    :param field: name of the column on ``ctx.model`` being written.
    :param value: the new value (empty strings have already been turned into
        None by the caller).
    :raises APIException: wrapping InvalidParameterError if *value* is None
        for a non-nullable column, or if a unique column is written in a
        bulk (multi-object) mutation; wrapping NameExistsError if a changed
        unique ``name`` already exists for this model.
    """
    column = getattr(ctx.model.__table__.c, field)

    if value is None and not column.nullable:
        raise APIException(
            InvalidParameterError(parameter=field, text="cannot be empty")
        )

    if column.unique and ctx.multiple:
        # Setting a unique column on several rows at once can never succeed.
        # Report the offending field instead of always blaming "name".
        raise APIException(
            InvalidParameterError(
                parameter=field, text="Cannot bulk-update unique fields"
            )
        )

    # Only hit the database when the name actually changes; keeping an
    # object's current name must not count as a clash with itself.
    if column.unique and field == "name" and value != ctx.root.name:
        if await ops.has_with_name(ctx.session, ctx.model, value):
            raise APIException(NameExistsError(ctx.model))


# Mutation resolvers use the factory pattern. Given a modelcls, each factory
# returns a strawberry resolver; where that resolver takes an `input`
# argument, it is annotated with the Input type generated for the modelcls.


def add(modelcls, post_add=None):
    """
    Build a strawberry resolver that creates one new *modelcls* object.

    The resolver fills a fresh ``modelcls()`` from its Add input, adds it to
    the session and flushes it. If *post_add* is given, it is then awaited as
    ``post_add(session, input, obj)``. A truthy result from it is returned by
    the resolver. Otherwise the resolver returns ``AddSuccess(modelcls, id)``.
    Any APIException raised while reading the input is returned as its GraphQL
    error before anything is committed.
    """
    input_type = add_input_cls(modelcls)

    async def inner(input: input_type):
        async with db.session() as s:
            try:
                obj = modelcls()
                ctx = MutationContext(input, obj, s)

                async for field, value, _ in fetch_fields(input, ctx):
                    setattr(obj, field, value)
            except APIException as e:
                return e.graphql_error

            s.add(obj)
            # Flush so the new row exists (and obj has its id) before
            # post_add runs inside the same transaction.
            await s.flush()

            result = await post_add(s, input, obj) if post_add else None

            await s.commit()

        return result or AddSuccess(modelcls, obj.id)

    return inner


async def post_add_comic(session, input, comic):
    """
    post_add hook for comics.

    Looks up the archive's pages that are still left over after creating
    *comic*. When none remain, marks the archive as organized. Returns an
    AddComicSuccess that reports whether pages remain.
    """
    pages_left = await ops.get_remaining_pages_for(session, input.archive.id)
    has_remaining = len(pages_left) != 0

    if not has_remaining:
        # Every page of the archive is now accounted for.
        comic.archive.organized = True

    return AddComicSuccess(Comic, comic.id, has_remaining)


def update_attr(obj, field, value, mode):
    if mode != UpdateMode.REPLACE and isinstance(value, list):
        attr = getattr(obj, field)
        match mode:
            case UpdateMode.ADD:
                value.extend(attr)
            case UpdateMode.REMOVE:
                value = list(set(attr) - set(value))

    setattr(obj, field, value)


async def _update(ids: list[int], modelcls, input, successcls):
    """
    Apply *input* to every *modelcls* object whose id is listed in *ids*.

    The names of the input fields that are not UNSET are handed to
    ``modelcls.load_update`` and the result is passed on to ``ops.get_all``.

    Returns an IDNotFoundError if any id does not exist, or the GraphQL error
    of the first APIException raised while applying fields (returned before
    the session is committed). Otherwise commits and returns ``successcls()``.
    Objects deriving from MixinModifyDates get ``updated_at`` refreshed when
    the session reports them as modified.
    """
    bulk = len(ids) > 1

    async with db.session() as s:
        requested = [
            name for name, val in input.__dict__.items() if val is not UNSET
        ]
        loader = modelcls.load_update(requested)
        objects, missing = await ops.get_all(s, modelcls, ids, loader)

        if missing:
            return IDNotFoundError(modelcls, missing.pop())

        for obj in objects:
            s.add(obj)

            try:
                ctx = MutationContext(input, obj, s, multiple=bulk)
                async for field, value, mode in fetch_fields(input, ctx):
                    update_attr(obj, field, value, mode)
            except APIException as e:
                # Bail out before s.commit() is reached.
                return e.graphql_error

            touched = isinstance(obj, MixinModifyDates) and s.is_modified(obj)
            if touched:
                obj.updated_at = datetime.now(tz=timezone.utc)

        await s.commit()

    return successcls()


def update(modelcls):
    """
    Build a strawberry resolver that updates the *modelcls* objects named by
    ``ids`` using its Update input. On success the resolver returns
    UpdateSuccess. Otherwise it returns whatever error _update produces.
    """
    input_type = update_input_cls(modelcls)

    async def inner(ids: list[int], input: input_type):
        return await _update(ids, modelcls, input, successcls=UpdateSuccess)

    return inner


def upsert(modelcls):
    """
    Build a strawberry resolver that applies its Upsert input to the
    *modelcls* objects named by ``ids``. On success the resolver returns
    UpsertSuccess. Otherwise it returns whatever error _update produces.
    """
    input_type = upsert_input_cls(modelcls)

    async def inner(ids: list[int], input: input_type):
        return await _update(ids, modelcls, input, successcls=UpsertSuccess)

    return inner


def delete(modelcls, post_delete=None):
    """
    Build a strawberry resolver that deletes the *modelcls* objects named by
    ``ids``.

    If any id is missing, the resolver returns IDNotFoundError and deletes
    nothing. Otherwise it deletes every object and flushes. It then awaits
    ``post_delete(session, objects)`` when a hook was given, commits, and
    returns DeleteSuccess.
    """

    async def inner(ids: list[int]):
        async with db.session() as session:
            objects, missing = await ops.get_all(session, modelcls, ids)
            if missing:
                return IDNotFoundError(modelcls, missing.pop())

            for target in objects:
                await session.delete(target)
            # Flush so the hook runs against the post-delete database state.
            await session.flush()

            if post_delete:
                await post_delete(session, objects)

            await session.commit()

        return DeleteSuccess()

    return inner


async def post_delete_archive(session, objects):
    """
    post_delete hook for archives.

    Unlinks each deleted archive's file from disk. It then asks
    ``ops.get_image_orphans`` for ``(id, hash)`` pairs of orphaned images,
    unlinks their "full" and "thumb" object files, and deletes those Image
    rows in *session*. Missing files are ignored.

    NOTE(review): files are removed before the caller commits the session,
    so a failed commit leaves rows whose files are already gone.
    """
    for archive in objects:
        Path(archive.path).unlink(missing_ok=True)

    dirs = get_dir_structure()
    orphans = await ops.get_image_orphans(session)

    orphan_ids = []
    for image_id, image_hash in orphans:
        orphan_ids.append(image_id)
        for suffix in ("full", "thumb"):
            object_file = thumb.object_path(dirs.objects, image_hash, suffix)
            Path(object_file).unlink(missing_ok=True)

    if orphan_ids:
        await ops.delete_all(session, Image, orphan_ids)
