from torch.utils.data import Dataset

class TrainingTorchDataset(Dataset):
    """
    Map-style dataset pairing each input sample with its target.

    Parameters
    ----------
    x : 2D array
        The input matrix, indexed along its first axis.
    y : 2D array
        The one-hot encoded target, indexed along its first axis with the
        same indices as ``x``.

    Raises
    ------
    ValueError
        If ``x`` and ``y`` do not hold the same number of samples.
    """

    def __init__(self, x, y):
        # Fail fast: a mismatch would otherwise only surface as an IndexError
        # deep inside an iteration (y shorter than x) or silently drop the
        # trailing targets (y longer than x), since __len__ only looks at x.
        n_inputs, n_targets = len(x), len(y)
        if n_inputs != n_targets:
            raise ValueError(
                "x and y must contain the same number of samples, "
                f"got {n_inputs} and {n_targets}"
            )
        self.x = x
        self.y = y

    def __len__(self):
        """Return the number of samples (length of ``x``)."""
        return len(self.x)

    def __getitem__(self, index):
        """Return the ``(x[index], y[index])`` pair."""
        return self.x[index], self.y[index]

class PredictTorchDataset(Dataset):
    """
    Map-style dataset exposing input samples only (no targets).

    Parameters
    ----------
    x : 2D array
        The input matrix, indexed along its first axis.
    """

    def __init__(self, x):
        self.x = x

    def __len__(self):
        """Return the number of samples (length of ``x``)."""
        n_samples = len(self.x)
        return n_samples

    def __getitem__(self, index):
        """Return the sample stored at ``index``."""
        return self.x[index]
