import asyncio
import multiprocessing
import os
import platform
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor
from datetime import datetime, timezone
from enum import Enum
from hashlib import file_digest
from typing import NamedTuple
from zipfile import BadZipFile, ZipFile, is_zipfile

from blake3 import blake3
from natsort import natsorted, ns
from sqlalchemy import insert, select, update
from sqlalchemy.dialects.sqlite import insert as sqlite_upsert
from sqlalchemy.orm import raiseload

import hircine.db as db
from hircine.db.models import Archive, Image, Page
from hircine.thumbnailer import Thumbnailer, params_from


class Status(Enum):
    """Outcome assigned to an archive path during a scan.

    The value of each member is the single-character tag that ``log``
    prints between square brackets in front of the archive's file name.
    """

    # Archive seen on disk for the first time; it will be added to the DB.
    NEW = "+"
    # Modification time matches the DB record; ``log`` prints nothing for it.
    UNCHANGED = "="
    # Modification time differs but the content hash is the same.
    UPDATED = "*"
    # Content hash of a missing DB archive was found at a different path.
    RENAMED = ">"
    # Content hash was already registered in this scan (duplicate data).
    IGNORED = "I"
    # Path is known to the DB but the content hash no longer matches.
    CONFLICT = "!"
    # Path is referenced by the DB but does not exist in the file system.
    MISSING = "?"
    # Unchanged archive whose members were processed again (reprocess mode).
    REIMAGE = "~"


def log(status, path, renamed_to=None):
    """Print a one-line status report for an archive.

    The line has the form ``[<tag>] <basename>`` and, when ``renamed_to`` is
    truthy, is extended with ``-> <new basename>``. Nothing is printed for
    ``Status.UNCHANGED``.

    Args:
        status: The ``Status`` member describing what happened to the archive.
        path: Path of the archive; only its base name is shown.
        renamed_to: Optional new path of the archive; only its base name is
            shown.
    """
    if status == Status.UNCHANGED:
        return

    line = f"[{status.value}] {os.path.basename(path)}"

    if renamed_to:
        line = f"{line} -> {os.path.basename(renamed_to)}"

    print(line)


class Registry:
    """Bookkeeping for everything encountered during a single scan.

    Attributes:
        paths: Set of archive paths that were found on disk for archives
            already known to the database.
        orphans: Maps a content hash to data about a database archive whose
            path no longer exists.
        conflicts: Maps a path to data about an archive whose content changed.
        marked: Maps a content hash to the list of ``(path, status)`` pairs
            recorded for it via ``mark``.
    """

    def __init__(self):
        self.marked = defaultdict(list)
        self.conflicts = {}
        self.orphans = {}
        self.paths = set()

    def mark(self, status, hash, path, renamed_to=None):
        """Log the status of ``path`` and record it under its content hash.

        Args:
            status: Status to log and store for the path.
            hash: Content hash the path is filed under.
            path: Path of the archive.
            renamed_to: Optional new path, forwarded to ``log`` only.
        """
        log(status, path, renamed_to)

        entry = (path, status)
        self.marked[hash].append(entry)

    @property
    def duplicates(self):
        """Yield every list of ``(path, status)`` pairs sharing one hash.

        Only hashes that were marked more than once are reported.
        """
        yield from (
            entries for entries in self.marked.values() if len(entries) > 1
        )


class Member(NamedTuple):
    """A single processable image found inside a zip archive."""

    # Name of the member inside the archive.
    path: str
    # Hex digest of the member's contents.
    hash: str
    # Dimensions as returned by the thumbnailer for this member.
    width: int
    height: int


class UpdateArchive(NamedTuple):
    """Pending update of the path and modification time of a known archive."""

    id: int
    path: str
    mtime: datetime

    async def execute(self, session):
        """Write the new path and mtime to the archive row with this id.

        Args:
            session: Database session used to execute the UPDATE statement.
        """
        statement = (
            update(Archive)
            .where(Archive.id == self.id)
            .values(path=self.path, mtime=self.mtime)
        )

        await session.execute(statement)


class AddArchive(NamedTuple):
    """Pending insertion of a newly discovered archive and its pages."""

    hash: str
    path: str
    size: int
    mtime: datetime
    members: list[Member]

    async def upsert_images(self, session):
        """Make sure an image row exists for every member.

        Images are inserted with "on conflict do nothing" on the ``hash``
        column. Hashes that are not part of the rows returned by the insert
        are fetched with a follow-up SELECT.

        Args:
            session: Database session used to run the statements.

        Returns:
            A dict mapping image hash to image id.
        """
        rows = [
            {
                "hash": member.hash,
                "width": member.width,
                "height": member.height,
            }
            for member in self.members
        ]

        statement = (
            sqlite_upsert(Image)
            .returning(Image)
            .on_conflict_do_nothing(index_elements=["hash"])
        )

        ids = {}
        for image in await session.scalars(statement, rows):
            ids[image.hash] = image.id

        # Anything the insert did not hand back has to be looked up.
        absent = [member.hash for member in self.members if member.hash not in ids]
        if absent:
            lookup = select(Image).where(Image.hash.in_(absent))
            existing = await session.scalars(lookup)
            ids.update((image.hash, image.id) for image in existing)

        return ids

    async def execute(self, session):
        """Insert the archive row followed by one page row per member.

        The image of the first member is used as the archive's cover.

        Args:
            session: Database session used to run the statements.
        """
        image_ids = await self.upsert_images(session)

        archive_row = {
            "hash": self.hash,
            "path": self.path,
            "size": self.size,
            "mtime": self.mtime,
            "cover_id": image_ids[self.members[0].hash],
            "page_count": len(self.members),
        }

        created = await session.scalars(
            insert(Archive).returning(Archive),
            archive_row,
        )
        archive = created.one()

        page_rows = []
        for index, member in enumerate(self.members):
            page_rows.append(
                {
                    "index": index,
                    "path": member.path,
                    "image_id": image_ids[member.hash],
                    "archive_id": archive.id,
                }
            )

        await session.execute(insert(Page), page_rows)


class Scanner:
    def __init__(self, config, dirs, reprocess=False):
        self.directory = dirs.scan
        self.thumbnailer = Thumbnailer(dirs.objects, params_from(config))
        self.registry = Registry()

        self.reprocess = reprocess

    async def scan(self):
        if platform.system() == "Windows":
            ctx = multiprocessing.get_context("spawn")  # pragma: no cover
        else:
            ctx = multiprocessing.get_context("forkserver")

        workers = multiprocessing.cpu_count() // 2

        with ProcessPoolExecutor(max_workers=workers, mp_context=ctx) as pool:
            async with db.session() as s:
                sql = select(Archive).options(raiseload(Archive.cover))

                for archive in await s.scalars(sql):
                    action = await self.scan_existing(archive, pool)

                    if action:
                        await action.execute(s)

                async for action in self.scan_dir(self.directory, pool):
                    await action.execute(s)

                await s.commit()

    def report(self):  # pragma: no cover
        if self.registry.orphans:
            print()
            print(
                "WARNING: The following paths are referenced in the DB, but do not exist in the file system:"  # noqa: E501
            )
            for orphan in self.registry.orphans.values():
                _, path = orphan
                log(Status.MISSING, path)

        for duplicate in self.registry.duplicates:
            print()
            print("WARNING: The following archives contain the same data:")
            for path, status in duplicate:
                log(status, path)

        for path, conflict in self.registry.conflicts.items():
            db_hash, fs_hash = conflict
            print()
            print("ERROR: The contents of the following archive have changed:")
            log(Status.CONFLICT, path)
            print(f"    Database: {db_hash}")
            print(f"    File system: {fs_hash}")

    async def scan_existing(self, archive, pool):
        try:
            stat = os.stat(archive.path, follow_symlinks=False)
        except FileNotFoundError:
            self.registry.orphans[archive.hash] = (archive.id, archive.path)
            return None

        self.registry.paths.add(archive.path)

        mtime = datetime.fromtimestamp(stat.st_mtime, tz=timezone.utc)

        if mtime == archive.mtime:
            if self.reprocess:
                await self.process_zip(archive.path, pool)

                self.registry.mark(Status.REIMAGE, archive.hash, archive.path)
                return None
            else:
                self.registry.mark(Status.UNCHANGED, archive.hash, archive.path)
                return None

        hash, _ = await self.process_zip(archive.path, pool)

        if archive.hash == hash:
            self.registry.mark(Status.UPDATED, archive.hash, archive.path)
            return UpdateArchive(id=archive.id, path=archive.path, mtime=mtime)
        else:
            log(Status.CONFLICT, archive.path)
            self.registry.conflicts[archive.path] = (archive.hash, hash)

        return None

    async def scan_dir(self, path, pool):
        path = os.path.realpath(path)

        for root, dirs, files in os.walk(path):
            for file in files:
                absolute = os.path.join(path, root, file)

                if os.path.islink(absolute):
                    continue

                if not is_zipfile(absolute):
                    continue

                if absolute in self.registry.paths:
                    continue

                async for result in self.scan_zip(absolute, pool):
                    yield result

    async def scan_zip(self, path, pool):
        stat = os.stat(path, follow_symlinks=False)
        mtime = datetime.fromtimestamp(stat.st_mtime, tz=timezone.utc)

        hash, members = await self.process_zip(path, pool)

        if hash in self.registry.marked:
            self.registry.mark(Status.IGNORED, hash, path)
            return

        if hash in self.registry.orphans:
            id, old_path = self.registry.orphans[hash]
            del self.registry.orphans[hash]

            self.registry.mark(Status.RENAMED, hash, old_path, renamed_to=path)
            yield UpdateArchive(id=id, path=path, mtime=mtime)
            return
        elif members:
            self.registry.mark(Status.NEW, hash, path)
            yield AddArchive(
                hash=hash,
                path=path,
                size=stat.st_size,
                mtime=mtime,
                members=natsorted(members, key=lambda m: m.path, alg=ns.P | ns.IC),
            )

    async def process_zip(self, path, pool):
        members = []
        hash = blake3()

        with ZipFile(path, mode="r") as z:
            try:
                z.testzip()
            except Exception as e:
                raise BadZipFile(f"Corrupt zip file {path}") from e

            input = [(path, info.filename) for info in z.infolist()]

        loop = asyncio.get_event_loop()

        tasks = [loop.run_in_executor(pool, self.process_member, i) for i in input]
        results = await asyncio.gather(*tasks)
        for digest, entry in results:
            hash.update(digest)
            if entry:
                members.append(entry)

        return hash.hexdigest(), members

    def process_member(self, input):
        path, name = input

        with ZipFile(path, mode="r") as ziph, ziph.open(name, mode="r") as member:
            _, ext = os.path.splitext(name)
            digest = file_digest(member, blake3).digest()

            if self.thumbnailer.can_process(ext):
                hash = digest.hex()

                width, height = self.thumbnailer.process(
                    member, hash, reprocess=self.reprocess
                )
                return digest, Member(
                    path=member.name, hash=hash, width=width, height=height
                )

        return digest, None
