
from typing import List
import numpy as np
from src.models.base_models import BaseRetrainAlgo
from src.models.models_helper.dp import add_retraining_cost, dp_iterative, retrains_iterative
from sklearn.metrics.pairwise import rbf_kernel
from models.cara_algorithms.threshold import optimize as optimize_thr
from models.cara_algorithms.threshold import run as run_thr
from models.cara_algorithms.periodic import optimize as optimize_per
from models.cara_algorithms.periodic import run as run_per
from models.cara_algorithms.cum_threshold import optimize as optimize_cthr
from models.cara_algorithms.cum_threshold import run as run_cthr
"""
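CARA retraining scheduler: tunes a threshold, periodic, or cumulative-threshold retraining policy on the offline period and replays it on the online period.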
"""


class CARA(BaseRetrainAlgo):
    """CARA scheduler driven by a parameterised retraining policy.

    The policy ("threshold", "periodic" or "cumul_threshold") is tuned on the
    offline part of the CARA cost matrix (`train_offline`). It is then run on
    the online part to produce a fixed retrain schedule (`initialize_eval`),
    which `decide` consults at each timestep.
    """

    def __init__(self, T: int, t_offline: int, relative_pe: bool, variant: str):
        super().__init__(T, t_offline, relative_pe)
        self.name = 'Cara_' + variant
        if self.relative_pe:
            # Tag relative-PE runs so they are distinguishable by name.
            self.name += '_rel'

        # Each variant supplies an (optimize, run) pair of policy functions.
        strategies = {
            'threshold': (optimize_thr, run_thr),
            'periodic': (optimize_per, run_per),
            'cumul_threshold': (optimize_cthr, run_cthr),
        }
        self.variant = variant
        try:
            self.optimize, self.run = strategies[variant]
        except KeyError:
            # Fail fast instead of leaving optimize/run unset until first use.
            raise ValueError(
                f"Unknown CARA variant {variant!r}; expected one of {sorted(strategies)}"
            ) from None

    def train_offline(self, training_data, testing_data):
        """Tune the policy parameters on the offline cost matrix.

        Args:
            training_data (dict): Must provide ``"train_dict_trained_f"``
                (models indexed by timestep) and ``"datasets"`` (data
                dictionaries indexed by timestep).
            testing_data: Unused here.
        """
        trained_f = training_data["train_dict_trained_f"]
        data_dict = training_data["datasets"]
        C = compute_C_cara(
            trained_f, data_dict, timesteps=range(self.t_offline), T=self.T, relative_pe=self.relative_pe)
        # Keep rows/cols 0..t_offline-2. Index t_offline-1 is where the online
        # matrix used by initialize_eval starts.
        self.C_offline = C[:self.t_offline - 1, :self.t_offline - 1]
        # Fit the variant's policy parameters on the offline cost matrix.
        results = self.optimize(self.C_offline, self.retrain_cost)
        self.parameters = results['parameters']

    def initialize_eval(self, all_data):
        """Run the tuned policy on the online period and fix the retrain schedule.

        Args:
            all_data (dict): Must provide ``"dict_trained_f"`` and ``"datasets"``.
        """
        all_f = all_data['dict_trained_f']
        all_dict = all_data['datasets']
        C = compute_C_cara(all_f, all_dict, timesteps=range(
            self.t_offline - 1, self.T), T=self.T, relative_pe=self.relative_pe)
        self.C_online = C[self.t_offline - 1:, self.t_offline - 1:]
        result = self.run(add_retraining_cost(
            self.C_online, self.retrain_cost), **self.parameters)
        # The policy reports retrains relative to the online window; shift them
        # back to absolute timesteps.
        offset = self.t_offline - 1
        self.fixed_retrain_indices = [offset + rel_t for rel_t in result["retrains"]]

    def decide(self, t):
        """Return ``(retrain, most_recent_available_model)`` for timestep ``t``."""
        retrain = t in self.fixed_retrain_indices
        if retrain:
            # The model retrained at this step becomes the most recent one.
            self.most_recent_available_model = t
        return retrain, self.most_recent_available_model


"""
The oracle version from the CARA paper.
"""


class CARAOracle(BaseRetrainAlgo):
    """Oracle CARA: a retrain schedule from dynamic programming over the cost matrix.

    The schedule is computed once in `train_offline` from all data, including
    data from the evaluation period. `decide` then just looks it up.
    """

    def __init__(self, T: int, t_offline: int, relative_pe: bool):
        super().__init__(T, t_offline, relative_pe)
        self.name = 'OracleCara'

    def train_offline(self, training_data, testing_data, all_data):
        """Compute the DP retrain schedule from the full cost matrix.

        Args:
            training_data: Unused here.
            testing_data: Unused here.
            all_data (dict): Must provide ``"dict_trained_f"`` and ``"seq_datasets"``.
        """
        all_f = all_data['dict_trained_f']
        all__dict = all_data['seq_datasets']

        # Compute the full T x T C matrix. Only rows/cols from t_offline-1 on
        # are filled; the rest stay np.inf.
        C = compute_C_cara(all_f, all__dict, timesteps=range(
            self.t_offline - 1, self.T), T=self.T, relative_pe=self.relative_pe)
        self.C_online = C[self.t_offline - 1:, self.t_offline - 1:]

        # Build the dynamic-programming table and backtrack the retrain points.
        # NOTE(review): the DP runs on the full C, whose entries before
        # t_offline-1 are all np.inf, rather than on self.C_online. The
        # resulting indices are used as absolute timesteps without any offset.
        # Confirm that dp_iterative/retrains_iterative handle this as intended.
        dp = dp_iterative(add_retraining_cost(C, self.retrain_cost))
        retrains = retrains_iterative(dp)
        result = {"retrains": retrains, "num_retrains": len(
            retrains) - 1, "parameters": {}}
        self.schedule = result
        self.fixed_retrain_indices = result["retrains"]

    def decide(self, t):
        """Return ``(retrain, most_recent_available_model)`` for timestep ``t``."""
        retrain = t in self.fixed_retrain_indices
        if retrain:
            # The model retrained at this step becomes the most recent one.
            self.most_recent_available_model = t
        return retrain, self.most_recent_available_model


"""
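CARA cost matrix: C[i, j] is the QDM-based loss of the model trained at timestep i, evaluated on the data of timestep j (np.inf where not computed).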
"""


def compute_C_cara(trained_f: dict, data_dict: dict, timesteps: List[int], T: int, relative_pe: bool):
    """Build the CARA cost matrix.

    For every ``i`` in ``timesteps`` and every ``j`` with
    ``i <= j <= timesteps[-1]``, ``C[i, j]`` is the score returned by
    ``compute_qdm_diff`` for the model trained at ``i`` on the data of ``j``.
    All other entries are ``np.inf``. That includes every entry when
    ``timesteps`` is empty.

    Args:
        trained_f (dict): Models indexed by the timestep they were trained at.
        data_dict (dict): Data dictionaries indexed by timestep.
        timesteps (List[int]): Model timesteps, e.g. a ``range``. Presumably
            increasing, since ``j`` runs from ``i`` up to ``timesteps[-1]``.
        T (int): Side length of the square output matrix.
        relative_pe (bool): Forwarded to ``compute_qdm_diff``.

    Returns:
        np.ndarray: ``(T, T)`` float matrix.
    """
    C = np.full((T, T), np.inf)
    if len(timesteps) == 0:
        # Nothing to evaluate; avoid IndexError on timesteps[-1].
        return C
    last = timesteps[-1]
    for i in timesteps:
        # Model trained at batch i (invariant across j).
        model = trained_f[i]
        for j in range(i, last + 1):
            # The model's reference data is the batch just before j, but never
            # earlier than the batch the model was trained on.
            model_data = data_dict[max(j - 1, i)]
            new_data = data_dict[j]
            # Queries are drawn from the incoming data itself.
            C[i, j] = compute_qdm_diff(
                model, model_data, new_data, new_data, relative_pe)
    return C


def bayes_query_data_model_loss(model, data, query_data=None, gamma=None):
    """Return the 0/1 Query Data Model (QDM) overlap loss.

    Each training point's 0/1 misclassification indicator is weighted by its
    RBF similarity to every query point. The weighted sum is then normalised.
    Lower is better.

    Args:
        model: Trained classifier exposing ``predict``.
        data (dict): Data dictionary. Its ``X_train``/``y_train`` are scored,
            and its ``X_val`` supplies the queries when ``query_data`` is None.
        query_data (dict, optional): Dictionary whose ``X_val`` supplies the
            queries, for when queries and training data come from different
            sources. Defaults to None, which uses the queries in ``data``.
        gamma (float, optional): Hyperparameter forwarded to ``rbf_kernel``.
            Defaults to None, which uses that function's default.

    Returns:
        float: The normalised, query-weighted misclassification score.
    """
    X_new = data['X_train']
    y_new = data['y_train']
    # Queries come from query_data when given, otherwise from data itself.
    X_query = (query_data if query_data is not None else data)['X_val']
    zero_one_loss = (model.predict(X_new) != y_new)

    # Scale queries and data by one shared factor before the RBF kernel.
    # NOTE(review): this is the largest raw value, not the largest magnitude.
    norm_X = max(np.max(X_query), np.max(X_new))
    if norm_X == 0:
        # Dividing by 0 would turn every similarity into NaN.
        norm_X = 1.0
    sim = rbf_kernel(X_query / norm_X, X_new / norm_X, gamma=gamma)

    # Estimated query misclassifications (lower is better).
    # NOTE(review): both divisions use the number of training points; a mean
    # over queries would divide the sum by sim.shape[0] instead. Kept as-is so
    # existing results stay reproducible.
    n_train = zero_one_loss.shape[0]
    query_data_model_overlap = ((sim @ zero_one_loss) / n_train).sum()
    return query_data_model_overlap / n_train


def compute_qdm_diff(model, model_data, new_data, query_data, relative_pe, qdm_fn=None):
    """Compute the QDM-based score used as one entry of the CARA cost matrix.

    Args:
        model: The current model instance.
        model_data (dict): The data dictionary used to train the current model.
        new_data (dict): The data dictionary corresponding to the incoming data.
        query_data (dict): The data dictionary corresponding to incoming queries.
        relative_pe (bool): If True, return ``QDM(new_data) - QDM(model_data)``.
            Otherwise return ``QDM(model_data)``. Both are computed against
            ``query_data``.
        qdm_fn (Optional[Callable], optional): Function computing a single QDM.
            It is called as ``qdm_fn(model=..., data=..., query_data=...)``.
            Defaults to None, which means ``bayes_query_data_model_loss``.

    Returns:
        float: The (relative) QDM score.
    """
    if qdm_fn is None:
        qdm_fn = bayes_query_data_model_loss

    # QDM of the current model on the data it was trained with, against the queries.
    qdm_model_data = qdm_fn(model=model, data=model_data, query_data=query_data)
    if not relative_pe:
        return qdm_model_data

    # QDM of the current model on the incoming data, against the same queries.
    qdm_new_data = qdm_fn(model=model, data=new_data, query_data=query_data)
    return qdm_new_data - qdm_model_data
