
from typing import Tuple, Union

import numpy as np
from river import drift
from torch.utils.data import DataLoader

from src.models.base_models import BaseRetrainAlgo

"""
Distribution drift detector, ADWIN
"""

import torch 
# Seed torch's global RNG once at import time -- presumably so that the
# shuffle=True DataLoader iteration used by ADWIN below is reproducible
# across runs. NOTE(review): this affects every user of the global torch RNG
# in the process, not just this module.
torch.manual_seed(2809)
class ADWIN(BaseRetrainAlgo):
    """Retraining policy that triggers on drift reported by river's ADWIN detector.

    The detector is fed the training targets (``y``) one value at a time:
    first the last offline dataset (``train_offline``), then every newly
    arrived dataset (``update_at_t``).  When the ``retrain`` flag is set after
    consuming a new dataset, the detector is rebuilt and fed that dataset
    again.  ``decide`` reports the flag and clears it.

    Targets are read either from ``'visionSeqDataset'`` entries through a
    shuffled ``DataLoader`` (``data_in_dataloader=True``) or directly from
    ``'y_train'`` arrays.

    NOTE(review): ``self.t`` and ``self.most_recent_available_model`` are not
    assigned in ``__init__`` here; presumably ``BaseRetrainAlgo`` initialises
    them -- confirm.
    """

    def __init__(self, T: int, t_offline: int, significance: float, data_in_dataloader: bool):
        """Store the detector configuration.

        Args:
            T: Forwarded unchanged to ``BaseRetrainAlgo``.
            t_offline: Forwarded to ``BaseRetrainAlgo``; ``t_offline - 1`` is
                the index of the dataset used to warm up the detector.
            significance: Passed to ``drift.ADWIN`` as ``delta``.
            data_in_dataloader: If true, targets are read through a
                ``DataLoader``; otherwise from ``'y_train'`` arrays.
        """
        super().__init__(T, t_offline)
        self.name = 'adwin'
        self.significance = significance
        self.retrain = False
        self.data_in_dataloader = data_in_dataloader
        self.max_batch = np.inf
        self.batch_size = 300

    # ------------------------------------------------------------------ helpers

    def _new_detector(self):
        """Return a fresh ``drift.ADWIN`` built from the current settings."""
        return drift.ADWIN(delta=self.significance,
                           clock=self.clock,
                           min_window_length=self.min_window_length)

    def _label_source(self, datasets, t):
        """Return the object holding the targets of the dataset at index ``t``.

        This is a shuffled ``DataLoader`` in dataloader mode, otherwise the
        ``'y_train'`` entry itself.
        """
        entry = datasets[t]
        if self.data_in_dataloader:
            train_data = entry['visionSeqDataset'][t, "train"]
            return DataLoader(train_data, batch_size=self.batch_size,
                              shuffle=True, pin_memory=False)
        return entry['y_train']

    def _iter_targets(self, source):
        """Yield individual target values from a source made by ``_label_source``."""
        if self.data_in_dataloader:
            # Each loader item is a 3-tuple; only the middle element (targets) is used.
            for _, y_batch, _ in source:
                yield from y_batch
        else:
            yield from source

    def _feed(self, source) -> bool:
        """Push every target of ``source`` into the detector.

        Returns:
            True if the detector reported drift after at least one update.
        """
        drifted = False
        for y in self._iter_targets(source):
            self.adwin_detector.update(y)
            if self.adwin_detector.drift_detected:
                drifted = True
        return drifted

    # --------------------------------------------------------------- public API

    def train_offline(self, training_data, testing_data):
        """Create the detector and warm it up on the last offline dataset.

        Args:
            training_data: Mapping whose ``'datasets'`` entry is indexed by
                time step; the entry at ``t_offline - 1`` is used.
            testing_data: Unused.
        """
        self.clock = 1  # check for change after every single update
        source = self._label_source(training_data['datasets'], self.t_offline - 1)

        if self.data_in_dataloader:
            # Count samples, not batches: ``len(dataloader)`` is the number of
            # batches, which is not comparable with ``max_batch * batch_size``
            # and would make the window far smaller than one dataset.
            dataset_size = min(len(source.dataset), self.max_batch * self.batch_size)
        else:
            dataset_size = source.shape[0]

        self.min_window_length = dataset_size  # 1 dataset

        self.adwin_detector = self._new_detector()
        # Drift reported during warm-up is deliberately ignored.
        self._feed(source)

    def update_at_t(self, info):
        """Feed the dataset of the current step and advance ``self.t``.

        Args:
            info: Mapping such that ``info['new_training_data']['datasets']``
                is indexable by ``self.t``.
        """
        source = self._label_source(info['new_training_data']['datasets'], self.t)

        if self._feed(source):
            self.retrain = True

        if self.retrain:
            # If any drift is detected in the current dataset, reset the
            # detector and feed it the current dataset again.
            self.adwin_detector = self._new_detector()
            self._feed(source)

        self.t += 1

    def decide(self, t: int) -> Tuple[bool, int]:
        """Report whether to retrain, then clear the pending flag.

        Args:
            t: Unused; the internal counter ``self.t`` is used instead.

        Returns:
            ``(retrain, most_recent_available_model)``.  When ``retrain`` is
            true, ``most_recent_available_model`` has just been set to
            ``self.t``.
        """
        retrain = self.retrain
        if retrain:
            self.most_recent_available_model = self.t
            self.retrain = False  # consume the flag
        return retrain, self.most_recent_available_model
