import argparse
import datetime

import h5py
import numpy as np
import time
import torch
import torch.backends.cudnn as cudnn

import os
import models
from torch.utils.data import TensorDataset
from pathlib import Path
from collections import OrderedDict

from timm.models import create_model
from timm.loss import LabelSmoothingCrossEntropy
from timm.utils import ModelEma
from typing import List, Dict, Optional



class EpochDataset(torch.utils.data.Dataset):
    """
    A custom PyTorch Dataset class for loading epoch-based data from an HDF5 file.

    Every top-level member of the file is scanned for datasets whose name is a key of
    ``events``; each row (first axis) of such a dataset becomes one sample labelled with the
    integer that ``events`` assigns to that name.  Optionally a subset of channels can be
    selected and reordered through the ``chOrder`` attribute of the event dataset.

    The file is re-opened on every ``__getitem__`` call, so no HDF5 handle is kept on the
    instance.
    """

    # Number of leading time points kept from every epoch (second axis of one epoch;
    # assumes epochs are laid out as (channels, time) -- TODO confirm against the writer).
    N_TIME_POINTS = 800

    def __init__(self, h5_path: str, events: Dict[str, int], channels_select: Optional[List[str]] = None,
                 extra_info: Optional[List[str]] = None):
        """
        Initialize the EpochDataset and build the sample index.

        Parameters
        ----------
        h5_path : str
            Path to the HDF5 file containing the processed epochs.
        events : dict
            A dictionary mapping event names to integer labels (e.g., {'classA': 0, 'classB': 1}).
        channels_select : list of str, optional
            List of channel names to select and reorder. Default is None (keep all channels).
        extra_info : list of str, optional
            Attribute names to return with every sample.  Each name is looked up first in the
            event dataset's attributes, then in the attributes of the group's 'info' member.
            Default is None.

        Raises
        ------
        ValueError
            If ``events`` is not a dictionary.
        FileNotFoundError
            If ``h5_path`` does not exist.
        """
        if not isinstance(events, dict):
            raise ValueError(
                "Parameter 'events' must be a dictionary with event names as keys and integer labels as values.")

        if not os.path.exists(h5_path):
            raise FileNotFoundError(f"HDF5 file not found at the specified path: {h5_path}")

        self.h5_path = h5_path
        self.events = events
        self.channels_select = channels_select
        self.extra_info = extra_info if extra_info is not None else []
        # One (group_name, event_name, epoch_idx, label) tuple per sample.
        self.index = []

        print(f"Loading HDF5 data from {h5_path}")

        # Build index by reading the HDF5 structure
        with h5py.File(h5_path, 'r') as h5_file:
            for group_name in h5_file.keys():
                group = h5_file[group_name]
                for event_name, label in self.events.items():
                    if event_name not in group:
                        continue
                    num_epochs = group[event_name].shape[0]
                    self.index.extend(
                        (group_name, event_name, epoch_idx, label) for epoch_idx in range(num_epochs))

        self.total_epochs = len(self.index)

    def __len__(self):
        """
        Return the total number of epochs in the dataset.

        Returns
        -------
        int
            The total number of epochs.
        """
        return self.total_epochs

    def __getitem__(self, idx: int):
        """
        Retrieve the data, label, and optional metadata for a given epoch index.

        Parameters
        ----------
        idx : int
            The index of the epoch to retrieve.

        Returns
        -------
        tuple
            A tuple containing:
            - data (torch.Tensor): The epoch data as a float32 tensor, with selected channels and
              at most ``N_TIME_POINTS`` time points.
            - label (int): The integer label corresponding to the class of the epoch.
            - info (dict): The requested ``extra_info`` entries that were found (possibly empty).

        Raises
        ------
        IndexError
            If ``idx`` is past the end of the dataset.
        ValueError
            If a name in ``channels_select`` is not present in the dataset's 'chOrder' attribute.
        """
        # IndexError is reserved for "past the end": plain ``for x in dataset`` iteration
        # stops silently on it, so no other failure below may surface as IndexError.
        if idx >= len(self):
            raise IndexError("Index out of range")

        group_name, event_name, epoch_idx, label = self.index[idx]

        # Open the HDF5 file and retrieve the epoch data
        with h5py.File(self.h5_path, 'r') as h5_file:
            group = h5_file[group_name]
            event_ds = group[event_name]
            epoch_data = event_ds[epoch_idx]

            # If channels_select is specified, reorder and select the specified channels
            if self.channels_select:
                ch_names = event_ds.attrs['chOrder']
                ch_idx = []
                for ch in self.channels_select:
                    matches = np.where(ch_names == ch)[0]
                    if len(matches) == 0:
                        raise ValueError(
                            f"Channel '{ch}' not found in 'chOrder' of '{group_name}/{event_name}'.")
                    ch_idx.append(matches[0])
                epoch_data = epoch_data[ch_idx]

            # Keep only the leading time points
            epoch_data = epoch_data[:, :self.N_TIME_POINTS]

            # Retrieve additional info if required
            info = {}
            for col in self.extra_info:
                # First check in the dataset attributes
                if col in event_ds.attrs:
                    info[col] = event_ds.attrs[col]
                # If not found in attributes, check in the info dataset (when the group has one)
                elif 'info' in group and col in group['info'].attrs:
                    info[col] = group['info'].attrs[col]
                else:
                    # If not found, print a warning
                    print(f"Warning: '{col}' not found in dataset attributes or info.")

        # Convert data to PyTorch tensor
        data = torch.tensor(epoch_data, dtype=torch.float32)

        return data, label, info


def get_args():
    """
    Parse the command line for the fine-tuning / evaluation script.

    The parser is created with ``add_help=False``.  Arguments are parsed twice: a first
    tolerant pass (``parse_known_args``) only decides whether DeepSpeed is requested, so that
    DeepSpeed can register its own options before the final strict ``parse_args``.

    Returns
    -------
    tuple
        ``(args, ds_init)`` where ``args`` is the parsed ``argparse.Namespace`` and ``ds_init``
        is ``deepspeed.initialize`` when ``--enable_deepspeed`` is given, otherwise ``None``.

    Raises
    ------
    SystemExit
        With status 1 if ``--enable_deepspeed`` is given but DeepSpeed cannot be imported
        (argparse itself also exits on invalid arguments).
    """
    parser = argparse.ArgumentParser('fine-tuning and evaluation script for EEG classification', add_help=False)
    parser.add_argument('--batch_size', default=64, type=int)
    parser.add_argument('--epochs', default=50, type=int)
    parser.add_argument('--update_freq', default=5, type=int)
    parser.add_argument('--save_ckpt_freq', default=1, type=int)

    # robust evaluation
    parser.add_argument('--robust_test', default=None, type=str,
                        help='robust evaluation dataset')

    # Model parameters
    parser.add_argument('--model', default='model_5M', type=str, metavar='MODEL',
                        help='Name of model to train')
    parser.add_argument('--qkv_bias', action='store_true')
    parser.add_argument('--disable_qkv_bias', action='store_false', dest='qkv_bias')
    parser.set_defaults(qkv_bias=True)
    parser.add_argument('--rel_pos_bias', action='store_true')
    parser.add_argument('--disable_rel_pos_bias', action='store_false', dest='rel_pos_bias')
    parser.set_defaults(rel_pos_bias=True)
    parser.add_argument('--abs_pos_emb', action='store_true')
    parser.set_defaults(abs_pos_emb=True)
    parser.add_argument('--layer_scale_init_value', default=0.1, type=float,
                        help="0.1 for base, 1e-5 for large. set 0 to disable layer scale")

    parser.add_argument('--input_size', default=800, type=int,
                        help='EEG input size')

    parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                        help='Dropout rate (default: 0.)')
    parser.add_argument('--attn_drop_rate', type=float, default=0.0, metavar='PCT',
                        help='Attention dropout rate (default: 0.)')
    parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT',
                        help='Drop path rate (default: 0.1)')

    parser.add_argument('--disable_eval_during_finetuning', action='store_true', default=False)
    parser.add_argument('--model_ema', action='store_true', default=False)
    parser.add_argument('--model_ema_decay', type=float, default=0.9999, help='')
    parser.add_argument('--model_ema_force_cpu', action='store_true', default=False, help='')

    # Optimizer parameters
    parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
                        help='Optimizer (default: "adamw"')
    parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON',
                        help='Optimizer Epsilon (default: 1e-8)')
    parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA',
                        help='Optimizer Betas (default: None, use opt default)')
    parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
                        help='Clip gradient norm (default: None, no clipping)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--weight_decay', type=float, default=0.05,
                        help='weight decay (default: 0.05)')
    parser.add_argument('--weight_decay_end', type=float, default=None, help="""Final value of the
        weight decay. We use a cosine schedule for WD and using a larger decay by
        the end of training improves performance for ViTs.""")

    parser.add_argument('--lr', type=float, default=5e-4, metavar='LR',
                        help='learning rate (default: 5e-4)')
    parser.add_argument('--layer_decay', type=float, default=0.9)

    parser.add_argument('--warmup_lr', type=float, default=1e-6, metavar='LR',
                        help='warmup learning rate (default: 1e-6)')
    parser.add_argument('--min_lr', type=float, default=1e-5, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')

    parser.add_argument('--warmup_epochs', type=int, default=0, metavar='N',
                        help='epochs to warmup LR, if scheduler supports')
    parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N',
                        help='num of steps to warmup LR, will overload warmup_epochs if set > 0')

    parser.add_argument('--smoothing', type=float, default=0.1,
                        help='Label smoothing (default: 0.1)')

    # * Random Erase params
    parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
                        help='Random erase prob (default: 0.25)')
    parser.add_argument('--remode', type=str, default='pixel',
                        help='Random erase mode (default: "pixel")')
    parser.add_argument('--recount', type=int, default=1,
                        help='Random erase count (default: 1)')
    parser.add_argument('--resplit', action='store_true', default=False,
                        help='Do not random erase first (clean) augmentation split')

    # * Finetuning params
    parser.add_argument('--finetune', default='',
                        help='finetune from checkpoint')
    parser.add_argument('--model_key', default='model|module', type=str)
    parser.add_argument('--model_prefix', default='', type=str)
    parser.add_argument('--model_filter_name', default='gzp', type=str)
    parser.add_argument('--init_scale', default=0.001, type=float)
    parser.add_argument('--use_mean_pooling', action='store_true')
    parser.set_defaults(use_mean_pooling=True)
    parser.add_argument('--use_cls', action='store_false', dest='use_mean_pooling')
    parser.add_argument('--disable_weight_decay_on_rel_pos_bias', action='store_true', default=False)

    # Dataset parameters
    parser.add_argument('--nb_classes', default=0, type=int,
                        help='number of the classification types')

    parser.add_argument('--output_dir', default='',
                        help='path where to save, empty for no saving')
    parser.add_argument('--log_dir', default=None,
                        help='path where to tensorboard log')
    parser.add_argument('--device', default='cuda:0',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='',
                        help='resume from checkpoint')
    parser.add_argument('--auto_resume', action='store_true')
    parser.add_argument('--no_auto_resume', action='store_false', dest='auto_resume')
    parser.set_defaults(auto_resume=False)

    parser.add_argument('--save_ckpt', action='store_true')
    parser.add_argument('--no_save_ckpt', action='store_false', dest='save_ckpt')
    parser.set_defaults(save_ckpt=False)

    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--eval', action='store_true',
                        help='Perform evaluation only')
    parser.add_argument('--dist_eval', action='store_true', default=False,
                        help='Enabling distributed evaluation')
    parser.add_argument('--num_workers', default=8, type=int)
    parser.add_argument('--pin_mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
    parser.set_defaults(pin_mem=True)

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')

    parser.add_argument('--enable_deepspeed', action='store_true', default=False)
    parser.add_argument('--dataset', default='none', type=str,
                        help='dataset: TUAB | TUEV')

    # First, tolerant pass: only used to find out whether DeepSpeed was requested.
    known_args, _ = parser.parse_known_args()

    ds_init = None
    if known_args.enable_deepspeed:
        # Only a failed import means "not installed"; errors raised while registering the
        # DeepSpeed options are real problems and must not be masked by the install hint.
        try:
            import deepspeed
        except ImportError:
            print("Please 'pip install deepspeed==0.4.0'")
            # Non-zero status so calling scripts can detect the failure.
            raise SystemExit(1) from None
        parser = deepspeed.add_config_arguments(parser)
        ds_init = deepspeed.initialize

    return parser.parse_args(), ds_init


def get_models(args):
    """
    Instantiate the network named by ``args.model`` through timm's ``create_model``.

    All architecture options (class count, dropout rates, pooling mode, positional-embedding
    switches, layer-scale init, qkv bias) are forwarded from ``args``; no pretrained weights
    are requested here.

    Returns
    -------
    The model object returned by ``create_model``.
    """
    # Keyword options forwarded verbatim to the model factory.
    factory_kwargs = dict(
        pretrained=False,
        num_classes=args.nb_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        attn_drop_rate=args.attn_drop_rate,
        drop_block_rate=None,
        use_mean_pooling=args.use_mean_pooling,
        init_scale=args.init_scale,
        use_rel_pos_bias=args.rel_pos_bias,
        use_abs_pos_emb=args.abs_pos_emb,
        init_values=args.layer_scale_init_value,
        qkv_bias=args.qkv_bias,
    )
    return create_model(args.model, **factory_kwargs)


def load_data(data_path, event_dict, fold):
    """
    Load all epochs and split them into train / validation / test sets for one of 5 folds.

    The samples returned by ``heterogeneous_data_input`` are cut into 5 contiguous, nearly
    equal chunks.  Chunk ``fold`` (1-based) is the test set, the next chunk (wrapping around)
    is the validation set and the remaining three chunks form the training set.

    Parameters
    ----------
    data_path : str
        Path of the HDF5 file, forwarded to ``heterogeneous_data_input``.
    event_dict : dict
        Event-name -> label mapping, forwarded to ``heterogeneous_data_input``.
    fold : int
        1-based index of the test fold, in ``1..5``.

    Returns
    -------
    tuple
        ``(train_dataset, val_dataset, test_dataset, ch_names)`` -- three ``TensorDataset``
        objects (float data, labels) and the cleaned channel-name list.

    Raises
    ------
    ValueError
        If ``fold`` is outside ``1..5``.
    """
    n_folds = 5
    # Reject bad folds up front: fold=0 would wrap to the last chunk via negative indexing
    # and that chunk would then be used both for testing and for training.
    if not 1 <= fold <= n_folds:
        raise ValueError(f"fold must be between 1 and {n_folds}, got {fold!r}")

    all_data, all_labels, num_samples, updated_channel_list = heterogeneous_data_input(data_path, event_dict)

    # Contiguous chunks; the first (num_samples % n_folds) chunks get one extra sample.
    folds = np.array_split(np.arange(num_samples), n_folds)

    test_fold = fold - 1
    val_fold = (test_fold + 1) % n_folds
    train_folds = [i for i in range(n_folds) if i not in (test_fold, val_fold)]

    train_indices = np.concatenate([folds[i] for i in train_folds])
    val_indices = folds[val_fold]
    test_indices = folds[test_fold]

    def _as_dataset(indices):
        # Convert the selected data and labels to tensors and wrap them in a TensorDataset.
        return TensorDataset(torch.from_numpy(all_data[indices]).float(),
                             torch.from_numpy(all_labels[indices]))

    train_dataset = _as_dataset(train_indices)
    val_dataset = _as_dataset(val_indices)
    test_dataset = _as_dataset(test_indices)

    # Keep the last space-separated token of each name and drop anything after the first '-'.
    ch_names = [name.split(' ')[-1].split('-')[0] for name in updated_channel_list]
    return train_dataset, val_dataset, test_dataset, ch_names


def heterogeneous_data_input(data_path, event_dict):
    """
    Load every epoch of a supported dataset, restricted to channels known to ``models``.

    Steps:
    1. Read the file's own channel order ('chOrder') from the first sample.
    2. Pick the dataset-specific formatted channel list from the file name
       ('bcic_iv_2b' or 'figshare_shudb', matched case-insensitively in ``data_path``).
    3. Keep only formatted channels present in ``models.standard_1020_format`` and translate
       them back to the file's channel names (positional pairing of the formatted list with
       'chOrder' -- assumes both lists are in the same order; TODO confirm per dataset).
    4. Load all samples with those channels and shuffle them with a fixed seed (42).

    Parameters
    ----------
    data_path : str
        Path to the HDF5 file; its lower-cased text selects the dataset format.
    event_dict : dict
        Event-name -> integer-label mapping passed to ``EpochDataset``.

    Returns
    -------
    tuple
        ``(all_data, all_labels, num_samples, updated_channel_list)`` where the first two are
        NumPy arrays shuffled with the same permutation and ``updated_channel_list`` holds the
        ``models.standard_1020`` names of the kept channels.

    Raises
    ------
    ValueError
        If the dataset is empty, carries no 'chOrder' info, or the file name matches no known
        dataset format.
    """
    # Probe dataset: only used to read the file's channel order from the first sample.
    probe = EpochDataset(data_path, event_dict, channels_select=None,
                         extra_info=['chOrder'])
    if len(probe) == 0:
        raise ValueError(f"No epochs found in {data_path} for events {list(event_dict)}.")
    _, _, probe_info = probe[0]
    ch_order = probe_info.get("chOrder")
    if ch_order is None:
        raise ValueError(f"'chOrder' attribute not found in {data_path}.")
    channel_list = ch_order.tolist()

    dataset_name = data_path.lower()
    if 'bcic_iv_2b' in dataset_name:
        format_channel_list = ['EEG:C3', 'EEG:Cz', 'EEG:C4', 'EOG:ch01', 'EOG:ch02', 'EOG:ch03']
        print("Dataset detected: BCIC IV 2b")
    elif 'figshare_shudb' in dataset_name:
        format_channel_list = [
            'EEG:Fp1', 'EEG:Fp2', 'EEG:Fz', 'EEG:F3', 'EEG:F4', 'EEG:F7', 'EEG:F8',
            'EEG:FC1', 'EEG:FC2', 'EEG:FC5', 'EEG:FC6', 'EEG:Cz', 'EEG:C3', 'EEG:C4',
            'EEG:T3', 'EEG:T4', 'EEG:A1', 'EEG:A2', 'EEG:CP1', 'EEG:CP2', 'EEG:CP5',
            'EEG:CP6', 'EEG:Pz', 'EEG:P3', 'EEG:P4', 'EEG:T5', 'EEG:T6', 'EEG:PO3',
            'EEG:PO4', 'EEG:Oz', 'EEG:O1', 'EEG:O2'
        ]
        print("Dataset detected: Figshare SHUDB")
    else:
        raise ValueError("Unknown dataset format in data_path.")

    # Formatted channels that the model vocabulary knows about.
    modified_channel_list = [channel for channel in format_channel_list if channel in models.standard_1020_format]
    # Formatted name -> name as stored in the file (paired by position).
    channel_mapping = dict(zip(format_channel_list, channel_list))
    original_corresponding_list = [channel_mapping[channel] for channel in modified_channel_list]
    # Formatted name -> standard 10-20 name used by the model.
    channel_mapping_standard = dict(zip(models.standard_1020_format, models.standard_1020))
    updated_channel_list = [channel_mapping_standard[channel] for channel in modified_channel_list]

    dataset = EpochDataset(data_path, event_dict, channels_select=original_corresponding_list,
                           extra_info=['chOrder'])

    # Index explicitly rather than iterating the dataset: plain iteration would treat any
    # IndexError raised while loading a sample as end-of-data and silently drop the rest.
    all_data = []
    all_labels = []
    for sample_idx in range(len(dataset)):
        data, label, _ = dataset[sample_idx]
        all_data.append(data[..., :800])
        all_labels.append(label)

    # Convert data and labels to NumPy arrays
    all_data = np.stack(all_data)  # Shape: (num_samples, num_channels, ...)
    all_labels = np.stack(all_labels)  # Shape: (num_samples,)
    num_samples = all_data.shape[0]

    # Fixed-seed shuffle so the fold split is reproducible across runs.
    indices = np.random.RandomState(seed=42).permutation(num_samples)
    all_data = all_data[indices]
    all_labels = all_labels[indices]
    return all_data, all_labels, num_samples, updated_channel_list


def get_dataset(args, fold):
    """
    Build the train / test / validation datasets for ``args.dataset`` and the given fold.

    Side effects: sets ``args.nb_classes`` and ``args.batch_size`` to the values fixed for the
    selected dataset.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide ``dataset`` ('bcic_iv_2b' or 'figshare_shudb').
    fold : int
        Fold number forwarded to ``load_data``.

    Returns
    -------
    tuple
        ``(train_dataset, test_dataset, val_dataset, ch_names)`` -- note that the test set comes
        before the validation set in this return order.

    Raises
    ------
    ValueError
        If ``args.dataset`` is not a supported dataset name.
    """
    # Directory prefix prepended to the data file name (empty: current working directory).
    folder = r''

    if args.dataset == 'bcic_iv_2b':
        data_path = folder + 'bcic_iv_2b_combined.hdf5'
        event_dict = {
            "Cue onset left (class 1)": 0,
            "Cue onset right (class 2)": 1,
        }
    elif args.dataset == 'figshare_shudb':
        data_path = folder + 'figshare_shudb_combined.hdf5'
        event_dict = {
            "left": 0,
            "right": 1,
        }
    else:
        raise ValueError(
            f"Unsupported dataset '{args.dataset}'; expected 'bcic_iv_2b' or 'figshare_shudb'.")

    # Both supported datasets are two-class problems trained with the same batch size.
    args.nb_classes = 2
    args.batch_size = 64

    # load_data returns (train, val, test, ch_names); unpack in that order so that the
    # validation and test folds are not swapped before being handed to the caller.
    train_dataset, val_dataset, test_dataset, ch_names = load_data(data_path, event_dict, fold)

    return train_dataset, test_dataset, val_dataset, ch_names


def model_loader(args, device):
    """
    Create the model, optionally initialise it from a checkpoint, and set up EMA.

    Side effects: stores ``window_size`` and ``patch_size`` on ``args``.

    When ``args.finetune`` is set, the checkpoint is loaded (from a URL if it starts with
    'https', otherwise from disk), the state dict is picked via the first matching key in
    ``args.model_key`` (falling back to the whole checkpoint), only 'student.'-prefixed
    entries are kept (with the prefix stripped) when ``args.model_filter_name`` is non-empty,
    head parameters with a mismatching shape and all 'relative_position_index' entries are
    dropped, and the result is loaded through ``models.load_state_dict``.

    Returns
    -------
    tuple
        ``(model, model_ema, model_without_ddp, n_parameters)``; ``model_ema`` is ``None``
        unless ``args.model_ema`` is set, ``model_without_ddp`` is the same object as
        ``model`` and ``n_parameters`` counts the trainable parameters.
    """
    model = get_models(args)
    model.to(device)

    patch_size = model.patch_size
    print("Patch size = %s" % str(patch_size))
    args.window_size = (1, args.input_size // patch_size)
    args.patch_size = patch_size

    if args.finetune:
        load_from_url = args.finetune.startswith('https')
        if load_from_url:
            checkpoint = torch.hub.load_state_dict_from_url(
                args.finetune, map_location='cpu', check_hash=True)
        else:
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            checkpoint = torch.load(args.finetune, map_location='cpu')
        print("Load ckpt from %s" % args.finetune)

        # Use the first candidate key present in the checkpoint, else the checkpoint itself.
        checkpoint_model = None
        for candidate_key in args.model_key.split('|'):
            if candidate_key not in checkpoint:
                continue
            checkpoint_model = checkpoint[candidate_key]
            print("Load state_dict by model_key = %s" % candidate_key)
            break
        if checkpoint_model is None:
            checkpoint_model = checkpoint

        if checkpoint_model is not None and args.model_filter_name != '':
            # Keep only the 'student.' weights and strip that prefix from their names.
            student_prefix = 'student.'
            checkpoint_model = OrderedDict(
                (name[len(student_prefix):], weight)
                for name, weight in checkpoint_model.items()
                if name.startswith(student_prefix)
            )

        # Drop classifier-head weights whose shape does not fit the freshly built model.
        model_state = model.state_dict()
        for head_key in ('head.weight', 'head.bias'):
            if head_key in checkpoint_model and checkpoint_model[head_key].shape != model_state[head_key].shape:
                print(f"Removing key {head_key} from pretrained checkpoint")
                del checkpoint_model[head_key]

        # Remove every relative-position index entry from the checkpoint.
        stale_keys = [name for name in checkpoint_model.keys() if "relative_position_index" in name]
        for name in stale_keys:
            checkpoint_model.pop(name)

        models.load_state_dict(model, checkpoint_model, prefix=args.model_prefix)

    model.to(device)

    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        ema_device = 'cpu' if args.model_ema_force_cpu else ''
        model_ema = ModelEma(
            model,
            decay=args.model_ema_decay,
            device=ema_device,
            resume='')
        print("Using EMA with decay = %.8f" % args.model_ema_decay)
    else:
        model_ema = None

    model_without_ddp = model
    n_parameters = sum(param.numel() for param in model.parameters() if param.requires_grad)
    return model, model_ema, model_without_ddp, n_parameters


def creat_dataloader(args, fold):
    """
    Build the train / validation / test ``DataLoader`` objects for one fold.

    All three loaders use ``args.batch_size``, ``args.num_workers`` and ``args.pin_mem`` and
    shuffle their data; only the training loader drops the last incomplete batch.

    Parameters
    ----------
    args : argparse.Namespace
        Run configuration; ``get_dataset`` may overwrite ``args.batch_size`` and
        ``args.nb_classes``.
    fold : int
        Fold number forwarded to ``get_dataset``.

    Returns
    -------
    tuple
        ``(ch_names, data_loader_test, data_loader_train, data_loader_val, dataset_train)``.
    """
    dataset_train, dataset_test, dataset_val, ch_names = get_dataset(args, fold)

    # Use args.batch_size (read after get_dataset, which sets it) instead of a hard-coded
    # value: the caller derives the number of optimisation steps per epoch from
    # args.batch_size, so the loaders must batch with the same size.
    shared_options = dict(
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        shuffle=True,
    )
    data_loader_train = torch.utils.data.DataLoader(dataset_train, drop_last=True, **shared_options)
    data_loader_val = torch.utils.data.DataLoader(dataset_val, drop_last=False, **shared_options)
    data_loader_test = torch.utils.data.DataLoader(dataset_test, drop_last=False, **shared_options)
    return ch_names, data_loader_test, data_loader_train, data_loader_val, dataset_train


def configure_components(args, ch_names, data_loader_test, device, ds_init, metrics, model, model_ema,
                         model_without_ddp, num_layers, num_training_steps_per_epoch):
    """
    Create the optimizer, loss scaler, LR / weight-decay schedules and loss criterion.

    Also wraps the model for DeepSpeed or DistributedDataParallel when requested, restores a
    checkpoint through ``models.auto_load_model`` and, when ``args.eval`` is set, runs the test
    evaluation, prints the result and terminates the process.

    Side effects: may set ``args.weight_decay_end`` (defaults to ``args.weight_decay``).

    Parameters
    ----------
    data_loader_test : DataLoader or list/tuple of DataLoader
        Test loader(s) used only in evaluation-only mode.
    ds_init : callable or None
        ``deepspeed.initialize`` when DeepSpeed is enabled.
    num_layers : int
        Number of model layers, used for layer-wise LR decay and weight-decay skipping.
    (remaining parameters are forwarded to the ``models`` helpers unchanged)

    Returns
    -------
    tuple
        ``(criterion, loss_scaler, lr_schedule_values, model, model_without_ddp, optimizer,
        wd_schedule_values)``.

    Raises
    ------
    SystemExit
        With status 0 after evaluation when ``args.eval`` is set.
    AssertionError
        If DeepSpeed's gradient accumulation steps differ from ``args.update_freq``.
    """
    # Layer-wise learning-rate decay: one scale per layer id.
    if args.layer_decay < 1.0:
        assigner = models.LayerDecayValueAssigner(
            [args.layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2)])
        print("Assigned values = %s" % str(assigner.values))
    else:
        assigner = None
    get_layer_id = assigner.get_layer_id if assigner is not None else None
    get_layer_scale = assigner.get_scale if assigner is not None else None

    skip_weight_decay_list = model.no_weight_decay()
    if args.disable_weight_decay_on_rel_pos_bias:
        for i in range(num_layers):
            skip_weight_decay_list.add("blocks.%d.attn.relative_position_bias_table" % i)

    if args.enable_deepspeed:
        loss_scaler = None
        # NOTE(review): was `model.get_parameter_groups(model, ...)`, i.e. a call on the
        # network object that also passes the network as first argument; every other helper
        # here lives in the `models` module -- confirm `models.get_parameter_groups` exists.
        optimizer_params = models.get_parameter_groups(
            model, args.weight_decay, skip_weight_decay_list,
            get_layer_id, get_layer_scale)
        model, optimizer, _, _ = ds_init(
            args=args, model=model, model_parameters=optimizer_params, dist_init_required=not args.distributed,
        )

        print("model.gradient_accumulation_steps() = %d" % model.gradient_accumulation_steps())
        assert model.gradient_accumulation_steps() == args.update_freq
    else:
        if args.distributed:
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu],
                                                              find_unused_parameters=True)
            model_without_ddp = model.module

        optimizer = models.create_optimizer(
            args, model_without_ddp, skip_list=skip_weight_decay_list,
            get_num_layer=get_layer_id,
            get_layer_scale=get_layer_scale)
        loss_scaler = models.NativeScalerWithGradNormCount()

    print("Use step level LR scheduler!")
    lr_schedule_values = models.cosine_scheduler(
        args.lr, args.min_lr, args.epochs, num_training_steps_per_epoch,
        warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps,
    )
    if args.weight_decay_end is None:
        args.weight_decay_end = args.weight_decay
    wd_schedule_values = models.cosine_scheduler(
        args.weight_decay, args.weight_decay_end, args.epochs, num_training_steps_per_epoch)
    print("Max WD = %.7f, Min WD = %.7f" % (max(wd_schedule_values), min(wd_schedule_values)))

    # A single output unit means binary classification with a sigmoid-based loss.
    if args.nb_classes == 1:
        criterion = torch.nn.BCEWithLogitsLoss()
    elif args.smoothing > 0.:
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss()
    print("criterion = %s" % str(criterion))

    models.auto_load_model(
        args=args, model=model, model_without_ddp=model_without_ddp,
        optimizer=optimizer, loss_scaler=loss_scaler, model_ema=model_ema)

    if args.eval:
        # Accept either one loader or a collection of loaders; iterating a single DataLoader
        # would yield batches, not loaders.
        if isinstance(data_loader_test, (list, tuple)):
            test_loaders = data_loader_test
        else:
            test_loaders = [data_loader_test]

        balanced_accuracy = []
        accuracy = []
        for data_loader in test_loaders:
            # Module-level evaluate helper, same call shape as in the training loop.
            test_stats = models.evaluate(data_loader, model, device, header='Test:', ch_names=ch_names,
                                         metrics=metrics,
                                         is_binary=(args.nb_classes == 1))
            accuracy.append(test_stats['accuracy'])
            balanced_accuracy.append(test_stats['balanced_accuracy'])
        print(
            f"======Accuracy: {np.mean(accuracy)} {np.std(accuracy)}, balanced accuracy: {np.mean(balanced_accuracy)} {np.std(balanced_accuracy)}")
        # Evaluation-only mode ends the process here.
        raise SystemExit(0)

    return criterion, loss_scaler, lr_schedule_values, model, model_without_ddp, optimizer, wd_schedule_values


def epoch_training(args, best_val_loss, ch_names, criterion, data_loader_test, data_loader_train, data_loader_val,
                   device, fold_num, logs, loss_scaler, lr_schedule_values, max_accuracy, max_accuracy_test, metrics,
                   model, model_ema, model_without_ddp, n_parameters, num_training_steps_per_epoch, optimizer, seed,
                   wd_schedule_values):
    """
    Run the training loop from ``args.start_epoch`` to ``args.epochs``.

    Each epoch trains once over ``data_loader_train``, optionally saves a checkpoint, and --
    when ``data_loader_val`` is given -- evaluates on the validation and test loaders.  The
    validation/test statistics of the epoch with the lowest validation loss are remembered.
    One dictionary of per-epoch statistics is appended to ``logs`` (mutated in place).

    Parameters
    ----------
    best_val_loss : float
        Starting threshold; an epoch only becomes "best" if its validation loss is lower.
    fold_num, seed
        Stored under the 'fold' and 'seed' keys of the returned best-stats dictionaries.
    max_accuracy, max_accuracy_test : float
        Starting values for the running maxima that are printed each epoch.
    (remaining parameters are forwarded to the ``models`` training / evaluation helpers)

    Returns
    -------
    tuple
        ``(best_test_stats, best_val_stats)``; both are ``None`` when no epoch improved on
        ``best_val_loss`` (including when no epoch ran or ``data_loader_val`` is ``None``).
    """
    # Pre-bind the results so the return statement is valid even if no epoch ever improves.
    best_val_stats = None
    best_test_stats = None
    is_binary = args.nb_classes == 1

    for epoch in range(args.start_epoch, args.epochs):

        train_stats = models.train_one_epoch(
            model, criterion, data_loader_train, optimizer,
            device, epoch, loss_scaler, args.clip_grad, model_ema,
            start_steps=epoch * num_training_steps_per_epoch,
            lr_schedule_values=lr_schedule_values, wd_schedule_values=wd_schedule_values,
            num_training_steps_per_epoch=num_training_steps_per_epoch, update_freq=args.update_freq,
            ch_names=ch_names, is_binary=is_binary
        )

        if args.output_dir and args.save_ckpt:
            # NOTE(review): was `model.save_model(...)`, a call on the network object that
            # also passes the network as `model=`; the checkpoint helpers used elsewhere
            # (e.g. auto_load_model) live in `models` -- confirm `models.save_model` exists.
            models.save_model(
                args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
                loss_scaler=loss_scaler, epoch=epoch, model_ema=model_ema, save_ckpt_freq=args.save_ckpt_freq)

        log_stats = {f"train_{k}": v for k, v in train_stats.items()}

        if data_loader_val is not None:
            val_stats = models.evaluate(data_loader_val, model, device, header='Val:', ch_names=ch_names,
                                        metrics=metrics, is_binary=is_binary)
            test_stats = models.evaluate(data_loader_test, model, device, header='Test:', ch_names=ch_names,
                                         metrics=metrics, is_binary=is_binary)

            if val_stats["accuracy"] > max_accuracy:
                max_accuracy = val_stats["accuracy"]

            # Model selection: remember the stats of the epoch with the lowest validation loss.
            if val_stats["loss"] < best_val_loss:
                best_val_loss = val_stats["loss"]
                print(f"New best validation loss: {best_val_loss:.4f}, saving best stats...")

                best_val_stats = {'seed': seed, 'fold': fold_num}
                best_val_stats.update(val_stats)

                best_test_stats = {'seed': seed, 'fold': fold_num}
                best_test_stats.update(test_stats)

            if test_stats["accuracy"] > max_accuracy_test:
                max_accuracy_test = test_stats["accuracy"]

            print(f'Max accuracy val: {max_accuracy:}%, max accuracy test: {max_accuracy_test:}%')
            print(f'Test acc: {test_stats["accuracy"]}')

            log_stats.update({f"val_{k}": v for k, v in val_stats.items()})
            log_stats.update({f"test_{k}": v for k, v in test_stats.items()})

        log_stats["epoch"] = epoch
        log_stats["n_parameters"] = n_parameters
        logs.append(log_stats)

    return best_test_stats, best_val_stats


def _export_best_stats(stats_list, csv_path, label):
    """Write the per-run best stats plus 'mean' and 'std' summary rows to ``csv_path``."""
    import pandas as pd

    df = pd.DataFrame(stats_list)

    # Summary statistics are computed over every column except the run identifiers.
    numeric = df.drop(columns=['seed', 'fold'], errors='ignore')
    mean_row = numeric.mean().to_frame().T
    std_row = numeric.std().to_frame().T

    mean_row['seed'] = 'mean'
    mean_row['fold'] = ''
    std_row['seed'] = 'std'
    std_row['fold'] = ''

    pd.concat([df, mean_row, std_row], ignore_index=True).to_csv(csv_path, index=False)
    print(f"Saved best {label} stats to {csv_path}")


def main(args, ds_init, seed_num, fold_num):
    """
    Train and evaluate the model for ``seed_num`` seeds and ``fold_num`` folds each.

    For every (seed, fold) pair the data loaders, model, optimizer and schedules are rebuilt
    and a full training run is executed.  Results are written below
    ``./CSV/<args.model>/<args.dataset>/``:

    - ``all_metrics.csv``: per-epoch statistics of all runs so far, tagged with 'seed' and
      'fold' columns (rewritten after each run);
    - ``best_val_stats.csv`` / ``best_test_stats.csv``: the statistics at the best-validation
      epoch of each run, followed by 'mean' and 'std' rows.

    Parameters
    ----------
    args : argparse.Namespace
        Run configuration from ``get_args`` (mutated by the helpers called here).
    ds_init : callable or None
        ``deepspeed.initialize`` when DeepSpeed is enabled, else ``None``.
    seed_num : int
        Number of seeds; seeds ``0 .. seed_num - 1`` are used.
    fold_num : int
        Number of folds to run; folds ``1 .. fold_num`` are used.
    """
    # pandas is not imported at file level; import once up front so it is available both
    # inside the run loop and for the final summaries, even if the loop body never executes.
    import pandas as pd

    models.init_distributed_mode(args)

    if ds_init is not None:
        models.create_ds_config(args)

    CSV_save_path = './CSV/' + args.model + '/' + args.dataset + '/'  # directory for the CSV outputs
    os.makedirs(os.path.dirname(CSV_save_path), exist_ok=True)

    metrics = ["accuracy", "balanced_accuracy", "cohen_kappa", "f1_weighted", "f1_macro", "f1_micro"]

    device = torch.device(args.device)
    cudnn.benchmark = True
    best_val_stats_list = []
    best_test_stats_list = []
    # Per-epoch logs of every run, so later folds do not overwrite earlier ones on disk.
    all_logs = []

    for seed in range(seed_num):
        print(f"Running for seed {seed}")
        torch.manual_seed(seed)
        np.random.seed(seed)
        for fold in range(1, fold_num + 1):
            ch_names, data_loader_test, data_loader_train, data_loader_val, dataset_train = creat_dataloader(args, fold)

            model, model_ema, model_without_ddp, n_parameters = model_loader(args, device)

            print("Model = %s" % str(model_without_ddp))
            print('number of params:', n_parameters)

            total_batch_size = args.batch_size * args.update_freq * models.get_world_size()
            num_training_steps_per_epoch = len(dataset_train) // total_batch_size
            print("LR = %.8f" % args.lr)
            print("Batch size = %d" % total_batch_size)
            print("Update frequent = %d" % args.update_freq)
            print("Number of training examples = %d" % len(dataset_train))
            print("Number of training training per epoch = %d" % num_training_steps_per_epoch)

            num_layers = model_without_ddp.get_num_layers()
            criterion, loss_scaler, lr_schedule_values, model, model_without_ddp, optimizer, wd_schedule_values = configure_components(
                args, ch_names, data_loader_test, device, ds_init, metrics, model, model_ema, model_without_ddp,
                num_layers, num_training_steps_per_epoch)

            print(f"Start training for {args.epochs} epochs")
            start_time = time.time()
            max_accuracy = 0.0
            max_accuracy_test = 0.0
            best_val_loss = float('inf')

            logs = []
            # Pass the current fold (not the fold count) so the best-stats rows record
            # which fold they belong to.
            best_test_stats, best_val_stats = epoch_training(args, best_val_loss, ch_names, criterion, data_loader_test,
                                                             data_loader_train, data_loader_val, device, fold, logs,
                                                             loss_scaler, lr_schedule_values, max_accuracy,
                                                             max_accuracy_test, metrics, model, model_ema,
                                                             model_without_ddp, n_parameters,
                                                             num_training_steps_per_epoch, optimizer, seed,
                                                             wd_schedule_values)

            total_time = time.time() - start_time
            total_time_str = str(datetime.timedelta(seconds=int(total_time)))
            print('Training time {}'.format(total_time_str))

            # Rewrite the cumulative log after every run so partial results survive a crash.
            all_logs.extend({'seed': seed, 'fold': fold, **entry} for entry in logs)
            os.makedirs(CSV_save_path, exist_ok=True)
            pd.DataFrame(all_logs).to_csv(f"{CSV_save_path}/all_metrics.csv", index=False)
            print(f"Saved metrics to {CSV_save_path}/all_metrics.csv")

            if best_val_stats is not None:
                best_val_stats_list.append(best_val_stats)

            if best_test_stats is not None:
                best_test_stats_list.append(best_test_stats)

    _export_best_stats(best_val_stats_list, os.path.join(CSV_save_path, 'best_val_stats.csv'), 'validation')
    _export_best_stats(best_test_stats_list, os.path.join(CSV_save_path, 'best_test_stats.csv'), 'test')


if __name__ == '__main__':
    # Script entry point: walk over the known datasets and launch the experiment for the
    # enabled one with the 'tiny' model and its pretrained checkpoint.
    datasets = ['bcic_iv_2b', 'figshare_shudb']

    for dataset in datasets:
        # Only the BCIC IV 2b experiment is enabled; every other entry is skipped.
        if dataset != 'bcic_iv_2b':
            continue

        print('Experiment on dataset', dataset)
        opts, ds_init = get_args()

        # Fixed experiment settings override whatever was given on the command line.
        opts.dataset = dataset
        opts.model = 'tiny'
        opts.finetune = './model/tiny/checkpoint.pth'

        # Rendezvous address used when distributed mode is initialised from the environment.
        os.environ.update({'MASTER_ADDR': 'localhost', 'MASTER_PORT': '12346'})

        if opts.output_dir:
            Path(opts.output_dir).mkdir(parents=True, exist_ok=True)

        # One seed, one fold.
        main(opts, ds_init, 1, 1)
