import logging
import wandb
import os
import os.path as osp
import sys
import time
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from eracs.utils.data.dataloader.SL.adaptive import GradMatchDataLoader
from eracs.utils.data.data_utils import WeightedSubset
from eracs.utils.data.datasets.SL import gen_dataset
from eracs.utils.models import *
from eracs.utils.data.data_utils.collate import *
import pickle
from datetime import datetime
import argparse
from argparse import Namespace
from dotmap import DotMap
import os
import pandas as pd
import matplotlib.pyplot as plt

class TrainClassifier:
    def __init__(self, config_file_data):
        self.cfg = config_file_data
        results_dir = osp.abspath(osp.expanduser(self.cfg.train_args.results_dir))#'results/'
        subset_selection_name = self.cfg.dss_args.type #"GradMatch"
        all_logs_dir = os.path.join(results_dir, 
                                    self.cfg.setting,#"SL"
                                    self.cfg.dataset.name,#"cifar10"
                                    subset_selection_name,#"GradMatchPB"
                                    self.cfg.model.architecture,#"ResNet18"
                                    self.cfg.method,
                                    str(self.cfg.dss_args.fraction),#0.1
                                    str(self.cfg.dss_args.select_every),#20
                                    #str(self.cfg.dataloader.batch_size),#128
                                    str(self.cfg.scheduler.type),
                                    str(self.cfg.dss_args.lam))#0

        os.makedirs(all_logs_dir, exist_ok=True)
        self.saveplace = all_logs_dir
        # print(all_logs_dir)
        # setup logger
        plain_formatter = logging.Formatter("[%(asctime)s] %(name)s %(levelname)s: %(message)s",
                                            datefmt="%m/%d %H:%M:%S")
        now = datetime.now()
        current_time = now.strftime("%y/%m/%d %H:%M:%S")
        self.logger = logging.getLogger(__name__+"  " + current_time)
        self.logger.setLevel(logging.INFO)
        s_handler = logging.StreamHandler(stream=sys.stdout)
        s_handler.setFormatter(plain_formatter)
        s_handler.setLevel(logging.INFO)
        self.logger.addHandler(s_handler)
        f_handler = logging.FileHandler(os.path.join(all_logs_dir, self.cfg.dataset.name + "_" +
                                                     self.cfg.dss_args.type + ".log"), mode='w')
        f_handler.setFormatter(plain_formatter)
        f_handler.setLevel(logging.DEBUG)
        self.logger.addHandler(f_handler)
        self.logger.propagate = False
    """
    ############################## Loss Evaluation ##############################
    """
    def model_eval_loss(self, data_loader, model, criterion):
        total_loss = 0
        with torch.no_grad():
            for _, (inputs, targets) in enumerate(data_loader):
                inputs, targets = inputs.to(self.cfg.train_args.device), \
                                  targets.to(self.cfg.train_args.device, non_blocking=True)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                total_loss += loss.item()
        return total_loss
    """
    ############################## Model Creation ##############################
    """
    def create_model(self):
        if self.cfg.model.architecture == 'RegressionNet':
            model = RegressionNet(self.cfg.model.input_dim)
        elif self.cfg.model.architecture == 'ResNet18':
            model = ResNet18(self.cfg.model.numclasses)
            if self.cfg.dataset.name in ['cifar10', 'cifar100', 'tinyimagenet']:
                model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
                model.maxpool = nn.Identity()
        elif self.cfg.model.architecture == 'ResNet34':
            model = ResNet34(self.cfg.model.numclasses)
            if self.cfg.dataset.name in ['cifar10', 'cifar100', 'tinyimagenet']:
                model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
                model.maxpool = nn.Identity()
        elif self.cfg.model.architecture == 'ResNet50':
            model = ResNet50(self.cfg.model.numclasses)
            if self.cfg.dataset.name in ['cifar10', 'cifar100', 'tinyimagenet']:
                model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
                model.maxpool = nn.Identity()
                # model.conv1 = nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
                # model.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        elif self.cfg.model.architecture == 'ResNet101':
            model = ResNet101(self.cfg.model.numclasses)
            if self.cfg.dataset.name in ['cifar10', 'cifar100', 'tinyimagenet']:
                model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
                model.maxpool = nn.Identity()
        # model = ResNet18(self.cfg.model.numclasses)
        # model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        # model.maxpool = nn.Identity()
        model = model.to(self.cfg.train_args.device)
        return model
    """
    ############################## Loss Type, Optimizer and Learning Rate Scheduler ##############################
    """
    def loss_function(self):
        criterion = nn.CrossEntropyLoss()
        criterion_nored = nn.CrossEntropyLoss(reduction='none')
        return criterion, criterion_nored
    
    def log_learning_rate(self, logger, optimizer, epoch):
        lr = optimizer.param_groups[0]['lr']
        # logger.info(f"[Epoch {epoch}] Learning rate: {lr:.6f}")

    def optimizer_with_scheduler(self, model):
        optimizer = optim.SGD(model.parameters(),
                                    lr=self.cfg.optimizer.lr,
                                    momentum=self.cfg.optimizer.momentum,
                                    weight_decay=self.cfg.optimizer.weight_decay,
                                    nesterov=self.cfg.optimizer.nesterov)
        # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
        #                                                            T_max=self.cfg.scheduler.T_max)
            # 动态调度器选择
        if self.cfg.scheduler.type == "cosine_annealing":
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer,
                T_max=self.cfg.scheduler.T_max
            )
        elif self.cfg.scheduler.type == "cosine_annealing_warm_restarts":
            scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                optimizer,
                T_0=self.cfg.scheduler.T_0,     # 你可以在 config 中加这个字段，比如 100
                T_mult=self.cfg.scheduler.T_mult,  # 例如 2
                eta_min=self.cfg.scheduler.eta_min if hasattr(self.cfg.scheduler, "eta_min") else 0
            )
        elif self.cfg.scheduler.type == "step":
            scheduler = torch.optim.lr_scheduler.StepLR(
                optimizer,
                step_size=self.cfg.scheduler.stepsize,
                gamma=self.cfg.scheduler.gamma
            )
        else:
            scheduler = None

        return optimizer, scheduler
    @staticmethod
    def generate_cumulative_timing(mod_timing):
        tmp = 0
        mod_cum_timing = np.zeros(len(mod_timing))
        for i in range(len(mod_timing)):
            tmp += mod_timing[i]
            mod_cum_timing[i] = tmp
        return mod_cum_timing
    @staticmethod
    def save_ckpt(state, ckpt_path):
        torch.save(state, ckpt_path)
    @staticmethod
    def load_ckpt(ckpt_path, model, optimizer):
        checkpoint = torch.load(ckpt_path)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        loss = checkpoint['loss']
        metrics = checkpoint['metrics']
        return start_epoch, model, optimizer, loss, metrics
    def count_pkl(self, path):
        if not osp.exists(path):
            return -1
        return_val = 0
        file = open(path, 'rb')
        while(True):
            try:
                _ = pickle.load(file)
                return_val += 1
            except EOFError:
                break
        file.close()
        return return_val
    def train(self, **kwargs):
        """
        ############################## General Training Loop with Data Selection Strategies ##############################
        """
        # Loading the Dataset
        logger = self.logger
        if ('trainset' in kwargs) and ('validset' in kwargs) and ('testset' in kwargs) and ('num_cls' in kwargs):
            trainset, validset, testset, num_cls = kwargs['trainset'], kwargs['validset'], kwargs['testset'], kwargs['num_cls']
        else:
            #logger.info(self.cfg)
            if self.cfg.dataset.feature == 'classimb':
                trainset, validset, testset, num_cls = gen_dataset(self.cfg.dataset.datadir,
                                                                self.cfg.dataset.name,
                                                                self.cfg.dataset.feature,
                                                                classimb_ratio=self.cfg.dataset.classimb_ratio, dataset=self.cfg.dataset)
            else:
                trainset, validset, testset, num_cls = gen_dataset(self.cfg.dataset.datadir,
                                                                self.cfg.dataset.name,
                                                                self.cfg.dataset.feature, dataset=self.cfg.dataset)
        trn_batch_size = self.cfg.dataloader.batch_size
        val_batch_size = self.cfg.dataloader.batch_size
        tst_batch_size = self.cfg.dataloader.batch_size
        batch_sampler = lambda _, __ : None
        drop_last = False
        if 'collate_fn' not in self.cfg.dataloader.keys():
            collate_fn = None
        else:
            collate_fn = self.cfg.dataloader.collate_fn
        # Creating the Data Loaders
        trainloader = torch.utils.data.DataLoader(trainset, batch_size=trn_batch_size, sampler=batch_sampler(trainset, trn_batch_size),
                                                  shuffle=False, pin_memory=True, collate_fn = collate_fn, drop_last=drop_last)

        valloader = torch.utils.data.DataLoader(validset, batch_size=val_batch_size, sampler=batch_sampler(validset, val_batch_size),
                                                shuffle=False, pin_memory=True, collate_fn = collate_fn, drop_last=drop_last)

        testloader = torch.utils.data.DataLoader(testset, batch_size=tst_batch_size, sampler=batch_sampler(testset, tst_batch_size),
                                                 shuffle=False, pin_memory=True, collate_fn = collate_fn, drop_last=drop_last)
	
        train_eval_loader = torch.utils.data.DataLoader(trainset, batch_size=trn_batch_size * 20, sampler=batch_sampler(trainset, trn_batch_size),
                                                  shuffle=False, pin_memory=True, collate_fn = collate_fn, drop_last=drop_last)

        val_eval_loader = torch.utils.data.DataLoader(validset, batch_size=val_batch_size * 20, sampler=batch_sampler(validset, val_batch_size),
                                                shuffle=False, pin_memory=True, collate_fn = collate_fn, drop_last=drop_last)

        test_eval_loader = torch.utils.data.DataLoader(testset, batch_size=tst_batch_size * 20, sampler=batch_sampler(testset, tst_batch_size),
                                                 shuffle=False, pin_memory=True, collate_fn = collate_fn, drop_last=drop_last)
        substrn_losses = list()  # np.zeros(cfg['train_args']['num_epochs'])
        trn_losses = list()
        val_losses = list()  # np.zeros(cfg['train_args']['num_epochs'])
        tst_losses = list()
        subtrn_losses = list()
        timing = []
        trn_acc = list()
        val_acc = list()  # np.zeros(cfg['train_args']['num_epochs'])
        tst_acc = list()  # np.zeros(cfg['train_args']['num_epochs'])
        best_acc = list()
        curr_best_acc = 0
        subtrn_acc = list()  # np.zeros(cfg['train_args']['num_epochs'])
        # Checkpoint file
        checkpoint_dir = osp.abspath(osp.expanduser(self.cfg.ckpt.dir))
        subset_selection_name = self.cfg.dss_args.type
        ckpt_dir = os.path.join(checkpoint_dir, 
                                self.cfg.setting,
                                self.cfg.dataset.name,
                                subset_selection_name,
                                self.cfg.model.architecture,
                                str(self.cfg.dss_args.fraction),
                                str(self.cfg.dss_args.select_every),
                                str(self.cfg.train_args.run))   
        ckpt_dir = self.saveplace           
        checkpoint_path = os.path.join(ckpt_dir, 'model.pt')
        os.makedirs(ckpt_dir, exist_ok=True)
        # Model Creation
        model = self.create_model()
        # logger.info("Model architecture:\n" + str(model))
        #

        #
        if self.cfg.train_args.wandb:
            wandb.watch(model)
        #Initial Checkpoint Directory
        init_ckpt_dir = os.path.abspath(os.path.expanduser("checkpoints"))
        os.makedirs(init_ckpt_dir, exist_ok=True)
        model_name = ""
        for key in self.cfg.model.keys():
            if r"/" not in str(self.cfg.model[key]):
                model_name += (str(self.cfg.model[key]) + "_")
        if model_name[-1] == "_":
            model_name = model_name[:-1]  
        print(init_ckpt_dir)  
        if not os.path.exists(os.path.join(init_ckpt_dir, model_name + ".pt")):
            ckpt_state = {'state_dict': model.state_dict()}
            # save checkpoint
            self.save_ckpt(ckpt_state, os.path.join(init_ckpt_dir, model_name + ".pt"))
        #else:
            #new model
            '''
            checkpoint = torch.load(os.path.join(init_ckpt_dir, model_name + ".pt"))
            model.load_state_dict(checkpoint['state_dict'])
            '''
            #checkpoint = torch.load(os.path.join(init_ckpt_dir, model_name + ".pt"), map_location="cuda:0")
            #model.load_state_dict(checkpoint['state_dict'])
            #model.to("cuda:0")
            
        # Loss Functions
        criterion, criterion_nored = self.loss_function()
        # Getting the optimizer and scheduler
        optimizer, scheduler = self.optimizer_with_scheduler(model)
        """
        ############################## Custom Dataloader Creation ##############################
        """
        if 'collate_fn' not in self.cfg.dss_args:
                self.cfg.dss_args.collate_fn = None
        if self.cfg.dss_args.type in ['GradMatch', 'GradMatchPB', 'GradMatch-Warm', 'GradMatchPB-Warm']:
            """
            ############################## GradMatch Dataloader Additional Arguments ##############################
            """
            self.cfg.dss_args.model = model
            self.cfg.dss_args.loss = criterion_nored
            self.cfg.dss_args.eta = self.cfg.optimizer.lr
            self.cfg.dss_args.num_classes = self.cfg.model.numclasses
            self.cfg.dss_args.num_epochs = self.cfg.train_args.num_epochs
            self.cfg.dss_args.device = self.cfg.train_args.device

            dataloader = GradMatchDataLoader(trainloader, valloader, self.cfg.dss_args, self.cfg, logger,
                                             batch_size=self.cfg.dataloader.batch_size,
                                             shuffle=self.cfg.dataloader.shuffle,
                                             pin_memory=self.cfg.dataloader.pin_memory,
                                             collate_fn = self.cfg.dss_args.collate_fn)
            # new module
            dataloader.set_model(model)
            # new module
        elif self.cfg.dss_args.type in ['Random', 'Random-Warm']:
            """
            ############################## Random Dataloader Additional Arguments ##############################
            """
            self.cfg.dss_args.device = self.cfg.train_args.device
            self.cfg.dss_args.num_epochs = self.cfg.train_args.num_epochs

            dataloader = RandomDataLoader(trainloader, self.cfg.dss_args, self.cfg, logger,
                                          batch_size=self.cfg.dataloader.batch_size,
                                          shuffle=self.cfg.dataloader.shuffle,
                                          pin_memory=self.cfg.dataloader.pin_memory, 
                                          collate_fn = self.cfg.dss_args.collate_fn)
        elif self.cfg.dss_args.type == 'Full':
            """
            ############################## Full Dataloader Additional Arguments ##############################
            """
            wt_trainset = WeightedSubset(trainset, list(range(len(trainset))), [1] * len(trainset))

            dataloader = torch.utils.data.DataLoader(wt_trainset,
                                                     batch_size=self.cfg.dataloader.batch_size,
                                                     shuffle=self.cfg.dataloader.shuffle,
                                                     pin_memory=self.cfg.dataloader.pin_memory,
                                                     collate_fn=self.cfg.dss_args.collate_fn)
        is_selcon = False
        """
        ################################################# Checkpoint Loading #################################################
        """
        #new training
        start_epoch = 0
        """
        ################################################# Training Loop #################################################
        """
        train_time = 0
        gradient_diff_subset_history = []
        gradient_diff_random_subset_history = []
        d = 352
        torch.manual_seed(42)
        random_indices = torch.randint(0, d, (int(d * 0.1),), dtype=torch.long)
        for epoch in range(start_epoch, self.cfg.train_args.num_epochs+1):
            """
            ################################################# Evaluation Loop #################################################
            """
            print_args = self.cfg.train_args.print_args#print_args=["trn_loss", "trn_acc", "val_loss", "val_acc", "tst_loss", "tst_acc", "time"]
            if (epoch % self.cfg.train_args.print_every == 0) or (epoch == self.cfg.train_args.num_epochs) or (epoch == 0):
                trn_loss = 0
                trn_correct = 0
                trn_total = 0
                val_loss = 0
                val_correct = 0
                val_total = 0
                tst_correct = 0
                tst_total = 0
                tst_loss = 0
                model.eval()
                logger_dict = {}
                if ("trn_loss" in print_args) or ("trn_acc" in print_args):
                    samples=0
                    with torch.no_grad():
                        for _, data in enumerate(train_eval_loader):
                            if is_selcon:
                                inputs, targets, _ = data
                            else:
                                inputs, targets = data

                            inputs, targets = inputs.to(self.cfg.train_args.device), \
                                              targets.to(self.cfg.train_args.device, non_blocking=True)
                            outputs = model(inputs)
                            loss = criterion(outputs, targets)
                            trn_loss += (loss.item() * train_eval_loader.batch_size)
                            samples += targets.shape[0]
                            if "trn_acc" in print_args:
                                if is_selcon: predicted = outputs
                                else: _, predicted = outputs.max(1)
                                trn_total += targets.size(0)
                                trn_correct += predicted.eq(targets).sum().item()
                        trn_loss = trn_loss/samples
                        trn_losses.append(trn_loss)
                        logger_dict['trn_loss'] = trn_loss
                    if "trn_acc" in print_args:
                        trn_acc.append(trn_correct / trn_total)
                        logger_dict['trn_acc'] = trn_correct / trn_total

                if ("val_loss" in print_args) or ("val_acc" in print_args):
                    samples =0
                    with torch.no_grad():
                        for _, data in enumerate(val_eval_loader):
                            if is_selcon:
                                inputs, targets, _ = data
                            else:
                                inputs, targets = data

                            inputs, targets = inputs.to(self.cfg.train_args.device), \
                                              targets.to(self.cfg.train_args.device, non_blocking=True)
                            outputs = model(inputs)
                            loss = criterion(outputs, targets)
                            val_loss += (loss.item() * val_eval_loader.batch_size)
                            samples += targets.shape[0]
                            if "val_acc" in print_args:
                                if is_selcon: predicted = outputs
                                else: _, predicted = outputs.max(1)
                                val_total += targets.size(0)
                                val_correct += predicted.eq(targets).sum().item()
                        val_loss = val_loss/samples
                        val_losses.append(val_loss)
                        logger_dict['val_loss'] = val_loss

                    if "val_acc" in print_args:
                        val_acc.append(val_correct / val_total)
                        logger_dict['val_acc'] = val_correct / val_total

                if ("tst_loss" in print_args) or ("tst_acc" in print_args):
                    samples =0
                    with torch.no_grad():
                        for _, data in enumerate(test_eval_loader):
                            if is_selcon:
                                inputs, targets, _ = data
                            else:
                                inputs, targets = data

                            inputs, targets = inputs.to(self.cfg.train_args.device), \
                                              targets.to(self.cfg.train_args.device, non_blocking=True)
                            outputs = model(inputs)
                            loss = criterion(outputs, targets)
                            tst_loss += (loss.item() * test_eval_loader.batch_size)
                            samples += targets.shape[0]
                            if "tst_acc" in print_args:
                                if is_selcon: predicted = outputs
                                else: _, predicted = outputs.max(1)
                                tst_total += targets.size(0)
                                tst_correct += predicted.eq(targets).sum().item()
                        tst_loss = tst_loss/samples
                        tst_losses.append(tst_loss)
                        logger_dict['tst_loss'] = tst_loss

                    if (tst_correct/tst_total) > curr_best_acc:
                        curr_best_acc = (tst_correct/tst_total)

                    if "tst_acc" in print_args:
                        tst_acc.append(tst_correct / tst_total)
                        best_acc.append(curr_best_acc)
                        logger_dict['tst_acc'] = tst_correct / tst_total
                        logger_dict['best_acc'] = curr_best_acc
                if "subtrn_acc" in print_args:
                    if epoch == 0:
                        subtrn_acc.append(0)
                        logger_dict['subtrn_acc'] = 0
                    else:    
                        subtrn_acc.append(subtrn_correct / subtrn_total)
                        logger_dict['subtrn_acc'] = subtrn_correct / subtrn_total

                if "subtrn_loss" in print_args:
                    if epoch == 0:
                        subtrn_losses.append(0)
                        logger_dict['subtrn_loss'] = 0
                    else: 
                        subtrn_losses.append(subtrn_loss / batch_idx)
                        logger_dict['subtrn_loss'] = subtrn_loss / batch_idx
                
                
                print_str = "Epoch: " + str(epoch)
                logger_dict['Epoch'] = epoch
                logger_dict['Time'] = train_time
                timing.append(train_time)
                if self.cfg.train_args.wandb:
                    wandb.log(logger_dict)

                """
                ################################################# Results Printing #################################################
                """
                for arg in print_args:
                    if arg == "val_loss":
                        print_str += " , " + "Validation Loss: " + str(val_losses[-1])
                    if arg == "val_acc":
                        print_str += " , " + "Validation Accuracy: " + str(val_acc[-1])
                    if arg == "tst_loss":
                        print_str += " , " + "Test Loss: " + str(tst_losses[-1])
                    if arg == "tst_acc":
                        print_str += " , " + "Test Accuracy: " + str(tst_acc[-1])
                        print_str += " , " + "Best Accuracy: " + str(best_acc[-1])
                    if arg == "trn_loss":
                        print_str += " , " + "Training Loss: " + str(trn_losses[-1])
                    if arg == "trn_acc":
                        print_str += " , " + "Training Accuracy: " + str(trn_acc[-1])
                    if arg == "subtrn_loss":
                        print_str += " , " + "Subset Loss: " + str(subtrn_losses[-1])
                    if arg == "subtrn_acc":
                        print_str += " , " + "Subset Accuracy: " + str(subtrn_acc[-1])
                    if arg == "time":
                        print_str += " , " + "Timing: " + str(timing[-1])

                # report metric to ray for hyperparameter optimization
                # if 'report_tune' in self.cfg and self.cfg.report_tune and len(dataloader) and epoch > 0:
                #     tune.report(mean_accuracy=np.array(val_acc).max())
                logger.info(print_str)
            subtrn_loss = 0
            subtrn_correct = 0
            subtrn_total = 0
            model.train()
            start_time = time.time()
            """
            ################################################# update model #################################################
            """
            for batch_idx, (inputs, targets, weights) in enumerate(dataloader):#会调用dataloader 的 __iter__() 
                # if batch_idx == 1:
                #     logger.info(f'[Epoch {epoch}] Batch {batch_idx}: inputs.shape = {inputs.shape}, targets[:5] = {targets[:5].tolist()}, weights[:5] = {weights[:5].tolist()}')
                inputs = inputs.to(self.cfg.train_args.device)
                targets = targets.to(self.cfg.train_args.device, non_blocking=True)
                weights = weights.to(self.cfg.train_args.device)
                optimizer.zero_grad()
                outputs = model(inputs)
                losses = criterion_nored(outputs, targets)
                #if epoch == 20:
                    #logger.info(f'Weights: {weights}')
                # loss = torch.dot(losses, weights / (weights.sum()))
                weights = weights.float() # 
                loss = torch.dot(losses, weights*2 / (weights.sum()))
                loss.backward()
                subtrn_loss += loss.item()
                optimizer.step()

                if not self.cfg.is_reg:
                    _, predicted = outputs.max(1)
                    subtrn_total += targets.size(0)
                    subtrn_correct += predicted.eq(targets).sum().item()
                if epoch >= 10 and self.cfg.dss_args.type == "GradMatchPB" and self.cfg.record_200th == True: #
                    if batch_idx % 10 == 0 or epoch == 20:
                        gradient_diff_subset, gradient_diff_random_subset = dataloader.evaluate_grad_subset(dataloader.subset_batch_indx, dataloader.subset_batch_gammas, random_indices)
                        # Record the gradient differences
                        gradient_diff_subset_history.append(gradient_diff_subset)
                        gradient_diff_random_subset_history.append(gradient_diff_random_subset)
            self.log_learning_rate(logger, optimizer, epoch)
         
                

            epoch_time = time.time() - start_time
            #scheduler.step()
            if isinstance(scheduler, torch.optim.lr_scheduler.CosineAnnealingWarmRestarts):
                scheduler.step(epoch + 1)
            else:
                scheduler.step()
            train_time += epoch_time
            """
            ################################################# Checkpoint Saving #################################################
            """
            if ((epoch + 1) % self.cfg.ckpt.save_every == 0) and self.cfg.ckpt.is_save:
                metric_dict = {}
                for arg in print_args:
                    if arg == "val_loss":
                        metric_dict['val_loss'] = val_losses
                    if arg == "val_acc":
                        metric_dict['val_acc'] = val_acc
                    if arg == "tst_loss":
                        metric_dict['tst_loss'] = tst_losses
                    if arg == "tst_acc":
                        metric_dict['tst_acc'] = tst_acc
                        metric_dict['best_acc'] = best_acc
                    if arg == "trn_loss":
                        metric_dict['trn_loss'] = trn_losses
                    if arg == "trn_acc":
                        metric_dict['trn_acc'] = trn_acc
                    if arg == "time":
                        metric_dict['time'] = timing
                ckpt_state = {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'loss': self.loss_function(),
                    'metrics': metric_dict}
                # save checkpoint
                self.save_ckpt(ckpt_state, checkpoint_path)
                # checkpoint_path_epoch = os.path.join(ckpt_dir, f"model_run_{self.cfg.train_args.run}_epoch_{epoch + 1}.pt")
                # self.save_ckpt(ckpt_state, checkpoint_path_epoch)
                logger.info("Model checkpoint saved at epoch: {0:d}".format(epoch + 1))
            
        """
        ################################################# Results Summary #################################################
        """
        original_idxs = set([x for x in range(len(trainset))])
        encountered_idxs = []
        if self.cfg.dss_args.type != 'Full':
            for key in dataloader.selected_idxs.keys():
                encountered_idxs.extend(dataloader.selected_idxs[key])
            encountered_idxs = set(encountered_idxs)
            rem_idxs = original_idxs.difference(encountered_idxs)
            encountered_percentage = len(encountered_idxs)/len(original_idxs)
            logger.info("Selected Indices: ") 
            #logger.info(dataloader.selected_idxs)
            logger.info("Percentages of data samples encountered during training: %.2f", encountered_percentage)
            logger.info("Not Selected Indices: ")
            #logger.info(rem_idxs)
            if self.cfg.train_args.wandb:
                wandb.log({"Data Samples Encountered(in %)": encountered_percentage})           
        logger.info(self.cfg.dss_args.type + " Selection Run---------------------------------")
        logger.info("Final SubsetTrn: {0:f}".format(subtrn_loss))
        logger.info("Validation Loss: %.2f , Validation Accuracy: %.2f", val_loss, val_acc[-1])
        logger.info("Test Loss: %.2f, Test Accuracy: %.2f, Best Accuracy: %.2f", tst_loss, tst_acc[-1], best_acc[-1]) 
        logger.info('---------------------------------------------------------------------')
        logger.info(self.cfg.dss_args.type)
        logger.info('---------------------------------------------------------------------')
        """
        ################################################# Final Results Logging #################################################
        """
        if "val_acc" in print_args:
            val_str = "Validation Accuracy: "
            for val in val_acc:
                if val_str == "Validation Accuracy: ":
                    val_str = val_str + str(val)
                else:
                    val_str = val_str + " , " + str(val)
            logger.info(val_str)
        if "tst_acc" in print_args:
            tst_str = "Test Accuracy: "
            for tst in tst_acc:
                if tst_str == "Test Accuracy: ":
                    tst_str = tst_str + str(tst)
                else:
                    tst_str = tst_str + " , " + str(tst)
            logger.info(tst_str)
            tst_str = "Best Accuracy: "
            for tst in best_acc:
                if tst_str == "Best Accuracy: ":
                    tst_str = tst_str + str(tst)
                else:
                    tst_str = tst_str + " , " + str(tst)
            logger.info(tst_str)
        if "time" in print_args:
            time_str = "Time: "
            for t in timing:
                if time_str == "Time: ":
                    time_str = time_str + str(t)
                else:
                    time_str = time_str + " , " + str(t)
            logger.info(time_str)
        omp_timing = np.array(timing)
        # omp_cum_timing = list(self.generate_cumulative_timing(omp_timing))
        logger.info("Total time taken by %s = %.4f ", self.cfg.dss_args.type, omp_timing[-1])

        return trn_acc, val_acc, tst_acc, best_acc, omp_timing

# Accepted spellings for boolean command-line values (compared case-insensitively).
_TRUE_STRINGS = frozenset({"1", "true", "t", "yes", "y", "on"})
_FALSE_STRINGS = frozenset({"0", "false", "f", "no", "n", "off"})


def _str2bool(value):
    """Convert a command-line token such as ``"False"`` / ``"1"`` into a bool.

    ``argparse`` with ``type=bool`` would call ``bool("False")``, which is
    ``True`` for any non-empty string; this converter parses the text instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean
            spelling, so argparse reports a usage error instead of silently
            treating it as True.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in _TRUE_STRINGS:
        return True
    if text in _FALSE_STRINGS:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_args(argv=None):
    """Parse the command-line options for a training run.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) parses ``sys.argv[1:]``, matching the previous behaviour.

    Returns:
        argparse.Namespace holding one attribute per option defined below.

    Boolean options accept an explicit value (``--shuffle false``,
    ``--wandb 1``) or may be given bare (``--wandb``), which means True.
    """
    parser = argparse.ArgumentParser(description="Train a model using CORDS with command-line arguments.")

    # Shared settings for boolean options: parse the text properly, and let a
    # bare flag (no value) mean True.
    bool_opts = dict(type=_str2bool, nargs="?", const=True)

    # method
    parser.add_argument("--method", type=str, default="gradmatch", help="Method.")

    # Dataset settings
    parser.add_argument("--dataset_name", type=str, default="cifar10", help="Dataset name.")
    parser.add_argument("--datadir", type=str, default="../data", help="Path to dataset directory.")
    parser.add_argument("--feature", type=str, default="dss", help="Feature type.")
    parser.add_argument("--data_type", type=str, default="image", help="Dataset type (e.g., image).")

    # Dataloader settings
    parser.add_argument("--shuffle", default=True, help="Whether to shuffle the dataset.", **bool_opts)
    parser.add_argument("--batch_size", type=int, default=128, help="Batch size.")
    parser.add_argument("--pin_memory", default=True, help="Use pin memory in DataLoader.", **bool_opts)

    # Model settings
    parser.add_argument("--architecture", type=str, default="ResNet18", help="Model architecture.")
    parser.add_argument("--num_classes", type=int, default=10, help="Number of classes.")

    # Checkpoint settings
    parser.add_argument("--is_load_ckpt", default=False, help="Load checkpoint.", **bool_opts)
    parser.add_argument("--is_save_ckpt", default=True, help="Save checkpoint.", **bool_opts)
    parser.add_argument("--ckpt_dir", type=str, default="results/", help="Checkpoint directory.")
    parser.add_argument("--save_every", type=int, default=20, help="Save checkpoint every n epochs.")

    # Loss function settings
    parser.add_argument("--loss_type", type=str, default="CrossEntropyLoss", help="Loss function type.")
    parser.add_argument("--use_sigmoid", default=False, help="Use sigmoid in loss function.", **bool_opts)

    # Optimizer settings
    parser.add_argument("--optimizer", type=str, default="sgd", help="Optimizer type.")
    parser.add_argument("--momentum", type=float, default=0.9, help="Momentum for SGD.")
    parser.add_argument("--lr", type=float, default=0.025, help="Learning rate.")
    parser.add_argument("--weight_decay", type=float, default=5e-4, help="Weight decay.")
    parser.add_argument("--nesterov", default=True, help="Use Nesterov momentum.", **bool_opts)

    # Scheduler settings
    parser.add_argument("--scheduler", type=str, default="cosine_annealing", help="Scheduler type.")
    parser.add_argument("--T_max", type=int, default=300, help="Maximum iterations for scheduler.")
    parser.add_argument("--stepsize", type=int, default=20, help="Step size for scheduler.")
    parser.add_argument("--gamma", type=float, default=0.1, help="Learning rate decay factor.")
    parser.add_argument("--T_0", type=int, default=100, help="Initial period for CosineAnnealingWarmRestarts.")
    parser.add_argument("--T_mult", type=int, default=2, help="Cycle length multiplier.")
    parser.add_argument("--eta_min", type=float, default=0.0, help="Min learning rate for cosine annealing.")

    # Data subset selection (DSS) settings
    parser.add_argument("--dss_type", type=str, default="GradMatchPB", help="Data selection strategy.")
    parser.add_argument("--fraction", type=float, default=0.2, help="Fraction of dataset to select.")
    parser.add_argument("--select_every", type=int, default=20, help="Selection frequency.")
    parser.add_argument("--lam", type=float, default=0.5, help="Lambda hyperparameter.")
    parser.add_argument("--selection_type", type=str, default="PerBatch", help="Selection type.")
    parser.add_argument("--v1", default=True, help="Enable version 1 of DSS.", **bool_opts)
    parser.add_argument("--valid", default=False, help="Enable validation selection.", **bool_opts)
    parser.add_argument("--eps", type=float, default=1e-100, help="Epsilon for stability.")
    parser.add_argument("--linear_layer", default=True, help="Use linear layer.", **bool_opts)
    parser.add_argument("--kappa", type=float, default=0, help="Kappa hyperparameter.")
    parser.add_argument("--record_gradient", default=False, help="Record gradients.", **bool_opts)
    parser.add_argument("--record_200th", default=False, help="Record gradients of 200th.", **bool_opts)
    parser.add_argument("--acc_loss", default=False, help="Record acc_loss.", **bool_opts)

    # File paths for saving.
    # NOTE(review): these defaults are absolute paths tied to one particular
    # machine layout; override them when running elsewhere.
    parser.add_argument("--save_place", type=str, default="/root/cords_project_retry/cords-main/benchmarks/SL/results/cifar10/grad_diff_epoch_10%_lam=0_new.csv", help="Path to save training results.")
    parser.add_argument("--save_place_200th", type=str, default="/root/cords_project_retry/cords-main/benchmarks/SL/results/cifar10/grad_diff_in_200th_epoch_10%_lam=0_new.csv", help="Path to save 200th epoch results.")

    # Training settings
    parser.add_argument("--num_epochs", type=int, default=300, help="Number of training epochs.")
    parser.add_argument("--device", type=str, default="cuda", help="Device to use for training.")
    parser.add_argument("--print_every", type=int, default=20, help="Print frequency.")
    parser.add_argument("--run", type=int, default=1, help="Run index.")
    parser.add_argument("--wandb", default=False, help="Use Weights & Biases.", **bool_opts)
    parser.add_argument("--results_dir", type=str, default="results/", help="Results directory.")

    # Parse the arguments (argv=None means sys.argv[1:]).
    return parser.parse_args(argv)

def main():
    """Parse command-line flags, build the nested training config, and train.

    The flat namespace returned by ``get_args`` is regrouped into per-section
    dictionaries: ``dataset``, ``dataloader``, ``model``, ``ckpt``, ``loss``,
    ``optimizer``, ``scheduler``, ``dss_args`` and ``train_args``, plus a few
    top-level keys. The result is wrapped in a ``DotMap`` so it supports
    attribute access, and is then passed to ``TrainClassifier``, whose
    ``train()`` method is invoked.
    """
    opts = get_args()

    # Build the config dict one section at a time.
    dataset_section = dict(
        name=opts.dataset_name,
        datadir=opts.datadir,
        feature=opts.feature,
        type=opts.data_type,
    )
    loader_section = dict(
        shuffle=opts.shuffle,
        batch_size=opts.batch_size,
        pin_memory=opts.pin_memory,
    )
    model_section = dict(
        architecture=opts.architecture,
        type='pre-defined',
        numclasses=opts.num_classes,
    )
    ckpt_section = dict(
        is_load=opts.is_load_ckpt,
        is_save=opts.is_save_ckpt,
        dir=opts.ckpt_dir,
        save_every=opts.save_every,
    )
    loss_section = dict(
        type=opts.loss_type,
        use_sigmoid=opts.use_sigmoid,
    )
    optim_section = dict(
        type=opts.optimizer,
        momentum=opts.momentum,
        lr=opts.lr,
        weight_decay=opts.weight_decay,
        nesterov=opts.nesterov,
    )
    sched_section = dict(
        type=opts.scheduler,
        T_max=opts.T_max,
        stepsize=opts.stepsize,
        gamma=opts.gamma,
        T_0=opts.T_0,
        T_mult=opts.T_mult,
        eta_min=opts.eta_min,
    )
    # NOTE(review): the key is spelled "record_gradiant" (sic). It is kept
    # as-is on the presumption that downstream code reads this exact key;
    # confirm before renaming.
    dss_section = dict(
        type=opts.dss_type,
        fraction=opts.fraction,
        select_every=opts.select_every,
        lam=opts.lam,
        selection_type=opts.selection_type,
        v1=opts.v1,
        valid=opts.valid,
        eps=opts.eps,
        linear_layer=opts.linear_layer,
        kappa=opts.kappa,
        collate_fn=None,
        record_gradiant=opts.record_gradient,
        save_place=opts.save_place,
    )
    train_section = dict(
        num_epochs=opts.num_epochs,
        device=opts.device,
        print_every=opts.print_every,
        run=opts.run,
        wandb=opts.wandb,
        results_dir=opts.results_dir,
        print_args=["trn_loss", "trn_acc", "val_loss", "val_acc", "tst_loss", "tst_acc", "subtrn_acc", "subtrn_loss", "time"],
        return_args=[],
    )

    # Assemble the top-level config. Insertion order matches the original layout.
    config = DotMap({
        "setting": "SL",
        "is_reg": False,
        "method": opts.method,
        "dataset": dataset_section,
        "dataloader": loader_section,
        "model": model_section,
        "ckpt": ckpt_section,
        "loss": loss_section,
        "optimizer": optim_section,
        "scheduler": sched_section,
        "dss_args": dss_section,
        "record_200th": opts.record_200th,
        "acc_loss": opts.acc_loss,
        "save_place_200th": opts.save_place_200th,
        "save_place": opts.save_place,
        "train_args": train_section,
    })

    trainer = TrainClassifier(config)
    trainer.train()


if __name__ == "__main__":
    main()


