from trainers.train import Trainer_MoSSDA
from configs.data_model_configs import set_global_kernelType, set_global_mixType
from configs.hparams import set_global_weight
import os
import argparse

# Command-line parser for this script; its arguments are added later in
# this module, once the script is run as ``__main__``.
parser = argparse.ArgumentParser()

import torch

# Prefer the first CUDA device when one is visible; otherwise use the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

# Release unoccupied memory held by PyTorch's CUDA caching allocator
# (a no-op when CUDA has not been initialized).
torch.cuda.empty_cache()

print(device)

def register_arguments(parser):
    """Add every command-line option of the MoSSDA experiment script to *parser*.

    Args:
        parser: An ``argparse.ArgumentParser`` that is extended in place.

    Returns:
        The same *parser*, so calls can be chained, e.g.
        ``register_arguments(parser).parse_args()``.
    """
    # ========  Experiments Phase ================
    # NOTE(review): ``--phase`` is parsed but not consulted by the
    # ``__main__`` section; the script always trains and then tests.
    parser.add_argument('--phase', default='train', type=str, help='train, test')

    # ========  Experiments Name ================
    parser.add_argument('--save_dir', default='experiments_logs', type=str,
                        help='Directory containing all experiments')
    parser.add_argument('--exp_name', default='EXP1', type=str, help='experiment name')

    # ========= Select the DA methods ============
    parser.add_argument('--da_method', default='SSDA', type=str,
                        help='NO_ADAPT, TARGET_ONLY, MoSSDA_target, MoSSDA_source, MoSSDA_all')

    # ========= Select the DATASET ==============
    parser.add_argument('--data_path', default=r'./Datasets/UCIHAR', type=str,
                        help='Path containing dataset')
    parser.add_argument('--dataset', default='HAR', type=str,
                        help='Dataset of choice: (WISDM - EEG - HAR - HHAR - PTBXL - MFD)')

    # ========= Select the BACKBONE ==============
    parser.add_argument('--backbone', default='CNN', type=str,
                        help='Backbone of choice: (CNN - RESNET18 - TCN)')

    # ========= Experiment settings ===============
    parser.add_argument('--num_runs', default=1, type=int,
                        help='Number of consecutive run with different seeds')
    # Default to CUDA only when it is actually available so the script can
    # also start on CPU-only machines (the default used to be hard-coded to
    # "cuda"). GPU machines still get the previous default of "cuda".
    parser.add_argument('--device', default="cuda" if torch.cuda.is_available() else "cpu",
                        type=str, help='cpu or cuda')

    # ======== SSDA Settings =======================
    parser.add_argument('--num_epochs', default=50, type=int,
                        help='Number of training epochs for SSDA')
    parser.add_argument('--post_epochs', default=30, type=int,
                        help='Number of post training epochs for SSDA')
    parser.add_argument('--learning_rate', default=0.001, type=float,
                        help='Learning rate for training')
    parser.add_argument('--weight_decay', default=0.0001, type=float,
                        help='Weight decay for optimizer')
    parser.add_argument('--unlabeled_ratio', default=0.7, type=float,
                        help='Ratio of unsupervised data used in training')

    parser.add_argument('--mmd_weight', default=1.0, type=float,
                        help='Weight of MMD loss for MoSSDA Training')
    parser.add_argument('--kernel_type', default='linear', type=str,
                        help='Type of MMD kernel for MoSSDA [rbf, linear]')
    parser.add_argument('--ctr_weight', default=1.0, type=float,
                        help='Weight of Supervised Contrastive loss for MoSSDA Training')
    parser.add_argument('--mix_type', default='cross', type=str,
                        help='Mix type for Phase 2 training [cross, within, all, noMix]')
    return parser


if __name__ == "__main__":
    # Parse the command line using the module-level parser.
    args = register_arguments(parser).parse_args()

    # Push loss weights, MMD kernel type and mix type into the project's
    # global configuration before the trainer is built.
    set_global_weight(args.mmd_weight, args.ctr_weight)
    set_global_kernelType(args.kernel_type)
    set_global_mixType(args.mix_type)

    # create trainer object
    trainer = Trainer_MoSSDA(args)

    # Both phases always run, in order; ``args.phase`` is not consulted.
    print("+++++++++++ TRAIN PHASE +++++++++++++")
    trainer.fit()

    print("+++++++++++ TEST PHASE +++++++++++++")
    trainer.test()