import logging
from abc import abstractmethod
from torch.utils.data import DataLoader
from ..dssdataloader import DSSDataLoader
from math import ceil
import matplotlib.pyplot as plt
import torch
import pandas as pd
import os
import sys
import pickle
import os.path as osp
import torch.nn as nn
import numpy as np
import torch.distributed as dist

class AdaptiveDSSDataLoader(DSSDataLoader):
    """
    Implementation of AdaptiveDSSDataLoader class which serves as base class for dataloaders of other
    adaptive subset selection strategies for supervised learning framework.

    Parameters
    -----------
    train_loader: torch.utils.data.DataLoader class
        Dataloader of the training dataset
    val_loader: torch.utils.data.DataLoader class
        Dataloader of the validation dataset
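    dss_args: dict
        Data subset selection arguments dictionary (read both through ``keys()`` and through
        attribute access, e.g. ``dss_args.select_every``)
    cfg: object
        Experiment configuration, read through attribute access (e.g. ``cfg.train_args.results_dir``,
        ``cfg.dss_args.type``, ``cfg.method``)
    logger: class
        Logger for logging the information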
    """
    def __init__(self, train_loader, val_loader, dss_args, cfg, logger, *args, **kwargs):
        """
        Constructor function

        Parameters
        -----------
        train_loader: torch.utils.data.DataLoader class
            Dataloader of the training dataset; its ``dataset`` is forwarded to the parent constructor
        val_loader: torch.utils.data.DataLoader class
            Dataloader of the validation dataset
        dss_args: dict
            Data subset selection arguments. 'select_every', 'device' and 'kappa' are asserted to be
            present, and 'num_epochs' as well when kappa > 0. 'type', 'record_gradiant', 'save_place'
            and 'num_epochs' are additionally read through attribute access.
        cfg: object
            Experiment configuration; read through attribute access to build ``self.all_logs_dir``
            and to obtain ``cfg.method``
        logger: class
            Logger for logging the information
        *args, **kwargs:
            Forwarded unchanged to the parent constructor
        """
        super(AdaptiveDSSDataLoader, self).__init__(train_loader.dataset, dss_args,
                                                    logger, *args, **kwargs)
        self.is_distributed = dist.is_initialized()
        self.rank = dist.get_rank() if self.is_distributed else 0
        self.train_loader = train_loader
        self.val_loader = val_loader

        # Diagnostics appended to by __iter__ on the epochs where gradients are evaluated.
        self.gradient_diff_subset_history = []  # history of gradient_diff_subset values
        self.gradient_diff_random_subset_history = []  # history of gradient_diff_random_subset values
        self.variance = []
        self.recorded_epochs = []

        # Fixed set of int(352 * 0.1) == 35 indices drawn (with replacement) from [0, 352).
        # NOTE(review): torch.manual_seed reseeds the *global* torch RNG, so constructing this
        # loader also resets whatever seed the caller configured earlier -- confirm this is intended.
        d = 352
        torch.manual_seed(42)
        self.random_indices = torch.randint(0, d, (int(d * 0.1),), dtype=torch.long)

        self.cfg = cfg
        # Directory path assembled from the experiment configuration; the trailing comments show
        # example values left by the original author.
        results_dir = osp.abspath(osp.expanduser(str(self.cfg.train_args.results_dir)))  # e.g. 'results/'
        subset_selection_name = self.cfg.dss_args.type  # e.g. "GradMatch"
        self.all_logs_dir = os.path.join(results_dir,
                                         self.cfg.setting,  # e.g. "SL"
                                         self.cfg.dataset.name,  # e.g. "cifar10"
                                         subset_selection_name,  # e.g. "GradMatchPB"
                                         self.cfg.model.architecture,  # e.g. "ResNet18"
                                         self.cfg.method,
                                         str(self.cfg.dss_args.fraction),  # e.g. 0.1
                                         str(self.cfg.dss_args.select_every),  # e.g. 20
                                         str(self.cfg.scheduler.type),
                                         str(self.cfg.dss_args.lam))  # e.g. 0

        # Arguments assertion check
        assert "select_every" in dss_args.keys(), "'select_every' is a compulsory argument. Include it as a key in dss_args"
        assert "device" in dss_args.keys(), "'device' is a compulsory argument. Include it as a key in dss_args"
        assert "kappa" in dss_args.keys(), "'kappa' is a compulsory argument. Include it as a key in dss_args"
        self.select_every = dss_args.select_every
        self.device = dss_args.device
        self.kappa = dss_args.kappa
        self.type = dss_args.type
        self.record_gradiant = dss_args.record_gradiant
        self.save_place = dss_args.save_place
        self.method = cfg.method
        self.select_times = 0  # number of resample1 calls counted by __iter__
        self.count = 0
        self.num_epochs = dss_args.num_epochs
        self.train_after = 330

        self.loss_val_history = []
        self.loss_core_history = []
        self.loss_drift_history = []
        self.last_drift = 0
        # Starts far in the past so that cool-down logic does not misfire
        # (previously assigned 0 first and then immediately overwritten with -100).
        self.last_selection_epoch = -100

        # State of the CV-based re-selection trigger used in __iter__.
        self.CV_count = 0          # consecutive epochs with CV above the threshold
        self.warmup_cv = []        # CV values collected during the CV warm-up period
        # Dynamic threshold; stays None until the CV warm-up period ends
        # (previously assigned 0 first and then immediately overwritten with None).
        self.cv_threshold = None
        self.warmup_cv_epochs = 120  # number of epochs in the CV warm-up period

        if dss_args.kappa > 0:
            assert "num_epochs" in dss_args.keys(), "'num_epochs' is a compulsory argument when warm starting the model(i.e., kappa > 0). Include it as a key in dss_args"
            self.select_after = int(dss_args.kappa * dss_args.num_epochs)
            # select_after is already an int, so warmup_epochs == select_after here.
            self.warmup_epochs = ceil(self.select_after)
        else:
            self.select_after = 0
            self.warmup_epochs = 0
        self.initialized = False
        # Parameter snapshots returned by record_para(); handed to resample1 by __iter__.
        self.last_para1 = None
        self.last_para2 = None
        self.model = None
    
    def __iter__(self):
        """
        Iter function that returns the iterator of full data loader or data subset loader or empty loader based on the 
        warmstart kappa value.
        """
        self.initialized = True
        if self.warmup_epochs < self.cur_epoch < self.select_after:
            self.logger.debug(
                "Skipping epoch {0:d} due to warm-start option. ".format(self.cur_epoch, self.warmup_epochs))
            loader = DataLoader([])
            
        elif self.cur_epoch < self.warmup_epochs:
            self.logger.debug('Epoch: {0:d}, reading dataloader... '.format(self.cur_epoch))
            loader = self.wtdataloader
            self.logger.debug('Epoch: {0:d}, finished reading dataloader. '.format(self.cur_epoch))

        else:
            self.logger.debug('Epoch: {0:d}, reading dataloader... '.format(self.cur_epoch))

            if self.type =="GradMatchPB":
                if ((self.cur_epoch + 5) % self.select_every == 0) and (self.cur_epoch > 1):
                    self.last_para2 = self.record_para()
                if ((self.cur_epoch + 3) % self.select_every == 0) and (self.cur_epoch > 1):
                    self.last_para1 = self.record_para()
                if (((self.cur_epoch) % self.select_every == 0) and (self.cur_epoch > 1)) and (not self.is_distributed or self.rank == 0):
                    if self.cur_epoch < 200:
                        xishu = 1
                    else:
                        xishu = 1
                    if(self.cur_epoch == 20):
                        gradient_diff_subset, gradient_diff_random_subset = self.resample1(self.last_para1, self.last_para2, self.method)
                        self.select_times += 1
                    if self.is_distributed and self.rank != 0:
                        
                        subset_indices = torch.empty(self.budget, dtype=torch.long, device=self.device)
                        subset_weights = torch.empty(self.budget, dtype=torch.float, device=self.device)
                        dist.broadcast(subset_indices, src=0)
                        dist.broadcast(subset_weights, src=0)
                        self.subset_indices = subset_indices.cpu().tolist()
                        self.subset_weights = subset_weights.cpu().tolist()

            else:
                if ((self.cur_epoch) % self.select_every == 0) and (self.cur_epoch > 1):
                    self.resample()
                
            loader = self.subset_loader
            self.logger.debug('Epoch: {0:d}, finished reading dataloader. '.format(self.cur_epoch))

        if self.type =="GradMatchPB" and self.cur_epoch == self.num_epochs:
            #print("Total select times:",self.select_times)
            self.logger.info(f"Total select times: {self.select_times}")


        d = 352
        
        if (self.cur_epoch >= 20 and self.type == "GradMatchPB" and self.record_gradiant) or ((self.method == "gradmatch_threshold" or self.method == "gradmatch_threshold_theta") and self.cur_epoch >= 150):

            gradient_diff_subset, gradient_diff_random_subset, var = self.evaluate_grad_subset(
                self.subset_batch_indx, self.subset_batch_gammas, self.random_indices
            )
            
            self.recorded_epochs.append(self.cur_epoch)  # 仅记录实际计算的 epoch
            self.gradient_diff_subset_history.append(gradient_diff_subset)
            self.gradient_diff_random_subset_history.append(gradient_diff_random_subset)
            self.variance.append(var)
            print(self.recorded_epochs)
            print(self.gradient_diff_subset_history)
            print(self.gradient_diff_random_subset_history)

        if self.type == "GradMatchPB" and self.cur_epoch >= 20:
            CV = self.calculate_CV()
            if self.cur_epoch <= self.warmup_cv_epochs:
                self.warmup_cv.append(CV)
                if self.cur_epoch == self.warmup_cv_epochs:
                    mean_cv = sum(self.warmup_cv) / len(self.warmup_cv)
                    std_cv = (sum((x - mean_cv) ** 2 for x in self.warmup_cv) / len(self.warmup_cv)) ** 0.5
                    self.cv_threshold = mean_cv + 1 * std_cv  
            
            # 动态阈值已设置后才生效
            if self.cv_threshold is not None and CV > self.cv_threshold:
                self.CV_count += 1
            else:
                self.CV_count = 0  # reset if not consecutively high

            if self.CV_count > 3:
                gradient_diff_subset, gradient_diff_random_subset = self.resample1(self.last_para1, self.last_para2, self.method)
                self.select_times += 1
                self.CV_count = 0

        self.cur_epoch += 1
        return loader.__iter__()

    def plot_gradient_diffs(self):
        """
        Plot how gradient_diff_subset and gradient_diff_random_subset evolve over the epochs.

        The x-axis uses ``self.recorded_epochs`` (the epochs at which ``__iter__`` actually stored
        a value) whenever it lines up with the history; recording does not always start at
        epoch 20, so a fixed offset would mislabel the axis. If the lengths disagree, the
        previous ``index + 20`` numbering is used as a fallback.

        Opens a matplotlib figure and calls ``plt.show()``; returns nothing.
        """
        def as_float(value):
            # Tensors are moved to the CPU and unwrapped; plain floats are left as they are.
            return value.cpu().item() if isinstance(value, torch.Tensor) else value

        subset_history = [as_float(grad) for grad in self.gradient_diff_subset_history]
        random_history = [as_float(grad) for grad in self.gradient_diff_random_subset_history]

        if len(self.recorded_epochs) == len(subset_history):
            epochs = list(self.recorded_epochs)
        else:
            epochs = [idx + 20 for idx in range(len(subset_history))]

        plt.figure()
        plt.plot(epochs, subset_history, label="Subset Gradient Difference")
        plt.plot(epochs, random_history, label="Random 10% Subset Gradient Difference")
        plt.xlabel('Epoch')
        plt.ylabel('Gradient Difference')
        plt.title('Gradient Difference Over Epochs')
        plt.legend()
        plt.grid(True)
        plt.show()

    def __len__(self) -> int:
        """
        Returns the length of the current data loader
        """
        if self.warmup_epochs < self.cur_epoch <= self.select_after:
            self.logger.debug(
                "Skipping epoch {0:d} due to warm-start option. ".format(self.cur_epoch, self.warmup_epochs))
            loader = DataLoader([])
            return len(loader)

        elif self.cur_epoch <= self.warmup_epochs:
            self.logger.debug('Epoch: {0:d}, reading dataloader... '.format(self.cur_epoch))
            loader = self.wtdataloader
            #self.logger.debug('Epoch: {0:d}, finished reading dataloader. '.format(self.cur_epoch))
            return len(loader)
        else:
            self.logger.debug('Epoch: {0:d}, reading dataloader... '.format(self.cur_epoch))
            loader = self.subset_loader
            return len(loader)
        
    def set_model(self, model):
        self.model = model

    def resample1(self, last_para1, last_para2, method):
        """
        Function that resamples the subset indices and recalculates the subset weights
        """
        self.subset_indices, self.subset_weights, self.subset_batch_indx, self.subset_batch_gammas, gradient_diff_subset, gradient_diff_random_subset = self._resample_subset_indices1(last_para1, last_para2, method)
        self.logger.debug("Subset indices length: %d", len(self.subset_indices))
        self._refresh_subset_loader()
        self.logger.debug("Subset loader initiated, args: %s, kwargs: %s", self.loader_args, self.loader_kwargs)
        self.logger.debug('Subset selection finished, Training data size: %d, Subset size: %d',
                     self.len_full, len(self.subset_loader.dataset))
        return gradient_diff_subset, gradient_diff_random_subset

    
    def record_para(self):
        return self._record_para()

    def resample(self):
        """
        Function that resamples the subset indices and recalculates the subset weights
        """
        self.subset_indices, self.subset_weights = self._resample_subset_indices()
        self.logger.debug("Subset indices length: %d", len(self.subset_indices))
        self._refresh_subset_loader()
        self.logger.debug("Subset loader initiated, args: %s, kwargs: %s", self.loader_args, self.loader_kwargs)
        self.logger.debug('Subset selection finished, Training data size: %d, Subset size: %d',
                     self.len_full, len(self.subset_loader.dataset))

    @abstractmethod
    def _resample_subset_indices(self):
        """
        Abstract function that needs to be implemented in the child classes. 
        Needs implementation of subset selection implemented in child classes.
        """
        raise Exception('Not implemented.')

    @abstractmethod
    def _resample_subset_indices(self):
        """
        Abstract function that needs to be implemented in the child classes. 
        Needs implementation of subset selection implemented in child classes.
        """
        raise Exception('Not implemented.')

    @abstractmethod
    def _resample_subset_indices1(self, last_para1, last_para2, xishu):
        """
        Abstract function that needs to be implemented in the child classes. 
        Needs implementation of subset selection implemented in child classes.
        """
        raise Exception('Not implemented.')

    @abstractmethod
    def record_grad(self):
        raise Exception('Not implemented.')

    @abstractmethod
    def evaluate_grad_subset(self, subset_indices, subset_weights):
        raise Exception('Not implemented.')
    
    @abstractmethod
    def calculate_CV(self):
        raise Exception('Not implemented')

    @abstractmethod
    def _record_para(self):

        raise Exception('Not implemented.')
    
    @abstractmethod
    def update_model(self, para):
        raise Exception('Not implemented.')