"""Transformed Dataset."""
import numpy as np
from scipy.sparse import csr_matrix

from ..feature import interaction_consumed
from ..utils.sampling import NegativeSampling


class TransformedSet:
    """Dataset after transforming.

    Often generated by calling functions in ``DatasetPure`` or ``DatasetFeat``,
    then ``TransformedSet`` is used in formal training.

    Parameters
    ----------
    user_indices : numpy.ndarray or None, default: None
        All user rows in data, represented in inner id.
    item_indices : numpy.ndarray or None, default: None
        All item rows in data, represented in inner id.
    labels : numpy.ndarray or None, default: None
        All labels in data.
    sparse_indices : numpy.ndarray or None, default: None
        All sparse rows in data, represented in inner id.
    dense_values : numpy.ndarray or None, default: None
        All dense rows in data.
    train : bool, default: True
        Whether it is train data. Train data gets a user-item interaction
        matrix (see ``sparse_interaction``); non-train data gets
        ``user_consumed`` instead.

    Attributes
    ----------
    has_sampled : bool
        Whether ``build_negative_samples`` has been called.
    user_consumed : object or None
        First value returned by
        ``interaction_consumed(user_indices, item_indices)`` when
        ``train=False``, otherwise ``None``.
    user_indices_orig, item_indices_orig, labels_orig, sparse_indices_orig, \
dense_values_orig : numpy.ndarray or None
        Data as it was before negative sampling. ``None`` until
        ``build_negative_samples`` is called.

    See Also
    --------
    :class:`~libreco.data.dataset.DatasetPure`
    :class:`~libreco.data.dataset.DatasetFeat`
    """

    def __init__(
        self,
        user_indices=None,
        item_indices=None,
        labels=None,
        sparse_indices=None,
        dense_values=None,
        train=True,
    ):
        self._user_indices = user_indices
        self._item_indices = item_indices
        self._labels = labels
        self._sparse_indices = sparse_indices
        self._dense_values = dense_values
        self.has_sampled = False

        # Define both attributes unconditionally, so reading the one that does
        # not apply to this dataset yields ``None`` rather than AttributeError.
        self._sparse_interaction = None
        self.user_consumed = None
        if train:
            self._sparse_interaction = csr_matrix(
                (labels, (user_indices, item_indices)), dtype=np.float32
            )
        else:
            self.user_consumed, _ = interaction_consumed(user_indices, item_indices)

        self.user_indices_orig = None
        self.item_indices_orig = None
        self.labels_orig = None
        self.sparse_indices_orig = None
        self.dense_values_orig = None

    def build_negative_samples(
        self, data_info, num_neg=1, item_gen_mode="random", seed=42
    ):
        """Perform negative sampling on all the data contained.

        The data held before the first call is kept in the ``*_orig``
        attributes. Calling this method again re-samples starting from that
        preserved data instead of from the previously sampled result.

        Parameters
        ----------
        data_info : DataInfo
            Object contains data information.
        num_neg : int, default: 1
            Number of negative samples for each positive sample.
        item_gen_mode : str, default: 'random'
            Sampling strategy, currently only 'random' is supported.
        seed : int, default: 42
            Random seed.
        """
        if self.has_sampled and self.labels_orig is not None:
            # Repeated call: roll back to the pre-sampling snapshot so it is
            # not overwritten with sampled data and negatives do not pile up.
            # NOTE(review): assumes the sampler reads the dataset's current
            # arrays -- confirm against ``NegativeSampling``.
            self._user_indices = self.user_indices_orig
            self._item_indices = self.item_indices_orig
            self._labels = self.labels_orig
            self._sparse_indices = self.sparse_indices_orig
            self._dense_values = self.dense_values_orig
        else:
            # First call: remember the data as it was before sampling.
            self.user_indices_orig = self._user_indices
            self.item_indices_orig = self._item_indices
            self.labels_orig = self._labels
            self.sparse_indices_orig = self._sparse_indices
            self.dense_values_orig = self._dense_values

        # The flag is set before the sampler is created, since the sampler
        # receives this dataset object.
        self.has_sampled = True
        self._sampling_impl(data_info, num_neg, item_gen_mode, seed)

    def _sampling_impl(self, data_info, num_neg=1, item_gen_mode="random", seed=42):
        """Run ``NegativeSampling`` and replace the held data with its output."""
        neg = NegativeSampling(
            self,
            data_info,
            num_neg,
            sparse=self.sparse_indices is not None,
            dense=self.dense_values is not None,
        )

        (
            self._user_indices,
            self._item_indices,
            self._labels,
            self._sparse_indices,
            self._dense_values,
        ) = neg.generate_all(seed=seed, item_gen_mode=item_gen_mode)

    def __len__(self):
        """Number of labels currently held."""
        return len(self.labels)

    def __getitem__(self, index):
        """Get a slice of data.

        Returns
        -------
        tuple
            ``(user_indices, item_indices, labels, sparse_indices,
            dense_values)`` indexed by ``index``. The last two entries are
            ``None`` when the dataset has no sparse or dense features.
        """
        pure_part = (
            self.user_indices[index],
            self.item_indices[index],
            self.labels[index],
        )
        sparse_part = (
            (self.sparse_indices[index],)
            if self.sparse_indices is not None
            else (None,)
        )
        dense_part = (
            (self.dense_values[index],) if self.dense_values is not None else (None,)
        )
        return pure_part + sparse_part + dense_part

    @property
    def user_indices(self):
        """All user rows in data"""
        return self._user_indices

    @property
    def item_indices(self):
        """All item rows in data"""
        return self._item_indices

    @property
    def sparse_indices(self):
        """All sparse rows in data"""
        return self._sparse_indices

    @property
    def dense_values(self):
        """All dense rows in data"""
        return self._dense_values

    @property
    def labels(self):
        """All labels in data"""
        return self._labels

    @property
    def sparse_interaction(self):
        """User-item interaction data, in :class:`scipy.sparse.csr_matrix` format.

        ``None`` when the dataset was created with ``train=False``.
        """
        return self._sparse_interaction
