import inspect
import json
import logging
import random
import resource
import subprocess
import sys
import warnings
from functools import wraps
from itertools import chain, islice
from typing import Any, Callable, Dict, Iterable, Iterator, List, Sequence, TypeVar

import numpy
import torch

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Generic element type shared by the iterable helpers in this module.
A = TypeVar("A")


def namespace_match(pattern: str, namespace: str) -> bool:
    """
    Matches a namespace pattern against a namespace string.

    A pattern starting with ``*`` is a suffix wildcard: ``*tags`` matches
    ``passage_tags`` and ``question_tags`` (and ``tags`` itself). Any other
    pattern must match exactly, so ``tokens`` matches ``tokens`` but not
    ``stemmed_tokens``. An empty pattern only matches an empty namespace.
    """
    # ``startswith`` (unlike ``pattern[0]``) is safe for an empty pattern.
    if pattern.startswith("*"):
        # An exact match of a ``*``-pattern always satisfies this suffix test too,
        # so no separate equality check is needed in this branch.
        return namespace.endswith(pattern[1:])
    return pattern == namespace


def flatten_filename(file_path: str) -> str:
    """
    Encode a path as a single filename component by substituting the literal
    marker ``_SLASH_`` for every ``/`` in ``file_path``.
    """
    return "_SLASH_".join(file_path.split("/"))


def ensure_list(iterable: Iterable[A]) -> List[A]:
    """
    Return ``iterable`` as a list.

    A list is handed back as-is (the same object, no copy); any other iterable,
    such as a generator, is materialized with ``list()``.
    """
    return iterable if isinstance(iterable, list) else list(iterable)


def lazy_groups_of(iterator: Iterable[A], group_size: int) -> Iterator[List[A]]:
    """
    Lazily batch the items of ``iterator`` into lists of ``group_size`` items.

    The last list may be smaller if there are items left over. Any iterable is
    accepted: it is wrapped with ``iter()`` first, so a list is consumed once
    rather than restarted for every group.

    Raises ``ValueError`` if ``group_size`` is less than 1.
    """
    if group_size < 1:
        raise ValueError(f"group_size must be a positive integer, got {group_size}")
    # Without this, ``islice`` over a list would start from index 0 on every call
    # and the sentinel ``[]`` would never be reached (infinite loop).
    items = iter(iterator)
    # ``iter(callable, sentinel)`` keeps calling the lambda until it returns ``[]``,
    # i.e. until ``items`` is exhausted.
    return iter(lambda: list(islice(items, 0, group_size)), [])


def is_lazy(iterable: Iterable[A]) -> bool:
    """
    Report whether ``iterable`` counts as lazy, which here simply means it is
    not a ``list``; tuples, sets, generators, etc. are all considered lazy.
    """
    if isinstance(iterable, list):
        return False
    return True


def dump_metrics(file_path: str, metrics: Dict[str, Any], log: bool = False) -> None:
    """
    Write ``metrics`` to ``file_path`` as JSON indented by two spaces,
    overwriting any existing file.

    When ``log`` is true, the same JSON text is also logged at INFO level
    after the file has been written.
    """
    serialized = json.dumps(metrics, indent=2)
    with open(file_path, "w") as out_file:
        out_file.write(serialized)
    if not log:
        return
    logger.info("Metrics: %s", serialized)


def gpu_memory_mb() -> Dict[int, int]:
    """
    Get the current GPU memory usage by querying ``nvidia-smi``.
    Based on https://discuss.pytorch.org/t/access-gpu-memory-usage-in-pytorch/3192/4

    Returns
    -------
    ``Dict[int, int]``
        Keys are the positions (0, 1, ...) of the lines in the ``nvidia-smi``
        output, used as device ids. Values are the ``memory.used`` figures
        (reported unitless; treated as MB) as integers.
        Returns an empty ``dict`` if ``nvidia-smi`` is missing or the query fails.
    """
    query = ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,nounits,noheader"]
    try:
        raw_output = subprocess.check_output(query, encoding="utf-8")
        usage_per_gpu = [int(line) for line in raw_output.strip().split("\n")]
        return dict(enumerate(usage_per_gpu))
    except FileNotFoundError:
        # No `nvidia-smi` executable: treat it as "no GPUs".
        return {}
    except Exception:  # pylint: disable=broad-except
        # This check is a nice-to-have, so any other failure (non-zero exit,
        # unparsable output, ...) is logged and swallowed instead of failing a run.
        logger.exception("unable to check gpu_memory_mb(), continuing")
        return {}


def peak_memory_mb() -> float:
    """
    Peak memory usage of this process, derived from the maximum resident set
    size reported by ``getrusage``:
    https://unix.stackexchange.com/questions/30940/getrusage-system-call-what-is-maximum-resident-set-size

    Only Linux and macOS are supported; every other platform yields 0.0.
    """
    unsupported = resource is None or sys.platform not in ("linux", "darwin")
    if unsupported:
        return 0.0

    max_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss  # type: ignore

    # macOS reports ru_maxrss in bytes; Linux reports it in kilobytes.
    # NOTE(review): Linux kilobytes are 1024-byte units, so KB / 1_000 is only
    # approximately MB — confirm whether that precision matters to callers.
    divisor = 1_000_000 if sys.platform == "darwin" else 1_000
    return max_rss / divisor


def sanitize(x: Any) -> Any:  # pylint: disable=invalid-name,too-many-return-statements
    """
    Sanitize turns PyTorch and Numpy types into basic Python types so they
    can be serialized into JSON.

    - ``str``, ``float``, ``int`` and ``bool`` are returned unchanged.
    - Tensors (moved to CPU first) and NumPy arrays become (nested) lists.
    - NumPy scalars, including ``numpy.bool_``, become the matching Python scalar.
    - Dicts keep their keys and have their values sanitized; lists and tuples
      become lists of sanitized items.
    - ``None`` becomes the string ``"None"``.
    - Any other object with a ``to_json()`` method is converted with it.

    Raises ``ValueError`` for anything else.
    """
    if isinstance(x, (str, float, int, bool)):
        # x is already serializable
        return x
    elif isinstance(x, torch.Tensor):
        # tensor needs to be converted to a list (and moved to cpu if necessary)
        return x.cpu().tolist()
    elif isinstance(x, numpy.ndarray):
        # array needs to be converted to a list
        return x.tolist()
    elif isinstance(x, (numpy.number, numpy.bool_)):  # pylint: disable=no-member
        # NumPy scalars need to be converted to Python scalars. ``numpy.bool_`` is
        # not a subclass of ``numpy.number``, so it has to be listed explicitly.
        return x.item()
    elif isinstance(x, dict):
        # Dicts need their values sanitized
        return {key: sanitize(value) for key, value in x.items()}
    elif isinstance(x, (list, tuple)):
        # Lists and Tuples need their values sanitized
        return [sanitize(x_i) for x_i in x]
    elif x is None:
        return "None"
    elif hasattr(x, "to_json"):
        return x.to_json()
    else:
        raise ValueError(
            f"Cannot sanitize {x} of type {type(x)}. "
            "If this is your own custom class, add a `to_json(self)` method "
            "that returns a JSON-like object."
        )


def tensor_to_ohe(x: torch.Tensor, num_classes: int) -> torch.Tensor:
    """
    One-hot encode a 1-D tensor of integer class indices.

    Parameters
    ----------
    x : ``torch.Tensor``
        1-D tensor of integer class indices, each expected in ``[0, num_classes)``.
    num_classes : ``int``
        Width of the one-hot encoding.

    Returns
    -------
    ``torch.Tensor``
        Tensor of shape ``(len(x), num_classes)`` with the default float dtype,
        on the same device as ``x``, holding a single 1 per row at that row's index.

    Raises
    ------
    ValueError
        If ``x`` is not 1-dimensional.
    """
    if x.dim() != 1:
        raise ValueError(f"Expected a 1-D tensor of class indices, got shape {tuple(x.size())}")
    # scatter_ requires int64 indices, so narrower integer dtypes are cast.
    index = x.long().reshape(x.size(0), 1)
    # Allocate on x's device; a CPU matrix would make scatter_ fail for CUDA inputs.
    zero_matrix = torch.zeros((x.size(0), num_classes), device=x.device)
    return zero_matrix.scatter_(1, index, 1)


def set_seed(seed: int, n_gpu: int) -> None:
    """
    Seed Python's ``random`` module, NumPy's global RNG and PyTorch's RNG with
    ``seed``. When ``n_gpu`` is positive, additionally seed every CUDA device.
    """
    for seed_fn in (random.seed, numpy.random.seed, torch.manual_seed):  # type: ignore
        seed_fn(seed)
    if n_gpu <= 0:
        return
    torch.cuda.manual_seed_all(seed)  # type: ignore


def deprecated_arguments(**renaming: str) -> Callable:  # type: ignore
    """
    Decorator factory that lets a function keep accepting old keyword-argument
    names after its parameters were renamed.

    ``renaming`` maps old argument names to new parameter names, e.g.
    ``@deprecated_arguments(old_name="new_name")``.

    How the wrapper handles a call:

    * A parameter named ``cls`` (or else ``self``) that is not itself a key of
      ``renaming`` is treated as the implicit method argument.
    * A warning is emitted whenever positional arguments other than that
      implicit one are passed, because their remapping may be wrong.
    * If every keyword argument is one of the new names in ``renaming``
      (including when there are no keyword arguments), the call is forwarded
      unchanged.
    * Otherwise every argument is converted to a keyword. Positional
      arguments fill, in order, the implicit method argument, then the new
      names in the order they were given to this decorator, then the
      remaining parameters in signature order. Keyword arguments are
      translated through ``renaming``. Parameters not covered by ``renaming``
      keep their own name.

    On that remapping path, ``ValueError`` is raised if a parameter not
    covered by ``renaming`` has no default value. ``KeyError`` is raised for a
    keyword argument that is neither an old name nor such an uncovered
    parameter; this includes mixing old and new names in one call.
    """

    def decorator(func: Callable[[Any], Any]) -> Callable:  # type: ignore
        # ``func``'s signature is fixed, so inspect it once instead of on every call.
        signature_parameters = inspect.signature(func).parameters

        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Callable:  # type: ignore
            # Per-call copy: the remapping path adds identity entries for parameters
            # that were not renamed. Adding them to the decorator-level ``renaming``
            # would leak state between calls, so the same call could be forwarded
            # differently depending on which calls happened before it.
            call_renaming = dict(renaming)
            func_parameters = dict(signature_parameters)
            if "cls" not in call_renaming and "cls" in func_parameters:
                method_arg = ["cls"]
                func_parameters.pop("cls")
            elif "self" not in call_renaming and "self" in func_parameters:
                method_arg = ["self"]
                func_parameters.pop("self")
            else:
                method_arg = []
            # Positional arguments (beyond an implicit self/cls) may follow either the
            # old or the new signature, so mapping them is unreliable.
            if (args and not method_arg) or len(args) > 1:
                warnings.warn(
                    "Consider providing only keyword arguments, as this method changed "
                    "signature and automated remapping from old one (if old signature was used) "
                    "can lead to unexpected behavior"
                )
            # Only new names were used: nothing to translate.
            if all(kwarg in call_renaming.values() for kwarg in kwargs):
                return func(*args, **kwargs)  # type: ignore
            updated_kwargs = {}
            new_parameters_not_in_renaming = [
                (parameter_name, parameter)
                for parameter_name, parameter in func_parameters.items()
                if parameter_name not in call_renaming.values()
            ]
            for new_parameter_name, parameter in new_parameters_not_in_renaming:
                # An old-style call has no old name for this parameter, so it can only
                # be left out, which requires a default value. ``is`` avoids invoking a
                # default's custom ``__eq__`` (e.g. array defaults).
                if parameter.default is inspect.Parameter.empty:
                    raise ValueError(
                        f"{new_parameter_name} not in 'renaming' and it doesn't have default value. "
                        "It's impossible to set it when user will provide old argument names"
                    )
                call_renaming[new_parameter_name] = new_parameter_name
            # Positional args fill self/cls, then the renamed targets in decorator order,
            # then the unchanged parameters in signature order.
            for new_parameter_name, arg in zip(chain(method_arg, call_renaming.values()), args):
                updated_kwargs[new_parameter_name] = arg
            for old_parameter_name, arg in kwargs.items():
                updated_kwargs[call_renaming[old_parameter_name]] = arg
            return func(**updated_kwargs)  # type: ignore

        return wrapper

    return decorator


def argsort(seq: Sequence[Any]) -> List[int]:
    """
    Return the list of indices that would sort ``seq`` in ascending order.

    The sort is stable: equal elements keep their original relative order.
    """
    indices = list(range(len(seq)))
    indices.sort(key=lambda position: seq[position])
    return indices
