# Public names exported by this module.
__all__ = ['AccumTable', 'accum_ratio', 'accum_ratiop', 'accum_cols']


import numpy as np
import warnings
from collections import OrderedDict

from .rt_numpy import full
from .rt_categorical import Categorical
from .rt_accum2 import Accum2
from .rt_enum import TypeRegister


class AccumTable(Accum2):
    """
    AccumTable is a wrapper on Accum2 that enables the creation of tables that
    combine the results of multiple reductions generated from the Accum2 object.
    The three parts of a table generated by the AccumTable gen() method are these:

    * **Inner Table** - a table of values indexed by row labels and column key names. A
      generated table contains only one inner table, but any number of inner tables may be created
      and used to create margin columns and footer rows, or as a reference for display formatting
      (future functionality).

    * **Margin Columns** - columns of values on the right margin, associated with an inner table, and
      indexed by and representing a value associated with a given row label.  Call
      set_margin_columns() to adjust them.

    * **Footer Rows** - rows of values on the bottom margin, associated with an inner table and indexed by
      and representing a value associated with a given column key.  Call
      set_footer_rows() to adjust them.

    Inner tables are stored and retrieved by name with item syntax
    (``at['pnl'] = ds`` / ``at['pnl']``).  Storing a table also registers it as a
    margin column and a footer row, and both storing and retrieving a table make it
    the default table used by gen().

    Parameters
    ----------
    cat_rows: Categorical or an array converted to same
        The array used to create the row labels in the AccumTable
    cat_cols: Categorical or an array converted to same
        The array used to create the column keys in the AccumTable
    filter: ndarray
        Boolean mask array applied as filter before constructing the groupings
    showfilter: bool
        Whether to include groupings whose values were all filtered out
    """
    # -------------------------------------------------------
    def __init__(self, cat_rows, cat_cols, filter=None, showfilter=False):
        # Intentionally empty: all state is set up in __new__.
        pass

    def __new__(cls, cat_rows, cat_cols, filter=None, showfilter=False):
        """
        Create the instance through Accum2 and attach the empty table registries.

        Parameters
        ----------
        cat_rows: Categorical or an array converted to same
            The array used to create the row labels in the AccumTable
        cat_cols: Categorical or an array converted to same
            The array used to create the column keys in the AccumTable
        filter: ndarray
            Boolean mask array applied as filter before constructing the groupings
        showfilter: bool
            Whether to include groupings whose values were all filtered out

        Returns
        -------
        AccumTable
            The new instance
        """
        instance = super().__new__(cls, cat_rows, cat_cols, filter, showfilter)
        # name -> Dataset for every stored inner table
        instance._inner = OrderedDict()
        # Ordered sets (values are always None) of the inner-table names to show
        # as footer rows / margin columns in gen()
        instance._rows = OrderedDict()
        instance._cols = OrderedDict()
        # Name of the most recently stored or retrieved inner table
        instance._default_inner_name = None
        return instance

    # -------------------------------------------------------
    def __repr__(self):
        """
        Return a string representation of the object

        Returns
        -------
        str
            Three lines listing the inner table names, the margin column names
            and the footer row names.
        """
        return '\n'.join([
            f'Inner Tables: {list(self._inner)}',
            f'Margin Columns: {list(self._cols)}',
            f'Footer Rows: {list(self._rows)}',
        ])

    # -------------------------------------------------------
    def __setitem__(self, name: str, ds):
        """
        Store `ds` as the inner table called `name`.

        The table is also registered as a margin column and a footer row, its
        single summary column / footer row (if it has exactly one) is renamed to
        `name`, and it becomes the default table for gen().

        Parameters
        ----------
        name : str
            Inner table name
        ds : Dataset
            The dataset

        Raises
        ------
        IndexError
            If `name` is not a string (table name).
        ValueError
            If `ds` is not a Dataset
        """
        if not isinstance(name, str):
            raise IndexError('name must be a string table name')
        if not isinstance(ds, TypeRegister.Dataset):
            raise ValueError('ds must be a Dataset')
        self._inner[name] = ds
        self._rows[name] = None
        self._cols[name] = None
        self._rename_summary_row_and_col(ds, name)
        self._default_inner_name = name

    # -------------------------------------------------------
    def __getitem__(self, index: str):
        """
        Return the inner table called `index` and make it the default table for gen().

        Parameters
        ----------
        index : str
            Inner table name

        Returns
        -------
        Dataset
            The specified inner table

        Raises
        ------
        IndexError
            If `index` is not a string (table name).
        KeyError
            If no inner table with that name has been stored.
        """
        if not isinstance(index, str):
            raise IndexError('Index must be a string table name')
        self._default_inner_name = index
        return self._inner[index]

    # -------------------------------------------------------
    def _rename_summary_row_and_col(self, ds, new_name: str):
        """
        Rename the summary column and the footer row of `ds` to `new_name`.

        Each rename is applied only when `ds` has exactly one summary column /
        exactly one footer row; otherwise that part is left untouched.
        `ds` is modified in place.

        Parameters
        ----------
        ds : Dataset
            The dataset
        new_name : str
            the new name for the summary column and footer row

        Returns
        -------
        Dataset
            The same `ds` object that was passed in.
        """
        col_names = ds.summary_get_names()
        if len(col_names) == 1:
            ds.col_rename(col_names[0], new_name)
        footers = ds.footer_get_dict()
        if len(footers) == 1:
            old_name, values = next(iter(footers.items()))
            # A footer is renamed by removing it and re-adding its values under the new label
            ds.footer_remove(old_name)
            ds.footer_set_values(new_name, values)
        return ds

    # -------------------------------------------------------
    def gen(self, table_name=None, format=None, ref_table=None, remove_blanks=True):
        """
        Generate an AccumTable view.

        Parameters
        ----------
        table_name: string or tuple (not implemented yet)
            The name of the AccumTable table to display, or a tuple of table names if
            more than one value is to be displayed in each cell (not implemented yet).
            Defaults to the most recently stored or retrieved inner table.
        format (not implemented yet): dict
            A dictionary used to specify the formatting of each cell in the table. The
            keys are formatting types, such as 'bold', 'color', and 'background',
            and the values are functions that are applied to the value (or tuple)
            in each table cell to determine the applicability of a formatting type.
            For example, one could set format={'bold': lambda v: v > 0} to make all
            positive values in the table bold.
        ref_table (not implemented yet): string or Dataset
            The name of the AccumTable table, or a Dataset of the same shape, to be
            referenced for formatting the displayed table (not implemented yet).
        remove_blanks: bool
            Do not display rows or columns containing all zeros or NaNs

        Returns
        -------
        Dataset
            The generated table

        Raises
        ------
        ValueError
            If no table name is given and no default table has been set.
        KeyError
            If `table_name`, or a configured margin column / footer row name,
            is not a stored inner table.

        Examples
        --------
        View the pnl values in the table, coloring negative values red (not implemented yet):

        >>> at.gen('pnl', format={'color': lambda v: 'red' if v < 0 else 'black'})
        """
        # Get the displayed, inner table
        table_name = self._default_inner_name if table_name is None else table_name
        self._default_inner_name = table_name
        if table_name is None:
            raise ValueError('Must specify a table name')
        orig = self._inner[table_name]

        # Remove blanks, as required, and set the row filter
        if remove_blanks:
            (clean, row_filter, _) = orig.copy().trim(ret_filters=True)
            row_filter = row_filter if row_filter is not None else slice(None, None, None)
        else:
            clean = orig.copy()
            row_filter = slice(None, None, None)

        # Add the margin columns to the right. Work on a copy of the name list so
        # the list handed back by the Dataset is never extended in place.
        summary_names = list(clean.summary_get_names())
        for mar_col in self._cols:
            if mar_col == table_name:
                continue
            # Apply the same row filter so the margin column lines up with the trimmed rows
            clean[mar_col] = self._inner[mar_col][row_filter, mar_col]
            summary_names.append(mar_col)
        clean.summary_set_names(summary_names)

        # Add the footer rows at the bottom, restricted to the columns that are
        # actually present in the generated table.
        visible = set(clean.keys())
        for footer_row in self._rows:
            if footer_row == table_name:
                continue
            fd = list(self._inner[footer_row].footer_get_dict(footer_row).values())[0]
            # Build a filtered copy instead of deleting keys from the dict returned
            # by footer_get_dict, so the source table's footer data is left alone.
            kept = {key: value for key, value in fd.items() if key in visible}
            clean.footer_set_values(footer_row, kept)

        return clean

    # -------------------------------------------------------
    def set_margin_columns(self, cols):
        """
        Specify the names of the inner tables whose margin columns should appear in the generated AccumTable view.

        This replaces any previously configured (or automatically registered) margin columns.

        Parameters
        ----------
        cols: list of str
            The list of inner table names, in order.
        """
        self._cols = OrderedDict.fromkeys(cols)

    # -------------------------------------------------------
    def set_footer_rows(self, rows):
        """
        Specify the names of the inner tables whose footer rows should appear in the generated AccumTable view.

        This replaces any previously configured (or automatically registered) footer rows.

        Parameters
        ----------
        rows: list
            The list of inner table names, in order.
        """
        self._rows = OrderedDict.fromkeys(rows)


def accum_ratio(cat1, cat2=None, val1=None, val2=None, filt1=None, filt2=None, func1='nansum', func2=None, return_table=False, include_numer=False, include_denom=True, remove_blanks=True):
    """
    Compute a bucketed ratio of two accums, using AccumTable.

    The numerator and denominator are each reduced over the (cat1, cat2) groups
    and stored as the inner tables 'Numer' and 'Denom'; their quotient is stored
    as 'Ratio'.

    Parameters
    ----------
    cat1 : Categorical
        First categorical label to group by.
    cat2 : Categorical, optional
        Second categorical label to group by. When omitted, a single
        'NotGrouped' column is used. Categories named 'Numer', 'Denom' or
        'Ratio' are renamed with a trailing underscore to avoid clashing with
        the inner table names.
    val1
        Numerator data
    val2
        Denominator data
    filt1
        Filter for `val1` data. Defaults to an all-True mask.
    filt2
        Filter for `val2` data. Optional, defaults to `filt1`
    func1
        String of function name to pass into numerator AccumTable call
    func2
        String of function name to pass into denominator AccumTable call. Defaults to `func1`.
    include_numer : bool
        If set to True, include the totals from the numerator data in the output. Ignored if `return_table` is True.
    include_denom : bool
        If set to True, include the totals from the denominator data in the output. Ignored if `return_table` is True.
    return_table : bool
        If set to True, returns the whole AccumTable instead of just the gen'd ratio data.
    remove_blanks : bool
        If set to True, blanks will be removed from the output.

    Returns
    -------
    Dataset or AccumTable
        Either a view of the ratio data, or the entire AccumTable, depending on `return_table` flag.

    Raises
    ------
    ValueError
        If `val1` is not supplied.

    Notes
    -----
    The positional shorthand ``accum_ratio(cat1, val1, val2)`` is supported: when
    `val2` is missing but `cat2` is present, the arguments are shifted one slot
    to the right and `cat2` is treated as omitted.
    """
    if val1 is None:
        raise ValueError('Missing argument val1')

    # Shorthand call accum_ratio(cat1, val1, val2): the values landed one slot early.
    # val1 is known to be present at this point, so only cat2/val2 need checking.
    if val2 is None and cat2 is not None:
        cat2, val1, val2 = None, cat2, val1

    # Fill in the optional arguments
    if filt1 is None:
        filt1 = full(val1.shape[0], True, dtype=bool)  # This was playa.utils.truecol
    filt2 = filt1 if filt2 is None else filt2
    func2 = func1 if func2 is None else func2
    if cat2 is None:
        cat2 = Categorical(full(val1.shape[0], 1, dtype=np.int8), ['NotGrouped'])  # This was playa.utils.onescol

    # Column keys must not clash with the names of the inner tables created below.
    # NOTE(review): category_replace looks like it edits the caller's cat2 in place - confirm.
    for reserved in ('Numer', 'Denom', 'Ratio'):
        if reserved in cat2.categories():
            cat2.category_replace(reserved, reserved + '_')

    table = AccumTable(cat1, cat2)

    # Resolve both reductions before running either one
    # TODO: In the future, when arbitrary functions are allowed in Accum2 calls, handle a missing attr here by passing it in by name
    numer_reduce = getattr(table, func1)
    denom_reduce = getattr(table, func2)
    table['Numer'] = numer_reduce(val1, filter=filt1)
    table['Denom'] = denom_reduce(val2, filter=filt2)

    table['Ratio'] = table['Numer'] / table['Denom']

    if return_table:
        return table

    # Only the requested totals are shown alongside the ratio, as both
    # margin columns and footer rows.
    shown_totals = []
    if include_numer:
        shown_totals.append('Numer')
    if include_denom:
        shown_totals.append('Denom')
    table.set_margin_columns(shown_totals)
    table.set_footer_rows(shown_totals)
    return table.gen('Ratio', remove_blanks=remove_blanks)


def accum_ratiop(cat1, cat2=None, val=None, filter=None, func='nansum', norm_by='T', include_total=True, remove_blanks=True, filt=None):
    """
    Compute an internal ratio, either by Total (T), Row (R), or Column (C).

    The data is reduced over the (cat1, cat2) groups and every cell is then
    expressed as a percentage (scaled by 100) of the chosen denominator.

    Parameters
    ----------
    cat1 : Categorical
        First categorical label to group by
    cat2 : Categorical, optional
        Second categorical label to group by. When omitted, a single
        'NotGrouped' column is used.
    val
        Data column. Defaults to a column of ones with the length of `cat1`.
    filter: boolean column
        Filter for `val` data; replaces the deprecated `filt`. Defaults to an all-True mask.
    func : str
        String of function name to pass into AccumTable call
    norm_by : {'T', 'C', 'R'}
        What to use as the denominator: the grand total, the column total or
        the row total. Case-insensitive.
    include_total : bool
        Include the total amounts in addition to the ratios, defaults to True.
    remove_blanks : bool
        If set to True, blanks will be removed from the output; defaults to True.
    filt
        DEPRECATED FOR "filter". Only consulted when `filter` is not given.

    Returns
    -------
    Dataset
        AccumTable view of ratios

    Raises
    ------
    ValueError
        If `norm_by` is not one of T, R, C. This is checked before any
        computation is performed.
    """
    # Validate the cheap argument first so a bad selection fails before any reduction runs
    mode = norm_by.upper()
    if mode not in ('T', 'R', 'C'):
        raise ValueError(f'Invalid norm_by selection: {norm_by}. Valid choices are T, R, C.')

    # Handle missing inputs
    if val is None:
        val = full(cat1.shape[0], 1, dtype=np.float64)  # This was playa.utils.onescol
    if filter is None:
        if filt is not None:  # Temporary until deprecated
            # stacklevel=2 attributes the warning to the caller's line
            warnings.warn('Kwarg "filt" is being deprecated for "filter" to align with common syntax. "filt" will be removed in a future version', FutureWarning, stacklevel=2)
            filter = filt
        else:
            filter = full(val.shape[0], True, dtype=bool)  # This was playa.utils.truecol
    if cat2 is None:
        cat2 = Categorical(full(val.shape[0], 1, dtype=np.int8), ['NotGrouped'])  # This was playa.utils.onescol

    # Compute accum
    accum = AccumTable(cat1, cat2)

    # TODO: In the future, when arbitrary functions are allowed in Accum2 calls, handle a missing attr here by passing it in by name
    reduce_func = getattr(accum, func)
    accum['TotalRatio'] = reduce_func(val, filter=filter)
    if include_total:
        # Second, independent result: storing a table renames its summary column and
        # footer in place, so the Dataset stored under 'TotalRatio' is not reused here.
        accum['Total'] = reduce_func(val, filter=filter)

    accumr = accum.gen('TotalRatio', remove_blanks=remove_blanks)

    # Columns to normalize: skip the first column and, when totals are included,
    # the last one.
    # NOTE(review): assumes exactly one leading row-label column and that the
    # 'Total' margin column is last - confirm for multi-key categoricals.
    all_keys = accumr.keys()
    keys = all_keys[1:-1] if include_total else all_keys[1:]

    footer = accumr.footer_get_dict()['TotalRatio']
    if mode == 'C':
        # Each column is divided by its own footer total, so every footer entry becomes 100%
        for col in keys:
            accumr[col] = 100 * accumr[col] / footer[col]
        accumr.footer_set_values('TotalRatio', {key: 100.0 for key in footer})
    else:
        # 'T' and 'R' both express the footer as a share of the grand total
        total = footer['TotalRatio']
        accumr.footer_set_values('TotalRatio', {key: 100 * item / total for (key, item) in footer.items()})
        for col in keys:
            # 'R' divides by the row-total column, re-read on every pass.
            # NOTE(review): this relies on 'TotalRatio' being visited after the data
            # columns, since that column is itself overwritten when reached.
            denominator = total if mode == 'T' else accumr.TotalRatio
            accumr[col] = 100 * accumr[col] / denominator

    return accumr


def accum_cols(cat, val_list, name_list=None, filt_list=None, func_list='nansum', remove_blanks=True):
    """
    Compute multiple accum calculations on the same categorical label, output as a single dataset.

    Parameters
    ----------
    cat : Categorical
        Categorical label to group by. Anything else is converted to a Categorical.
    val_list : list
        List of data columns; a single non-list value is treated as a one-element list.
        If an element is a two-element list itself, a ratio will be calculated.
        If an element is a two-element list of type [val, 'p'], an accum_ratiop-style percentile will be calculated
    name_list : list
        List of column names in the eventual dataset. Defaults to colN.
    filt_list : list
        List of filters, either one for all or one for each. Defaults to an
        all-True mask with the length of `cat`.
    func_list : str or list of str
        String of function name (or list of strings of function names) to pass into AccumTable call,
        either one for all or one for each. Defaults to 'nansum'.
    remove_blanks : bool
        If set to true, blanks will be removed from the output. Defaults to True.

    Returns
    -------
    Dataset
        Accum2 view of calculated data.

    Raises
    ------
    ValueError
        If `val_list` is empty, or if a two-element entry has a string
        specifier other than 'p' / 'P'.
    """
    # Handle mistyped inputs
    if not isinstance(val_list, list):
        val_list = [val_list]
    if not val_list:
        raise ValueError('val_list must contain at least one data column')
    if not isinstance(cat, Categorical):
        cat = Categorical(cat)

    # Every data column is grouped by `cat`, so its length is the row count to use.
    # (Taking it from val_list[0] fails when that entry is a [numer, denom] list.)
    nrows = cat.shape[0]

    # Handle missing inputs
    if name_list is None:
        name_list = [f'col{n}' for n in range(len(val_list))]
    if filt_list is None:
        filt_list = full(nrows, True, dtype=bool)  # This was playa.utils.truecol
    # Broadcast a single function name / filter to every data column
    if not isinstance(func_list, list):
        func_list = [func_list] * len(val_list)
    if not isinstance(filt_list, list):
        filt_list = [filt_list] * len(val_list)

    # Compute accum, grouping against a single dummy column
    temp_cat = Categorical(full(nrows, 1, dtype=np.int8), ['NotGrouped'])  # This was playa.utils.onescol
    accum = Accum2(cat, temp_cat)

    results = None
    for (val, name, filt, func_name) in zip(val_list, name_list, filt_list, func_list):
        if isinstance(val, list):  # Special cases
            if isinstance(val[1], str):  # Named cases
                # Exact match only: a substring test would also accept '' and 'pP'
                if val[1] not in ('p', 'P'):
                    raise ValueError(f'Invalid accum_cols specifier "{val[1]}" in second argument for column {name}')
                # accum_ratiop type: normalize by grand total, no total column, keep blanks
                curr_data = accum_ratiop(cat, temp_cat, val[0], filt, func_name, 'T', False, False)
            else:  # accum_ratio type
                curr_data = accum_ratio(cat, temp_cat, val[0], val[1], filt, filt, func_name, func_name, remove_blanks=False)
        else:
            func = getattr(accum, func_name)
            # must pass multiple input params as list now
            # NOTE(review): the Accum2 object itself is passed as the first list element - confirm this is intended.
            curr_data = func([accum, val], filter=filt)

        if results is None:
            # First column: seed the output with the row-label column(s).
            # Get number of keys in (potentially) multikey categorical. This only happens once.
            cat_width = len(cat.category_dict)
            results = curr_data[:, 0:cat_width]
            results.footer_remove()
        results[name] = curr_data['NotGrouped']

        # Carry over this column's total from the first footer row, 0.0 if absent
        footer_val = list(curr_data.footer_get_dict().values())[0].get('NotGrouped', 0.0)
        results.footer_set_values('Total', {name: footer_val})

    return results.trim() if remove_blanks else results


# Register the class on TypeRegister so it is reachable as TypeRegister.AccumTable.
# Keep this as the last line of the module.
TypeRegister.AccumTable = AccumTable
