from ProsNet.model.model import Model
from sklearn.preprocessing import MinMaxScaler
import pandas as pd

class ShallowModel(Model):
    """Model that predicts from a hand-picked subset of summary features.

    ``make_predictions`` scales every column of ``self.dataset`` with
    ``self.scaler``, labels the scaled columns with ``FEATURE_NAMES`` and
    passes only the ``PREDICTION_FEATURES`` columns to
    ``self.model.predict``. ``self.dataset``, ``self.scaler`` and
    ``self.model`` are not assigned in this class; presumably the ``Model``
    base class (or the caller) sets them -- TODO confirm.
    """

    # Labels for the scaler's output columns, in positional order. The scaled
    # array is labelled purely by position, so this order must match the
    # column order of ``self.dataset`` and of the data the scaler was fitted
    # on. NOTE(review): confirm against the feature-extraction code.
    FEATURE_NAMES = (
        'x mean', 'y mean', 'z mean', 'vm mean',
        'x std', 'y std', 'z std', 'vm std',
        'x med abs dev', 'y med abs dev', 'z med abs dev', 'vm med abs dev',
        'x max', 'y max', 'z max', 'vm max',
        'x min', 'y min', 'z min', 'vm min',
        'x sig mag area', 'y sig mag area', 'z sig mag area',
        'vm sig mag area',
        'x energy', 'y energy', 'z energy', 'vm energy',
        'x int qu range', 'y int qu range', 'z int qu range',
        'vm int qu range',
        'x autocorrelation', 'y autocorrelation', 'z autocorrelation',
        'vm autocorrelation',
        'x spec peak pos 1', 'x spec peak pos 2', 'x spec peak pos 3',
        'x spec peak pos 4', 'x spec peak pos 5', 'x spec peak pos 6',
        'x spec peak freq 1', 'x spec peak freq 2', 'x spec peak freq 3',
        'x spec peak freq 4', 'x spec peak freq 5', 'x spec peak freq 6',
        'y spec peak pos 1', 'y spec peak pos 2', 'y spec peak pos 3',
        'y spec peak pos 4', 'y spec peak pos 5', 'y spec peak pos 6',
        'y spec peak freq 1', 'y spec peak freq 2', 'y spec peak freq 3',
        'y spec peak freq 4', 'y spec peak freq 5', 'y spec peak freq 6',
        'z spec peak pos 1', 'z spec peak pos 2', 'z spec peak pos 3',
        'z spec peak pos 4', 'z spec peak pos 5', 'z spec peak pos 6',
        'z spec peak freq 1', 'z spec peak freq 2', 'z spec peak freq 3',
        'z spec peak freq 4', 'z spec peak freq 5', 'z spec peak freq 6',
        'vm spec peak pos 1', 'vm spec peak pos 2', 'vm spec peak pos 3',
        'vm spec peak pos 4', 'vm spec peak pos 5', 'vm spec peak pos 6',
        'vm spec peak freq 1', 'vm spec peak freq 2', 'vm spec peak freq 3',
        'vm spec peak freq 4', 'vm spec peak freq 5', 'vm spec peak freq 6',
        'x spec power band 1', 'x spec power band 2',
        'x spec power band 3', 'x spec power band 4',
        'y spec power band 1', 'y spec power band 2',
        'y spec power band 3', 'y spec power band 4',
        'z spec power band 1', 'z spec power band 2',
        'z spec power band 3', 'z spec power band 4',
        'vm spec power band 1', 'vm spec power band 2',
        'vm spec power band 3', 'vm spec power band 4',
    )

    # Columns handed to the estimator, in exactly this order. Every entry must
    # appear in FEATURE_NAMES. NOTE(review): presumably this must match the
    # column set/order the model was trained on -- confirm.
    PREDICTION_FEATURES = (
        'x mean', 'y mean', 'z mean', 'vm mean',
        'x std',
        'x max', 'y max', 'z max', 'vm max',
        'x min', 'y min',
        'x spec peak freq 1', 'x spec peak freq 2', 'x spec peak freq 3',
        'x spec peak freq 4', 'x spec peak freq 5', 'x spec peak freq 6',
        'y spec peak freq 1', 'y spec peak freq 2', 'y spec peak freq 3',
        'y spec peak freq 4', 'y spec peak freq 5', 'y spec peak freq 6',
        'z spec peak freq 1', 'z spec peak freq 2', 'z spec peak freq 3',
        'z spec peak freq 4', 'z spec peak freq 5', 'z spec peak freq 6',
        'vm spec peak freq 1', 'vm spec peak freq 2', 'vm spec peak freq 3',
        'vm spec peak freq 4', 'vm spec peak freq 5', 'vm spec peak freq 6',
    )

    def __init__(self):
        super().__init__()

    def make_predictions(self):
        """Scale ``self.dataset``, keep the prediction features, and predict.

        The output of ``self.model.predict`` is stored in ``self.predictions``;
        nothing is returned.

        Raises:
            ValueError: raised by pandas if the scaler output does not have
                exactly ``len(FEATURE_NAMES)`` columns.
            KeyError: if a name in ``PREDICTION_FEATURES`` is not one of
                ``FEATURE_NAMES``.
        """
        feature_set_scaled = self.scaler.transform(self.dataset)

        model_set = pd.DataFrame(
            data=feature_set_scaled, columns=list(self.FEATURE_NAMES)
        )

        # Select by explicit label list rather than Index.intersection: a
        # misspelled feature now raises instead of being silently dropped, and
        # the column order is exactly PREDICTION_FEATURES rather than whatever
        # order the intersection yields. A list (not a tuple) is required --
        # pandas would treat a tuple as a single column key.
        model_set = model_set[list(self.PREDICTION_FEATURES)]
        self.predictions = self.model.predict(model_set)