from typing import List, Optional
from src.factory import SeqDataset
from src.loss_matrix_utils import generate_loss_matrix, train_f_on_datasets
from src.models.base_models import BaseRetrainAlgo
from reporter import Reporter
from retraining_task_processing.f_generator import PilotPredictorFGenerator, PredictorFGenerator
import numpy as np

class RetrainingExp:
    def __init__(
        self,
        models: List[BaseRetrainAlgo],
        seqdataset: SeqDataset,
        reporter: Reporter,
        f_generator: PredictorFGenerator,
        cfg: dict
    ):
        self.models = {model.name: model for model in models}
        self.seqdataset = seqdataset
        self.reporter = reporter
        self.f_generator = f_generator
        self.cfg_exp = cfg['experiment']
        self.mode = cfg['experiment']['mode']

        self.offline_t = cfg['dataset']['offline_t']
        self.T = cfg['dataset']['T']
        self.relative_pe = cfg['experiment']['relative_pe']
        self.retrain_cost = self.cfg_exp['retrain_cost']
        self.verbose = False
        if self.mode == 'alpha_too_high':
            self.train_retrain_cost = self.retrain_cost*1.5
        elif self.mode == 'alpha_too_low':
            self.train_retrain_cost = self.retrain_cost*0.75
        elif self.mode == 'wrong_alpha':
            self.train_retrain_cost = cfg['model']['model_alpha']
        else:
            self.train_retrain_cost = self.retrain_cost
        if self.verbose:
            print(' Running Exp. with alpha val of',
                  self.retrain_cost, 'train', self.train_retrain_cost)

        # to store metrics and results of the experiments
        self.experiment_results = {}
        # the  T X T loss matrix, used to store *ALL* perfromance samples (both online and offline).
        self.loss_dict_matrix = {t: {} for t in range(self.T)}

    def report(self, metrics: dict, primary_metrics: Optional[tuple[str, str]]):
        if self.reporter is None:
            return
        self.reporter(metrics, primary_metrics)

    def metric_per_model_dict(self):
        return {model_name: [0] for model_name in self.models.keys()}

    def generate_offline_training_data(self, cut_off):

        train_loss_dict_matrix = {}
        train_dict_trained_f = {}

        for model_t in range(cut_off):  # all models before the cutoff time

            train_dict_trained_f[model_t] = self.dict_trained_f[model_t]
            loss_dict_of_model = self.loss_dict_matrix[model_t]
            train_model_PE_dict = {}
            for t, loss in loss_dict_of_model.items():
                if t < self. offline_t:  # we get the performance of those models on all timesteps before cutoff time
                    train_model_PE_dict[t] = loss
            train_loss_dict_matrix[model_t] = train_model_PE_dict
        offline_training_data = {
            "train_dict_trained_f": train_dict_trained_f,
            'loss_dict': train_loss_dict_matrix,
            'datasets': self.seqdataset.offline}
        if self.verbose:
            print("Providing everything up to the time", cut_off)
            self.printing_the_included_entries(train_loss_dict_matrix)
        return offline_training_data

    def printing_the_included_entries(self, some_dict):
        for key, val in some_dict.items():
            print('Included model ', key, 'with entries from',
                  min(val.keys()), 'to', max(val.keys()))

    def generate_online_test_data(self, last_model, cut_off_t):

        test_loss_dict_matrix = {}
        test_dict_trained_f = {}
        test_dict_trained_f[last_model] = self.dict_trained_f[last_model]

        # future values of the current model
        test_loss_dict_matrix[last_model] = {}
        for t in range(cut_off_t, self.T):
            test_loss_dict_matrix[last_model][t] = self.loss_dict_matrix[last_model][t]
        # future vlaues of future models
        for model_index in range(cut_off_t, self.T):
            test_dict_trained_f[model_index] = self.dict_trained_f[model_index]
            test_loss_dict_matrix[model_index] = self.loss_dict_matrix[model_index]
        datasets = {}
        # we get the dataset distr. before predicting
        for accessible_t in range(self.offline_t, cut_off_t+1):

            datasets[accessible_t] = self.seqdataset.online[accessible_t]

        datasets = datasets | self.seqdataset.offline
        online_testing_data = {'test_dict_trained_f': test_dict_trained_f,  # to be removed, should not be given
                               'loss_dict': test_loss_dict_matrix,
                               'datasets': datasets}
        if self.verbose:
            print('Generating the testing online rollout. Last model index',
                  last_model, 'Starting rollout time', cut_off_t)
            self.printing_the_included_entries(test_loss_dict_matrix)
        return online_testing_data

    def update_online_train_test(self, t_next, retrain, model_to_use):
        t_now = t_next-1
        self.training_data['datasets'][t_now] = self.seqdataset.online[t_now]
        if retrain:  # if we retrain, then we have access to that new classifier so we add it to our data

            self.training_data['train_dict_trained_f'][model_to_use] = self.dict_trained_f[model_to_use]
            self.training_data['loss_dict'][model_to_use] = {
                t_now: self.loss_dict_matrix[model_to_use][t_now]}

        else:  # if not, we have access to the new performance at t_now for the model we are using
            self.training_data['loss_dict'][model_to_use][t_now] = self.loss_dict_matrix[model_to_use][t_now]
        if self.verbose:
            self.printing_the_included_entries(self.training_data['loss_dict'])
        self.testing_data = self.generate_online_test_data(
            last_model=model_to_use, cut_off_t=t_next)

    def train_offline(self):  # the training, or offline stage.
        # we provide the training data to each model.
        self.training_data = self.generate_offline_training_data(
            cut_off=self.offline_t)
        self.testing_data = self.generate_online_test_data(
            last_model=self.offline_t-1, cut_off_t=self.offline_t)
        training_metrics_per_model = self.metric_per_model_dict()

        for model_name, model in self.models.items():
            # we set the retrain cost used by the training algo
            model.set_retrain_cost(self.train_retrain_cost)
            training_metrics = model.train_offline(
                self.training_data, self.testing_data)
            training_metrics_per_model[model_name] = training_metrics
        self.experiment_results["training_metrics_per_model"] = training_metrics_per_model

    """
    This function will go through each timestep, train all f() functions, evaluate them at each subsequent timestep t, 
    and store the performance in the loss matrix
    """

    def precompute(self, dict_of_things_needed=None):
        if dict_of_things_needed is None:
            # dict to store the trained models.
            timesteps_to_generate = range(self.T)
            self.dict_trained_f = train_f_on_datasets(timesteps_to_generate,
                                                    self.seqdataset,  self.f_generator)
            if 'pilot' in self.models.keys():  # if the pilot algo is here, we generate the simpler models as well
                pilot_f_generator = PilotPredictorFGenerator(self.cfg_exp)
                self.pilot_dict_trained_f = train_f_on_datasets(timesteps_to_generate,
                                                                self.seqdataset, pilot_f_generator)
                self.pilot_loss_dict_matrix = generate_loss_matrix(
                    timesteps_to_generate, self.seqdataset, pilot_f_generator, self.pilot_dict_trained_f, self.relative_pe)

            self.loss_dict_matrix = generate_loss_matrix(
                timesteps_to_generate, self.seqdataset, self.f_generator, self.dict_trained_f, self.relative_pe)
            
            
            dict_of_things_needed = {
            'loss_dict_matrix': self.loss_dict_matrix,
            'dict_trained_f': self.dict_trained_f,
            'seqdataset' :self.seqdataset
            }
            
            
        else:
            self.loss_dict_matrix = dict_of_things_needed['loss_dict_matrix']
            self.dict_trained_f = dict_of_things_needed['dict_trained_f']
            self.seqdataset = dict_of_things_needed['seqdataset']
            
            timesteps_to_generate = range(self.T)
            if 'pilot' in self.models.keys():  # if the pilot algo is here, we generate the simpler models as well
                pilot_f_generator = PilotPredictorFGenerator(self.cfg_exp)
                self.pilot_dict_trained_f = train_f_on_datasets(timesteps_to_generate,
                                                                self.seqdataset, pilot_f_generator)
                self.pilot_loss_dict_matrix = generate_loss_matrix(
                    timesteps_to_generate, self.seqdataset, pilot_f_generator, self.pilot_dict_trained_f, self.relative_pe)
        # work around for CARA, we should not need that
        self.all_data = {'dict_trained_f': self.dict_trained_f,
                                'datasets': self.seqdataset.offline | self.seqdataset.online}
        try:
            self.experiment_results['dataset'] = self.seqdataset.get_interesting_datasets()
        except Exception:
            pass
            
        # compute the min bound
        loss_dict_matrix = dict_of_things_needed['loss_dict_matrix']
        for model_ind, val in loss_dict_matrix.items():
            past_pe = None
            max_diff = 0
            for t, pe in val.items():
                if past_pe is None:
                    diff = pe
                    past_pe = pe
                else:
                    diff = np.abs(pe-past_pe)
                    past_pe = pe
                if diff > max_diff:
                    max_diff = diff
       # print('max diff', max_diff)
        return dict_of_things_needed
        # we store some datasets for plotting

        
    """
    This will go through each model in self.models, follow its retraining strategy, and evaluate how it does by computing its total cost.
    """

    def eval(self):

        for model_name, model in self.models.items():  # work around for CARA, we should not need to do that
            model.initialize_eval(self.all_data)
        # we start at zero cost. [ [0,0,..], [Cost trial 1, Cost trial 2 ..], ... [Cost trial 1, Cost trial 2, ...]_T ]

        # to store the costs
        C_per_model_t = self.metric_per_model_dict()
        loss_per_model_t = self.metric_per_model_dict()
        # to store the decisions
        decisions_per_model = self.metric_per_model_dict()
        info = None

        for model_name, C_per_t in C_per_model_t.items():
            loss_per_t = loss_per_model_t[model_name]
            for t in range(self.offline_t, self.T):
                C = C_per_t[-1]  # obtain the previous costs
                # get what the model proposes to do
                retraining, model_to_use = self.models[model_name].decide(t)
                if self.verbose:
                    print('at timestep', t, 'we chose to retrain:',
                          retraining, 'using model ', model_to_use)
                if retraining:  # if we choose to retrain, we add the cost of retraining
                    if model_name == 'pilot':
                        C = self.retrain_cost * \
                            self.models[model_name].discounted_cost + C
                    else:
                        C = self.retrain_cost + C

                # obtain the cost of perfromance
                if model_name == 'pilot':
                    new_loss = self.pilot_loss_dict_matrix[model_to_use][t]
                else:
                    new_loss = self.loss_dict_matrix[model_to_use][t]
                C = new_loss+C

                C_per_t.append(C)
                loss_per_t.append(new_loss)

                if t < self.T-1:  # if we are not at the end

                    # we update the available training data and testing data
                    self.update_online_train_test(
                        t_next=t+1, retrain=retraining, model_to_use=model_to_use)
                    info = {'new_training_data': self.training_data,
                            'new_testing_data': self.testing_data}
                    # signal to the retraining model that we are at the next t.
                    self.models[model_name].update_at_t(info)
                decisions_per_model[model_name].append(int(retraining))
                wandb_metrics = {"decision": int(retraining), "cost_t": C,
                                 "model_name": model_name, 'step': t, 'retrain_cost': self.retrain_cost, 'loss': new_loss}
                self.report(wandb_metrics, primary_metrics=(
                    "cost_t", "minimize"))
            if self.verbose:
                print('Schedule : ', decisions_per_model[model_name][1:])
            metrics = {"decision": retraining, "cost": C,
                       "model_name": model_name, 'step': t, 'retrain_cost': self.retrain_cost}
            self.report(metrics,  primary_metrics=("cost", "minimize"))

        self.experiment_results["C_per_model_t"] = C_per_model_t
        self.experiment_results["retraining_decisions"] = decisions_per_model
        self.experiment_results["loss_per_model_per_t"] = loss_per_model_t
