import logging
from collections import namedtuple
from itertools import groupby
from typing import Any, Dict, Iterator, Union

from databuilder import Scoped
from databuilder.extractor.base_extractor import Extractor
from databuilder.extractor.sql_alchemy_extractor import SQLAlchemyExtractor
from databuilder.models.table_metadata import ColumnMetadata, TableMetadata
from pyhocon import ConfigFactory, ConfigTree

# Grouping key for result rows: one (schema, table_name) pair per table.
TableKey = namedtuple("TableKey", "schema table_name")

# Module-level logger, named after this module.
LOGGER = logging.getLogger(__name__)


class SqliteMetadataExtractor(Extractor):
    """
    Extracts SQLite table and column metadata using SQLAlchemyExtractor.

    The extractor runs a single query that joins ``sqlite_master`` with the
    ``pragma_table_info`` table-valued function, producing one row per column.
    Consecutive rows that share the same table key are folded into one
    ``TableMetadata`` record, returned one at a time by :meth:`extract`.
    """

    # CONFIG KEYS
    # SQL boolean expression AND-ed onto the built-in ``m.type = 'table'`` filter.
    WHERE_CLAUSE_SUFFIX_KEY = "where_clause_suffix"
    # Cluster name reported on every extracted table.
    CLUSTER_KEY = "cluster_key"
    # NOTE(review): given a default below but never read inside this class.
    USE_CATALOG_AS_CLUSTER_NAME = "use_catalog_as_cluster_name"
    # Database name passed to TableMetadata; falls back to "sqlite" in init().
    DATABASE_KEY = "database_key"

    # Default values
    DEFAULT_CLUSTER_NAME = "master"

    DEFAULT_CONFIG = ConfigFactory.from_dict(
        {
            # "1=1" is always true, i.e. no additional filtering.
            WHERE_CLAUSE_SUFFIX_KEY: "1=1",
            CLUSTER_KEY: DEFAULT_CLUSTER_NAME,
            USE_CATALOG_AS_CLUSTER_NAME: True,
        }
    )

    # One result row per (table, column).
    #
    # * The user-supplied suffix is wrapped in parentheses so that a suffix
    #   containing a top-level OR cannot escape the ``m.type = 'table'`` filter
    #   through AND/OR precedence.
    # * The explicit ORDER BY guarantees that all rows of a table are adjacent,
    #   which itertools.groupby in _get_extract_iter() depends on.
    # * col_sort_order is numbered across the whole result set (the window has
    #   no PARTITION BY), ordered by table name and then column name.
    SQL_STATEMENT = """
            SELECT
                '{cluster_source}' as cluster,
                '' as schema,
                m.name as name,
                '' as description,
                p.name as col_name,
                p.type as col_type,
                '' as col_description,
                row_number() over win as col_sort_order
            FROM
                sqlite_master AS m
            JOIN
                pragma_table_info(m.name) AS p
            WHERE
                m.type = 'table' AND ({where_clause_suffix})
            WINDOW WIN as (ORDER BY m.name, p.name)
            ORDER BY
                m.name, p.name
    """
    # NOTE(review): not referenced inside this class; presumably a ready-made
    # where_clause_suffix selecting text-like columns - confirm against callers.
    where = "p.type like 'text' or p.type like 'varchar%' or p.type like 'char%'"

    def init(self, conf: ConfigTree) -> None:
        """
        Configure the extractor and the wrapped SQLAlchemyExtractor.

        :param conf: extractor configuration; keys missing from it are taken
            from DEFAULT_CONFIG.
        :return: None
        """
        conf = conf.with_fallback(SqliteMetadataExtractor.DEFAULT_CONFIG)
        self._cluster = conf.get_string(SqliteMetadataExtractor.CLUSTER_KEY)

        self._database = conf.get_string(
            SqliteMetadataExtractor.DATABASE_KEY, default="sqlite"
        )

        # Both values are interpolated verbatim into the SQL text, so they must
        # come from trusted configuration.
        self.sql_stmt = SqliteMetadataExtractor.SQL_STATEMENT.format(
            where_clause_suffix=conf.get_string(
                SqliteMetadataExtractor.WHERE_CLAUSE_SUFFIX_KEY
            ),
            cluster_source=self._cluster,
        )

        self._alchemy_extractor = SQLAlchemyExtractor()
        # The generated statement is only a fallback: an extract SQL already
        # present in the scoped SQLAlchemy configuration is kept.
        sql_alch_conf = Scoped.get_scoped_conf(
            conf, self._alchemy_extractor.get_scope()
        ).with_fallback(
            ConfigFactory.from_dict({SQLAlchemyExtractor.EXTRACT_SQL: self.sql_stmt})
        )

        # Re-read so that sql_stmt reflects the statement that will actually run.
        self.sql_stmt = sql_alch_conf.get_string(SQLAlchemyExtractor.EXTRACT_SQL)

        LOGGER.info("SQL for sqlite metadata: %s", self.sql_stmt)

        self._alchemy_extractor.init(sql_alch_conf)
        # Created lazily on the first extract() call.
        self._extract_iter: Union[None, Iterator] = None

    def extract(self) -> Union[TableMetadata, None]:
        """
        Return the next table's metadata.

        :return: the next TableMetadata, or None once all tables were returned
        """
        # Compare against None explicitly: the check means "not created yet",
        # not "iterator is falsy".
        if self._extract_iter is None:
            self._extract_iter = self._get_extract_iter()
        try:
            return next(self._extract_iter)
        except StopIteration:
            return None

    def _get_extract_iter(self) -> Iterator[TableMetadata]:
        """
        Group the raw rows by table key and yield one TableMetadata per group.

        itertools.groupby only merges consecutive rows, so the rows of a table
        must arrive adjacent to each other.
        :return: iterator of TableMetadata
        """
        raw_rows = self._get_raw_extract_iter()
        for _, group in groupby(raw_rows, self._get_table_key):
            # groupby never yields an empty group, so rows[-1] below is safe.
            rows = list(group)
            columns = [
                ColumnMetadata(
                    row["col_name"],
                    row["col_description"],
                    row["col_type"],
                    row["col_sort_order"],
                )
                for row in rows
            ]

            # Table-level fields are taken from the last row of the group.
            table_row = rows[-1]
            yield TableMetadata(
                self._database,
                table_row["cluster"],
                table_row["schema"],
                table_row["name"],
                table_row["description"],
                columns,
            )

    def _get_raw_extract_iter(self) -> Iterator[Dict[str, Any]]:
        """
        Provides iterator of result rows from the SQLAlchemy extractor.

        Iteration stops at the first falsy value returned by the wrapped
        extractor's extract().
        :return: iterator of result rows
        """
        row = self._alchemy_extractor.extract()
        while row:
            yield row
            row = self._alchemy_extractor.extract()

    def _get_table_key(self, row: Dict[str, Any]) -> Union[TableKey, None]:
        """
        Table key consists of schema and table name.

        :param row: result row providing the "schema" and "name" keys
        :return: TableKey for the row, or None when the row is falsy
        """
        if row:
            return TableKey(schema=row["schema"], table_name=row["name"])

        return None
