
import itertools
from typing import List
from sklearn.linear_model import ElasticNet, ElasticNetCV
from src.models.base_models import BaseRetrainAlgo
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import  WhiteKernel
import numpy as np
import numpy as np
from scipy.optimize import curve_fit
from src.models.optimal_schedule import find_optimal_schedule
from src.models.perf_forecast.features_utils import aggregate_pred, create_X_y_data
from sklearn.gaussian_process.kernels import RBF
           
           
"""
Our Predictor
"""


class PerformanceForecaster(BaseRetrainAlgo):
    """Retraining policy ('PF') driven by a learned performance forecast.

    An ``ElasticNet`` regressor is fitted on feature/target arrays built by
    ``create_X_y_data``.  Its predictions on the test arrays are aggregated
    with ``aggregate_pred`` and handed to ``find_optimal_schedule``, which
    chooses among all binary schedules for the remaining time steps.  The
    chosen schedule is stored in ``self.fixed_retrain_indices`` and consulted
    by :meth:`decide`.

    ``self.T``, ``self.t_offline``, ``self.t``, ``self.retrain_cost`` and
    ``self.most_recent_available_model`` are read by this class but never
    assigned in ``__init__``; presumably ``BaseRetrainAlgo`` initialises
    them -- confirm against the base class.
    """

    def __init__(self, T: int, t_offline: int,  relative_pe: bool, variant: str, loss_is_01: bool, features: List[str], unc, var):
        """Store the configuration of the forecaster.

        Args:
            T: Forwarded to the base class; used here as the exclusive upper
                bound of the time indices.
            t_offline: Forwarded to the base class; used here as the "now"
                index when the offline schedule is computed.
            relative_pe: Forwarded unchanged to the base class.
            variant: Stored on the instance; not read anywhere in this class.
            loss_is_01: Forwarded to ``find_optimal_schedule``.
            features: Feature names passed to ``create_X_y_data``.
            unc: Forwarded to ``find_optimal_schedule``.
            var: Stored on the instance; not read anywhere in this class.
        """
        super().__init__(T, t_offline, relative_pe)
        self.name = 'PF'
        self.variant = variant
        self.features = features
        self.unc = unc
        self.verbose = False
        self.var = var
        self.loss_is_01 = loss_is_01

    def train_PE_predictor(self, X_train, y_train) -> dict:
        """Fit a default ``ElasticNet`` on the training arrays.

        The fitted estimator is kept in ``self.model``.

        Returns:
            ``{'r2': score}`` where ``score`` is ``self.model.score`` evaluated
            on the same training arrays.
        """
        self.model = ElasticNet().fit(X_train, y_train)
        return {'r2': self.model.score(X_train, y_train)}

    def predict(self, X_test):
        """Return ``(predictions, variance)`` for ``X_test``.

        The variance slot is always ``None``: the fitted model only produces
        point predictions.
        """
        return self.model.predict(X_test), None

    def assign_var_per_pred(self, training_metrics) -> dict:
        """Build a nested ``{model_t: {t: variance}}`` table.

        Every entry is the same value, ``training_metrics['training_mse']``.
        ``model_t`` runs from ``self.most_recent_available_model`` up to (but
        excluding) ``self.T`` and ``t`` runs from ``model_t`` up to ``self.T``.
        """
        # The value is loop-invariant, so look it up once.
        mse = training_metrics['training_mse']
        horizon = self.T
        return {
            model_t: {t: mse for t in range(model_t, horizon)}
            for model_t in range(self.most_recent_available_model, horizon)
        }

    def _plan_schedule(self, predicted_loss, training_metrics, now, index_last_trained_model):
        """Pick a schedule for steps ``now .. T-1`` and store its retrain indices.

        Shared by :meth:`train_offline` and :meth:`retrain`, which differ only
        in the ``now`` and ``index_last_trained_model`` they pass.
        """
        # Exhaustive enumeration: 2 ** (T - now) candidate 0/1 schedules.
        candidate_schedules = list(
            itertools.product([0, 1], repeat=self.T - now))
        loss_var = self.assign_var_per_pred(training_metrics)
        optimal_theta = find_optimal_schedule(
            predicted_loss, candidate_schedules, self.retrain_cost,
            now=now, index_last_trained_model=index_last_trained_model,
            loss_is_01=self.loss_is_01, unc=self.unc, loss_var=loss_var)

        # Shift schedule positions to absolute time indices.  np.argwhere
        # yields one row per non-zero entry, so each list element is a small
        # index array rather than a plain int; decide() relies on ``in``
        # comparing those arrays with a scalar.
        self.fixed_retrain_indices = list(np.argwhere(optimal_theta) + now)

    def train_offline(self, training_data, testing_data) -> dict:
        """Fit the predictor on offline data and compute the initial schedule.

        The feature/target arrays derived from both datasets are cached on the
        instance (``X_train``, ``y_train``, ``X_test``, ``y_test``).

        Returns:
            The metrics dictionary produced by :meth:`train`.
        """
        self.X_train, self.y_train = create_X_y_data(
            training_data, self.features)
        self.X_test, self.y_test = create_X_y_data(testing_data, self.features)

        predicted_loss, training_metrics = self.train(
            self.X_train, self.y_train, self.X_test, self.y_test)

        self._plan_schedule(
            predicted_loss, training_metrics,
            now=self.t_offline,
            index_last_trained_model=self.t_offline - 1)

        return training_metrics

    def train(self, X_train, y_train, X_test, y_test):
        """Fit the predictor and forecast on the test arrays.

        Returns:
            ``(predicted_loss, training_metrics)`` where ``predicted_loss`` is
            ``aggregate_pred(X_test, y_pred)`` and ``training_metrics`` holds
            the keys ``'r2'``, ``'training_mse'``, ``'predicted_pe'``,
            ``'training_pe'``, ``'val_pe'`` and ``'error'``.
        """
        training_metrics = self.train_PE_predictor(X_train, y_train)
        y_pred, _ = self.predict(X_test)

        # In-sample mean squared error; later reused as the per-prediction
        # variance by assign_var_per_pred().
        y_train_pred, _ = self.predict(X_train)
        training_metrics['training_mse'] = np.mean(
            (y_train_pred - y_train)**2)

        training_metrics['predicted_pe'] = (X_test, y_pred)
        training_metrics['training_pe'] = (X_train, y_train)
        training_metrics['val_pe'] = (X_test, y_test)
        training_metrics['error'] = y_test - y_pred

        return aggregate_pred(X_test, y_pred), training_metrics

    def retrain(self):
        """Refit on the cached arrays and recompute the schedule from ``self.t``."""
        predicted_loss, training_metrics = self.train(
            self.X_train, self.y_train, self.X_test, self.y_test)

        self._plan_schedule(
            predicted_loss, training_metrics,
            now=self.t,
            index_last_trained_model=self.most_recent_available_model)

    def update_at_t(self, info):
        """Replace the cached arrays with new data, advance ``self.t``, and replan.

        Args:
            info: Mapping that must provide ``'new_training_data'`` and
                ``'new_testing_data'``.
        """
        self.X_train, self.y_train = create_X_y_data(
            info['new_training_data'], self.features)
        self.X_test, self.y_test = create_X_y_data(
            info['new_testing_data'], self.features)
        self.t += 1
        self.retrain()

    def decide(self, t) -> tuple:
        """Report whether step ``t`` is a scheduled retrain.

        Returns:
            ``(retrain, most_recent_available_model)``.
        """
        retrain = t in self.fixed_retrain_indices
        if retrain:
            # NOTE(review): this records self.t, not the argument ``t``; the
            # two agree only if callers always pass the current step -- confirm.
            self.most_recent_available_model = self.t

        return retrain, self.most_recent_available_model
