import logging
from abc import abstractmethod
from torch.utils.data import DataLoader
from ..dssdataloader import DSSDataLoader
from math import ceil
import matplotlib.pyplot as plt
import torch
import pandas as pd
import os

class AdaptiveDSSDataLoader(DSSDataLoader):
    """
    Implementation of AdaptiveDSSDataLoader class which serves as base class for dataloaders of other
    adaptive subset selection strategies for supervised learning framework.

    When ``dss_args.type == "GradMatchPB"`` the loader additionally snapshots model
    parameters a few epochs before every selection round, hands them to
    ``_resample_subset_indices1``, and (from epoch ``select_every`` on) evaluates the
    selected subset against a fixed random set of indices every epoch, keeping the
    results in ``gradient_diff_subset_history`` / ``gradient_diff_random_subset_history``.

    Parameters
    -----------
    train_loader: torch.utils.data.DataLoader class
        Dataloader of the training dataset
    val_loader: torch.utils.data.DataLoader class
        Dataloader of the validation dataset
    dss_args: dict
        Data subset selection arguments dictionary. Must contain ``select_every``,
        ``device`` and ``kappa`` (and ``num_epochs`` when ``kappa > 0``); ``type`` and
        ``save_place`` are read from it as attributes too.
    logger: class
        Logger for logging the information
    """

    # Upper bound (exclusive) of the index range the fixed random comparison subset
    # is drawn from. NOTE(review): presumably the number of training minibatches in
    # the original experiment -- confirm and override per dataset.
    RANDOM_POOL_SIZE = 352
    # Fraction of RANDOM_POOL_SIZE drawn (with replacement) for the random subset.
    RANDOM_FRACTION = 0.1
    # Seed of the private generator used to draw the random subset.
    RANDOM_SEED = 42
    # Epoch count used to scale the GradMatchPB coefficient (cur_epoch / SCHEDULE_EPOCHS)
    # and to pick the epoch (SCHEDULE_EPOCHS - 1) at which the history may be saved.
    SCHEDULE_EPOCHS = 300
    # Fallback x-axis offset for history entries whose epoch was not recorded.
    PLOT_EPOCH_OFFSET = 20
    # When True, the gradient-difference history is written to ``save_place`` and
    # plotted at epoch SCHEDULE_EPOCHS - 1 (GradMatchPB only). Disabled by default.
    save_grad_history = False

    def __init__(self, train_loader, val_loader, dss_args, logger, *args,
                 **kwargs):
        """
        Constructor function
        """
        super(AdaptiveDSSDataLoader, self).__init__(train_loader.dataset, dss_args,
                                                    logger, *args, **kwargs)
        self.train_loader = train_loader
        self.val_loader = val_loader
        # History of gradient_diff_subset values (one entry per evaluated epoch).
        self.gradient_diff_subset_history = []
        # History of gradient_diff_random_subset values (one entry per evaluated epoch).
        self.gradient_diff_random_subset_history = []
        # Epoch number at which each history entry above was recorded.
        self.gradient_diff_epoch_history = []
        # Draw the fixed random comparison indices from a private generator instead of
        # calling torch.manual_seed(), which would reseed torch's *global* RNG and
        # silently override the caller's seed for model init, shuffling, etc.
        generator = torch.Generator().manual_seed(self.RANDOM_SEED)
        self.random_indices = torch.randint(
            0, self.RANDOM_POOL_SIZE,
            (int(self.RANDOM_POOL_SIZE * self.RANDOM_FRACTION),),
            dtype=torch.long, generator=generator)

        # Arguments assertion check
        for key in ("select_every", "device", "kappa"):
            assert key in dss_args.keys(), "'{}' is a compulsory argument. Include it as a key in dss_args".format(key)
        self.select_every = dss_args.select_every
        self.device = dss_args.device
        self.kappa = dss_args.kappa
        self.type = dss_args.type
        self.save_place = dss_args.save_place
        if dss_args.kappa > 0:
            assert "num_epochs" in dss_args.keys(), "'num_epochs' is a compulsory argument when warm starting the model(i.e., kappa > 0). Include it as a key in dss_args"
            self.select_after = int(dss_args.kappa * dss_args.num_epochs)
            # Warm-start on the full data for the whole select_after period.
            self.warmup_epochs = ceil(self.select_after)
        else:
            self.select_after = 0
            self.warmup_epochs = 0
        self.initialized = False
        # GradMatchPB parameter snapshots, taken 3 (last_para1) and 5 (last_para2)
        # epochs before each selection epoch (a multiple of select_every).
        self.last_para1 = None
        self.last_para2 = None

    def __iter__(self):
        """
        Return the iterator for the current epoch's loader and advance ``cur_epoch``.

        * Epochs strictly between ``warmup_epochs`` and ``select_after`` are skipped
          (an empty loader is returned).
        * Epochs before ``warmup_epochs`` use the full warm-start loader ``wtdataloader``.
        * All other epochs use ``subset_loader``; on epochs that are a multiple of
          ``select_every`` (and > 1) the subset is re-selected first.

        For ``type == "GradMatchPB"`` the current subset is also evaluated against
        ``random_indices`` each epoch, and the history is optionally saved.
        """
        self.initialized = True
        if self.warmup_epochs < self.cur_epoch < self.select_after:
            self.logger.debug(
                "Skipping epoch {0:d} due to warm-start option. ".format(self.cur_epoch, self.warmup_epochs))
            loader = DataLoader([])
        elif self.cur_epoch < self.warmup_epochs:
            self.logger.debug('Epoch: {0:d}, reading dataloader... '.format(self.cur_epoch))
            loader = self.wtdataloader
            self.logger.debug('Epoch: {0:d}, finished reading dataloader. '.format(self.cur_epoch))
        else:
            self.logger.debug('Epoch: {0:d}, reading dataloader... '.format(self.cur_epoch))
            self._maybe_resample()
            loader = self.subset_loader
            self.logger.debug('Epoch: {0:d}, finished reading dataloader. '.format(self.cur_epoch))

        if self.type == "GradMatchPB":
            self._evaluate_and_record()
            if self.save_grad_history and self.cur_epoch == self.SCHEDULE_EPOCHS - 1:
                self._save_gradient_diff_history()

        self.cur_epoch += 1
        return loader.__iter__()

    def _maybe_resample(self):
        """Run the subset-selection schedule for the current (non-warm-up) epoch."""
        epoch = self.cur_epoch
        # No selection activity happens on epochs 0 and 1.
        if epoch <= 1:
            return
        if self.type == "GradMatchPB":
            # Snapshot parameters 5 and 3 epochs ahead of the next selection round.
            if (epoch + 5) % self.select_every == 0:
                self.last_para2 = self.record_para()
            if (epoch + 3) % self.select_every == 0:
                self.last_para1 = self.record_para()
            if epoch % self.select_every == 0:
                # Coefficient handed to the strategy: fraction of SCHEDULE_EPOCHS elapsed.
                xishu = epoch / self.SCHEDULE_EPOCHS
                self.resample1(self.last_para1, self.last_para2, xishu)
        elif epoch % self.select_every == 0:
            self.resample()

    def _evaluate_and_record(self):
        """Append this epoch's GradMatchPB gradient differences to the history lists."""
        # subset_batch_indx / subset_batch_gammas are only set by resample1; skip the
        # evaluation until a selection round has produced them (e.g. while epoch
        # select_every still falls inside warm-up), instead of raising AttributeError.
        if self.cur_epoch < self.select_every or not hasattr(self, "subset_batch_indx"):
            return
        diff_subset, diff_random = self.evaluate_grad_subset(
            self.subset_batch_indx, self.subset_batch_gammas, self.random_indices)
        self.gradient_diff_subset_history.append(diff_subset)
        self.gradient_diff_random_subset_history.append(diff_random)
        self.gradient_diff_epoch_history.append(self.cur_epoch)

    @staticmethod
    def _as_float_list(values):
        """Convert scalar tensors (possibly on GPU) to Python numbers; leave other values as-is."""
        return [v.cpu().item() if isinstance(v, torch.Tensor) else v for v in values]

    def _history_epochs(self):
        """Return the epoch number of each entry in the gradient-difference histories."""
        count = len(self.gradient_diff_subset_history)
        if len(self.gradient_diff_epoch_history) == count:
            return list(self.gradient_diff_epoch_history)
        # Histories were filled without epoch bookkeeping: fall back to a fixed offset.
        return [index + self.PLOT_EPOCH_OFFSET for index in range(count)]

    def _save_gradient_diff_history(self):
        """Write the gradient-difference history to ``self.save_place`` as CSV, then plot it."""
        df = pd.DataFrame({
            'Epoch': self._history_epochs(),
            'Gradient Diff (Subset)': self._as_float_list(self.gradient_diff_subset_history),
            'Gradient Diff (Random)': self._as_float_list(self.gradient_diff_random_subset_history),
        })
        file_path = self.save_place
        df.to_csv(file_path, index=False)
        print(f"Data saved to {file_path}")
        self.plot_gradient_diffs()

    def plot_gradient_diffs(self):
        """Plot gradient_diff_subset and gradient_diff_random_subset over the recorded epochs."""
        epochs = self._history_epochs()
        plt.figure()
        plt.plot(epochs, self._as_float_list(self.gradient_diff_subset_history),
                 label="Subset Gradient Difference")
        plt.plot(epochs, self._as_float_list(self.gradient_diff_random_subset_history),
                 label=f"Random {self.RANDOM_FRACTION:.0%} Subset Gradient Difference")
        plt.xlabel('Epoch')
        plt.ylabel('Gradient Difference')
        plt.title('Gradient Difference Over Epochs')
        plt.legend()
        plt.grid(True)
        plt.show()

    def __len__(self) -> int:
        """
        Returns the length of the current data loader

        ``__iter__`` increments ``cur_epoch`` before the epoch's batches are consumed,
        so the bounds here are shifted by one (``<=`` instead of ``<``) relative to
        ``__iter__``: a ``len()`` taken while iterating epoch e sees ``cur_epoch == e + 1``.
        """
        if self.warmup_epochs < self.cur_epoch <= self.select_after:
            self.logger.debug(
                "Skipping epoch {0:d} due to warm-start option. ".format(self.cur_epoch, self.warmup_epochs))
            return len(DataLoader([]))
        elif self.cur_epoch <= self.warmup_epochs:
            self.logger.debug('Epoch: {0:d}, reading dataloader... '.format(self.cur_epoch))
            return len(self.wtdataloader)
        else:
            self.logger.debug('Epoch: {0:d}, reading dataloader... '.format(self.cur_epoch))
            return len(self.subset_loader)

    def resample1(self, last_para1, last_para2, xishu):
        """
        Re-select the subset using previous parameter snapshots (GradMatchPB variant).

        Updates ``subset_indices``, ``subset_weights``, ``subset_batch_indx`` and
        ``subset_batch_gammas``, refreshes the subset loader, and returns the
        ``(gradient_diff_subset, gradient_diff_random_subset)`` pair reported by
        ``_resample_subset_indices1``.
        """
        (self.subset_indices, self.subset_weights, self.subset_batch_indx,
         self.subset_batch_gammas, gradient_diff_subset,
         gradient_diff_random_subset) = self._resample_subset_indices1(last_para1, last_para2, xishu)
        self.logger.debug("Subset indices length: %d", len(self.subset_indices))
        self._refresh_subset_loader()
        self.logger.debug("Subset loader initiated, args: %s, kwargs: %s", self.loader_args, self.loader_kwargs)
        self.logger.debug('Subset selection finished, Training data size: %d, Subset size: %d',
                          self.len_full, len(self.subset_loader.dataset))
        return gradient_diff_subset, gradient_diff_random_subset

    def record_para(self):
        """Return a parameter snapshot produced by the subclass's ``_record_para``."""
        return self._record_para()

    def resample(self):
        """
        Function that resamples the subset indices and recalculates the subset weights
        """
        self.subset_indices, self.subset_weights = self._resample_subset_indices()
        self.logger.debug("Subset indices length: %d", len(self.subset_indices))
        self._refresh_subset_loader()
        self.logger.debug("Subset loader initiated, args: %s, kwargs: %s", self.loader_args, self.loader_kwargs)
        self.logger.debug('Subset selection finished, Training data size: %d, Subset size: %d',
                          self.len_full, len(self.subset_loader.dataset))

    @abstractmethod
    def _resample_subset_indices(self):
        """
        Abstract function that needs to be implemented in the child classes. 
        Needs implementation of subset selection implemented in child classes.

        Must return ``(subset_indices, subset_weights)``.
        """
        raise Exception('Not implemented.')

    @abstractmethod
    def _resample_subset_indices1(self, last_para1, last_para2, xishu):
        """
        Abstract function that needs to be implemented in the child classes. 
        Needs implementation of subset selection implemented in child classes.

        Receives the two parameter snapshots from ``record_para`` and the ``xishu``
        coefficient; must return the 6-tuple ``(subset_indices, subset_weights,
        subset_batch_indx, subset_batch_gammas, gradient_diff_subset,
        gradient_diff_random_subset)``.
        """
        raise Exception('Not implemented.')

    @abstractmethod
    def record_grad(self):
        """Abstract hook; not called by this base class."""
        raise Exception('Not implemented.')

    @abstractmethod
    def evaluate_grad_subset(self, subset_indices, subset_weights, random_indices=None):
        """
        Compare the selected subset against ``random_indices``.

        Called every GradMatchPB epoch with ``(subset_batch_indx, subset_batch_gammas,
        random_indices)``; must return ``(gradient_diff_subset, gradient_diff_random_subset)``.
        """
        raise Exception('Not implemented.')

    @abstractmethod
    def _record_para(self):
        """Return a snapshot of the model parameters for later use by ``_resample_subset_indices1``."""
        raise Exception('Not implemented.')