# Copyright 2024 Marimo. All rights reserved.
"""
Interactive matplotlib plots, based on WebAgg.

Adapted from https://matplotlib.org/stable/gallery/user_interfaces/embedding_webagg_sgskip.html
"""

from __future__ import annotations

import asyncio
import html
import io
import mimetypes
import os
import threading
import time
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional, Union

from marimo import _loggers
from marimo._output.builder import h
from marimo._output.formatting import as_html
from marimo._output.hypertext import Html
from marimo._output.rich_help import mddoc
from marimo._runtime.cell_lifecycle_item import CellLifecycleItem
from marimo._runtime.context import (
    RuntimeContext,
    get_context,
)
from marimo._runtime.context.kernel_context import KernelRuntimeContext
from marimo._runtime.runtime import app_meta
from marimo._server.utils import find_free_port
from marimo._utils.platform import is_pyodide

# Module-level logger obtained from marimo's logging helpers.
LOGGER = _loggers.marimo_logger()

if TYPE_CHECKING:
    from matplotlib.axes import Axes
    from matplotlib.backends.backend_webagg_core import FigureManagerWebAgg
    from matplotlib.figure import Figure, SubFigure
    from starlette.applications import Starlette
    from starlette.requests import Request
    from starlette.responses import HTMLResponse, Response
    from starlette.websockets import WebSocket


class FigureManagers:
    """Registry of live WebAgg figure managers, keyed by ``str(manager.num)``.

    Figure ids arrive as strings from HTTP/WebSocket query parameters, so
    keys are normalized to ``str`` both when registering and when looking up.
    """

    def __init__(self) -> None:
        self.figure_managers: dict[str, FigureManagerWebAgg] = {}

    def add(self, manager: FigureManagerWebAgg) -> None:
        """Register ``manager``, replacing any manager with the same ``num``."""
        self.figure_managers[str(manager.num)] = manager

    def get(self, figure_id: str) -> FigureManagerWebAgg:
        """Return the manager registered under ``figure_id``.

        Raises:
            RuntimeError: If no manager is registered under ``figure_id``.
        """
        # Normalize once so the "is it there?" check and the lookup agree,
        # even when a non-str id is passed.
        manager = self.figure_managers.get(str(figure_id))
        if manager is None:
            raise RuntimeError(f"Figure {figure_id} not found.")
        return manager

    def remove(self, manager: FigureManagerWebAgg) -> None:
        """Unregister ``manager`` if it is still the one stored for its ``num``.

        A missing entry is tolerated (the registry is cleared when the server
        restarts). If a *different* manager has since been registered under
        the same ``num`` (e.g. the same figure shown again from another cell,
        or an ``id()`` reused by a new figure), it is left in place so that
        plot keeps working.
        """
        key = str(manager.num)
        current = self.figure_managers.get(key)
        if current is None:
            # Figure already removed, this can happen during server restart
            LOGGER.debug(f"Figure {manager.num} already removed from manager")
            return
        if current is not manager:
            LOGGER.debug(
                f"Figure {manager.num} is registered to another manager; "
                "not removing it"
            )
            return
        del self.figure_managers[key]


# Process-wide registry shared by the HTTP/WebSocket endpoints and
# `interactive`.
figure_managers = FigureManagers()


class MplServerManager:
    """Manages the matplotlib server lifecycle with lazy recovery."""

    def __init__(self) -> None:
        self.process: Optional[threading.Thread] = None
        self._restart_lock = threading.Lock()

    def is_running(self) -> bool:
        """Check if the server thread is still running."""
        if self.process is None:
            return False
        # Check if the thread is still alive
        return self.process.is_alive()

    def start(
        self,
        app_host: Optional[str] = None,
        free_port: Optional[int] = None,
        secure_host: Optional[bool] = None,
    ) -> Starlette:
        """Start the matplotlib server and return the Starlette app."""
        import uvicorn

        host = app_host if app_host is not None else _get_host()
        secure = secure_host if secure_host is not None else _get_secure()

        # Find a free port, with some randomization to avoid conflicts
        import random

        base_port = 10_000 + random.randint(0, 1000)  # Add some randomization
        port = (
            free_port if free_port is not None else find_free_port(base_port)
        )
        app = create_application()
        app.state.host = host
        app.state.port = port
        app.state.secure = secure

        def start_server() -> None:
            # Don't try to set signal handlers in background thread
            # The original signal handlers will remain in place
            server = uvicorn.Server(
                uvicorn.Config(
                    app=app,
                    port=port,
                    host=host,
                    log_level="critical",
                )
            )
            try:
                server.run()
            except Exception as e:
                LOGGER.error(f"Matplotlib server failed: {e}")
                # Thread will exit, making is_running() return False
                # This allows for automatic restart on next use

        # Start server in background thread
        thread = threading.Thread(target=start_server, daemon=True)
        thread.start()

        # Store thread reference to track server
        self.process = thread

        # TODO: Consider if we need this sleep from original code
        # Original comment: "arbitrary wait 200ms for the server to start"
        # With lazy recovery, this may no longer be necessary
        time.sleep(0.02)

        LOGGER.info(f"Started matplotlib server at {host}:{port}")
        return app

    def stop(self) -> None:
        """Stop the server process."""
        if self.process is not None:
            # Note: We can't easily terminate uvicorn server from here,
            # but marking process as None will cause is_running() to return False
            # and trigger a restart on next use
            self.process = None
            LOGGER.debug("Marked matplotlib server for restart")


_server_manager = MplServerManager()


def _get_host() -> str:
    """
    Get the host from environment variable or fall back to localhost.
    """
    host = os.environ.get("MARIMO_MPL_HOST", "localhost")
    if not host or not isinstance(host, str):
        return "localhost"
    if "://" in host:
        raise ValueError(
            f"Invalid host '{host}': should not include protocol (http:// or https://)"
        )
    if "/" in host:
        raise ValueError(f"Invalid host '{host}': should not include paths")
    if ":" in host:
        raise ValueError(
            f"Invalid host '{host}': should not include port numbers"
        )
    return host


def _get_secure() -> bool:
    """
    Get the secure status from environment variable or fall back to False.
    """
    secure = os.environ.get("MARIMO_MPL_SECURE", "false")
    if not secure or not isinstance(secure, str):
        return False
    secure = secure.lower().strip()
    if secure in ("true", "1", "yes", "on"):
        return True
    if secure in ("false", "0", "no", "off"):
        return False

    raise ValueError(
        f"Invalid secure value '{secure}': should be 'true' or 'false'"
    )


def _get_remote_url() -> str:
    """Return the ``x-runtime-url`` header of the current request, or "".

    The request is taken from ``app_meta()``; trailing slashes are stripped.
    An empty string is returned when there is no request or no such header.
    """
    request = app_meta().request
    header_value = request.headers.get("x-runtime-url") if request else None
    return header_value.rstrip("/") if header_value else ""


def _convert_scheme_to_ws(url: str) -> str:
    if url.startswith("http://"):
        return url.replace("http://", "ws://")
    if url.startswith("https://"):
        return url.replace("https://", "wss://")
    return url


def _template(fig_id: str, port: int) -> str:
    """Render ``html_content`` for figure ``fig_id`` on the server at ``port``.

    Both the HTTP assets and the websocket are addressed under
    ``<base>/mpl/<port>``. The figure itself is always selected with the
    ``figure`` query parameter. ``<base>`` is the ``x-runtime-url`` header
    (via ``_get_remote_url``) when present, otherwise
    ``http://localhost:<port>``.
    """
    # No trailing slash here: "/mpl/..." is appended below.
    base_url = _get_remote_url() or f"http://localhost:{port}"
    # Route by server port (as the websocket URL already did), not by figure
    # id, so that assets, downloads and the websocket reach the same server.
    base_url_and_path = f"{base_url}/mpl/{port}"
    ws_base_url = _convert_scheme_to_ws(base_url_and_path)

    return html_content % {
        "ws_uri": f"{ws_base_url}/ws?figure={fig_id}",
        "fig_id": fig_id,
        "base_url": base_url_and_path,
    }


# Defined at module level so the endpoint can be reused.
async def mpl_js(request: Request) -> Response:
    """Serve matplotlib's WebAgg JavaScript with its focus-grabbing disabled."""
    from matplotlib.backends.backend_webagg_core import (
        FigureManagerWebAgg,
    )
    from starlette.responses import Response

    del request  # unused
    script = FigureManagerWebAgg.get_javascript()  # type: ignore[no-untyped-call]
    patched = patch_javascript(script)
    return Response(content=patched, media_type="application/javascript")


async def mpl_custom_css(request: Request) -> Response:
    """Serve the marimo stylesheet held in ``css_content``."""
    from starlette.responses import Response

    del request  # unused
    stylesheet = css_content
    return Response(content=stylesheet, media_type="text/css")


# Overall application for handling figures on a per-kernel basis.
def create_application() -> Starlette:
    """Create the Starlette app that serves interactive matplotlib figures.

    Figures are selected with the ``figure`` query parameter.

    Routes:
        GET /                 HTML page embedding the figure
        GET /mpl.js           matplotlib's WebAgg JavaScript (patched)
        GET /custom.css       marimo toolbar styling
        GET /download.{fmt}   the figure rendered by ``savefig(format=fmt)``
        WS  /ws               WebAgg event/update channel
        /_static, /_images    matplotlib's bundled web assets
    """
    import matplotlib as mpl
    from matplotlib.backends.backend_webagg_core import (
        FigureManagerWebAgg,
    )
    from starlette.applications import Starlette
    from starlette.responses import HTMLResponse, Response
    from starlette.routing import Mount, Route, WebSocketRoute
    from starlette.staticfiles import StaticFiles
    from starlette.websockets import (
        WebSocketDisconnect,
        WebSocketState,
    )

    def plain_text(message: str, status_code: int) -> Response:
        # Plain-text response for client errors.
        return Response(
            content=message, status_code=status_code, media_type="text/plain"
        )

    async def main_page(request: Request) -> Response:
        figure_id = request.query_params.get("figure")
        # Figure ids are str(manager.num), which `interactive` sets from
        # id(figure), an integer. The page template interpolates the id
        # unquoted into JavaScript, so reject anything non-numeric to avoid
        # script injection.
        if not figure_id or not (figure_id.isascii() and figure_id.isdigit()):
            return plain_text(
                "Invalid or missing 'figure' query parameter", 400
            )
        port = request.app.state.port
        return HTMLResponse(content=_template(figure_id, port))

    async def download(request: Request) -> Response:
        figure_id = request.query_params.get("figure")
        if not figure_id:
            return plain_text("Missing 'figure' query parameter", 400)
        try:
            figure_manager = figure_managers.get(figure_id)
        except RuntimeError as e:
            return plain_text(str(e), 404)
        fmt = request.path_params["fmt"]
        # mimetypes.types_map is keyed by extension *including* the dot.
        mime_type = mimetypes.types_map.get(
            f".{fmt}", "application/octet-stream"
        )
        buff = io.BytesIO()
        figure_manager.canvas.figure.savefig(
            buff, format=fmt, bbox_inches="tight"
        )
        return Response(content=buff.getvalue(), media_type=mime_type)

    async def websocket_endpoint(websocket: WebSocket) -> None:
        await websocket.accept()

        def is_connected() -> bool:
            # Avoid sending once either side has closed the socket.
            return (
                websocket.application_state != WebSocketState.DISCONNECTED
                and websocket.client_state != WebSocketState.DISCONNECTED
            )

        async def send_error(message: str) -> None:
            # Best effort: report an error if the client is still there.
            if not is_connected():
                return
            try:
                await websocket.send_json({"type": "error", "message": message})
            except Exception as e:
                LOGGER.debug(f"Could not send error to matplotlib client: {e}")

        async def close() -> None:
            if not is_connected():
                return
            try:
                await websocket.close()
            except Exception as e:
                LOGGER.debug(f"Could not close matplotlib websocket: {e}")

        figure_id = websocket.query_params.get("figure")
        if not figure_id:
            await send_error("No figure ID provided")
            await close()
            return

        try:
            figure_manager = figure_managers.get(figure_id)
        except RuntimeError:
            await send_error(
                f"Figure with id '{figure_id}' not found. The matplotlib server may have restarted. Please re-run the cell containing this plot."
            )
            await close()
            return

        # The figure manager pushes updates synchronously; buffer them in a
        # queue that the `send` task drains onto the real websocket.
        queue: asyncio.Queue[tuple[Any, str]] = asyncio.Queue()

        class SyncWebSocket:
            def send_json(self, content: Any) -> None:
                queue.put_nowait((content, "json"))

            def send_binary(self, blob: Any) -> None:
                queue.put_nowait((blob, "binary"))

        sync_socket = SyncWebSocket()
        figure_manager.add_web_socket(sync_socket)  # type: ignore[no-untyped-call]

        async def receive() -> None:
            try:
                while True:
                    data = await websocket.receive_json()
                    if data["type"] == "supports_binary":
                        # We always support binary, and the figure manager
                        # does not need this message.
                        continue
                    figure_manager.handle_json(data)  # type: ignore[no-untyped-call]
            except WebSocketDisconnect:
                pass
            except Exception as e:
                await send_error(
                    f"WebSocket receive error: {str(e)}. The matplotlib server may have restarted. Please refresh this plot."
                )

        async def send() -> None:
            try:
                while True:
                    data, mode = await queue.get()
                    if mode == "json":
                        await websocket.send_json(data)
                    else:
                        await websocket.send_bytes(data)
            except WebSocketDisconnect:
                # Client disconnected normally
                pass
            except Exception as e:
                await send_error(
                    f"WebSocket send error: {str(e)}. The matplotlib server may have restarted. Please refresh this plot."
                )

        tasks = [asyncio.create_task(receive()), asyncio.create_task(send())]
        try:
            # Whichever side finishes first ends the session. Otherwise a
            # disconnect would leave `send` waiting on the queue forever.
            await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        except Exception as e:
            await send_error(
                f"WebSocket connection error: {str(e)}. The matplotlib server may have restarted. Please refresh this plot."
            )
        finally:
            for task in tasks:
                task.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)
            # Stop the manager from queueing updates for this dead connection.
            try:
                figure_manager.remove_web_socket(sync_socket)  # type: ignore[no-untyped-call]
            except KeyError:
                pass
            await close()

    return Starlette(
        routes=[
            Route("/", main_page, methods=["GET"]),
            Route("/mpl.js", mpl_js, methods=["GET"]),
            Route("/custom.css", mpl_custom_css, methods=["GET"]),
            Route("/download.{fmt}", download, methods=["GET"]),
            WebSocketRoute("/ws", websocket_endpoint),
            Mount(
                "/_static",
                StaticFiles(
                    directory=FigureManagerWebAgg.get_static_file_path()  # type: ignore[no-untyped-call]
                ),
                name="mpl_static",
            ),
            Mount(
                "/_images",
                StaticFiles(directory=Path(mpl.get_data_path(), "images")),
                name="mpl_images",
            ),
        ],
    )


_app: Optional[Starlette] = None


def get_or_create_application(
    app_host: Optional[str] = None,
    free_port: Optional[int] = None,
    secure_host: Optional[bool] = None,
) -> Starlette:
    """Return the matplotlib Starlette app, starting its server if needed.

    If a server was started before but its thread is no longer alive, the
    figure registry is cleared and a fresh server is started. Arguments are
    forwarded to ``MplServerManager.start``. The check and restart run under
    the server manager's restart lock.
    """
    global _app

    with _server_manager._restart_lock:
        healthy = _app is not None and _server_manager.is_running()
        if not healthy:
            if _app is not None:
                # Previously started, but the server thread has died.
                LOGGER.info(
                    "Matplotlib server appears to have died, restarting..."
                )
                _server_manager.stop()
                # Managers registered with the dead server are stale.
                figure_managers.figure_managers.clear()
                _app = None

            _app = _server_manager.start(app_host, free_port, secure_host)

    return _app


def new_figure_manager_given_figure(
    num: int, figure: Union[Figure, SubFigure, Axes]
) -> Any:
    """Wrap ``figure`` in a new WebAgg canvas and return its figure manager.

    ``num`` becomes the manager's ``num``, which is used as the figure id.
    """
    from matplotlib.backends import backend_webagg_core as webagg_core

    class FigureManagerWebAgg(webagg_core.FigureManagerWebAgg):
        _toolbar2_class = webagg_core.NavigationToolbar2WebAgg  # type: ignore[assignment]

    class FigureCanvasWebAgg(webagg_core.FigureCanvasWebAggCore):
        manager_class = FigureManagerWebAgg  # type: ignore[assignment]

    return FigureManagerWebAgg(  # type: ignore[no-untyped-call]
        FigureCanvasWebAgg(figure),  # type: ignore[no-untyped-call]
        num,
    )


@mddoc
def interactive(figure: Union[Figure, SubFigure, Axes]) -> Html:
    """Render a matplotlib figure using an interactive viewer.

    The interactive viewer allows you to pan, zoom, and see plot coordinates
    on mouse hover. Outside a kernel runtime, or in Pyodide/WebAssembly, a
    static rendering is returned instead.

    Example:
        ```python
        plt.plot([1, 2])
        # plt.gcf() gets the current figure
        mo.mpl.interactive(plt.gcf())
        ```

    Args:
        figure (matplotlib Figure or Axes): A matplotlib `Figure` or `Axes` object.

    Returns:
        Html: An interactive matplotlib figure as an `Html` object.
    """
    # The viewer needs a live WebSocket, which Pyodide/WebAssembly cannot
    # provide; fall back to a static rendering there.
    if is_pyodide():
        LOGGER.error(
            "Interactive plots are not supported in Pyodide/WebAssembly"
        )
        return as_html(figure)

    # matplotlib is an optional dependency, so import it on demand only.
    from matplotlib.axes import Axes

    # Axes are rendered through their parent Figure.
    if isinstance(figure, Axes):
        parent = figure.get_figure()
        assert parent is not None, "Axes object does not have a Figure"
        figure = parent

    ctx = get_context()
    if not isinstance(ctx, KernelRuntimeContext):
        return as_html(figure)

    # Untyped (Any) because matplotlib ships without type hints.
    manager = new_figure_manager_given_figure(id(figure), figure)

    # TODO(akshayka): Proxy this server through the marimo server to help with
    # deployment.
    server_port = get_or_create_application().state.port

    class CleanupHandle(CellLifecycleItem):
        """Unregisters the figure manager when the cell's items are disposed."""

        def create(self, context: RuntimeContext) -> None:
            del context

        def dispose(self, context: RuntimeContext, deletion: bool) -> bool:
            del context, deletion
            figure_managers.remove(manager)
            return True

    figure_managers.add(manager)
    assert ctx.execution_context is not None
    ctx.cell_lifecycle_registry.add(CleanupHandle())
    ctx.stream.cell_id = ctx.execution_context.cell_id

    page = _template(str(manager.num), server_port)
    frame = h.iframe(
        srcdoc=html.escape(page),
        width="100%",
        height="550px",
        onload="__resizeIframe(this)",
    )
    return Html(frame)


# Page template for one interactive figure, filled in with `%` formatting
# (see `_template`). Placeholders:
#   base_url -- URL prefix for the page's assets, without a trailing slash
#               ("/" is appended below). It is also the document <base>, so
#               the relative "download.<fmt>?figure=<id>" link in
#               `ondownload` resolves against it.
#   ws_uri   -- full WebSocket URL passed to the constructor returned by
#               mpl.get_websocket_type().
#   fig_id   -- figure id, inserted *unquoted* into JavaScript as the first
#               argument of mpl.figure; it must be a trusted numeric id.
html_content = """
<!DOCTYPE html>
<html lang="en">
  <head>
    <base href='%(base_url)s/' />
    <link rel="stylesheet" href="%(base_url)s/_static/css/page.css" type="text/css" />
    <link rel="stylesheet" href="%(base_url)s/_static/css/boilerplate.css" type="text/css" />
    <link rel="stylesheet" href="%(base_url)s/_static/css/fbm.css" type="text/css" />
    <link rel="stylesheet" href="%(base_url)s/_static/css/mpl.css" type="text/css" />
    <link rel="stylesheet" href="%(base_url)s/custom.css" type="text/css" />
    <script src="%(base_url)s/mpl.js"></script>

    <script>
      function ondownload(figure, format) {
        window.open('download.' + format + '?figure=' + figure.id, '_blank');
      };

      function ready(fn) {
        if (document.readyState != "loading") {
          fn();
        } else {
          document.addEventListener("DOMContentLoaded", fn);
        }
      }

      ready(
        function() {
          var websocket_type = mpl.get_websocket_type();
          var websocket = new websocket_type("%(ws_uri)s");

          // mpl.figure creates a new figure on the webpage.
          var fig = new mpl.figure(
              // A unique numeric identifier for the figure
              %(fig_id)s,
              // A websocket object
              websocket,
              // A function called when a file type is selected for download
              ondownload,
              // The HTML element in which to place the figure
              document.getElementById("figure"));
        }
      );
    </script>

    <title>marimo</title>
  </head>

  <body>
    <div id="figure"></div>
  </body>
</html>
""".strip()  # noqa: E501

# Custom CSS to make the mpl toolbar fit the marimo UI.
# We do not support dark mode at the moment as the iframe does not know
# the theme of the parent page.
# NOTE(review): `mlp-canvas` below is a type selector for a tag that does not
# exist, so it matches nothing (presumably `.mpl-canvas` was meant). It is
# left unchanged because making it match could alter the canvas layout;
# verify visually before changing it.
css_content = """
body {
    background-color: transparent;
    width: 100%;
}
#figure, mlp-canvas {
    width: 100%;
}
.ui-dialog-titlebar + div {
    border-radius: 4px;
}
.ui-dialog-titlebar {
    display: none;
}
.mpl-toolbar {
    display: flex;
    align-items: center;
    gap: 8px;
}
select.mpl-widget,
.mpl-button-group {
    margin: 4px 0;
    border-radius: 6px;
    box-shadow: rgba(0, 0, 0, 0) 0px 0px 0px 0px,
        rgba(0, 0, 0, 0) 0px 0px 0px 0px,
        rgba(15, 23, 42, 0.1) 1px 1px 0px 0px,
        rgba(15, 23, 42, 0.1) 0px 0px 2px 0px;
}
.mpl-button-group + .mpl-button-group {
    margin-left: 0;
}
.mpl-button-group > .mpl-widget {
    padding: 4px;
}
.mpl-button-group > .mpl-widget > img {
    height: 16px;
    width: 16px;
}
.mpl-widget:disabled,
.mpl-widget[disabled],
.mpl-widget:disabled:hover,
.mpl-widget[disabled]:hover {
    opacity: 0.5;
    background-color: #fff;
    border-color: #ccc !important;
}
.mpl-message {
    color: rgb(139, 141, 152);
    font-size: 11px;
}
.mpl-widget img {
    filter: invert(0.3);
}
""".strip()


def patch_javascript(javascript: str) -> str:
    """Comment out the ``canvas``/``canvas_div`` ``.focus()`` calls in mpl.js.

    This keeps an embedded figure from stealing keyboard focus from the
    marimo page. See
    https://github.com/matplotlib/matplotlib/blob/4c345b42048811a2122ba0db68551c6ea4ddaf6a/lib/matplotlib/backends/web_backend/js/mpl.js#L338-L343
    """
    # Order matches the original: canvas first, then canvas_div. The leading
    # space keeps e.g. "this.canvas.focus();" from matching.
    for target in ("canvas", "canvas_div"):
        call = f"{target}.focus();"
        javascript = javascript.replace(
            f" {call}",
            f"// {call} // don't steal focus when in marimo",
        )
    return javascript
