
from typing import Tuple, Union

import numpy as np
import torch
from river import drift

from src.models.base_models import BaseRetrainAlgo

"""
Distribution drift detector, KSWIN
"""


class KSWIN(BaseRetrainAlgo):
    """Retraining policy that fires when KSWIN flags drift in 0-1 errors.

    A ``river.drift.KSWIN`` detector is fed the per-sample 0-1 errors that a
    previously trained model makes on a training split. ``train_offline``
    seeds the detector, ``update_at_t`` feeds it the errors for the current
    step and records whether drift was flagged, and ``decide`` reports (and
    clears) that flag.

    ``self.t`` and ``self.most_recent_available_model`` are read and written
    here but never initialised in this class; presumably ``BaseRetrainAlgo``
    sets them up -- confirm against the base class.
    """

    def __init__(self, T: int, t_offline: int, significance: float, data_in_dataloader:bool):
        """Configure the policy.

        Args:
            T: Forwarded unchanged to ``BaseRetrainAlgo.__init__``.
            t_offline: Forwarded to ``BaseRetrainAlgo.__init__``; this class
                reads ``self.t_offline - 1`` as the index of the model and
                dataset used to seed the detector.
            significance: Passed as ``alpha`` to ``river.drift.KSWIN``.
            data_in_dataloader: When true, errors come from
                ``model.get_zero_one`` on a ``'visionSeqDataset'`` entry;
                otherwise from comparing ``model.predict(X_train)`` with
                ``y_train``.
        """
        super().__init__(T, t_offline)
        self.name = 'kswin'
        self.significance = significance
        # Set by update_at_t when drift is flagged; cleared by decide.
        self.retrain = False
        self.data_in_dataloader = data_in_dataloader
        # Passed as max_batch to get_zero_one.
        self.max_batch = np.inf
        # Not read anywhere in this class.
        self.batch_size = 300

    def _new_detector(self):
        """Return a fresh KSWIN detector at the configured significance."""
        return drift.KSWIN(alpha=self.significance)

    def _feed_detector(self, errors) -> bool:
        """Push every error into the detector; return True if any update flagged drift."""
        drift_seen = False
        for error_f in errors:
            self.kswin_detector.update(error_f)
            # Checked after each update, exactly as the per-sample loops did.
            if self.kswin_detector.drift_detected:
                drift_seen = True
        return drift_seen

    def _dataloader_errors(self, model, data, step):
        """Return the model's 0-1 losses on the ``(step, "train")`` split as floats."""
        train_data = data['visionSeqDataset'][step, "train"]
        zero_one_loss = model.get_zero_one(train_data, max_batch=self.max_batch)
        # Always move to CPU and hand the detector plain Python floats, so the
        # offline and online paths behave the same regardless of the device
        # the loss tensor lives on.
        zero_one_loss = zero_one_loss.detach().to(torch.float).to("cpu")
        return zero_one_loss.reshape(-1).tolist()

    @staticmethod
    def _array_errors(model, X, y):
        """Return per-sample 0-1 errors of ``model.predict(X)`` against ``y``."""
        predictions = model.predict(X)
        # Per-index comparison is kept (rather than vectorised) because the
        # container types of ``y`` and the predictions are not visible here.
        return [1 - (y[i] == predictions[i]) for i in range(X.shape[0])]

    def train_offline(self, training_data, testing_data):
        """Create the detector and seed it with errors from the last offline step.

        Args:
            training_data: Mapping with ``'train_dict_trained_f'`` (models) and
                ``'datasets'``, both indexed by step; entry ``t_offline - 1``
                of each is used.
            testing_data: Not used.
        """
        step = self.t_offline - 1
        last_model = training_data['train_dict_trained_f'][step]
        dataset = training_data['datasets'][step]

        if self.data_in_dataloader:
            errors = self._dataloader_errors(last_model, dataset, step)
        else:
            errors = self._array_errors(last_model, dataset['X_train'], dataset['y_train'])

        self.kswin_detector = self._new_detector()
        # Drift flags raised while seeding are deliberately ignored.
        self._feed_detector(errors)

    def update_at_t(self, info):
        """Feed the errors for step ``self.t`` to the detector and advance ``self.t``.

        The model at index ``self.most_recent_available_model`` is evaluated
        on the dataset at index ``self.t``. If any sample triggers drift,
        ``self.retrain`` is set. Whenever ``self.retrain`` is set, the detector
        is replaced by a fresh one that is re-fed the same errors.

        Args:
            info: Mapping whose ``'new_training_data'`` entry holds
                ``'train_dict_trained_f'`` and ``'datasets'``.
        """
        new_data = info['new_training_data']
        last_model = new_data['train_dict_trained_f'][self.most_recent_available_model]
        dataset = new_data['datasets'][self.t]

        if self.data_in_dataloader:
            errors = self._dataloader_errors(last_model, dataset, self.t)
            # NOTE(review): the detector is rebuilt before every step on this
            # path, discarding the window seeded earlier; the array path below
            # keeps its detector across steps. Behaviour preserved as found --
            # confirm whether this reset is intentional.
            self.kswin_detector = self._new_detector()
        else:
            errors = self._array_errors(last_model, dataset['X_train'], dataset['y_train'])

        if self._feed_detector(errors):
            self.retrain = True

        if self.retrain:
            # if any drift is detected in the current batch, we reset the detector
            self.kswin_detector = self._new_detector()
            self._feed_detector(errors)

        self.t += 1

    def decide(self, t: int) -> Tuple[bool, int]:
        """Report whether to retrain and which model index is current.

        When the retrain flag is set, ``self.most_recent_available_model`` is
        updated to ``self.t`` and the flag is cleared.

        Args:
            t: Not used; the method reads ``self.t`` instead.

        Returns:
            A ``(retrain, most_recent_available_model)`` pair.
        """
        retrain = self.retrain
        if retrain:
            self.most_recent_available_model = self.t
            self.retrain = False  # reset to false

        return retrain, self.most_recent_available_model
