from typing import Optional, Union

from libcst import (
    Arg,
    Attribute,
    BaseExpression,
    BaseSmallStatement,
    BaseStatement,
    Call,
    FunctionDef,
    ImportFrom,
    ImportStar,
    MaybeSentinel,
    Name,
    RemovalSentinel,
    RemoveFromParent,
    Return,
)
from libcst import matchers as m
from libcst.codemod import ContextAwareTransformer
from libcst.codemod.visitors import AddImportsVisitor

from django_codemod.constants import DJANGO_19, DJANGO_20, DJANGO_21, DJANGO_111
from django_codemod.visitors.base import module_matcher


class ModelsPermalinkTransformer(ContextAwareTransformer):
    """Replaces ``@models.permalink`` decorator by a call to ``reverse()``.

    The decorator is recognised in the forms enabled by the module's imports:

    * ``from django.db import models`` (optionally aliased) enables matching
      ``@<alias>.permalink``;
    * ``from django.db.models import permalink`` (optionally aliased) enables
      matching ``@<alias>``, and ``permalink`` is dropped from that import.

    Inside a decorated function, ``return (viewname, args, kwargs)`` becomes
    ``return reverse(viewname, None, args, kwargs)``, the decorator is removed
    and ``from django.urls import reverse`` is requested.
    """

    deprecated_in = DJANGO_111
    removed_in = DJANGO_21
    ctx_key_prefix = "ModelsPermalinkTransformer"
    # Scratch key for a stack of booleans, one per enclosing function, telling
    # whether that function carries the permalink decorator.
    ctx_key_inside_method = f"{ctx_key_prefix}-inside_method"
    # Scratch key for the list of decorator matchers collected from imports.
    ctx_key_decorator_matchers = f"{ctx_key_prefix}-decorator_matchers"

    def leave_ImportFrom(
        self, original_node: ImportFrom, updated_node: ImportFrom
    ) -> Union[BaseSmallStatement, RemovalSentinel]:
        """Record decorator forms from imports and drop ``permalink`` imports."""
        # ``from x import *`` has a non-iterable ``names``; nothing to inspect.
        if isinstance(updated_node.names, ImportStar):
            return super().leave_ImportFrom(original_node, updated_node)
        if m.matches(
            updated_node, m.ImportFrom(module=module_matcher(["django", "db"])),
        ):
            for imported_name in updated_node.names:
                if m.matches(imported_name, m.ImportAlias(name=m.Name("models"))):
                    # Honour ``models as alias`` so ``@alias.permalink`` matches.
                    models_name = self._bound_name(imported_name, "models")
                    self.add_decorator_matcher(
                        m.Decorator(
                            decorator=m.Attribute(
                                value=m.Name(models_name), attr=m.Name("permalink")
                            )
                        )
                    )
            return super().leave_ImportFrom(original_node, updated_node)
        if m.matches(
            updated_node,
            m.ImportFrom(module=module_matcher(["django", "db", "models"])),
        ):
            kept_names = []
            for imported_name in updated_node.names:
                if m.matches(imported_name, m.ImportAlias(name=m.Name("permalink"))):
                    decorator_name = self._bound_name(imported_name, "permalink")
                    self.add_decorator_matcher(
                        m.Decorator(decorator=m.Name(decorator_name))
                    )
                else:
                    kept_names.append(imported_name)
            if len(kept_names) == len(updated_node.names):
                # ``permalink`` is not imported here: leave the statement as it
                # is instead of reordering unrelated imports.
                return super().leave_ImportFrom(original_node, updated_node)
            if not kept_names:
                return RemoveFromParent()
            # sort the remaining imports
            new_names = sorted(kept_names, key=lambda n: n.evaluated_name)
            # the new last name must not keep an explicit trailing comma
            last_name = new_names[-1]
            if last_name.comma != MaybeSentinel.DEFAULT:
                new_names[-1] = last_name.with_changes(comma=MaybeSentinel.DEFAULT)
            return updated_node.with_changes(names=new_names)
        return super().leave_ImportFrom(original_node, updated_node)

    def add_decorator_matcher(self, matcher):
        """Record ``matcher`` as one more form the permalink decorator may take."""
        self.context.scratch.setdefault(self.ctx_key_decorator_matchers, []).append(
            matcher
        )

    @property
    def decorator_matcher(self):
        """Matcher for any recorded decorator form, or ``None`` if none recorded."""
        matchers_list = self.context.scratch.get(self.ctx_key_decorator_matchers, [])
        if not matchers_list:
            return None
        if len(matchers_list) == 1:
            return matchers_list[0]
        return m.OneOf(*matchers_list)

    def _find_permalink_decorator(self, node: FunctionDef):
        """Return the first decorator of ``node`` matching a recorded form, or None."""
        matcher = self.decorator_matcher
        if matcher is None:
            return None
        return next(
            (decorator for decorator in node.decorators if m.matches(decorator, matcher)),
            None,
        )

    @staticmethod
    def _bound_name(imported_name, default: str) -> str:
        """Local name bound by an ``ImportAlias``: its ``as`` alias, else ``default``."""
        return imported_name.asname.name.value if imported_name.asname else default

    def visit_FunctionDef(self, node: FunctionDef) -> Optional[bool]:
        """Push whether this function is permalink-decorated onto the scope stack."""
        # A stack, not a single flag, keeps nested functions from inheriting or
        # clearing the state of the function that encloses them.
        self.context.scratch.setdefault(self.ctx_key_inside_method, []).append(
            self._find_permalink_decorator(node) is not None
        )
        return super().visit_FunctionDef(node)

    def leave_FunctionDef(
        self, original_node: FunctionDef, updated_node: FunctionDef
    ) -> Union[BaseStatement, RemovalSentinel]:
        """Pop this function's scope; if decorated, drop the decorator."""
        scope_stack = self.context.scratch.get(self.ctx_key_inside_method)
        is_permalink_scope = scope_stack.pop() if scope_stack else False
        if is_permalink_scope:
            decorator = self._find_permalink_decorator(updated_node)
            if decorator is not None:
                AddImportsVisitor.add_needed_import(
                    context=self.context, module="django.urls", obj="reverse",
                )
                updated_decorators = list(updated_node.decorators)
                updated_decorators.remove(decorator)
                return updated_node.with_changes(decorators=tuple(updated_decorators))
        return super().leave_FunctionDef(original_node, updated_node)

    @property
    def visiting_permalink_method(self):
        """Whether the innermost enclosing function is permalink-decorated."""
        scope_stack = self.context.scratch.get(self.ctx_key_inside_method)
        return bool(scope_stack) and scope_stack[-1]

    def leave_Return(
        self, original_node: Return, updated_node: Return
    ) -> Union[BaseSmallStatement, RemovalSentinel]:
        """Rewrite a returned tuple into a ``reverse(...)`` call.

        ``return (a, b, c, ...)`` becomes ``return reverse(a, None, b, c)``:
        only the first three elements are used. Empty tuples and non-tuple
        values are left unchanged.
        """
        if self.visiting_permalink_method and m.matches(
            updated_node, m.Return(value=m.Tuple())
        ):
            elements = updated_node.value.elements
            if elements:
                args = (
                    Arg(elements[0].value),
                    Arg(Name("None")),
                    *[Arg(el.value) for el in elements[1:3]],
                )
                return updated_node.with_changes(
                    value=Call(func=Name("reverse"), args=args)
                )
        return super().leave_Return(original_node, updated_node)


def is_foreign_key(node: Call) -> bool:
    """Tell whether ``node`` calls an attribute named ``ForeignKey``.

    Matches e.g. ``models.ForeignKey(...)``; a bare ``ForeignKey(...)`` name is
    not matched.
    """
    foreign_key_call = m.Call(func=m.Attribute(attr=m.Name(value="ForeignKey")))
    return m.matches(node, foreign_key_call)


def is_one_to_one_field(node: Call) -> bool:
    """Tell whether ``node`` calls an attribute named ``OneToOneField``.

    Matches e.g. ``models.OneToOneField(...)``; a bare ``OneToOneField(...)``
    name is not matched.
    """
    one_to_one_call = m.Call(func=m.Attribute(attr=m.Name(value="OneToOneField")))
    return m.matches(node, one_to_one_call)


def has_on_delete(node: Call) -> bool:
    """Return whether the call ``node`` may already supply ``on_delete``.

    True when any of these holds:

    * ``on_delete`` is passed by keyword;
    * a ``**mapping`` is unpacked anywhere in the call, since it may carry
      ``on_delete`` and appending another one could raise a duplicate-keyword
      ``TypeError`` at call time;
    * a second argument is passed without a keyword. A positional value (or
      ``*`` unpacking) there is assumed to cover ``on_delete``.
    """
    for arg in node.args:
        if m.matches(arg, m.Arg(keyword=m.Name("on_delete"))):
            return True
        # Treat ``**`` unpacking the same regardless of its position.
        if arg.star == "**":
            return True

    # if there are two or more args and the second has no keyword
    # then we can assume that positional arguments are being used
    # and on_delete is being handled.
    return len(node.args) >= 2 and node.args[1].keyword is None


class OnDeleteTransformer(ContextAwareTransformer):
    """Adds ``on_delete=models.CASCADE`` to relation fields lacking ``on_delete``.

    Targets ``<x>.ForeignKey(...)`` and ``<x>.OneToOneField(...)`` calls for
    which ``has_on_delete`` reports no existing ``on_delete``.
    """

    deprecated_in = DJANGO_19
    removed_in = DJANGO_20
    ctx_key_prefix = "OnDeleteTransformer"

    def leave_Call(self, original_node: Call, updated_node: Call) -> BaseExpression:
        """Append ``on_delete=models.CASCADE`` to matching calls."""
        is_relation_field = is_one_to_one_field(original_node) or is_foreign_key(
            original_node
        )
        if not is_relation_field or has_on_delete(original_node):
            return super().leave_Call(original_node, updated_node)

        # The appended argument refers to ``models``, so make sure it's imported.
        AddImportsVisitor.add_needed_import(
            context=self.context, module="django.db", obj="models",
        )
        cascade = Attribute(value=Name("models"), attr=Name("CASCADE"))
        on_delete_arg = Arg(keyword=Name("on_delete"), value=cascade)
        return updated_node.with_changes(args=(*updated_node.args, on_delete_arg))
