import enum
import importlib.util
import inspect
import logging
import os
import sys
import typing
from os.path import relpath
from pathlib import Path

from typedpy.commons import doublewrap_val
from typedpy.fields import FunctionCall
from typedpy.serialization.serialization_wrappers import Deserializer, Serializer
from typedpy.structures import (
    Field,
    ImmutableStructure,
    Structure,
)
from typedpy.stubs.type_info_getter import (
    get_all_type_info,
    get_type_info,
)
from typedpy.stubs.types_ast import (
    functions_to_str,
    get_imports,
    get_models,
    models_to_src,
)
from .methods_info_getter import (
    get_methods_and_attributes_as_code,
    get_init,
    get_additional_structure_methods,
)

from .function_info_getter import get_stubs_of_functions
from .utils import (
    is_sqlalchemy,
    INDENT,
    get_package,
    as_something,
    is_internal_sqlalchemy,
)

# The class of the ``inspect`` module object, i.e. the interpreter's module type.
# It is used in ``isinstance`` checks below to recognise module objects.
module = getattr(inspect, "__class__")


# Header lines prepended to the content of every stub file written by this module.
AUTOGEN_NOTE = [
    "",
    "#### This stub was autogenerated by Typedpy",
    "###########################################",
    "",
]


def _get_struct_classes(attrs):
    """Return the ``Structure`` subclasses that are declared in this module.

    :param attrs: the module namespace (name -> object); must contain ``__name__``
        whenever it holds at least one ``Structure`` subclass.
    :return: dict of name -> class, restricted to classes whose ``__module__``
        equals ``attrs["__name__"]`` and that are not one of the excluded
        typedpy classes listed below.
    """
    excluded = {Deserializer, Serializer, ImmutableStructure, FunctionCall}
    found = {}
    for name, candidate in attrs.items():
        if not inspect.isclass(candidate) or not issubclass(candidate, Structure):
            continue
        # Only classes declared in the module itself, not ones it imported.
        if candidate.__module__ != attrs["__name__"]:
            continue
        if candidate in excluded:
            continue
        found[name] = candidate
    return found


def _get_imported_classes(attrs):
    """Build import statements for the public names in ``attrs`` that come from elsewhere.

    A name is considered when it does not start with ``_``, its value is either a
    module object or an object whose ``__module__`` differs from
    ``attrs["__name__"]``, and ``is_internal_sqlalchemy`` does not flag it.

    :param attrs: the module namespace (name -> object), including ``__name__``.
    :return: list of ``(name, import_statement)`` tuples. Statements importing from
        ``__future__`` are placed at the front of the list.
    """

    def _valid_module(v):
        # The object declares an owning module and it is not the module being stubbed.
        return hasattr(v, "__module__") and attrs["__name__"] != v.__module__

    res = []
    for k, v in attrs.items():
        if k.startswith("_"):
            continue
        if not (isinstance(v, module) or _valid_module(v)):
            continue
        if is_internal_sqlalchemy(v):
            continue

        if isinstance(v, module):
            if v.__name__ == k:
                res.append((k, f"import {k}"))
                continue
            parts = v.__name__.split(".")
            if len(parts) > 1:
                # Dotted module bound to a different local name.
                res.append(
                    (k, f"from {'.'.join(parts[:-1])} import {parts[-1]} as {k}")
                )
            elif getattr(os, k, None) is v:
                # Submodule exposed by ``os`` under this name (e.g. ``os.path``).
                # The lookup must go through the stdlib ``os`` module; reading an
                # ``os`` attribute off the inspected module raises AttributeError
                # for any module that does not happen to import ``os`` itself.
                res.append((k, f"from os import {k} as {k}"))
            continue

        pkg = get_package(v.__module__, attrs)

        name_in_pkg = getattr(v, "__name__", k)
        name_to_import = name_in_pkg if name_in_pkg and name_in_pkg != k else k
        as_import = as_something(k, attrs) if name_to_import == k else f" as {k}"
        import_stmt = f"from {pkg} import {name_to_import}{as_import}"
        if pkg == "__future__":
            # __future__ imports have to precede every other statement.
            res = [(k, import_stmt)] + res
        else:
            res.append((k, import_stmt))
    return res


def _get_ordered_args(unordered_args: dict):
    optional_args = {k: v for k, v in unordered_args.items() if v.endswith("= None")}
    mandatory_args = {k: v for k, v in unordered_args.items() if k not in optional_args}
    return {**mandatory_args, **optional_args}


def _get_mapped_extra_imports(additional_imports) -> dict:
    """Map the type name of each extra class to the module it should be imported from.

    The module is taken from ``__module__`` for ``Structure`` subclasses, from
    ``get_type.__module__`` for ``Field`` instances, and from ``__module__`` for
    anything else whose type name is not ``"Any"``. Entries without a module are
    skipped. Any exception raised while handling an entry is logged and that
    entry is skipped.

    :param additional_imports: iterable of classes / fields to map.
    :return: dict of type name -> module name.
    """
    mapped = {}
    for candidate in additional_imports:
        try:
            type_name = get_type_info(candidate, {}, set())
            if inspect.isclass(candidate) and issubclass(candidate, Structure):
                owner = candidate.__module__
            elif isinstance(candidate, Field):
                owner = candidate.get_type.__module__
            elif type_name != "Any":
                owner = candidate.__module__
            else:
                owner = None
            if owner:
                mapped[type_name] = owner
        except Exception as e:
            logging.exception(e)
    return mapped


def get_stubs_of_structures(
    struct_classe_by_name: dict,
    local_attrs,
    additional_classes,
    additional_properties_default,
) -> list:
    """Render the stub source lines for the given Structure classes.

    For each class the output holds the ``class`` header, an ``__init__`` (unless
    the class methods already include one), the additional structure methods,
    one annotated line per field, and the methods/attributes code.
    A class with neither fields nor methods gets a ``pass`` body.

    :param struct_classe_by_name: dict of class name -> Structure class.
    :param local_attrs: namespace of the module being stubbed.
    :param additional_classes: collection handed to the type-info helpers.
    :param additional_properties_default: forwarded to ``get_init`` and
        ``get_additional_structure_methods``.
    :return: list of source lines.
    """
    out_src = []
    for cls_name, cls in struct_classe_by_name.items():
        fields_info = get_all_type_info(
            cls, locals_attrs=local_attrs, additional_classes=additional_classes
        )
        method_info = get_methods_and_attributes_as_code(
            cls,
            locals_attrs=local_attrs,
            additional_classes=additional_classes,
            ignore_attributes=set(fields_info.keys()),
        )
        ordered_args = _get_ordered_args(fields_info)

        bases = _get_bases_for_structure(cls, local_attrs, additional_classes)
        bases_str = f"({', '.join(bases)})" if bases else ""
        out_src.append(f"class {cls_name}{bases_str}:")

        if not fields_info and not method_info:
            out_src += [f"{INDENT}pass", ""]
            continue

        # Only synthesize a constructor when the class does not define its own.
        has_own_init = any(m.startswith("def __init__(self") for m in method_info)
        if not has_own_init:
            out_src.append(get_init(cls, ordered_args, additional_properties_default))
        out_src.append("")
        out_src.append(
            get_additional_structure_methods(
                cls, ordered_args, additional_properties_default
            )
        )
        out_src += ["", ""]

        out_src.extend(
            f"    {field_name}: {type_name}"
            for field_name, type_name in ordered_args.items()
        )
        out_src.extend(f"{INDENT}{m}" for m in method_info)
        out_src.append("\n")
    return out_src


def get_stubs_of_enums(
    enum_classes_by_name: dict, local_attrs, additional_classes
) -> list:
    """Render the stub source lines for the given enum classes.

    Every member is emitted as ``NAME = enum.auto()``; an enum without members
    gets a ``pass`` line instead. The methods/attributes code follows.

    :param enum_classes_by_name: dict of class name -> enum class.
    :param local_attrs: namespace of the module being stubbed.
    :param additional_classes: collection handed to the type-info helpers.
    :return: list of source lines.
    """
    out_src = []
    for cls_name, cls in enum_classes_by_name.items():
        method_info = get_methods_and_attributes_as_code(
            cls, locals_attrs=local_attrs, additional_classes=additional_classes
        )
        bases = _get_bases(cls, local_attrs, additional_classes)
        bases_str = f"({', '.join(bases)})" if bases else ""
        out_src.append(f"class {cls_name}{bases_str}:")

        members = [f"{INDENT}{member.name} = enum.auto()" for member in cls]
        if not members:
            members = [f"{INDENT}pass"]
        out_src.extend(members)
        out_src += ["", ""]

        out_src.extend(f"{INDENT}{m}" for m in method_info)
        out_src.append("\n")
    return out_src


def add_imports(local_attrs: dict, additional_classes, existing_imports: set) -> list:
    """Build the import lines that go at the head of the generated stub.

    The result starts with a ``from typing import ...`` line for the base typing
    names not already in ``existing_imports`` (omitted when none are missing),
    then ``from typedpy import Structure`` and a blank line, followed by the
    sorted imports for the additional classes.

    :param local_attrs: namespace of the module being stubbed.
    :param additional_classes: classes for which extra imports may be required.
    :param existing_imports: names that the stub already imports.
    :return: list of source lines.
    """
    statements = []
    missing_typing = [
        name
        for name in (
            "Union",
            "Optional",
            "Any",
            "TypeVar",
            "Type",
            "NoReturn",
            "Iterable",
        )
        if name not in existing_imports
    ]
    if missing_typing:
        statements.append(f"from typing import {', '.join(missing_typing)}")
    statements.extend(["from typedpy import Structure", ""])

    extra_imports = set()
    for type_name, module_name in _get_mapped_extra_imports(
        additional_classes
    ).items():
        # Names declared in the stubbed module itself need no import.
        declared_here = (
            type_name in local_attrs
            and local_attrs[type_name].__module__ == local_attrs["__name__"]
        )
        if declared_here or type_name in existing_imports:
            continue
        extra_imports.add(
            f"from {get_package(module_name, local_attrs)} import {type_name}{as_something(type_name, local_attrs)}"
        )
    return statements + sorted(extra_imports)


def _get_enum_classes(attrs):
    res = {}
    for k, v in attrs.items():
        if (
            inspect.isclass(v)
            and issubclass(v, enum.Enum)
            and (v.__module__ == attrs["__name__"])
        ):
            res[k] = v

    return res


def _get_other_classes(attrs):
    """Return every class in ``attrs`` that is neither an Enum nor a Structure.

    No check is made on the module a class belongs to.

    :param attrs: the module namespace (name -> object).
    :return: dict of name -> class.
    """
    return {
        name: candidate
        for name, candidate in attrs.items()
        if inspect.isclass(candidate)
        and not issubclass(candidate, enum.Enum)
        and not issubclass(candidate, Structure)
    }


def _get_functions(attrs):
    return {
        k: v
        for k, v in attrs.items()
        if (inspect.isfunction(v) and (v.__module__ == attrs["__name__"]))
    }


def _get_bases(cls, local_attrs, additional_classes) -> list:
    """Return the base-class expressions to put in the stub's ``class`` header.

    ``object``, bases from the ``typing`` module other than ``typing.Generic``,
    and bases flagged by ``is_sqlalchemy`` are left out, as are bases whose type
    info is ``"Any"``. A ``Generic`` base is rendered with the names of the
    class ``__parameters__`` in square brackets.

    :param cls: the class whose ``__bases__`` are inspected.
    :param local_attrs: namespace of the module being stubbed.
    :param additional_classes: collection handed to ``get_type_info``.
    :return: list of base-class strings.
    """
    rendered = []
    for base in cls.__bases__:
        if base is object:
            continue
        if base.__module__ == "typing" and base is not typing.Generic:
            continue
        if is_sqlalchemy(base):
            continue

        the_type = get_type_info(base, local_attrs, additional_classes)
        if base is typing.Generic and the_type == "Generic":
            params = [p.__name__ for p in cls.__parameters__]
            params_st = f"[{', '.join(params)}]" if params else ""
            rendered.append(f"{the_type}{params_st}")
        elif the_type != "Any":
            rendered.append(the_type)
    return rendered


def _get_bases_for_structure(cls, local_attrs, additional_classes) -> list:
    """Return the base-class expressions for a Structure stub's ``class`` header.

    ``object`` and bases whose module is ``typing`` or starts with ``typedpy``
    are left out, as are bases whose type info is ``"Any"``. ``"Structure"`` is
    always appended as the last base.

    :param cls: the class whose ``__bases__`` are inspected.
    :param local_attrs: namespace of the module being stubbed.
    :param additional_classes: collection handed to ``get_type_info``.
    :return: list of base-class strings, ending with ``"Structure"``.
    """
    rendered = []
    for base in cls.__bases__:
        skipped = (
            base is object
            or base.__module__ == "typing"
            or base.__module__.startswith("typedpy")
        )
        if skipped:
            continue
        the_type = get_type_info(base, local_attrs, additional_classes)
        if the_type != "Any":
            rendered.append(the_type)
    return rendered + ["Structure"]


def get_stubs_of_other_classes(
    *, other_classes, local_attrs, additional_classes, additional_imports
):
    """Render the stub source lines for classes that are neither enums nor structures.

    A class from another module gets an empty placeholder class, unless its name
    is already in ``additional_imports``, in which case nothing is emitted.
    A class declared in the stubbed module gets its header with bases followed
    by the methods/attributes code (or ``pass`` when there is none).

    :param other_classes: dict of class name -> class.
    :param local_attrs: namespace of the module being stubbed.
    :param additional_classes: collection handed to the type-info helpers.
    :param additional_imports: names already imported by the stub.
    :return: list of source lines.
    """
    stub_lines = []
    for cls_name, cls in other_classes.items():
        declared_here = cls.__module__ == local_attrs["__name__"]
        if not declared_here:
            if cls_name not in additional_imports:
                stub_lines.extend([f"class {cls_name}:", f"{INDENT}pass", ""])
            continue

        bases = _get_bases(cls, local_attrs, additional_classes)
        method_info = get_methods_and_attributes_as_code(
            cls, locals_attrs=local_attrs, additional_classes=additional_classes
        )
        header_bases = f"({', '.join(bases)})" if bases else ""
        stub_lines.append(f"class {cls_name}{header_bases}:")
        if not method_info:
            stub_lines.append(f"{INDENT}pass")
        stub_lines.append("")
        stub_lines.extend(f"{INDENT}{m}" for m in method_info)
        stub_lines.append("\n")
    return stub_lines


def get_typevars(attrs, additional_classes):
    """Render a ``TypeVar`` declaration for every TypeVar found in ``attrs``.

    The declaration uses the name the TypeVar is bound to in ``attrs`` and lists
    its constraints (rendered through ``get_type_info``), if it has any.

    :param attrs: the module namespace (name -> object).
    :param additional_classes: collection handed to ``get_type_info``.
    :return: list of source lines, wrapped in a leading and a trailing empty line.
    """
    declarations = []
    for name, value in attrs.items():
        if not isinstance(value, typing.TypeVar):
            continue
        rendered_constraints = ", ".join(
            get_type_info(constraint, attrs, additional_classes)
            for constraint in value.__constraints__
        )
        suffix = f", {rendered_constraints}" if rendered_constraints else ""
        declarations.append(f'{name} = TypeVar("{name}"{suffix})')
    return ["", *declarations, ""]


def _get_pyi_path(calling_source_file):
    full_path: Path = Path(calling_source_file)
    return (full_path.parent / f"{full_path.stem}.pyi").resolve()


def _get_direct_imported_as_code(attrs: dict, additional_imports):
    """Render the import statements found in the module's source file.

    The imports are read with ``get_imports`` from ``attrs["__file__"]``. Every
    imported name, except for star imports, is also appended to
    ``additional_imports`` (the list is modified in place).

    :param attrs: the module namespace; ``__file__`` is looked up with ``get``.
    :param additional_imports: list that receives the imported names.
    :return: list of import statement lines.
    """
    statements = []
    for level, pkg_name, val_info, alias in list(get_imports(attrs.get("__file__"))):
        dots = "." * level
        imported_name = ".".join(val_info)
        local_name = alias or imported_name

        if imported_name and pkg_name:
            if imported_name == "*":
                statements.append(f"from {dots}{pkg_name} import {imported_name}")
            else:
                statements.append(
                    f"from {dots}{pkg_name} import {imported_name} as {local_name}"
                )
                additional_imports.append(local_name)
        elif pkg_name:
            statements.append(f"import {dots}{pkg_name}")
            additional_imports.append(pkg_name)
        else:
            # NOTE(review): when an alias is present this emits ``import <alias>``
            # rather than ``import <name> as <alias>`` - confirm against the
            # tuples that get_imports actually yields.
            statements.append(f"import {dots}{local_name}")
            additional_imports.append(local_name)

    return statements


def create_pyi(
    calling_source_file,
    attrs: dict,
    additional_properties_default=True,
):
    """Generate the ``.pyi`` stub for a module and write it next to the given file.

    The stub is assembled from the imports found in the module source, TypeVars,
    constants, enums, other classes, structures and functions. Lines starting
    with ``from __future__`` are moved to the top, right after the autogenerated
    note, and are followed by the imports computed by ``add_imports``.

    :param calling_source_file: path used to derive the stub location
        (same directory, same stem, ``.pyi`` extension).
    :param attrs: the namespace of the module to generate the stub for.
    :param additional_properties_default: forwarded to the structure stub generation.
    """
    out_src = []
    additional_imports = []
    out_src.extend(
        _get_direct_imported_as_code(attrs=attrs, additional_imports=additional_imports)
    )
    enum_classes = _get_enum_classes(attrs)
    # additional_imports holds names (strings), so the check has to use the
    # string "enum"; comparing against the module object never matched and the
    # import was emitted a second time.
    if enum_classes and "enum" not in additional_imports:
        out_src += ["import enum", ""]
    struct_classes = _get_struct_classes(attrs)
    other_classes = _get_other_classes(attrs)
    functions = _get_functions(attrs)

    additional_classes = set()
    out_src += get_typevars(attrs, additional_classes=additional_classes)

    out_src += _get_consts(
        attrs,
        additional_classes=additional_classes,
        additional_imports=additional_imports,
    )

    out_src += get_stubs_of_enums(
        enum_classes, local_attrs=attrs, additional_classes=additional_classes
    )
    out_src += get_stubs_of_other_classes(
        other_classes=other_classes,
        local_attrs=attrs,
        additional_classes=additional_classes,
        additional_imports=additional_imports,
    )
    out_src += get_stubs_of_structures(
        struct_classes,
        local_attrs=attrs,
        additional_classes=additional_classes,
        additional_properties_default=additional_properties_default,
    )

    out_src += get_stubs_of_functions(
        functions, local_attrs=attrs, additional_classes=additional_classes
    )

    # __future__ imports must come before any other statement in the stub.
    future_imports = [s for s in out_src if s.startswith("from __future__")]
    remaining = [s for s in out_src if not s.startswith("from __future__")]
    out_src = (
        future_imports
        + add_imports(
            local_attrs=attrs,
            additional_classes=additional_classes,
            existing_imports=set(additional_imports),
        )
        + remaining
    )

    out_src = AUTOGEN_NOTE + out_src
    _write_stub(out_src=out_src, pyi_path=_get_pyi_path(calling_source_file))


def _write_stub(out_src, pyi_path):
    out_s = "\n".join(out_src)
    with open(pyi_path, "w", encoding="UTF-8") as f:
        f.write(out_s)


def _get_consts(attrs, additional_classes, additional_imports):
    """Render the stub lines for the module-level constants and annotated names.

    A name is rendered when it does not start with ``__``, is not in
    ``additional_imports``, and its value is a builtin value, ``None``, is
    annotated in ``__annotations__``, or is anything that is not a class,
    a function or a TypeVar. Each rendered line is followed by an empty line.

    The rendered value is a placeholder: strings become an empty string and
    containers an empty instance of their class, while numbers and booleans are
    kept. Values that are neither builtin nor ``None`` get no assignment.

    :param attrs: the module namespace (name -> object).
    :param additional_classes: collection handed to ``get_type_info``.
    :param additional_imports: names already imported by the stub.
    :return: list of source lines.
    """
    builtin_types = (int, float, str, dict, list, set, complex, bool, frozenset)

    def _is_of_builtin(v) -> bool:
        return isinstance(v, builtin_types)

    def _as_builtin(v) -> str:
        if isinstance(v, str):
            return ""
        if isinstance(v, (int, float, complex, bool)):
            return v
        # Containers are represented by an empty instance of the same class.
        return v.__class__()

    annotations = attrs.get("__annotations__", None) or {}

    def _is_const(name, value) -> bool:
        const_like = (
            _is_of_builtin(value)
            or value is None
            or name in annotations
            or not (
                inspect.isclass(value)
                or inspect.isfunction(value)
                or isinstance(value, typing.TypeVar)
            )
        )
        return (
            const_like
            and not name.startswith("__")
            and name not in additional_imports
        )

    constants = {
        name: value for name, value in attrs.items() if _is_const(name, value)
    }

    lines = []
    for name in constants:
        value = attrs[name]

        if name in annotations:
            the_type = get_type_info(annotations[name], attrs, additional_classes)
        elif not inspect.isclass(value) and value is not None:
            the_type = get_type_info(value.__class__, attrs, additional_classes)
        else:
            the_type = None
        type_str = f": {the_type}" if the_type else ""

        if _is_of_builtin(value):
            val = str(doublewrap_val(_as_builtin(value)))
        elif value is None:
            val = "None"
        else:
            val = ""
        val_st = f" = {val}" if val else ""

        lines.append(f"{name}{type_str}{val_st}")
        lines.append("")
    return lines


def create_stub_for_file(
    abs_module_path: str,
    src_root: str,
    stubs_root: str = None,
    *,
    additional_properties_default=True,
):
    """Import a Python source file and generate its ``.pyi`` stub.

    Files whose extension is not ``.py`` are ignored. The module is loaded from
    its file location and executed, so its top-level code runs. The stub is
    written under ``stubs_root`` (mirroring the directory layout relative to
    ``src_root``) or, when ``stubs_root`` is not given, next to the source file.
    An ``__init__.pyi`` is created in the stub directory if it is missing.

    :param abs_module_path: path of the module source file.
    :param src_root: root directory of the sources; used to derive the package name.
    :param stubs_root: optional root directory for the generated stubs.
    :param additional_properties_default: forwarded to ``create_pyi``.
    :return: None
    """
    ext = os.path.splitext(abs_module_path)[-1].lower()
    if ext != ".py":
        return
    module_path = Path(abs_module_path)
    stem = module_path.stem
    dir_name = str(module_path.parent)
    relative_dir = relpath(dir_name, src_root)
    package_name = ".".join(Path(relative_dir).parts)
    module_name = stem if stem != "__init__" else package_name

    # Make the module's own imports resolvable. Entries are added only once, so
    # that generating stubs for many files does not keep growing sys.path with
    # duplicates (a duplicate entry never changes the resolution order).
    for entry in (str(Path(dir_name).parent), src_root):
        if entry not in sys.path:
            sys.path.append(entry)

    spec = importlib.util.spec_from_file_location(module_name, abs_module_path)
    the_module = importlib.util.module_from_spec(spec)
    if not the_module.__package__:
        # Needed for relative imports inside the module to resolve.
        the_module.__package__ = package_name
    spec.loader.exec_module(the_module)

    pyi_dir = (
        Path(stubs_root) / Path(relative_dir) if stubs_root else module_path.parent
    )
    pyi_dir.mkdir(parents=True, exist_ok=True)
    (pyi_dir / Path("__init__.pyi")).touch(exist_ok=True)

    pyi_path = (pyi_dir / f"{stem}.pyi").resolve()
    # Checked again after execution, in case running the module cleared it.
    if not getattr(the_module, "__package__", None):
        the_module.__package__ = package_name
    create_pyi(
        str(pyi_path),
        the_module.__dict__,
        additional_properties_default=additional_properties_default,
    )


def create_pyi_ast(calling_source_file, pyi_path):
    """Generate a stub from the source text of a file, without importing it.

    The stub always contains a fixed set of base imports followed by the imports
    found in the source file. Models and functions (as returned by
    ``get_models``) are rendered only when at least one import comes from a
    package whose name starts with ``sqlalchemy``.

    :param calling_source_file: path of the source file to analyse.
    :param pyi_path: path of the stub file to write.
    """
    out_src = [
        "import datetime",
        "from typing import Optional, Any, Iterable, Union",
        "from typedpy import Structure",
    ]
    found_sqlalchemy = False
    for level, pkg_name, val_info, alias in list(get_imports(calling_source_file)):
        level_s = "." * level
        val = ".".join(val_info)
        alias = alias or val
        if pkg_name and pkg_name.startswith("sqlalchemy"):
            found_sqlalchemy = True
        if val and pkg_name:
            if val != "*":
                out_src.append(f"from {level_s}{pkg_name} import {val} as {alias}")
            else:
                out_src.append(f"from {level_s}{pkg_name} import {val}")
        elif pkg_name:
            out_src.append(f"import {level_s}{pkg_name}")
        else:
            out_src.append(f"import {level_s}{alias}")

    if found_sqlalchemy:
        out_src.append("from sqlalchemy import Column")
        out_src.extend([""] * 3)
        models, functions = get_models(calling_source_file)
        out_src += models_to_src(models)
        out_src += functions_to_str(functions)
    out_src = AUTOGEN_NOTE + out_src

    _write_stub(out_src, pyi_path)


def create_stub_for_file_using_ast(
    abs_module_path: str, src_root: str, stubs_root: str = None, **kw
):
    """Generate the ``.pyi`` stub of a Python source file from its source text.

    Files whose extension is not ``.py`` are ignored. The stub is written under
    ``stubs_root`` (mirroring the directory layout relative to ``src_root``) or,
    when ``stubs_root`` is not given, next to the source file. An
    ``__init__.pyi`` is created in the stub directory if it is missing.

    :param abs_module_path: path of the module source file.
    :param src_root: root directory of the sources.
    :param stubs_root: optional root directory for the generated stubs.
    :param kw: accepted and ignored.
    :return: None
    """
    if os.path.splitext(abs_module_path)[-1].lower() != ".py":
        return
    source = Path(abs_module_path)
    relative_dir = relpath(str(source.parent), src_root)

    if stubs_root:
        pyi_dir = Path(stubs_root) / Path(relative_dir)
    else:
        pyi_dir = source.parent
    pyi_dir.mkdir(parents=True, exist_ok=True)
    (pyi_dir / Path("__init__.pyi")).touch(exist_ok=True)

    stub_path = (pyi_dir / f"{source.stem}.pyi").resolve()
    create_pyi_ast(abs_module_path, str(stub_path))
