# Time : 2023/11/13 12:55
# Author : 小霸奔
# FileName: config.py
# Per-dataset training settings. Judging by the key names: learning rates for
# the pre-training, SSL and incremental stages, a seed-like "best_rand" value
# and a "batch" size — confirm the exact semantics at the call sites.
ISRUC_DICT = dict(
    pretrain_lr=1e-4,
    ssl_lr=1e-6,
    incremental_lr=1e-7,
    best_rand=4321,
    batch=16,
)

FACE_DICT = dict(
    pretrain_lr=1e-4,
    ssl_lr=1e-6,
    incremental_lr=1e-7,
    best_rand=432,
    batch=28,
)

# Same layout as above; note the larger incremental learning rate (1e-6).
Physionet_DICT = dict(
    pretrain_lr=1e-4,
    ssl_lr=1e-6,
    incremental_lr=1e-6,
    best_rand=4321,
    batch=32,
)


class ModelConfig(object):
    """Model hyper-parameters plus per-dataset metadata.

    Holds the encoder / MLP-head parameter objects, the class-name lists, the
    epoch lengths for each dataset family, and the channel count(s) of the
    selected dataset.

    Args:
        dataset: dataset name. Channel counts are known for "ISRUC", "Hang7",
            "SleepEDF", "HMC", "MASS", "FACE", "BCI2000", "BCI2000_2", "MDD"
            and "TUEV".

    Raises:
        ValueError: if ``get_channel_info`` has no entry for ``dataset``.
    """

    # Channel counts per dataset. Two-element entries are unpacked into
    # EegNum / EogNum; one-element entries go into the dataset-specific
    # *Cn attribute (FaceCn, BCICn, MDDCn, TUEVCn).
    _CHANNEL_INFO = {
        "ISRUC": (6, 2),
        "Hang7": (6, 2),
        "SleepEDF": (2, 1),
        "HMC": (4, 2),
        "MASS": (6, 4),
        "FACE": (32,),
        "BCI2000": (64,),
        "BCI2000_2": (64,),
        "MDD": (20,),
        "TUEV": (19,),
    }

    def __init__(self, dataset):
        self.dataset = dataset
        self.ConvDrop = 0.1
        self.EncoderParam = EncoderConfig()
        self.SleepMlpParam = SleepMlpParam()
        self.FaceMlpParam = FaceMlpParam()
        self.BCI2000MlpParam = BCI2000MlpParam()
        self.BCI2000_2_MlpParam = BCI2000_2_MlpParam()
        self.TUEVMlpParam = TUEVMlpParam()
        self.MDDMlpParam = MDDMlpParam()
        self.NumClasses = 5
        self.ClassNames = ['W', 'N1', 'N2', 'N3', 'REM']
        self.ClassNamesFace = ['Anger',
                               'Disgust',
                               'Fear',
                               'Sadness',
                               'Neutral',
                               'Amusement',
                               'Inspiration',
                               'Joy',
                               'Tenderness']
        self.ClassNamesBCI2000_2 = ['Left', 'Right']
        self.ClassNamesBCI2000 = ['Left', 'Right', 'Fist', 'Feet']
        self.ClassNamesMDD = ['HC', 'MDD']
        self.ClassNamesTUEV = ["spsw", "gped", "pled", "eyem", "artf", "bckg"]
        self.SeqLength = 20
        self.BatchSize = 32
        self.EpochLength = 3000
        self.EpochLengthFace = 7500
        self.EpochLengthBCI2000 = 640
        self.EpochLengthMDD = 1000
        self.EpochLengthTUEV = 1000
        channel_num = self.get_channel_info()
        if channel_num is None:
            # Fail loudly: indexing None below would only raise an opaque
            # "'NoneType' object is not subscriptable" TypeError.
            raise ValueError(
                "Unknown dataset {!r}; expected one of {}".format(
                    dataset, sorted(self._CHANNEL_INFO)))

        if self.dataset == "FACE":
            self.FaceCn = channel_num[0]
        elif self.dataset in ['BCI2000', 'BCI2000_2']:
            self.BCICn = channel_num[0]
        elif self.dataset == "MDD":
            self.MDDCn = channel_num[0]
        elif self.dataset == "TUEV":
            self.TUEVCn = channel_num[0]
        else:
            # Remaining known datasets all have [EEG, EOG] entries.
            self.EegNum = channel_num[0]
            self.EogNum = channel_num[1]

    def get_channel_info(self):
        """Return the channel counts for ``self.dataset``.

        Returns:
            A new list: ``[eeg, eog]`` for the two-entry datasets, ``[n]``
            for the single-count datasets, or ``None`` if the dataset is
            unknown.
        """
        counts = self._CHANNEL_INFO.get(self.dataset)
        # Hand out a fresh list on every call so callers cannot mutate the
        # shared lookup table.
        return None if counts is None else list(counts)


class EncoderConfig(object):
    """Fixed encoder hyper-parameters.

    Fields: ``n_head`` (8), ``d_model`` (512), ``layer_num`` (3) and dropout
    ``drop`` (0.1). The names suggest a Transformer-style encoder — confirm
    against the model code.
    """

    def __init__(self):
        # Tuple targets are bound left to right, so attribute creation order
        # matches the field list above.
        self.n_head, self.d_model, self.layer_num, self.drop = 8, 512, 3, 0.1


class _MlpParam(object):
    """Shared layout of the classification-head MLP hyper-parameters.

    Every dataset-specific head below uses the same dropout and hidden
    widths. They differ only in the size of the output layer.

    Args:
        num_classes: output width, stored as ``out_linear[1]``.
    """

    def __init__(self, num_classes):
        self.drop = 0.1
        # Each pair chains into the next (512 -> 256 -> 128 -> num_classes),
        # i.e. [in_features, out_features]. NOTE(review): 512 presumably has
        # to equal EncoderConfig.d_model — confirm before changing either.
        self.first_linear = [512, 256]
        self.second_linear = [256, 128]
        self.out_linear = [128, num_classes]


class SleepMlpParam(_MlpParam):
    """MLP head with 5 outputs (same count as ModelConfig.ClassNames)."""

    def __init__(self):
        super().__init__(num_classes=5)


class FaceMlpParam(_MlpParam):
    """MLP head with 9 outputs (same count as ModelConfig.ClassNamesFace)."""

    def __init__(self):
        super().__init__(num_classes=9)


class BCI2000MlpParam(_MlpParam):
    """MLP head with 4 outputs (same count as ModelConfig.ClassNamesBCI2000)."""

    def __init__(self):
        super().__init__(num_classes=4)


class BCI2000_2_MlpParam(_MlpParam):
    """MLP head with 2 outputs (same count as ModelConfig.ClassNamesBCI2000_2)."""

    def __init__(self):
        super().__init__(num_classes=2)


class MDDMlpParam(_MlpParam):
    """MLP head with 2 outputs (same count as ModelConfig.ClassNamesMDD)."""

    def __init__(self):
        super().__init__(num_classes=2)


class TUEVMlpParam(_MlpParam):
    """MLP head with 6 outputs (same count as ModelConfig.ClassNamesTUEV)."""

    def __init__(self):
        super().__init__(num_classes=6)






