import numpy as np

# Loss weights used for the "mmd_weight" / "ctr_weight" entries of the MoSSDA
# configurations in the dataset hparams classes of this module.
# They stay None until set_global_weight() is called; the classes read them
# when they are instantiated, so set them before creating an hparams object.
MMD_W = None
CTR_W = None

def set_global_weight(mmd_w, ctr_w):
    """Store the MMD and contrastive loss weights in the module globals.

    The dataset hparams classes in this module read ``MMD_W`` and ``CTR_W``
    inside ``__init__``, so this must be called before one of them is
    instantiated for the new values to be picked up.

    Args:
        mmd_w: Value assigned to the module-level ``MMD_W``.
        ctr_w: Value assigned to the module-level ``CTR_W``.
    """
    global MMD_W, CTR_W
    MMD_W = mmd_w
    CTR_W = ctr_w
    print(f"setting weights [MMD : {mmd_w}, CTR : {ctr_w}]")


## The current hyper-parameter values are not necessarily the best ones for a specific risk.
def get_hparams_class(dataset_name):
    """Return the hyper-parameter class defined under ``dataset_name``.

    The lookup is performed on this module's globals, so ``dataset_name``
    must be the exact name of a class available in this module
    (for example ``"HAR"``).

    Args:
        dataset_name: Name of the dataset hparams class to fetch.

    Returns:
        The class object itself, not an instance.

    Raises:
        NotImplementedError: If the name is unknown or does not refer to a
            class.
    """
    candidate = globals().get(dataset_name)
    # Only classes are valid results. Other module globals (imported modules,
    # helper functions, the MMD_W / CTR_W values) must not be handed back as
    # if they were dataset configurations.
    if not isinstance(candidate, type):
        raise NotImplementedError("Dataset not found: {}".format(dataset_name))
    return candidate

class HAR:
    """Hyper-parameter container for the HAR dataset.

    Instantiation fills three dictionaries:

    * ``train_params``  -- batch size, weight decay, step size and LR decay.
    * ``base_params``   -- step counts plus optimizer / scheduler values.
    * ``alg_hparams``   -- one settings dict per algorithm name.

    The MoSSDA entries read the module-level ``MMD_W`` and ``CTR_W`` at
    construction time.
    """

    def __init__(self):
        super().__init__()

        self.train_params = dict(
            batch_size=64,
            weight_decay=1e-4,
            step_size=50,
            lr_decay=0.5,
        )

        self.base_params = dict(
            num_steps=5000,
            eval_interval=500,
            # optimizer and scheduler
            rampup_length=20000,
            rampup_coef=30.0,
            weight_decay=5e-4,  # original note: "= lr_decay" (NOTE(review): meaning unclear, confirm)
            gamma=1e-4,
            warm_steps=250,  # original note: baselines = 500
        )

        hparams = {}
        hparams['NO_ADAPT'] = dict(learning_rate=1e-3, src_cls_loss_wt=1)
        hparams['TARGET_ONLY'] = dict(learning_rate=1e-3, trg_cls_loss_wt=1)
        hparams['LABELED_ONLY'] = dict(learning_rate=1e-3, trg_cls_loss_wt=1)

        # MoSSDA and its variants share the same values; each name still gets
        # its own independent dict.
        for variant in ('MoSSDA', 'MoSSDA_source', 'MoSSDA_all', 'MoSSDA_all_ablation'):
            hparams[variant] = dict(
                learning_rate=0.001,
                mmd_weight=MMD_W,    # weight for MMD loss
                ctr_weight=CTR_W,    # weight for contrastive loss
                projection_dim=128,  # dimension for projection head
            )

        hparams['MCD'] = dict(learning_rate=1e-2, src_cls_loss_wt=9.74, domain_loss_wt=5.43)

        # Baselines
        hparams['CDAC'] = dict(
            learning_rate=0.001, lr_f=1.0, multi=0.1,
            topk=5, threshold=0.95, temp=0.05,
        )
        hparams['PAC'] = dict(
            pre_lr=0.001, pre_lr_f=1.0, pre_multi=0.1, pre_temp=0.05,
            cls_normalize=True, cls_bias=False,
            learning_rate=0.01, lr_f=0.001, multi=0.001, temp=0.05,
            cls_layers='', cons_wt=1., cons_threshold=0.9,
        )
        # These three baselines are configured with identical values.
        for baseline in ('AdaMatch', 'DST', 'UniSSDA'):
            hparams[baseline] = dict(
                learning_rate=0.001, lr_f=1.0, multi=0.1, temp=0.05, tau=0.9,
            )
        hparams['CLDA'] = dict(
            learning_rate=0.001, lr_f=0.01, multi=0.1,
            temperature=0.5, contrastive_weight=0.5, mmd_weight=0.5,
        )

        self.alg_hparams = hparams

class EEG:
    """Hyper-parameter container for the EEG dataset.

    Instantiation fills three dictionaries:

    * ``train_params``  -- batch size, weight decay, step size and LR decay.
    * ``base_params``   -- step counts plus optimizer / scheduler values.
    * ``alg_hparams``   -- one settings dict per algorithm name.

    The MoSSDA entries read the module-level ``MMD_W`` and ``CTR_W`` at
    construction time.
    """

    def __init__(self):
        super().__init__()

        self.train_params = dict(
            batch_size=128,
            weight_decay=1e-4,
            step_size=50,  # added 240901
            lr_decay=0.5,  # added 240901
        )

        self.base_params = dict(
            num_steps=5000,
            eval_interval=500,
            # optimizer and scheduler
            rampup_length=20000,
            rampup_coef=30.0,
            weight_decay=5e-4,  # original note: "= lr_decay" (NOTE(review): meaning unclear, confirm)
            gamma=1e-4,
            warm_steps=250,  # original note: baselines = 500
        )

        hparams = {}
        hparams['NO_ADAPT'] = dict(learning_rate=1e-3, src_cls_loss_wt=1)
        hparams['TARGET_ONLY'] = dict(learning_rate=1e-3, trg_cls_loss_wt=1)
        hparams['LABELED_ONLY'] = dict(learning_rate=1e-3, trg_cls_loss_wt=1)

        # MoSSDA and its variants share the same values; each name still gets
        # its own independent dict.
        for variant in ('MoSSDA', 'MoSSDA_source', 'MoSSDA_all', 'MoSSDA_all_ablation'):
            hparams[variant] = dict(
                learning_rate=0.001,
                mmd_weight=MMD_W,    # weight for MMD loss
                ctr_weight=CTR_W,    # weight for contrastive loss
                projection_dim=128,  # dimension for projection head
            )

        hparams['MCD'] = dict(learning_rate=1e-2, src_cls_loss_wt=9.74, domain_loss_wt=5.43)

        # Baselines
        hparams['CDAC'] = dict(
            learning_rate=0.001, lr_f=1.0, multi=0.1,
            topk=5, threshold=0.95, temp=0.05,
        )
        hparams['PAC'] = dict(
            pre_lr=0.001, pre_lr_f=1.0, pre_multi=0.1, pre_temp=0.05,
            cls_normalize=True, cls_bias=False,
            learning_rate=0.01, lr_f=0.001, multi=0.001, temp=0.05,
            cls_layers='', cons_wt=1., cons_threshold=0.9,
        )
        # These three baselines are configured with identical values.
        for baseline in ('AdaMatch', 'DST', 'UniSSDA'):
            hparams[baseline] = dict(
                learning_rate=0.001, lr_f=1.0, multi=0.1, temp=0.05, tau=0.9,
            )
        hparams['CLDA'] = dict(
            learning_rate=0.001, lr_f=0.01, multi=0.1,
            temperature=0.5, contrastive_weight=0.5, mmd_weight=0.5,
        )

        self.alg_hparams = hparams


class WISDM:
    """Hyper-parameter container for the WISDM dataset.

    Instantiation fills three dictionaries:

    * ``train_params``  -- batch size, weight decay, step size and LR decay.
    * ``base_params``   -- step counts plus optimizer / scheduler values.
    * ``alg_hparams``   -- one settings dict per algorithm name.

    The MoSSDA entries read the module-level ``MMD_W`` and ``CTR_W`` at
    construction time.
    """

    def __init__(self):
        super().__init__()

        self.train_params = dict(
            batch_size=64,
            weight_decay=1e-4,
            step_size=50,  # added 240901
            lr_decay=0.5,  # added 240901
        )

        self.base_params = dict(
            num_steps=5000,
            eval_interval=500,
            # optimizer and scheduler
            rampup_length=20000,
            rampup_coef=30.0,
            weight_decay=5e-4,  # original note: "= lr_decay" (NOTE(review): meaning unclear, confirm)
            gamma=1e-4,
            warm_steps=250,  # original note: baselines = 500
        )

        hparams = {}
        hparams['NO_ADAPT'] = dict(learning_rate=1e-3, src_cls_loss_wt=1)
        hparams['TARGET_ONLY'] = dict(learning_rate=1e-3, trg_cls_loss_wt=1)
        hparams['LABELED_ONLY'] = dict(learning_rate=1e-3, trg_cls_loss_wt=1)

        # MoSSDA and its variants share the same values; each name still gets
        # its own independent dict.
        for variant in ('MoSSDA', 'MoSSDA_source', 'MoSSDA_all', 'MoSSDA_all_ablation'):
            hparams[variant] = dict(
                learning_rate=0.001,
                mmd_weight=MMD_W,    # weight for MMD loss
                ctr_weight=CTR_W,    # weight for contrastive loss
                projection_dim=128,  # dimension for projection head
            )

        hparams['MCD'] = dict(learning_rate=1e-2, src_cls_loss_wt=9.74, domain_loss_wt=5.43)

        # Baselines
        hparams['CDAC'] = dict(
            learning_rate=0.001, lr_f=1.0, multi=0.1,
            topk=5, threshold=0.95, temp=0.05,
        )
        hparams['PAC'] = dict(
            pre_lr=0.001, pre_lr_f=1.0, pre_multi=0.1, pre_temp=0.05,
            cls_normalize=True, cls_bias=False,
            learning_rate=0.01, lr_f=0.001, multi=0.001, temp=0.05,
            cls_layers='', cons_wt=1., cons_threshold=0.9,
        )
        # These three baselines are configured with identical values.
        for baseline in ('AdaMatch', 'DST', 'UniSSDA'):
            hparams[baseline] = dict(
                learning_rate=0.001, lr_f=1.0, multi=0.1, temp=0.05, tau=0.9,
            )
        hparams['CLDA'] = dict(
            learning_rate=0.001, lr_f=0.01, multi=0.1,
            temperature=0.5, contrastive_weight=0.5, mmd_weight=0.5,
        )

        self.alg_hparams = hparams

class HHAR:
    """Hyper-parameter container for the HHAR dataset.

    Instantiation fills three dictionaries:

    * ``train_params``  -- batch size, weight decay, step size and LR decay.
    * ``base_params``   -- step counts plus optimizer / scheduler values.
    * ``alg_hparams``   -- one settings dict per algorithm name.

    The MoSSDA entries read the module-level ``MMD_W`` and ``CTR_W`` at
    construction time.
    """

    def __init__(self):
        super().__init__()

        self.train_params = dict(
            batch_size=64,
            weight_decay=1e-4,
            step_size=50,  # added 240901
            lr_decay=0.5,  # added 240901
        )

        self.base_params = dict(
            num_steps=5000,
            eval_interval=500,
            # optimizer and scheduler
            rampup_length=20000,
            rampup_coef=30.0,
            weight_decay=5e-4,  # original note: "= lr_decay" (NOTE(review): meaning unclear, confirm)
            gamma=1e-4,
            warm_steps=250,  # original note: baselines = 500
        )

        hparams = {}
        hparams['NO_ADAPT'] = dict(learning_rate=1e-3, src_cls_loss_wt=1)
        hparams['TARGET_ONLY'] = dict(learning_rate=1e-3, trg_cls_loss_wt=1)
        hparams['LABELED_ONLY'] = dict(learning_rate=1e-3, trg_cls_loss_wt=1)

        # MoSSDA and its variants share the same values; each name still gets
        # its own independent dict.
        for variant in ('MoSSDA', 'MoSSDA_source', 'MoSSDA_all', 'MoSSDA_all_ablation'):
            hparams[variant] = dict(
                learning_rate=0.001,
                mmd_weight=MMD_W,    # weight for MMD loss
                ctr_weight=CTR_W,    # weight for contrastive loss
                projection_dim=128,  # dimension for projection head
            )

        hparams['MCD'] = dict(learning_rate=1e-2, src_cls_loss_wt=9.74, domain_loss_wt=5.43)

        # Baselines
        hparams['CDAC'] = dict(
            learning_rate=0.001, lr_f=1.0, multi=0.1,
            topk=5, threshold=0.95, temp=0.05,
        )
        hparams['PAC'] = dict(
            pre_lr=0.001, pre_lr_f=1.0, pre_multi=0.1, pre_temp=0.05,
            cls_normalize=True, cls_bias=False,
            learning_rate=0.01, lr_f=0.001, multi=0.001, temp=0.05,
            cls_layers='', cons_wt=1., cons_threshold=0.9,
        )
        # These three baselines are configured with identical values.
        for baseline in ('AdaMatch', 'DST', 'UniSSDA'):
            hparams[baseline] = dict(
                learning_rate=0.001, lr_f=1.0, multi=0.1, temp=0.05, tau=0.9,
            )
        hparams['CLDA'] = dict(
            learning_rate=0.001, lr_f=0.01, multi=0.1,
            temperature=0.5, contrastive_weight=0.5, mmd_weight=0.5,
        )

        self.alg_hparams = hparams

class MFD:
    """Hyper-parameter container for the MFD dataset.

    Instantiation fills three dictionaries:

    * ``train_params``  -- batch size, weight decay, step size and LR decay.
    * ``base_params``   -- step counts plus optimizer / scheduler values.
    * ``alg_hparams``   -- one settings dict per algorithm name.

    The MoSSDA entries read the module-level ``MMD_W`` and ``CTR_W`` at
    construction time.
    """

    def __init__(self):
        super().__init__()

        self.train_params = dict(
            batch_size=64,
            weight_decay=1e-4,
            step_size=50,  # added 240901
            lr_decay=0.5,  # added 240901
        )

        self.base_params = dict(
            num_steps=5000,
            eval_interval=500,
            # optimizer and scheduler
            rampup_length=20000,
            rampup_coef=30.0,
            weight_decay=5e-4,  # original note: "= lr_decay" (NOTE(review): meaning unclear, confirm)
            gamma=1e-4,
            warm_steps=250,  # original note: baselines = 500
        )

        hparams = {}
        hparams['NO_ADAPT'] = dict(learning_rate=1e-3, src_cls_loss_wt=1)
        hparams['TARGET_ONLY'] = dict(learning_rate=1e-3, trg_cls_loss_wt=1)
        hparams['LABELED_ONLY'] = dict(learning_rate=1e-3, trg_cls_loss_wt=1)

        # MoSSDA and its variants share the same values; each name still gets
        # its own independent dict.
        for variant in ('MoSSDA', 'MoSSDA_source', 'MoSSDA_all', 'MoSSDA_all_ablation'):
            hparams[variant] = dict(
                learning_rate=0.001,
                mmd_weight=MMD_W,    # weight for MMD loss
                ctr_weight=CTR_W,    # weight for contrastive loss
                projection_dim=128,  # dimension for projection head
            )

        hparams['MCD'] = dict(learning_rate=1e-2, src_cls_loss_wt=9.74, domain_loss_wt=5.43)

        # Baselines
        hparams['CDAC'] = dict(
            learning_rate=0.001, lr_f=1.0, multi=0.1,
            topk=5, threshold=0.95, temp=0.05,
        )
        hparams['PAC'] = dict(
            pre_lr=0.001, pre_lr_f=1.0, pre_multi=0.1, pre_temp=0.05,
            cls_normalize=True, cls_bias=False,
            learning_rate=0.01, lr_f=0.001, multi=0.001, temp=0.05,
            cls_layers='', cons_wt=1., cons_threshold=0.9,
        )
        # These three baselines are configured with identical values.
        for baseline in ('AdaMatch', 'DST', 'UniSSDA'):
            hparams[baseline] = dict(
                learning_rate=0.001, lr_f=1.0, multi=0.1, temp=0.05, tau=0.9,
            )
        hparams['CLDA'] = dict(
            learning_rate=0.001, lr_f=0.01, multi=0.1,
            temperature=0.5, contrastive_weight=0.5, mmd_weight=0.5,
        )

        self.alg_hparams = hparams

class PTBXL:
    """Hyper-parameter container for the PTBXL dataset.

    Instantiation fills three dictionaries:

    * ``train_params``  -- weight decay, step size and LR decay (no batch size).
    * ``base_params``   -- step counts plus optimizer / scheduler values.
    * ``alg_hparams``   -- one settings dict per algorithm name; here every
      entry carries its own ``batch_size``.

    The MoSSDA entries read the module-level ``MMD_W`` and ``CTR_W`` at
    construction time.
    """

    def __init__(self):
        super().__init__()

        # No 'batch_size' key here: it is given per algorithm in alg_hparams.
        # Original note on the disabled 'batch_size': 128 entry:
        # changed 32 -> 128 (241008) ==> out of memory in PAC.
        self.train_params = dict(
            weight_decay=1e-4,
            step_size=50,
            lr_decay=0.5,
        )

        self.base_params = dict(
            num_steps=5000,
            eval_interval=500,
            # optimizer and scheduler
            rampup_length=20000,
            rampup_coef=30.0,
            weight_decay=5e-4,  # original note: "= lr_decay" (NOTE(review): meaning unclear, confirm)
            gamma=1e-4,
            warm_steps=250,  # original note: baselines = 500
        )

        hparams = {}
        hparams['NO_ADAPT'] = dict(batch_size=128, learning_rate=1e-3, src_cls_loss_wt=1)
        hparams['TARGET_ONLY'] = dict(batch_size=128, learning_rate=1e-3, trg_cls_loss_wt=1)
        hparams['LABELED_ONLY'] = dict(batch_size=128, learning_rate=1e-3, trg_cls_loss_wt=1)

        # MoSSDA and its variants share the same values; each name still gets
        # its own independent dict.
        # NOTE(review): the original comment on the MoSSDA learning rate
        # mentions "241008 0.001->0.05", but the value in code is 0.001.
        for variant in ('MoSSDA', 'MoSSDA_source', 'MoSSDA_all', 'MoSSDA_all_ablation'):
            hparams[variant] = dict(
                batch_size=128,
                learning_rate=0.001,
                mmd_weight=MMD_W,    # weight for MMD loss
                ctr_weight=CTR_W,    # weight for contrastive loss
                projection_dim=128,  # dimension for projection head
            )

        hparams['MCD'] = dict(
            batch_size=128, learning_rate=1e-2,
            src_cls_loss_wt=9.74, domain_loss_wt=5.43,
        )

        # Baselines
        hparams['CDAC'] = dict(
            batch_size=128, learning_rate=0.001, lr_f=1.0, multi=0.1,
            topk=5, threshold=0.95, temp=0.05,
        )
        # PAC is the only entry with a batch size of 100 instead of 128.
        hparams['PAC'] = dict(
            batch_size=100,
            pre_lr=0.001, pre_lr_f=1.0, pre_multi=0.1, pre_temp=0.05,
            cls_normalize=True, cls_bias=False,
            learning_rate=0.01, lr_f=0.001, multi=0.001, temp=0.05,
            cls_layers='', cons_wt=1., cons_threshold=0.9,
        )
        # These three baselines are configured with identical values.
        for baseline in ('AdaMatch', 'DST', 'UniSSDA'):
            hparams[baseline] = dict(
                batch_size=128, learning_rate=0.001, lr_f=1.0,
                multi=0.1, temp=0.05, tau=0.9,
            )
        hparams['CLDA'] = dict(
            batch_size=128, learning_rate=0.001, lr_f=0.01, multi=0.1,
            temperature=0.5, contrastive_weight=0.5, mmd_weight=0.5,
        )

        self.alg_hparams = hparams