from .adaptivedataloader import AdaptiveDSSDataLoader
from eracs.selectionstrategies.SL import GradMatchStrategy
import time, copy, torch
import torch.distributed as dist
import os
import os.path as osp
class GradMatchDataLoader(AdaptiveDSSDataLoader):
    """
    Dataloader for the adaptive GradMatch subset selection strategy from the paper
    :footcite:`pmlr-v139-killamsetty21a`.

    On construction it validates ``dss_args``, creates a ``grad_info`` directory under a
    results path derived from ``cfg`` and builds a ``GradMatchStrategy`` that works on a
    deep copy of the training model.

    Several attributes used below (``logger``, ``cur_epoch``, ``budget``, ``is_distributed``,
    ``rank``, ``device``) are not set in this class; they are presumably provided by the
    base class -- TODO confirm.

    Parameters
    -----------
    train_loader: torch.utils.data.DataLoader class
        Dataloader of the training dataset
    val_loader: torch.utils.data.DataLoader class
        Dataloader of the validation dataset
    dss_args: dict
        Data subset selection arguments dictionary required for GradMatch subset selection strategy.
        Must contain the keys ``model``, ``loss`` (with ``reduction == "none"``), ``eta``,
        ``num_classes``, ``linear_layer``, ``selection_type``, ``valid``, ``v1``, ``lam`` and ``eps``;
        ``device`` is read as well.
    cfg: config object
        Experiment configuration. The fields ``train_args.results_dir``, ``setting``,
        ``dataset.name``, ``dss_args.type``, ``model.architecture``, ``method``,
        ``dss_args.fraction``, ``dss_args.select_every``, ``scheduler.type`` and ``dss_args.lam``
        are used to build the output directory; ``method`` also picks the selection routine.
    logger: class
        Logger for logging the information
    """

    def __init__(self, train_loader, val_loader, dss_args, cfg, logger, *args, **kwargs):
        """
        Constructor function

        Raises
        ------
        AssertionError
            If a compulsory key is missing from ``dss_args``.
        ValueError
            If the loss function's ``reduction`` is not ``"none"``.
        """
        # Arguments assertion check. NOTE: these are plain asserts, so they are skipped when
        # Python runs with -O; kept as asserts so callers see the same exception type.
        for key in ("model", "loss"):
            assert key in dss_args.keys(), f"'{key}' is a compulsory argument for GradMatch. Include it as a key in dss_args"
        if dss_args.loss.reduction != "none":
            raise ValueError("Please set 'reduction' of loss function to 'none' for adaptive subset selection strategies")
        assert "eta" in dss_args.keys(), "'eta' is a compulsory argument. Include it as a key in dss_args"
        for key in ("num_classes", "linear_layer", "selection_type", "valid", "v1", "lam", "eps"):
            assert key in dss_args.keys(), f"'{key}' is a compulsory argument for GradMatch. Include it as a key in dss_args"

        self.cfg = cfg
        results_dir = osp.abspath(osp.expanduser(str(cfg.train_args.results_dir)))  # e.g. 'results/'
        # One directory level per experiment setting so that different runs do not collide.
        self.all_logs_dir = os.path.join(
            results_dir,
            cfg.setting,                      # e.g. "SL"
            cfg.dataset.name,                 # e.g. "cifar10"
            cfg.dss_args.type,                # e.g. "GradMatch" / "GradMatchPB"
            cfg.model.architecture,           # e.g. "ResNet18"
            cfg.method,
            str(cfg.dss_args.fraction),       # e.g. 0.1
            str(cfg.dss_args.select_every),   # e.g. 20
            str(cfg.scheduler.type),
            str(cfg.dss_args.lam),            # e.g. 0
        )
        self.grad_info_dir = os.path.join(self.all_logs_dir, "grad_info")
        os.makedirs(self.grad_info_dir, exist_ok=True)

        super().__init__(train_loader, val_loader, dss_args, cfg, logger, *args, **kwargs)

        # The strategy gets its own deep copy of the model; the live training model is kept
        # separately in ``self.train_model``.
        self.strategy = GradMatchStrategy(train_loader, val_loader, copy.deepcopy(dss_args.model), dss_args.loss, dss_args.eta,
                                          dss_args.device, dss_args.num_classes, dss_args.linear_layer, dss_args.selection_type,
                                          logger, dss_args.valid, dss_args.v1, dss_args.lam, dss_args.eps)
        self.train_model = dss_args.model
        # Counter used to number the gradient files written by ``evaluate_grad_subset``; it is
        # incremented before each save, so the first file is numbered 20.
        # NOTE(review): 19 looks tied to a select_every of 20 -- confirm.
        self.save = 19
        self.logger.debug('Grad-match dataloader initialized. ')

    def _snapshot_train_state(self):
        """
        Return a new dict holding a detached CPU clone of every entry in the training model's state dict.
        """
        return {name: tensor.detach().cpu().clone() for name, tensor in self.train_model.state_dict().items()}

    def _with_restored_train_state(self, strategy_call):
        """
        Call ``strategy_call(clone_dict)`` with a snapshot of the training model's parameters and
        load a second, untouched snapshot back into the training model afterwards.

        The restore runs in a ``finally`` block, so it also happens when ``strategy_call`` raises.

        Parameters
        ----------
        strategy_call: callable
            Function taking the parameter snapshot as its only argument.

        Returns
        -------
        Whatever ``strategy_call`` returns.
        """
        cached_state_dict = self._snapshot_train_state()
        clone_dict = self._snapshot_train_state()
        try:
            return strategy_call(clone_dict)
        finally:
            self.train_model.load_state_dict(cached_state_dict)

    def _resample_subset_indices1(self, last_para1, last_para2, method):
        """
        Function that calls the GradMatch subset selection strategy to sample new subset indices and the corresponding subset weights.

        ``cfg.method`` chooses the strategy routine: ``"random"`` -> ``select_random``,
        ``"full"`` -> ``select_full``, anything else -> ``select``. ``last_para1``, ``last_para2``
        and ``method`` are forwarded to that routine unchanged.

        Returns
        -------
        tuple
            ``(subset_indices, subset_weights, batch_indx, batch_gammas, gradient_diff_subset,
            gradient_diff_random_subset)`` as produced by the strategy. In distributed mode,
            ranks other than 0 return ``(None, None, None, None, 0, 0)`` without selecting.
        """
        start = time.time()
        # Only rank 0 performs the selection in distributed runs.
        if self.is_distributed and self.rank != 0:
            return None, None, None, None, 0, 0

        self.logger.debug("Epoch: {0:d}, requires subset selection. ".format(self.cur_epoch))
        if self.cfg.method == "random":
            select_fn = self.strategy.select_random
        elif self.cfg.method == "full":
            select_fn = self.strategy.select_full
        else:
            select_fn = self.strategy.select
        (subset_indices, subset_weights, batch_indx, batch_gammas,
         gradient_diff_subset, gradient_diff_random_subset) = self._with_restored_train_state(
            lambda clone_dict: select_fn(self.budget, clone_dict, last_para1, last_para2, method))
        end = time.time()

        if self.is_distributed:
            # NOTE(review): non-zero ranks returned above and never reach these broadcasts.
            # Unless the caller issues matching broadcasts on those ranks, rank 0 may block
            # here -- confirm against the calling code.
            subset_indices = torch.tensor(subset_indices).to(self.device)
            subset_weights = torch.tensor(subset_weights).to(self.device)
            dist.broadcast(subset_indices, src=0)
            dist.broadcast(subset_weights, src=0)
            subset_indices = subset_indices.cpu().tolist()
            subset_weights = subset_weights.cpu().tolist()

        self.logger.info("Epoch: {0:d}, GradMatch subset selection finished, takes {1:.4f}. ".format(self.cur_epoch, (end - start)))
        return subset_indices, subset_weights, batch_indx, batch_gammas, gradient_diff_subset, gradient_diff_random_subset

    def _resample_subset_indices(self):
        """
        Function that calls the GradMatch subset selection strategy to sample new subset indices and the corresponding subset weights.

        Unlike the other methods of this class, the parameter snapshots here are made with
        ``copy.deepcopy`` rather than CPU clones.

        NOTE(review): ``strategy.select`` is called with two arguments and two return values
        here, while ``_resample_subset_indices1`` passes five arguments and unpacks six --
        confirm this path still matches the strategy's signature.

        Returns
        -------
        tuple
            ``(subset_indices, subset_weights)``
        """
        start = time.time()
        self.logger.debug("Epoch: {0:d}, requires subset selection. ".format(self.cur_epoch))
        cached_state_dict = copy.deepcopy(self.train_model.state_dict())
        clone_dict = copy.deepcopy(self.train_model.state_dict())
        try:
            subset_indices, subset_weights = self.strategy.select(self.budget, clone_dict)
        finally:
            # Put the training model back even if the selection failed.
            self.train_model.load_state_dict(cached_state_dict)
        end = time.time()
        self.logger.info("Epoch: {0:d}, GradMatch subset selection finished, takes {1:.4f}. ".format(self.cur_epoch, (end - start)))
        return subset_indices, subset_weights

    def evaluate_grad_subset(self, subset_indices, subset_weights, random_indices):
        """
        Run ``strategy.evaluate_grad`` for the given subset and random indices and save the
        gradient tensor it returns to ``grad_info_dir``.

        Each call increments ``self.save`` and writes ``gradients_epoch_<self.save>.pt``.

        Returns
        -------
        tuple
            ``(gradient_diff_subset, gradient_diff_random_subset, var)`` as returned by the strategy.
        """
        gradient_diff_subset, gradient_diff_random_subset, var, grad = self._with_restored_train_state(
            lambda clone_dict: self.strategy.evaluate_grad(self.budget, clone_dict, subset_indices, subset_weights, random_indices))

        self.save += 1
        # File name is numbered with the self.save counter.
        save_path = os.path.join(self.grad_info_dir, f'gradients_epoch_{self.save}.pt')
        # Save the gradients on CPU.
        torch.save(grad.cpu(), save_path)
        # Print a confirmation message.
        print(f"✅ Saved trn_gradients at {save_path}")
        return gradient_diff_subset, gradient_diff_random_subset, var

    def calculate_CV(self):
        """
        Return the result of ``strategy.evaluate_CV`` for the current budget and model parameters.
        """
        return self._with_restored_train_state(
            lambda clone_dict: self.strategy.evaluate_CV(self.budget, clone_dict))

    def _record_para(self):
        """
        Return a detached CPU copy of the training model's current state dict.
        """
        return self._snapshot_train_state()

    def record_grad(self):
        """
        Call ``record_gradiant`` with a snapshot of the training model's parameters, then
        restore the training model. Always returns 0.

        NOTE(review): ``record_gradiant`` is not defined in this class and is invoked as a
        bound attribute with ``self`` passed explicitly; if it is an ordinary method this
        passes ``self`` twice -- confirm where it lives and its signature.
        """
        self._with_restored_train_state(
            lambda clone_dict: self.record_gradiant(self, clone_dict))
        return 0

    def update_model(self, para):
        """
        Load the state dict ``para`` into ``self.model`` and return that model.

        NOTE(review): this uses ``self.model``, whereas this class only sets
        ``self.train_model`` -- presumably ``model`` is provided elsewhere; confirm.
        """
        self.model.load_state_dict(para)
        return self.model
    
    # def choose_or_not(self, model_para, criterion):
    #     decision = self.strategy.Loss_drift_check(model_para, criterion)
    #     return decision