"""Module containing tools to run result analytics."""
import numpy as np
import pandas as pd

from .db import DB
from .db import DbModel
from .utils import load_project_config


def _extract_metric_results(outcome, requested_metric):
    """Projects every cell of ``outcome`` onto a single metric.

    Args:
        outcome: a pandas DataFrame whose cells are subscriptable by metric
            name (e.g. dicts mapping metric names to values).
        requested_metric: the key looked up in every cell.

    Returns: a DataFrame of the same shape holding
        ``cell[requested_metric]`` for each cell.
    """
    def _select_metric(column):
        # DataFrame.apply (default axis) hands over one column at a time.
        return column.apply(lambda cell: cell[requested_metric])

    return outcome.apply(_select_metric)


def get_results_as_dataframe(project_name,
                             table_name='results',
                             filter_git_dirty=True):
    """Loads a results table and narrows it down to a single project.

    Args:
        project_name: only rows whose ``project_name`` column equals this
            value are kept.
        table_name: the name of the results table to read.
        filter_git_dirty: if true, only rows whose ``git_is_dirty`` column
            equals ``False`` are kept.

    Returns:
        A pandas dataframe with the matching rows of the table.
    """
    frame = pd.read_sql_table(table_name=table_name, con=DB.engine)

    if filter_git_dirty:
        clean_rows = frame['git_is_dirty'] == False  # noqa: E712
        frame = frame.loc[clean_rows]

    same_project = frame['project_name'] == project_name
    return frame.loc[same_project]


def fetch_by_git_commit_id(git_commit_id):
    """Builds a query over DbModel restricted to one git commit.

    Args:
        git_commit_id: only entries with this commit id are matched.

    Returns:
        A SQLAlchemy query on a fresh ``DB.session()``.
    """
    same_commit = DbModel.git_commit_id == git_commit_id
    return DB.session().query(DbModel).filter(same_commit)


def fetch_by_row_ids(from_id, to_id=None):
    """Builds a query over DbModel restricted to a range of database ids.

    Args:
        from_id: the smallest database id included in the results.
        to_id: if given, the largest database id included in the results;
            otherwise the range is open-ended.

    Returns:
        A SQLAlchemy query on a fresh ``DB.session()``.
    """
    # Query.filter ANDs all criteria passed to it.
    id_bounds = [DbModel.id >= from_id]
    if to_id is not None:
        id_bounds.append(DbModel.id <= to_id)
    return DB.session().query(DbModel).filter(*id_bounds)


def fetch_by_project_name(project=None):
    """Builds a query over DbModel restricted to one project.

    Args:
        project: the project name to match. When omitted, the name is taken
            from ``load_project_config()['project']['name']`` (the
            dbispipeline.ini configuration).

    Returns:
        A SQLAlchemy query on a fresh ``DB.session()``.
    """
    if project is not None:
        wanted = project
    else:
        wanted = load_project_config()['project']['name']

    return DB.session().query(DbModel).filter(
        DbModel.project_name == wanted)


def get_cv_epoch_evaluator_results(requested_metric=None,
                                   query_function=fetch_by_project_name):
    """Extracts CvEpochEvaluator results from the database.

    This is a generator. It iterates over the entries returned by
    ``query_function()`` and silently skips every entry whose
    ``evaluator['class']`` is not ``'CvEpochEvaluator'``.

    Args:
        requested_metric: allows to restrict the results to a single metric.
            Entries whose ``evaluator['scoring']`` does not contain this
            metric are skipped.
        query_function: a function returning a SQLAlchemy query when called.

    Yields: A tuple containing the prepared results as first element and the
        whole db entry as the second element. The prepared results are either
        a pandas dataframe if a metric is requested or a dict containing a
        pandas dataframe per metric listed in ``evaluator['scoring']``.
    """
    for run in query_function():
        if run.evaluator['class'] != 'CvEpochEvaluator':
            continue

        scoring = run.evaluator['scoring']
        if requested_metric is None:
            outcome = pd.DataFrame(run.outcome)
            results = {
                metric: _extract_metric_results(outcome, metric)
                for metric in scoring
            }
            yield results, run
        elif requested_metric in scoring:
            # Only build the (possibly large) outcome frame when it is used.
            outcome = pd.DataFrame(run.outcome)
            yield _extract_metric_results(outcome, requested_metric), run


def rows_to_dataframe(rows,
                      allow_git_different_rows=False,
                      allow_git_dirty_rows=False):
    """
    Converts database rows to a pandas DataFrame.

    args:
        rows: some object that iterates over rows. May be a query result or an
            actual list of rows. If this field is None or empty, an empty
            DataFrame will be returned.
        allow_git_different_rows: if set to true, allows that rows might have
            different git commit ids. Otherwise, an exception is thrown.
        allow_git_dirty_rows: if set to true, allows that rows have a dirty git
            status. Otherwise, an exception is thrown.

    raises:
        ValueError: if the rows span several git commit ids (and
            allow_git_different_rows is false), or if any row is git-dirty
            (and allow_git_dirty_rows is false).

    returns: a pandas DataFrame with all columns of the database as columns.
    """
    # None must be handled before materializing: list(None) raises.
    if rows is None:
        return pd.DataFrame()

    # rows might be a query, which is not yet fetched from the database.
    # Materialize it once so that it can be iterated several times below.
    if not isinstance(rows, (list, np.ndarray)):
        rows = list(rows)

    if len(rows) == 0:
        return pd.DataFrame()

    git_ids = {row.git_commit_id for row in rows}
    if not allow_git_different_rows and len(git_ids) > 1:
        raise ValueError(f'your result contains multiple git ids: {git_ids}')

    if not allow_git_dirty_rows and any(row.git_is_dirty for row in rows):
        raise ValueError('your result contains dirty git rows')

    # the DbModel objects will have additional columns that are not interesting
    # for the underlying application. In the __table__.columns field, the
    # actual list of "payload"-columns is stored.
    columns = [column.name for column in rows[0].__table__.columns]
    df_rows = [
        {column: getattr(row, column) for column in columns} for row in rows
    ]
    return pd.DataFrame(df_rows)


def _read_parameters(dictionary, prefix='', use_prefix=True):
    """Recursive helper method for extracting GridSearch param information.

    Flattens one grid search parameter combination into a flat dict of
    parameter name -> string value.

    Args:
        dictionary: mapping of parameter name to value. Keys ending in
            ``__selected_model`` are treated as PipelineHelper selections
            whose value is indexable as ``(model_name, nested_parameters)``.
        prefix: string prepended to the keys of this level when
            ``use_prefix`` is true.
        use_prefix: if false, keys are emitted without ``prefix``. The flag is
            forwarded to nested PipelineHelper levels, so that nested
            parameters are returned by their short name only.

    Returns: a new dict mapping parameter names to string values.
    """
    selector_suffix = '__selected_model'
    key_prefix = prefix if use_prefix else ''
    result = {}
    for key, value in dictionary.items():
        if key.endswith(selector_suffix):
            # parameter contains a PipelineHelper, we need recursion
            model_name, nested_parameters = value[0], value[1]
            base_name = key[:-len(selector_suffix)]
            result[key_prefix + base_name] = str(model_name)
            result.update(
                _read_parameters(nested_parameters,
                                 f'{model_name}__',
                                 use_prefix=use_prefix))
        else:
            result[key_prefix + key] = f'{value}'
    return result


def extract_gridsearch_parameters(
    df,
    score_name,
    drop_outcome=True,
    prefix_parameter_names=True,
):
    """
    Extracts parameters from a grid search result.

    This method creates one DataFrame row for each parameter combination in the
    "outcome -> cv_results" field, and one column for each distinct parameter.
    For example, if your grid search contains a parameter `svm__C: [1, 10]`,
    then this method will add a column `svm__C` to your DataFrame, and replace
    this row with two rows for the values 1 and 10. Parameter values are
    stored as strings. Each resulting row also gets a column `score_name`
    holding the score of its combination.

    This method will recursively resolve parameters used in PipelineHelpers.

    Depending on your configurations, the output of this method may make the
    row ids no longer unique.

    before:
    row0 = {
        dataloader: XY,
        outcome: {
            cv_results: {
                params: {
                    1: { SVM__C:  1},
                    2: { SVM__C: 10},
                }
            }
        }
    }

    after:
    row0 = {
        dataloader: XY,
        SVM__C: '1',
    }
    row1 = {
        dataloader: XY,
        SVM__C: '10',
    }

    args:
        df: a pandas DataFrame object that has one column 'outcome', which
            contains dictionaries which have a field 'cv_results'. Notably,
            this is the case for the result of the rows_to_dataframe method.
            'cv_results' -> 'params' may be a mapping from combination id to
            parameters, or a sequence of parameters (as in scikit-learn's
            cv_results_); 'cv_results' -> score_name is indexed the same way.
        score_name: name of the field that the score should be extracted from.
        drop_outcome: if true, the resulting dataframe will no longer have the
            original 'outcome' column.
        prefix_parameter_names: This parameter only affects models which have
            a PipelineHelper. If true, the resulting parameter names are
            returned by their full name. If set to false, only the part within
            the PipelineHelper is returned.
            Note that omitting this prefix may result in multiple parameters
            with the same name, possibly leading to grouping unrelated fields.

    raises:
        ValueError: if a row has no 'outcome' mapping containing 'cv_results'.

    returns: a pandas DataFrame with all possible parameters as columns, and
        all distinct parameter combinations as rows.
    """
    from collections.abc import Mapping

    result_rows = []
    for _, row in df.iterrows():
        outcome = row.get('outcome')
        # outcome may be missing, None or NaN; all of these mean "no grid
        # search results" rather than a TypeError from the `in` check.
        if not isinstance(outcome, Mapping) or 'cv_results' not in outcome:
            raise ValueError(
                'this result set does not seem to contain grid '
                "search results, missing field: row['outcome']['cv_results']")

        # Build the shared part of the row once; the outcome field may be big,
        # so it is never copied, only referenced.
        base_row = {k: v for k, v in row.items() if k != 'outcome'}
        if not drop_outcome:
            base_row['outcome'] = outcome

        cv = outcome['cv_results']
        params = cv['params']
        if isinstance(params, Mapping):
            combinations = params.items()
        else:
            combinations = enumerate(params)

        for combination_id, combination in combinations:
            result_row = dict(base_row)
            result_row[score_name] = cv[score_name][combination_id]
            result_row.update(
                _read_parameters(combination,
                                 use_prefix=prefix_parameter_names))
            result_rows.append(result_row)
    return pd.DataFrame(result_rows)
