from typing import Optional

import hircine.api.filters as filters
import hircine.api.sort as sort
import hircine.api.types as types
import hircine.db as db
import hircine.db.models as models
import hircine.db.ops as ops
import hircine.plugins as plugins
from hircine.api.filters import Input as FilterInput
from hircine.api.inputs import Pagination
from hircine.api.responses import (
    ComicTotals,
    IDNotFoundError,
    ScraperError,
    ScraperNotAvailableError,
    ScraperNotFoundError,
    Statistics,
    Totals,
)
from hircine.api.sort import Input as SortInput
from hircine.api.types import (
    ComicScraper,
    ComicTag,
    FilterResult,
    FullComic,
    ScrapeComicResult,
    ScrapedComic,
)
from hircine.scraper import ScrapeError

# Some query resolvers are built by factories: given a model class, a factory
# returns a strawberry resolver for that model (`single` looks up one object
# by ID, `every` lists all objects with pagination, filtering and sorting).


def single(model, full=False):
    """Build a resolver that fetches a single ``model`` row by its ID.

    The API type wrapping the result is looked up on ``types`` by name:
    ``<Model>`` normally, or ``Full<Model>`` when ``full`` is true.

    :param model: Database model class to look up.
    :param full: When true, pass ``model.load_full()`` as load options and
        wrap the result in the ``Full<Model>`` type.
    :returns: An async resolver taking ``id`` that returns the wrapped object,
        or an ``IDNotFoundError`` if no row with that ID exists.
    """
    prefix = "Full" if full else ""
    result_type = getattr(types, f"{prefix}{model.__name__}")

    async def inner(id: int):
        async with db.session() as session:
            found = await session.get(
                model, id, options=model.load_full() if full else []
            )

        if found:
            return result_type(found)

        return IDNotFoundError(model, id)

    return inner


def every(model):
    """Build a resolver that lists ``model`` rows.

    The API type, filter input and sort input are found by naming convention:
    ``types.<Model>``, ``filters.<Model>Filter`` and ``sort.<Model>Sort``.

    :param model: Database model class to list.
    :returns: An async resolver accepting optional ``pagination``, ``filter``
        and ``sort`` arguments. It returns a ``FilterResult`` built from the
        ``(count, objects)`` pair produced by ``ops.query_all``.
    """
    name = model.__name__
    result_type = getattr(types, name)
    filter_type = getattr(filters, f"{name}Filter")
    sort_type = getattr(sort, f"{name}Sort")

    # NOTE: `filter` and `sort` are the GraphQL argument names, so they must
    # keep these names even though they shadow a builtin / the `sort` module.
    async def inner(
        pagination: Optional[Pagination] = None,
        filter: Optional[FilterInput[filter_type]] = None,
        sort: Optional[SortInput[sort_type]] = None,
    ):
        async with db.session() as session:
            total, rows = await ops.query_all(
                session, model, pagination=pagination, filter=filter, sort=sort
            )

        edges = [result_type(row) for row in rows]
        return FilterResult(count=total, edges=edges)

    return inner


def namespace_tag_combinations_for(namespaces, tags, restrictions):
    """Yield a ``ComicTag`` for each (namespace, tag) pair allowed by ``restrictions``.

    ``restrictions`` maps a tag ID to the collection of namespace IDs that the
    tag may be combined with. A tag absent from the mapping allows no namespace.
    Pairs are produced namespace-major: every allowed tag for the first
    namespace, then for the second, and so on.
    """
    yield from (
        ComicTag(namespace=ns, tag=t)
        for ns in namespaces
        for t in tags
        if ns.id in restrictions.get(t.id, [])
    )


async def comic_tags(for_filter: bool = False):
    """Resolve the namespace/tag pairings available for comics.

    Every (namespace, tag) combination permitted by the tag restrictions is
    returned. When ``for_filter`` is true, the result is additionally prefixed
    with a namespace-only entry for every namespace, followed by a tag-only
    entry for every tag.

    :returns: A ``FilterResult`` whose ``count`` is the number of edges.
    """
    async with db.session() as session:
        _, all_tags = await ops.query_all(session, models.Tag)
        _, all_namespaces = await ops.query_all(session, models.Namespace)
        restrictions = await ops.tag_restrictions(session)

    pairs = list(
        namespace_tag_combinations_for(all_namespaces, all_tags, restrictions)
    )

    if for_filter:
        edges = [
            *(ComicTag(namespace=ns) for ns in all_namespaces),
            *(ComicTag(tag=t) for t in all_tags),
            *pairs,
        ]
    else:
        edges = pairs

    return FilterResult(count=len(edges), edges=edges)


async def comic_scrapers(id: int):
    """Resolve the scrapers that report themselves available for a comic.

    Every registered scraper class is instantiated for the comic and kept only
    if its ``is_available`` is truthy. Results are ordered by scraper ``name``.

    :param id: ID of the comic.
    :returns: A list of ``ComicScraper``; empty if no comic has this ID.
    """
    async with db.session() as s:
        comic = await s.get(models.Comic, id, options=models.Comic.load_full())

        if not comic:
            return []

    # Hand scrapers the same FullComic view that scrape_comic gives them, so
    # the availability check here runs against the object scraping will use.
    full_comic = FullComic(comic)
    registered = sorted(plugins.get_scrapers(), key=lambda p: p[1].name)

    scrapers = []
    # `scraper_id` rather than `id` to avoid shadowing the comic ID parameter.
    for scraper_id, cls in registered:
        scraper = cls(full_comic)
        if scraper.is_available:
            scrapers.append(ComicScraper(scraper_id, scraper))

    return scrapers


async def scrape_comic(id: int, scraper: str):
    """Run the named scraper against a comic.

    :param id: ID of the comic to scrape.
    :param scraper: Name under which the scraper plugin is registered.
    :returns: A ``ScrapeComicResult`` holding the scraped data and the
        scraper's warnings. On failure it returns one of these instead:
        ``ScraperNotFoundError`` for an unknown scraper name,
        ``IDNotFoundError`` for an unknown comic,
        ``ScraperNotAvailableError`` if the scraper declines the comic, or
        ``ScraperError`` if scraping raises ``ScrapeError``.
    """
    scrapercls = plugins.get_scraper(scraper)

    if not scrapercls:
        return ScraperNotFoundError(name=scraper)

    async with db.session() as s:
        comic = await s.get(models.Comic, id, options=models.Comic.load_full())

        if not comic:
            return IDNotFoundError(models.Comic, id)

    instance = scrapercls(FullComic(comic))

    if not instance.is_available:
        return ScraperNotAvailableError(scraper=scraper, comic_id=id)

    try:
        # collect() is called inside the try so that a ScrapeError raised
        # eagerly by collect() itself, not only while iterating its result,
        # is reported as ScraperError too.
        gen = instance.collect(plugins.transformers)
        # `data` is evaluated before `warnings`, so warnings are read only
        # after the generator has been consumed.
        return ScrapeComicResult(
            data=ScrapedComic.from_generator(gen),
            warnings=instance.get_warnings(),
        )
    except ScrapeError as e:
        return ScraperError(error=str(e))


async def statistics():
    """Resolve library-wide totals.

    Runs ``ops.count`` for each model within a single database session and
    also reports the number of registered scraper plugins.

    :returns: A ``Statistics`` wrapping the collected ``Totals``.
    """
    async with db.session() as session:

        async def tally(model):
            return await ops.count(session, model)

        # Keyword arguments are evaluated left to right, so the queries run in
        # the order listed here.
        totals = Totals(
            archives=await tally(models.Archive),
            artists=await tally(models.Artist),
            characters=await tally(models.Character),
            circles=await tally(models.Circle),
            comic=ComicTotals(
                artists=await tally(models.ComicArtist),
                characters=await tally(models.ComicCharacter),
                circles=await tally(models.ComicCircle),
                tags=await tally(models.ComicTag),
                worlds=await tally(models.ComicWorld),
            ),
            comics=await tally(models.Comic),
            images=await tally(models.Image),
            namespaces=await tally(models.Namespace),
            pages=await tally(models.Page),
            scrapers=len(plugins.get_scrapers()),
            tags=await tally(models.Tag),
            worlds=await tally(models.World),
        )

    return Statistics(total=totals)
