import asyncio
import hashlib
import logging
import os
import shutil
from asyncio.subprocess import Process
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Union

import click
import psutil
from anyio import (
    CancelScope,
    create_task_group,
    get_cancelled_exc_class,
    move_on_after,
    open_process,
)
from anyio.streams.text import TextReceiveStream
from watchfiles import Change, DefaultFilter, arun_process, awatch

from databutton.decorators.apps.streamlit import StreamlitApp
from databutton.decorators.jobs.schedule import Scheduler
from databutton.utils import get_databutton_components_path
from databutton.utils.build import generate_components, read_artifacts_json

# Logger for this module's diagnostics; DatabuttonRunner.run(debug=True)
# raises it to DEBUG.
logger = logging.getLogger("databutton.start")
# Quieten the "watchfiles.main" logger: only CRITICAL records get through.
awatch_logger = logging.getLogger("watchfiles.main")
awatch_logger.setLevel(logging.CRITICAL)


class DatabuttonFilter(DefaultFilter):
    """watchfiles filter that passes changes to Python source files only.

    On top of ``DefaultFilter``'s rules, anything under a ``.databutton``
    directory is ignored, and paths ending in ``artifacts.json`` are rejected.
    With ``include_artifacts_json=True``, a path ending in ``artifacts.json``
    is accepted unconditionally, bypassing every other rule (including the
    ignored directories).
    """

    def __init__(
        self,
        *,
        ignore_paths: Optional[Sequence[Union[str, Path]]] = None,
        extra_extensions: Sequence[str] = (),
        include_artifacts_json: bool = False,
    ) -> None:
        """
        Args:
            ignore_paths: Paths to ignore; forwarded to ``DefaultFilter``.
            extra_extensions: Additional file extensions to *watch*, on top
                of ``.py``, ``.pyx``, ``.pyd`` and ``.pyc``.
            include_artifacts_json: Also accept changes to ``artifacts.json``.
        """
        python_extensions = (".py", ".pyx", ".pyd", ".pyc")
        self.extensions = python_extensions + tuple(extra_extensions)
        self.include_artifacts_json = include_artifacts_json
        # ``ignore_dirs`` here is DefaultFilter's class-level default.
        ignored_dirs = (*self.ignore_dirs, ".databutton")
        super().__init__(ignore_paths=ignore_paths, ignore_dirs=ignored_dirs)

    def __call__(self, change: "Change", path: str) -> bool:
        is_artifacts_json = path.endswith("artifacts.json")
        # Opt-in short circuit: artifacts.json passes regardless of location.
        if self.include_artifacts_json and is_artifacts_json:
            return True
        if is_artifacts_json or not path.endswith(self.extensions):
            return False
        return super().__call__(change, path)


def get_components_hash(path: Optional[Path] = None):
    """Return the MD5 hex digest of the components file.

    Args:
        path: File to hash. Defaults to ``get_databutton_components_path()``.

    Returns:
        The hex digest (``str``), or ``False`` if the file does not exist.
        ``False`` rather than ``None`` is kept for backward compatibility.
    """
    if path is None:
        path = get_databutton_components_path()
    try:
        # Hash the raw bytes. This does not depend on the locale's default
        # text encoding. Using EAFP instead of exists()+open() also means the
        # file disappearing in between cannot raise.
        data = Path(path).read_bytes()
    except FileNotFoundError:
        return False
    return hashlib.md5(data).hexdigest()


class ComponentsJsonFilter(DefaultFilter):
    """watchfiles filter for ``artifacts.json`` that suppresses no-op changes.

    A change passes only if ``DefaultFilter`` accepts it, the path ends in
    ``artifacts.json`` and still exists, and ``get_components_hash()``
    differs from the last accepted hash. The filter is stateful: every
    accepted change updates ``prev_hash``.
    """

    def __init__(self, starting_hash: Optional[str] = None) -> None:
        """
        Args:
            starting_hash: Hash to treat as already seen, so that a first
                event whose content hash equals it is ignored.
        """
        super().__init__()
        self.prev_hash: Optional[str] = starting_hash

    def __call__(self, change: "Change", path: str) -> bool:
        # Return an explicit bool; this path previously fell off the end and
        # returned None.
        if not (super().__call__(change, path) and path.endswith("artifacts.json")):
            return False
        if not Path(path).exists():
            # Ignore deletions / files that are already gone.
            return False
        # NOTE(review): this hashes get_databutton_components_path(), not
        # necessarily `path`; presumably they are the same file — confirm.
        digest = get_components_hash()
        if digest == self.prev_hash:
            return False
        self.prev_hash = digest
        return True


class StreamlitWatcher:
    """Manage one ``streamlit run`` subprocess per Streamlit app.

    Apps are keyed by :meth:`get_stable_app_id` (``"<uid>-<port>"``).
    ``self.apps`` holds the apps from the latest
    :meth:`update_processes_from_apps` call. ``self.processes`` holds the
    subprocess started for each key.
    """

    # Text streamlit writes to stdout once its server is up.
    READY_MESSAGE = "You can now view your Streamlit app in your browser."
    # Seconds to wait for READY_MESSAGE after launching a process.
    STARTUP_TIMEOUT = 5
    # Seconds to wait for exit after terminate() before falling back to kill().
    SHUTDOWN_TIMEOUT = 5

    def __init__(self):
        self.apps: Dict[str, StreamlitApp] = {}
        self.processes: Dict[str, Process] = {}

    def get_stable_app_id(self, app: StreamlitApp) -> str:
        """Return the key used for *app* in ``self.apps`` / ``self.processes``."""
        return f"{app.uid}-{app.port}"

    async def clear(self):
        """Stop every process, then forget all apps and processes."""
        await self.cancel()
        self.apps = {}
        self.processes = {}

    async def cancel(self):
        """Stop every tracked process concurrently. Tracking dicts are left as-is."""
        async with create_task_group() as tg:
            for key, process in list(self.processes.items()):
                tg.start_soon(self.stop_process, process, self.apps.get(key))

    async def start_process(self, uid: str, app: StreamlitApp, task_status) -> Process:
        """Launch ``streamlit run`` for *app* and register it under *uid*.

        Intended for ``TaskGroup.start``. ``task_status.started()`` is always
        called: when streamlit reports it is serving, or else after
        ``STARTUP_TIMEOUT`` seconds or end of stdout. A slow start therefore
        no longer makes ``TaskGroup.start`` raise.
        """
        logger.debug(f"Starting process for {app.name}")
        # Use an argument list, not a shell string. No shell sits in between,
        # so terminate() signals streamlit itself, and filenames containing
        # spaces or shell metacharacters are passed verbatim.
        cmd = [
            "streamlit",
            "run",
            f"{app.filename}",
            f"--server.port={app.port}",
            "--server.headless=true",
            "--browser.gatherUsageStats=false",
            "--global.dataFrameSerialization=arrow",
            "--server.runOnSave=true",
        ]
        # Set environment and force PYTHONPATH
        current_env = os.environ.copy()
        current_env["PYTHONPATH"] = "."
        # stderr is never read here. Discard it rather than let an unread pipe
        # fill up and block the child on write.
        # NOTE(review): stdout is likewise not read after startup.
        process = await open_process(
            cmd, env=current_env, stderr=asyncio.subprocess.DEVNULL
        )
        # Register immediately so the process can always be stopped later,
        # even if it never reports readiness.
        self.processes[uid] = process
        ready = False
        with move_on_after(self.STARTUP_TIMEOUT, shield=True):
            async for text in TextReceiveStream(process.stdout):
                logger.debug(text)
                if self.READY_MESSAGE in text:
                    ready = True
                    break
        if ready:
            logger.debug(f"Started process for {app.name}")
        else:
            logger.warning(
                f"Streamlit app {app.name} did not report ready within "
                f"{self.STARTUP_TIMEOUT}s; continuing anyway."
            )
        task_status.started()
        return process

    async def stop_process(
        self,
        process: Optional[Process],
        app_filename: Union[str, StreamlitApp, None],
    ) -> None:
        """Terminate *process*, then kill leftover streamlit processes for the app.

        Args:
            process: Process returned by :meth:`start_process`. ``None`` is
                tolerated, e.g. for an app that was never started.
            app_filename: The app's script filename, or the ``StreamlitApp``
                itself (callers in this class pass the app object).
        """
        filename = getattr(app_filename, "filename", app_filename)
        port = getattr(app_filename, "port", None)

        if process is not None:
            try:
                process.terminate()
                with move_on_after(self.SHUTDOWN_TIMEOUT):
                    await process.wait()
                if process.returncode is None:
                    # Did not exit in time after terminate(); force it.
                    process.kill()
                    await process.wait()
                logger.debug(f"Stopped process {process.pid}")
            except Exception:
                # e.g. ProcessLookupError when it already exited. The sweep
                # below is the fallback. Cancellation is not swallowed.
                logger.debug(f"Could not terminate process for {process.pid}.")

        if not filename:
            return
        filename = str(filename)
        # Streamlit can leave dangling processes, so kill remaining
        # `streamlit` invocations of this app's file. Requiring "streamlit"
        # (and the port flag, when known) in the command line stops unrelated
        # processes that merely mention the file, such as an editor, from
        # being killed.
        port_flag = f"--server.port={port}" if port is not None else None
        for psprocess in psutil.process_iter():
            if process is not None and psprocess.pid == process.pid:
                continue
            try:
                cmdline = psprocess.cmdline()
            except (psutil.AccessDenied, psutil.NoSuchProcess, ProcessLookupError):
                continue
            if filename not in cmdline:
                continue
            if not any("streamlit" in part for part in cmdline):
                continue
            if port_flag is not None and port_flag not in cmdline:
                continue
            try:
                psprocess.kill()
                logger.debug(f"Forcefully stopped psutil.process {psprocess.pid}")
            except psutil.Error:
                logger.debug(
                    f"Could not terminate psutil.process {psprocess.pid}. "
                    + f"{psprocess}",
                )

    async def update_processes_from_apps(self, apps: List[StreamlitApp]) -> bool:
        """Reconcile running processes with *apps*.

        Starts processes for newly seen apps and stops those of removed
        apps. Kept apps whose ``uid`` changed are restarted.

        Returns:
            ``True`` if any app was added or removed.
        """
        apps_map: Dict[str, StreamlitApp] = {
            self.get_stable_app_id(app): app for app in apps
        }
        previous_apps_map = self.apps
        self.apps = apps_map
        old, new = set(previous_apps_map), set(apps_map)
        new_apps = list(new - old)
        deleted_apps = list(old - new)

        async with create_task_group() as tg:
            for new_uid in new_apps:
                await tg.start(self.start_process, new_uid, apps_map[new_uid])

            for deleted_uid in deleted_apps:
                await self.stop_process(
                    self.processes.pop(deleted_uid, None),
                    previous_apps_map[deleted_uid],
                )

            for running_uid in new & old:
                new_app = apps_map[running_uid]
                old_app = previous_apps_map[running_uid]
                # NOTE(review): the key already embeds uid and port, so apps
                # sharing a key share both and this branch looks unreachable.
                # Confirm what should trigger a restart.
                if old_app.uid != new_app.uid:
                    await self.stop_process(
                        self.processes.pop(running_uid, None), old_app
                    )
                    # tg.start needs the callable plus its args (not a
                    # coroutine object) so it can inject task_status.
                    await tg.start(self.start_process, running_uid, new_app)

        return len(new_apps) > 0 or len(deleted_apps) > 0


@dataclass
class DatabuttonConfig:
    """Settings for the databutton dev server.

    Defaults are read from the environment when an instance is created, not
    when this module is imported: ``PORT`` (default 8000) and ``LOG_LEVEL``
    (default ``"critical"``).
    """

    # Convert the env value so the field really is an int. os.environ only
    # yields strings.
    port: int = field(default_factory=lambda: int(os.environ.get("PORT", 8000)))
    log_level: str = field(
        default_factory=lambda: os.environ.get("LOG_LEVEL", "critical")
    )


class GracefulExit(SystemExit):
    """``SystemExit`` subclass whose ``code`` attribute is 1.

    ``code`` is set at class level, so reading ``code`` yields 1 rather than
    the ``None`` a bare ``SystemExit()`` would report.
    """

    code: int = 1


class DatabuttonRunner:
    """Run the databutton development stack.

    :meth:`run` regenerates components and then runs the watchers from
    :meth:`create_tasks` concurrently. These cover component generation,
    streamlit apps, the uvicorn web server and the job scheduler.
    """

    def __init__(self, root_dir=None, **config):
        """
        Args:
            root_dir: Directory to watch. Defaults to the current working
                directory at construction time. A ``Path.cwd()`` default
                argument would be frozen when the module is imported.
            **config: Keyword overrides for ``DatabuttonConfig``
                (``port``, ``log_level``).
        """
        self.root_dir = Path.cwd() if root_dir is None else root_dir
        self.config = DatabuttonConfig(**config)
        # Components hash taken after the initial generate in run(). It is
        # handed to each ComponentsJsonFilter so unchanged content is ignored.
        self.initial_hash: Optional[str] = None
        # Cleanup coroutine functions registered by watchers. They are
        # appended here but not invoked anywhere in this class.
        self.cancels = []

    async def create_webwatcher(self):
        """Run uvicorn for ``databutton.server.prod:app``, restarted on components changes."""
        args = [("port", self.config.port), ("log-level", self.config.log_level)]
        args_string = " ".join([f"--{arg}={value}" for arg, value in args])
        target_str = f"uvicorn {args_string} databutton.server.prod:app"
        return await arun_process(
            self.root_dir,
            target=target_str,
            target_type="command",
            watch_filter=ComponentsJsonFilter(starting_hash=self.initial_hash),
            callback=lambda _: click.secho("Restarting webserver..."),
        )

    async def create_streamlit_watcher(self):
        """Keep streamlit processes in sync with ``artifacts.json``."""
        streamlit_watcher = StreamlitWatcher()
        self.cancels.append(streamlit_watcher.cancel)
        components = read_artifacts_json()
        try:
            await streamlit_watcher.update_processes_from_apps(
                components.streamlit_apps
            )
            async for _ in awatch(
                self.root_dir,
                watch_filter=ComponentsJsonFilter(starting_hash=self.initial_hash),
            ):
                new_components = read_artifacts_json()
                await streamlit_watcher.clear()
                await streamlit_watcher.update_processes_from_apps(
                    new_components.streamlit_apps
                )
        finally:
            # Stop the subprocesses on cancellation and also when any other
            # error escapes the loop, so they are never leaked. Shielded so the
            # cleanup itself is not cancelled.
            with CancelScope(shield=True):
                await streamlit_watcher.clear()

    async def create_scheduler_watcher(self):
        """Run the job scheduler, restarted on Python or ``artifacts.json`` changes."""
        return await arun_process(
            self.root_dir,
            watch_filter=DatabuttonFilter(include_artifacts_json=True),
            target=Scheduler.create,
            callback=lambda _: click.secho("Restarting scheduler..."),
        )

    async def create_components_watcher(self):
        """Re-run ``generate_components`` whenever Python sources change."""
        return await arun_process(
            self.root_dir,
            watch_filter=DatabuttonFilter(),
            target=partial(generate_components, self.root_dir),
        )

    async def run(self, debug=False):
        """Regenerate components, then run all watchers until cancelled."""
        if debug:
            logger.setLevel(logging.DEBUG)
        # NOTE(review): this path is relative to the current working
        # directory, not self.root_dir. Confirm that is intended when they
        # differ.
        shutil.rmtree(Path(".databutton"), ignore_errors=True)
        generate_components(self.root_dir)
        self.initial_hash = get_components_hash()
        async with create_task_group() as tg:
            for task in self.create_tasks():
                tg.start_soon(task)
        click.secho("\nstopping...", fg="cyan")

    def create_tasks(self):
        """Return the watcher coroutine functions that :meth:`run` starts."""
        return [
            self.create_components_watcher,
            self.create_streamlit_watcher,
            self.create_webwatcher,
            self.create_scheduler_watcher,
        ]
