import argparse
import os
import torch
from exp.exp_classification import Exp_Classification
from utils.print_args import print_args
import random
import numpy as np

def seed_everything(seed=0):
    """Seed every RNG the experiments touch and force deterministic cuDNN.

    Seeds, in order, Python's ``random`` module, NumPy's global generator,
    torch's default generator and the generator of the current CUDA device.
    It then makes cuDNN pick deterministic kernels and disables its
    benchmark-mode autotuner.

    Args:
        seed: Integer seed shared by all generators. Defaults to 0.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)

    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

def str2bool(value):
    """Parse a command-line boolean flag value.

    argparse's ``type=bool`` converts any non-empty string, including
    ``"False"`` and ``"0"``, to ``True``. This converter accepts the usual
    spellings instead.

    Args:
        value: Raw string from the command line. A ``bool`` is passed
            through unchanged.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean
            spelling. argparse reports this as a usage error.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def build_parser():
    """Create the command-line parser for the time-series OOD experiments.

    Returns:
        An ``argparse.ArgumentParser`` whose options are grouped by the model
        family that consumes them.
    """
    parser = argparse.ArgumentParser(description='Time Series OOD')

    # basic config
    parser.add_argument('--task_name', type=str, default='classification',
                        help='task name, options:[classification]')
    parser.add_argument('--is_training', type=int, default=1, help='status')
    # NOTE: required=True makes the defaults of --model/--data unreachable;
    # they only document typical values.
    parser.add_argument('--model', type=str, required=True, default='Transformer',
                        help='model name, options: [Transformer, GILE, FEDNet, AdaRNN, DIVERSITY]')
    parser.add_argument('--logdir', type=str, default='./runs/', help='tensorboard logs path')
    parser.add_argument('--seed', type=int, default=2023, help="random seed")

    # data loader
    parser.add_argument('--data', type=str, required=True, default='UCIHAR', help='dataset name')
    parser.add_argument('--root_path', type=str, default='../data/', help='root path of the data file')
    parser.add_argument('--checkpoints', type=str, default='./checkpoints/', help='location of model checkpoints')
    parser.add_argument('--target_domain', type=int, required=False, default=0, help='domain of data')
    parser.add_argument('--balance', action='store_true', help='balance the data by loader sampler')
    parser.add_argument('--normalize', action='store_true', help='normalize data')
    parser.add_argument('--model_id', type=str, required=False, default='test', help='model id for general classification task')
    parser.add_argument('--augmentation_ratio', type=int, default=0, help="How many times to augment")

    # model define
    parser.add_argument('--n_features', type=int, default=9, help='name of feature dimension')
    parser.add_argument('--n_classes', type=int, default=6, help='name of class')
    parser.add_argument('--n_domains', type=int, default=5, help='name of domains')
    parser.add_argument('--len_seq', type=int, default=128, help='input sequence length')
    parser.add_argument('--dropout', type=float, default=0.5, help='dropout rate')
    parser.add_argument('--activation', type=str, default='relu', help='activation function')

    # GILE, FEDNet parameters
    parser.add_argument('--hidden_size', type=int, default=50, help='hiddlen representation dim of AutoEncoders')
    parser.add_argument('--sigma', type=float, default=1, help='parameter of mmd')
    parser.add_argument('--beta', type=float, default=1., help='multiplier for KL')
    parser.add_argument('--fc_dim', type=int, default=512, help='fc_dim')
    parser.add_argument('--beta_d', type=float, default=1., help='multiplier for KL d')
    parser.add_argument('--beta_x', type=float, default=0., help='multiplier for KL x')
    parser.add_argument('--beta_y', type=float, default=1., help='multiplier for KL y')
    parser.add_argument('--kernel_size', type=int, default=5, help='kernel_size')
    parser.add_argument('--weight_true', type=float, default=1000.0, help='weights for classifier true')
    parser.add_argument('--weight_false', type=float, default=1000.0, help='weights for classifier false')
    parser.add_argument('--aux_loss_multiplier_y', type=float, default=1000., help='multiplier for y classifier')
    parser.add_argument('--aux_loss_multiplier_d', type=float, default=1000., help='multiplier for d classifier')
    parser.add_argument('--aux_optim_learning_rate', type=float, default=1e-3, help='multiplier for x classifier')

    # FEDNet
    parser.add_argument('--alpha', type=float, default=0.2, help='ratio of time-invariant component')
    parser.add_argument('--freq_type', type=str, default='fft', help='type of frequency decomposition method')
    parser.add_argument('--temperature', type=float, default=0.2, help='temperature of contrastive loss')
    parser.add_argument('--S', type=int, default=4, help='sqeeze share heads')

    # PatchTST, FEDNet-TimestocEncoder, Autoformer
    parser.add_argument('--patch_len', type=int, default=16, help='patch length')
    parser.add_argument('--stride', type=int, default=8, help='stride')
    parser.add_argument('--d_model', type=int, default=512, help='dimension of model')
    parser.add_argument('--n_heads', type=int, default=8, help='num of heads')
    parser.add_argument('--e_layers', type=int, default=2, help='num of encoder layers')
    parser.add_argument('--d_layers', type=int, default=1, help='num of decoder layers')
    parser.add_argument('--d_ff', type=int, default=2048, help='dimension of fcn')
    parser.add_argument('--moving_avg', type=int, default=25, help='window size of moving average')
    parser.add_argument('--factor', type=int, default=1, help='attn factor')
    parser.add_argument('--output_attention', action='store_true', help='whether to output attention in ecoder')
    parser.add_argument('--constraint_type', type=str, default="constract", help='type of constraint loss options:[cross, constract]')

    # dlinear
    parser.add_argument('--enc_in', type=int, default=7, help='encoder input size')
    parser.add_argument('--dec_in', type=int, default=7, help='decoder input size')

    # TimesNet
    parser.add_argument('--top_k', type=int, default=5, help='for TimesBlock')
    parser.add_argument('--num_kernels', type=int, default=6, help='for Inception')

    # FEDformer
    parser.add_argument('--c_out', type=int, default=7, help='output size')

    # Autoformer
    parser.add_argument('--embed', type=str, default='timeF',
                        help='time features encoding, options:[timeF, fixed, learned]')
    parser.add_argument('--freq', type=str, default='h',
                        help='frequency of time series, options:[h, d, w, m, q, y]')
    # FreTS
    # NOTE(review): the help text may have the 0/1 meanings swapped relative to
    # the option name; confirm against the FreTS model code.
    parser.add_argument('--channel_independence', type=int, default=1, help='1: channel dependence 0: channel independence')

    # optimization
    parser.add_argument('--num_workers', type=int, default=4, help='data loader num workers')
    parser.add_argument('--itr', type=int, default=1, help='experiments times')
    parser.add_argument('--epochs', type=int, default=300, help='number of training epochs')
    parser.add_argument('--batch_size', type=int, default=128, help='batch size of train input data')
    parser.add_argument('--patience', type=int, default=1000, help='early stopping patience')
    parser.add_argument('--learning_rate', type=float, default=1e-3, help='optimizer learning rate')
    parser.add_argument('--des', type=str, default='test', help='exp description')
    parser.add_argument('--loss', type=str, default='CE', help='loss function')
    parser.add_argument('--lradj', type=str, default='keep', help='adjust learning rate')
    parser.add_argument('--lrstep', type=int, default=30, help='adjust learning rate step pace')
    parser.add_argument('--use_amp', action='store_true', help='use automatic mixed precision training', default=False)

    # mixer
    parser.add_argument('--mixer_kernel_size', type=int, default=8, help='patchmixer-kernel')

    # hyper-parameters FEDNet
    parser.add_argument('--w_det', type=float, default=1.0, help='trade-off parameter')
    parser.add_argument('--w_sto', type=float, default=1.0, help='trade-off parameter')

    # GPU
    # str2bool instead of type=bool: bool('False') is True, so the old option
    # could never disable the GPU from the command line.
    parser.add_argument('--use_gpu', type=str2bool, default=True, help='use gpu')
    parser.add_argument('--gpu', type=int, default=0, help='gpu')
    parser.add_argument('--use_multi_gpu', action='store_true', help='use multiple gpus', default=False)
    parser.add_argument('--devices', type=str, default='0,1,2,3', help='device ids of multile gpus')
    parser.add_argument('--mask_rate', type=float, default=0.0, help='mask rate')
    return parser


def make_setting(args, ii=0):
    """Build the experiment identifier passed to ``exp.train``/``exp.test``.

    Args:
        args: Parsed CLI namespace. Reads ``alpha``, ``task_name``, ``model``,
            ``data``, ``target_domain``, ``freq_type``, ``constraint_type``
            and ``mask_rate``.
        ii: Zero-based index of the repeated run (``--itr``).

    Returns:
        The setting string.
    """
    setting = 'alpha{}_{}_{}_{}_domain_{}_freq_type_{}_{}'.format(
        args.alpha,
        args.task_name,
        args.model,
        args.data,
        args.target_domain,
        args.freq_type,
        args.constraint_type,
    )
    # The original format string had one placeholder fewer than arguments.
    # The run index was silently dropped, so every --itr run shared one
    # setting. Append it only for ii > 0 so the first run's name is unchanged.
    # That keeps anything Exp already stored under it (presumably
    # checkpoints/results) loadable.
    if ii > 0:
        setting += '_{}'.format(ii)
    if args.mask_rate > 0:
        setting += '_mask{}'.format(args.mask_rate)
    return setting


def main():
    """Parse CLI arguments, seed RNGs, then train+test or test only."""
    parser = build_parser()
    args = parser.parse_args()
    # Fall back to CPU when CUDA is unavailable, even if --use_gpu is set.
    args.use_gpu = bool(torch.cuda.is_available() and args.use_gpu)

    seed = args.seed
    seed_everything(seed)
    print("seed:", seed)

    if args.use_gpu and args.use_multi_gpu:
        args.devices = args.devices.replace(' ', '')
        args.device_ids = [int(id_) for id_ in args.devices.split(',')]
        # args.gpu = args.device_ids[0]

    print('Args in experiment:')
    print_args(args)

    experiments = {'classification': Exp_Classification}
    Exp = experiments.get(args.task_name)
    if Exp is None:
        # Previously Exp was set to None here, and the script crashed later
        # with "'NoneType' object is not callable".
        parser.error('unsupported --task_name {!r}; options: {}'.format(
            args.task_name, sorted(experiments)))

    if args.is_training:
        for ii in range(args.itr):
            # setting record of experiments
            exp = Exp(args)  # set experiments
            setting = make_setting(args, ii)

            print('>>>>>>>start training : {}>>>>>>>>>>>>>>>>>>>>>>>>>>'.format(setting))
            exp.train(setting)

            print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))
            exp.test(setting)

            # print('>>>>>>>Adistance : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))
            # dis_a = exp.calculate_a_distance(setting)
            # print("A-distance", dis_a)
            torch.cuda.empty_cache()
    else:
        # Test-only mode evaluates the first run's setting.
        setting = make_setting(args, 0)

        exp = Exp(args)  # set experiments
        print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))
        exp.test(setting, test=1)
        torch.cuda.empty_cache()


if __name__ == '__main__':
    main()
