
from typing import Tuple, Union

import numpy as np
import torch
from river import drift
from torch.utils.data import DataLoader

from src.models.base_models import BaseRetrainAlgo

"""
Distribution drift detector based on FHDDM (Fast Hoeffding Drift Detection Method).
"""
import torch 
# Seed torch's global RNG when this module is imported (presumably for reproducible runs).
torch.manual_seed(2809)

class FHDDM(BaseRetrainAlgo):
    """Retraining policy driven by river's FHDDM drift detector.

    The detector is fed the per-sample errors that a previously trained model
    makes on each incoming training set.  When the detector reports drift
    during :meth:`update_at_t`, the next call to :meth:`decide` asks for a
    retrain.

    Args:
        T: forwarded unchanged to ``BaseRetrainAlgo.__init__``.
        t_offline: forwarded to ``BaseRetrainAlgo.__init__``; the dataset at
            index ``t_offline - 1`` is used to warm up the detector.
        significance: passed to river's FHDDM as ``confidence_level``.
        data_in_dataloader: if true, datasets are read from the
            ``'visionSeqDataset'`` entry and scored with the model's
            ``get_zero_one``; otherwise ``'X_train'`` / ``'y_train'`` arrays
            are scored with the model's ``predict``.
    """

    def __init__(self, T: int, t_offline: int, significance: float, data_in_dataloader: bool) -> None:
        super().__init__(T, t_offline)
        self.name = 'fhddm'
        self.significance = significance
        self.retrain = False  # set by update_at_t, consumed by decide
        self.data_in_dataloader = data_in_dataloader
        self.max_batch = np.inf  # forwarded to get_zero_one as max_batch
        self.batch_size = 300  # only used to count batches in train_offline

    def _new_detector(self):
        """Build a fresh FHDDM detector with the configured window and significance."""
        return drift.binary.FHDDM(
            sliding_window_size=self.min_window_length,
            confidence_level=self.significance,
        )

    def _feed(self, errors) -> bool:
        """Push every error value into the detector; return True if drift was flagged after any update."""
        drift_seen = False
        for err in errors:
            self.fhddm_detector.update(err)
            if self.fhddm_detector.drift_detected:
                drift_seen = True
        return drift_seen

    def _dataloader_errors(self, model, train_data) -> list:
        """Return the model's zero-one losses on ``train_data`` as plain Python floats."""
        losses = model.get_zero_one(train_data, max_batch=self.max_batch)
        # One bulk conversion instead of handing river one 0-dim tensor per sample.
        return losses.to(torch.float).tolist()

    @staticmethod
    def _array_errors(model, X, y) -> list:
        """Return ``1 - (y[i] == prediction[i])`` for every row index of ``X``."""
        predictions = model.predict(X)
        # Index-based on purpose: keeps the original ``y[i]`` lookup semantics.
        return [1 - (y[i] == predictions[i]) for i in range(X.shape[0])]

    def train_offline(self, training_data, testing_data):
        """Size the detector window and warm the detector up on the last offline dataset.

        Args:
            training_data: mapping with ``'train_dict_trained_f'`` (trained
                models by time index) and ``'datasets'`` (datasets by time
                index); both are read at index ``t_offline - 1``.
            testing_data: unused.
        """
        last_idx = self.t_offline - 1
        last_model = training_data['train_dict_trained_f'][last_idx]
        dataset = training_data['datasets'][last_idx]

        if self.data_in_dataloader:
            train_data = dataset['visionSeqDataset'][last_idx, "train"]
            errors = self._dataloader_errors(last_model, train_data)
            # The loader is built only to count batches; it is never iterated.
            n_batches = len(DataLoader(train_data, batch_size=self.batch_size))
            # NOTE(review): this is (number of batches - 1) capped by a
            # *sample* count, whereas the array branch below uses the number
            # of samples.  The units look inconsistent; kept as-is so the
            # detector behaves as before -- confirm the intended window size.
            self.min_window_length = min(n_batches - 1, self.max_batch * self.batch_size)
        else:
            X_last_dataset = dataset['X_train']
            errors = self._array_errors(last_model, X_last_dataset, dataset['y_train'])
            self.min_window_length = X_last_dataset.shape[0]  # 1 dataset

        self.fhddm_detector = self._new_detector()
        self._feed(errors)

    def update_at_t(self, info):
        """Feed the errors on the dataset at ``self.t`` to the detector and advance ``self.t``.

        The model at index ``self.most_recent_available_model`` is scored on
        the new dataset.  If drift is flagged (or a retrain request is still
        pending), the detector is rebuilt and fed the same errors again so it
        restarts from the current dataset.

        Args:
            info: mapping whose ``'new_training_data'`` entry has the same
                layout as ``training_data`` in :meth:`train_offline`.
        """
        new_data = info['new_training_data']
        last_model = new_data['train_dict_trained_f'][self.most_recent_available_model]
        dataset = new_data['datasets'][self.t]

        if self.data_in_dataloader:
            train_data = dataset['visionSeqDataset'][self.t, "train"]
            errors = self._dataloader_errors(last_model, train_data)
            # NOTE(review): only this branch starts from a fresh detector at
            # every step (dropping the offline warm-up); the array branch
            # keeps the detector state across steps.  Kept as-is -- confirm
            # which behaviour is intended.
            self.fhddm_detector = self._new_detector()
        else:
            errors = self._array_errors(last_model, dataset['X_train'], dataset['y_train'])

        if self._feed(errors):
            self.retrain = True

        if self.retrain:
            # if any drift is detected in the current batch, we reset the detector
            self.fhddm_detector = self._new_detector()
            self._feed(errors)

        self.t += 1

    def decide(self, t: int) -> Tuple[bool, int]:
        """Report and clear the pending retrain request.

        Args:
            t: unused; the decision relies on state set by :meth:`update_at_t`.

        Returns:
            ``(retrain, most_recent_available_model)``.  When ``retrain`` is
            true, ``most_recent_available_model`` has just been set to
            ``self.t`` and the pending flag has been cleared.
        """
        retrain = self.retrain
        if retrain:
            self.most_recent_available_model = self.t
            self.retrain = False  # reset to false

        return retrain, self.most_recent_available_model
