"""Abstract base class for API-type streams."""

from __future__ import annotations

import abc
import copy
import decimal
import logging
import typing as t
from functools import cached_property
from http import HTTPStatus
from urllib.parse import urlparse
from warnings import warn

import backoff
import requests
import requests.exceptions

from singer_sdk import metrics
from singer_sdk.authenticators import SimpleAuthenticator
from singer_sdk.exceptions import FatalAPIError, RetriableAPIError
from singer_sdk.helpers._compat import SingerSDKDeprecationWarning
from singer_sdk.helpers.jsonpath import extract_jsonpath
from singer_sdk.pagination import (
    JSONPathPaginator,
    LegacyStreamPaginator,
    SimpleHeaderPaginator,
    SinglePagePaginator,
)
from singer_sdk.streams.core import Stream

if t.TYPE_CHECKING:
    from collections.abc import Iterable, Mapping
    from datetime import datetime

    from backoff.types import Details

    from singer_sdk.helpers.types import Auth, Context, RequestFunc
    from singer_sdk.pagination import BaseAPIPaginator
    from singer_sdk.singerlib import Schema
    from singer_sdk.tap_base import Tap

# Default value for ``_HTTPStream._page_size``.
DEFAULT_PAGE_SIZE = 1000
# Default HTTP request timeout, in seconds (5 minutes); see ``_HTTPStream.timeout``.
DEFAULT_REQUEST_TIMEOUT = 5 * 60

# Type of the pagination token passed around as ``next_page_token``.
_TToken = t.TypeVar("_TToken")
# Numeric type of the wait values produced by ``_HTTPStream.backoff_runtime``.
_TNum = t.TypeVar("_TNum", int, float)


class _HTTPStream(Stream, t.Generic[_TToken], metaclass=abc.ABCMeta):  # noqa: PLR0904
    """Abstract base class for HTTP streams."""

    _page_size: int = DEFAULT_PAGE_SIZE
    _requests_session: requests.Session | None

    #: Response code reference for rate limit retries
    extra_retry_statuses: t.Sequence[int] = [HTTPStatus.TOO_MANY_REQUESTS]

    #: Whether to follow HTTP redirects. Set to False to disable them.
    #: Defaults to True.
    allow_redirects: bool = True

    #: Set this to True if the API expects a JSON payload in the request body.
    payload_as_json: bool = False

    # Private constants. May not be supported in future releases:
    _LOG_REQUEST_METRICS: bool = True
    # Disabled by default for safety:
    _LOG_REQUEST_METRIC_URLS: bool = False

    @property
    @abc.abstractmethod
    def url_base(self) -> str:
        """The base request URL, e.g. ``https://api.mysite.com/v3/``.

        Request URLs are generated by combining `url_base` and `path`, and expanding any
        context variables in the path.

        For example, if ``url_base`` is ``https://api.mysite.com/v3/`` and ``path`` is
        ``users/{user_id}/orders``, then if the stream has a context of
        ``{"user_id": 123}`` generated by its parent stream with
        :meth:`~singer_sdk.Stream.generate_child_contexts`, the full URL will be
        ``https://api.mysite.com/v3/users/123/orders``.
        """

    def __init__(
        self,
        tap: Tap,
        name: str | None = None,
        schema: dict[str, t.Any] | Schema | None = None,
        path: str | None = None,
        *,
        http_method: str | None = None,
    ) -> None:
        """Initialize the HTTP stream.

        Args:
            tap: Singer Tap this stream belongs to.
            schema: JSON schema for records in this stream.
            name: Name of this stream.
            path: URL path for this entity stream.
            http_method: HTTP method to use for requests.
        """
        if path:
            self.path = path
        self._http_headers: dict[str, str] = {}
        self._http_method = http_method
        self._requests_session = requests.Session()
        super().__init__(name=name, schema=schema, tap=tap)

    @staticmethod
    def _url_encode(val: str | datetime | bool | int | list[str]) -> str:  # noqa: FBT001
        """Encode a value for substitution into a URL path template.

        For strings, every forward slash is replaced with ``%2F``; no other
        characters are escaped. Any other value is converted with :func:`str`
        (so a list becomes its Python ``repr``-style text).

        Args:
            val: The value to encode.

        Returns:
            The encoded string.
        """
        return val.replace("/", "%2F") if isinstance(val, str) else str(val)

    def get_url(self, context: Context | None) -> str:
        """Get stream entity URL.

        Developers override this method to perform dynamic URL generation.

        Args:
            context: Stream partition or context dictionary.

        Returns:
            A URL, optionally targeted to a specific partition or context.
        """
        url = "".join([self.url_base, self.path or ""])
        vals = copy.copy(dict(self.config))
        # Context values take precedence over config values with the same key.
        vals.update(context or {})
        for k, v in vals.items():
            search_text = f"{{{k}}}"
            if search_text in url:
                url = url.replace(search_text, self._url_encode(v))
        return url

    # HTTP Request functions

    @property
    def http_headers(self) -> dict:
        """Return headers dict to be used for HTTP requests.

        If an authenticator is also specified, the authenticator's headers will be
        combined with `http_headers` when making HTTP requests.

        Returns:
            Dictionary of HTTP headers to use as a base for every request.
        """
        return {
            "User-Agent": self.user_agent,
            **self._http_headers,
        }

    @property
    def http_method(self) -> str:
        """HTTP method to use for requests. Defaults to "GET"."""
        if self._http_method:
            return self._http_method

        if hasattr(self, "rest_method"):
            warn(
                "Use `http_method` instead.",
                SingerSDKDeprecationWarning,
                stacklevel=2,
            )
            return self.rest_method  # type: ignore[no-any-return]

        return "GET"

    @http_method.setter
    def http_method(self, value: str) -> None:
        """Set the HTTP method for requests.

        Args:
            value: The HTTP method to use for requests.
        """
        self._http_method = value

    @property
    def requests_session(self) -> requests.Session:
        """Get requests session.

        Returns:
            The :class:`requests.Session` object for HTTP requests.
        """
        if not self._requests_session:
            self._requests_session = requests.Session()
        return self._requests_session

    @cached_property
    def user_agent(self) -> str:
        """Get the user agent string for the stream.

        Returns:
            The user agent string.

        .. versionadded:: 0.40.0
        """
        return self.config.get(  # type: ignore[no-any-return]
            "user_agent",
            f"{self.tap_name}/{self._tap.plugin_version}",
        )

    def validate_response(self, response: requests.Response) -> None:
        """Validate HTTP response.

        Checks for error status codes and whether they are fatal or retriable.

        In case an error is deemed transient and can be safely retried, then this
        method should raise an :class:`singer_sdk.exceptions.RetriableAPIError`.
        By default this applies to 5xx error codes, along with values set in:
        :attr:`~singer_sdk.RESTStream.extra_retry_statuses`

        In case an error is unrecoverable raises a
        :class:`singer_sdk.exceptions.FatalAPIError`. By default, this applies to
        4xx errors, excluding values found in:
        :attr:`~singer_sdk.RESTStream.extra_retry_statuses`

        Tap developers are encouraged to override this method if their APIs use HTTP
        status codes in non-conventional ways, or if they communicate errors
        differently (e.g. in the response body).

        .. image:: ../images/200.png

        Args:
            response: A :class:`requests.Response` object.

        Raises:
            FatalAPIError: If the request is not retriable.
            RetriableAPIError: If the request is retriable.
        """
        if (
            response.status_code in self.extra_retry_statuses
            or response.status_code >= HTTPStatus.INTERNAL_SERVER_ERROR
        ):
            msg = self.response_error_message(response)
            raise RetriableAPIError(msg, response)

        if (
            HTTPStatus.BAD_REQUEST
            <= response.status_code
            < HTTPStatus.INTERNAL_SERVER_ERROR
        ):
            msg = self.response_error_message(response)
            raise FatalAPIError(msg)

    def response_error_message(self, response: requests.Response) -> str:
        """Build error message for invalid http statuses.

        WARNING - Override this method when the URL path may contain secrets or PII

        Args:
            response: A :class:`requests.Response` object.

        Returns:
            str: The error message
        """
        full_path = urlparse(response.url).path or self.path
        error_type = (
            "Client"
            if HTTPStatus.BAD_REQUEST
            <= response.status_code
            < HTTPStatus.INTERNAL_SERVER_ERROR
            else "Server"
        )

        msg = (
            f"{response.status_code} {error_type} Error: "
            f"{response.reason} for path: {full_path}"
        )

        if response.content:
            msg += f", content is {response.text}"

        if response.headers:
            # Only header names are included; values may carry credentials.
            msg += f", headers: {','.join(response.headers.keys())}"

        return msg

    def request_decorator(self, func: RequestFunc) -> RequestFunc:
        """Instantiate a decorator for handling request failures.

        Uses a wait generator defined in `backoff_wait_generator` to
        determine backoff behaviour. Try limit is defined in
        `backoff_max_tries`, and will trigger the event defined in
        `backoff_handler` before retrying. Developers may override one or
        all of these methods to provide custom backoff or retry handling.

        Args:
            func: Function to decorate.

        Returns:
            A decorated method.
        """
        decorator: t.Callable = backoff.on_exception(
            self.backoff_wait_generator,
            (
                ConnectionResetError,
                RetriableAPIError,
                requests.exceptions.Timeout,
                requests.exceptions.ConnectionError,
                requests.exceptions.ChunkedEncodingError,
                requests.exceptions.ContentDecodingError,
            ),
            max_tries=self.backoff_max_tries,
            on_backoff=self.backoff_handler,
            jitter=self.backoff_jitter,
            logger=self.logger,
        )(func)
        return decorator

    def _request(
        self,
        prepared_request: requests.PreparedRequest,
        context: Context | None,
    ) -> requests.Response:
        """Authenticate, send and validate a single prepared request.

        The request is passed through :attr:`authenticator`, sent on
        :attr:`requests_session` using :attr:`timeout` and :attr:`allow_redirects`,
        its duration is recorded as a metric, and the response is checked with
        :meth:`validate_response` (which may raise for error status codes).

        Args:
            prepared_request: The request to send.
            context: Stream partition or context dictionary.

        Returns:
            The validated HTTP response.
        """
        authenticated_request = self.authenticator(prepared_request)
        response = self.requests_session.send(
            authenticated_request,
            timeout=self.timeout,
            allow_redirects=self.allow_redirects,
        )
        self._write_request_duration_log(
            endpoint=self.path,
            response=response,
            context=context,
            extra_tags={"url": authenticated_request.path_url}
            if self._LOG_REQUEST_METRIC_URLS
            else None,
        )
        self.validate_response(response)
        return response

    def get_url_params(  # noqa: PLR6301
        self,
        context: Context | None,  # noqa: ARG002
        next_page_token: _TToken | None,  # noqa: ARG002
    ) -> dict[str, t.Any] | str:
        """Return a dictionary or string of URL query parameters.

        If paging is supported, developers may override with specific paging logic.

        If your source needs special handling and, for example, parentheses should not
        be encoded, you can return a string constructed with
        :py:func:`urllib.parse.urlencode`:

        .. code-block:: python

           from urllib.parse import urlencode


           class MyStream(RESTStream):
               def get_url_params(self, context, next_page_token):
                   params = {"key": "(a,b,c)"}
                   return urlencode(params, safe="()")

        Args:
            context: Stream partition or context dictionary.
            next_page_token: Token, page number or any request argument to request the
                next page of data.

        Returns:
            Dictionary or encoded string with URL query parameters to use in the
                request.
        """
        return {}

    def build_prepared_request(
        self,
        *args: t.Any,
        **kwargs: t.Any,
    ) -> requests.PreparedRequest:
        """Build a generic but authenticated request.

        Uses the authenticator instance to mutate the request with authentication.

        Args:
            *args: Arguments to pass to :class:`requests.Request`.
            **kwargs: Keyword arguments to pass to :class:`requests.Request`.

        Returns:
            A :class:`requests.PreparedRequest` object.
        """
        request = requests.Request(*args, **kwargs)
        self.requests_session.auth = self.authenticator
        return self.requests_session.prepare_request(request)

    def prepare_request(
        self,
        context: Context | None,
        next_page_token: _TToken | None,
    ) -> requests.PreparedRequest:
        """Prepare a request object for this stream.

        If partitioning is supported, the `context` object will contain the partition
        definitions. Pagination information can be parsed from `next_page_token` if
        `next_page_token` is not None.

        Args:
            context: Stream partition or context dictionary.
            next_page_token: Token, page number or any request argument to request the
                next page of data.

        Returns:
            Build a request with the stream's URL, path, query parameters,
            HTTP headers and authenticator.
        """
        http_method = self.http_method
        url: str = self.get_url(context)
        params: dict | str = self.get_url_params(context, next_page_token)
        request_data = self.prepare_request_payload(context, next_page_token)
        headers = self.http_headers

        prepare_kwargs: dict[str, t.Any] = {
            "method": http_method,
            "url": url,
            "params": params,
            "headers": headers,
        }

        if self.payload_as_json:
            prepare_kwargs["json"] = request_data
        else:
            prepare_kwargs["data"] = request_data

        return self.build_prepared_request(**prepare_kwargs)

    def get_http_request_counter(self) -> metrics.Counter:
        """Get the HTTP request counter for the stream.

        Returns:
            The HTTP request counter for the stream.

        .. versionadded:: 0.51.0
        """
        return metrics.http_request_counter(self.name, endpoint=self.path)

    def request_records(self, context: Context | None) -> t.Iterable[dict]:
        """Request records from REST endpoint(s), returning response records.

        If pagination is detected, pages will be recursed automatically.

        Args:
            context: Stream partition or context dictionary.

        Yields:
            An item for every record in the response.
        """
        paginator = self.get_new_paginator() or SinglePagePaginator()
        decorated_request = self.request_decorator(self._request)
        pages = 0

        with self.get_http_request_counter() as request_counter:
            request_counter.with_context(context)

            while not paginator.finished:
                prepared_request = self.prepare_request(
                    context,
                    next_page_token=paginator.current_value,
                )
                resp = decorated_request(prepared_request, context)
                request_counter.increment()
                self.update_sync_costs(prepared_request, resp, context)
                records = iter(self.parse_response(resp))
                try:
                    first_record = next(records)
                except StopIteration:
                    # An empty page ends pagination unless the paginator says
                    # empty pages may be followed by more data.
                    if paginator.continue_if_empty(resp):
                        paginator.advance(resp)
                        continue

                    self.log(
                        "Pagination stopped after %d pages because no records were "
                        "found in the last response",
                        pages,
                    )
                    break
                yield first_record
                yield from records
                pages += 1

                paginator.advance(resp)

    def _write_request_duration_log(
        self,
        endpoint: str,
        response: requests.Response,
        context: Context | None,
        extra_tags: dict | None,
    ) -> None:
        """Emit a timer metric for the duration of an HTTP request.

        Does nothing when ``_LOG_REQUEST_METRICS`` is False. The metric is tagged
        with the stream name, endpoint, HTTP status code and a succeeded/failed
        status (failed for status codes >= 400), plus the context (if any) and
        ``extra_tags``. The caller's ``extra_tags`` dict is not modified.

        Args:
            endpoint: The endpoint of the request.
            response: The response object.
            context: Stream partition or context dictionary.
            extra_tags: A dictionary of extra tags to add to the metric.
        """
        if not self._LOG_REQUEST_METRICS:
            return

        # Copy so the context tag is never written into the caller's dict.
        tags = dict(extra_tags) if extra_tags else {}
        if context:
            tags[metrics.Tag.CONTEXT] = context

        point = metrics.Point(
            "timer",
            metric=metrics.Metric.HTTP_REQUEST_DURATION,
            value=response.elapsed.total_seconds(),
            tags={
                metrics.Tag.STREAM: self.name,
                metrics.Tag.ENDPOINT: endpoint,
                metrics.Tag.HTTP_STATUS_CODE: response.status_code,
                metrics.Tag.STATUS: (
                    metrics.Status.SUCCEEDED
                    if response.status_code < HTTPStatus.BAD_REQUEST
                    else metrics.Status.FAILED
                ),
                **tags,
            },
        )
        self._log_metric(point)

    def update_sync_costs(
        self,
        request: requests.PreparedRequest,
        response: requests.Response,
        context: Context | None,
    ) -> dict[str, int]:
        """Update internal calculation of Sync costs.

        Adds the costs of the request just made (as computed by
        :meth:`calculate_sync_cost`) to the running totals in ``self._sync_costs``.
        Cost domains accumulated by earlier requests but not reported for this
        request keep their existing totals.

        Args:
            request: the Request object that was just called.
            response: the :class:`requests.Response` object
            context: the context passed to the call

        Returns:
            The accumulated costs across all requests so far, keyed by "cost
            domain". See `calculate_sync_cost` for details.
        """
        call_costs = self.calculate_sync_cost(request, response, context)
        # Merge over the union of domains: iterating only over ``call_costs``
        # would silently drop the totals of any domain this call didn't report.
        totals = dict(self._sync_costs)
        for domain, cost in call_costs.items():
            totals[domain] = totals.get(domain, 0) + cost
        self._sync_costs = totals
        return self._sync_costs

    # Overridable:

    def calculate_sync_cost(  # noqa: PLR6301
        self,
        request: requests.PreparedRequest,  # noqa: ARG002
        response: requests.Response,  # noqa: ARG002
        context: Context | None,  # noqa: ARG002
    ) -> dict[str, int]:
        """Calculate the cost of the last API call made.

        This method can optionally be implemented in streams to calculate
        the costs (in arbitrary units to be defined by the tap developer)
        associated with a single API/network call. The request and response objects
        are available in the callback, as well as the context.

        The method returns a dict where the keys are arbitrary cost dimensions,
        and the values the cost along each dimension for this one call. For
        instance: { "rest": 0, "graphql": 42 } for a call to github's graphql API.
        All keys should be present in the dict.

        This method can be overridden by tap streams. By default it won't do
        anything.

        Args:
            request: the API Request object that was just called.
            response: the :class:`requests.Response` object
            context: the context passed to the call

        Returns:
            A dict of costs for this single call, keyed by "cost domain". The
            default implementation returns an empty dict.
        """
        return {}

    def prepare_request_payload(
        self,
        context: Context | None,
        next_page_token: _TToken | None,
    ) -> (
        Iterable[bytes]
        | str
        | bytes
        | list[tuple[t.Any, t.Any]]
        | tuple[tuple[t.Any, t.Any]]
        | Mapping[str, t.Any]
        | None
    ):
        """Prepare the data payload for the HTTP request.

        By default, no payload will be sent (this implementation returns None).
        The value returned is sent as ``json`` when :attr:`payload_as_json` is
        True, otherwise as ``data``.

        Developers may override this method if the API requires a custom payload along
        with the request. (This is generally not required for APIs which use the
        HTTP 'GET' method.)

        Args:
            context: Stream partition or context dictionary.
            next_page_token: Token, page number or any request argument to request the
                next page of data.
        """

    @property
    def timeout(self) -> int:
        """Return the request timeout limit in seconds.

        The default timeout is 300 seconds, or as defined by DEFAULT_REQUEST_TIMEOUT.

        Returns:
            The request timeout limit as number of seconds.
        """
        return DEFAULT_REQUEST_TIMEOUT

    # Records iterator

    def get_records(self, context: Context | None) -> t.Iterable[dict[str, t.Any]]:
        """Return a generator of record-type dictionary objects.

        Each record emitted should be a dictionary of property names to their values.

        Args:
            context: Stream partition or context dictionary.

        Yields:
            One item per (possibly processed) record in the API.
        """
        yield from self.request_records(context)

    # Abstract methods:

    @abc.abstractmethod
    def parse_response(self, response: requests.Response) -> t.Iterable[dict]:
        """Parse the response and return an iterator of result records.

        Args:
            response: A raw :class:`requests.Response`

        Yields:
            One item for every item found in the response.
        """
        ...

    @abc.abstractmethod
    def get_new_paginator(self) -> BaseAPIPaginator | None:
        """Get a fresh paginator for this endpoint.

        Returns:
            A paginator instance, or ``None`` to indicate pagination is not supported.
        """
        ...

    @property
    def authenticator(self) -> Auth:
        """Return or set the authenticator for managing HTTP auth headers.

        If an authenticator is not specified, REST-based taps will simply pass
        `http_headers` as defined in the stream class.

        Returns:
            Authenticator instance that will be used to authenticate all outgoing
            requests.
        """
        return SimpleAuthenticator()

    def backoff_wait_generator(self) -> t.Generator[float, None, None]:  # noqa: PLR6301
        """The wait generator used by the backoff decorator on request failure.

        See for options:
        https://github.com/litl/backoff/blob/master/backoff/_wait_gen.py

        And see for examples: `Code Samples <../code_samples.html#custom-backoff>`_

        Returns:
            The wait generator
        """
        return backoff.expo(factor=2)

    def backoff_max_tries(self) -> int:  # noqa: PLR6301
        """The number of attempts before giving up when retrying requests.

        Returns:
            Number of max retries.
        """
        return 5

    def backoff_jitter(self, value: float) -> float:  # noqa: PLR6301
        """Amount of jitter to add.

        For more information see
        https://github.com/litl/backoff/blob/master/backoff/_jitter.py

        We chose to default to ``random_jitter`` instead of ``full_jitter`` as we keep
        some level of default jitter to be "nice" to downstream APIs but it's still
        relatively close to the default value that's passed in to make tap developers'
        life easier.

        Args:
            value: Base amount to wait in seconds

        Returns:
            Time in seconds to wait until the next request.
        """
        return backoff.random_jitter(value)

    def backoff_handler(self, details: Details) -> None:
        """Adds additional behaviour prior to retry.

        By default will log out backoff details, developers can override
        to extend or change this behaviour.

        Args:
            details: backoff invocation details
                https://github.com/litl/backoff#event-handlers
        """
        if (
            (exc := details.get("exception"))
            and isinstance(exc, RetriableAPIError)
            and exc.response is not None
        ):
            self.log(
                "Backing off %0.2f seconds after %d tries "
                "for URL %s, failing with status %s: %s",
                details.get("wait"),
                details.get("tries"),
                self.path,
                exc.response.status_code,
                exc.response.reason,
                level=logging.ERROR,
            )
            return

        self.log(
            "Backing off %0.2f seconds after %d tries "
            "calling function %s with args %s and kwargs "
            "%s",
            details.get("wait"),
            details.get("tries"),
            details.get("target"),
            details.get("args"),
            details.get("kwargs"),
            level=logging.ERROR,
        )

    def backoff_runtime(  # noqa: PLR6301
        self,
        *,
        value: t.Callable[[t.Any], _TNum],
    ) -> t.Generator[_TNum, None, None]:
        """Optional backoff wait generator that can replace the default `backoff.expo`.

        It is based on parsing the thrown exception of the decorated method, making it
        possible for response values to be in scope.

        You may want to review :meth:`~singer_sdk.RESTStream.backoff_jitter` if you're
        overriding this function.

        Args:
            value: a callable which takes as input the decorated
                function's thrown exception and determines how
                long to wait.

        Yields:
            The time to wait before the next attempt, computed by calling ``value``
            on the exception most recently sent into the generator. The very first
            ``yield`` only primes the generator and produces ``None``.
        """
        exception = yield  # type: ignore[misc]
        while True:
            exception = yield value(exception)


class RESTStream(_HTTPStream, t.Generic[_TToken], metaclass=abc.ABCMeta):
    """Abstract base class for REST API streams.

    Records are pulled out of JSON response bodies with :attr:`records_jsonpath`.
    Unless a paginator is supplied by an override, pagination is driven by
    :attr:`next_page_token_jsonpath` when set, or by the ``X-Next-Page`` response
    header otherwise.
    """

    #: Optional JSONPath expression to extract a pagination token from the API response.
    #: Example: `"$.next_page"`
    next_page_token_jsonpath: str | None = None

    payload_as_json: bool = True
    """Set this to False if the API expects something other than JSON in the request
    body.

    .. versionadded:: 0.43.0
    """

    def __init__(
        self,
        tap: Tap,
        name: str | None = None,
        schema: dict[str, t.Any] | Schema | None = None,
        path: str | None = None,
        *,
        http_method: str | None = None,
    ) -> None:
        """Set up a REST stream on top of the generic HTTP stream machinery.

        Args:
            tap: Singer Tap that owns this stream.
            schema: JSON schema describing the stream's records.
            name: Stream name.
            path: URL path of the entity this stream reads.
            http_method: HTTP verb used when issuing requests
        """
        super().__init__(tap, name, schema, path, http_method=http_method)
        # Placeholders for compiled JSONPath expressions; both start unset.
        self._compiled_jsonpath = None
        self._next_page_token_compiled_jsonpath = None

    @property
    def records_jsonpath(self) -> str:
        """JSONPath expression to extract records from the API response."""
        return "$[*]"

    def parse_response(self, response: requests.Response) -> t.Iterable[dict]:
        """Yield each record located by :attr:`records_jsonpath` in the JSON body.

        Floating-point numbers in the body are decoded as :class:`decimal.Decimal`
        to avoid precision loss.

        Args:
            response: The raw :class:`requests.Response` to parse.

        Yields:
            One record per JSONPath match in the response body.
        """
        expression = self.records_jsonpath
        body = response.json(parse_float=decimal.Decimal)
        yield from extract_jsonpath(expression, input=body)

    def get_new_paginator(self) -> BaseAPIPaginator | None:
        """Build a new paginator for this endpoint.

        A legacy ``get_next_page_token`` method, if defined, takes priority (and
        triggers a deprecation warning). Otherwise a JSONPath paginator is used
        when :attr:`next_page_token_jsonpath` is set, falling back to the
        ``X-Next-Page`` header.

        Returns:
            A paginator instance, or ``None`` to indicate pagination is not supported.
        """
        uses_legacy_token_method = hasattr(self, "get_next_page_token")
        if uses_legacy_token_method:
            warn(
                "`RESTStream.get_next_page_token` is deprecated and will not be used "
                "in a future version of the Meltano Singer SDK. "
                "Override `RESTStream.get_new_paginator` instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            return LegacyStreamPaginator(self)

        token_expression = self.next_page_token_jsonpath
        return (
            JSONPathPaginator(token_expression)
            if token_expression
            else SimpleHeaderPaginator("X-Next-Page")
        )
