# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.


import random
import numpy as np
import torch
import sys
import os
import argparse
import torchvision
import PIL


def set_random_seed(seed=0):
    """Seed the Python, NumPy and PyTorch RNGs and pin the cuDNN flags.

    Args:
        seed: Value handed unchanged to every seeding function (default 0).

    After seeding, ``torch.backends.cudnn.deterministic`` is set to ``True``
    and ``torch.backends.cudnn.benchmark`` to ``False``.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def train_valid_target_eval_names(args):
    """Group evaluation-split names by role.

    Each index ``i`` in ``range(args.domain_num)`` that is not contained in
    ``args.test_envs`` contributes ``'eval<i>_in'`` to ``'train'`` and
    ``'eval<i>_out'`` to ``'valid'``. Each index that is contained in
    ``args.test_envs`` contributes ``'eval<i>_out'`` to ``'target'``.

    Args:
        args: Object exposing ``domain_num`` (int) and ``test_envs``
            (container of domain indices).

    Returns:
        dict with the keys ``'train'``, ``'valid'`` and ``'target'``; every
        value is a list of names in ascending domain order.
    """
    domains = range(args.domain_num)
    source_ids = [d for d in domains if d not in args.test_envs]
    target_ids = [d for d in domains if d in args.test_envs]
    return {
        'train': ['eval%d_in' % d for d in source_ids],
        'valid': ['eval%d_out' % d for d in source_ids],
        'target': ['eval%d_out' % d for d in target_ids],
    }


def alg_loss_dict(args):
    """Return the loss-term names associated with ``args.algorithm``.

    The name is looked up exactly first. If that fails, a case-insensitive
    match is tried, so that ``"diversify"`` (the default value of
    ``--algorithm`` in ``get_args``) resolves to the ``'Diversify'`` entry
    instead of raising.

    Args:
        args: Object exposing ``algorithm`` (str).

    Returns:
        list of str: the loss names registered for the algorithm.

    Raises:
        KeyError: if the algorithm matches no entry, even ignoring case.
    """
    loss_dict = {
        'ANDMask': ['total'],
        'CORAL': ['class', 'coral', 'total'],
        'DANN': ['class', 'dis', 'total'],
        'ERM': ['class'],
        'Mixup': ['class'],
        'MLDG': ['total'],
        'MMD': ['class', 'mmd', 'total'],
        'GroupDRO': ['group'],
        'RSC': ['class'],
        'VREx': ['loss', 'nll', 'penalty'],
        'DIFEX': ['class', 'dist', 'exp', 'align', 'total'],
        'Diversify': ['class', 'dis', 'total'],
        'AdaRNN': ['total', 'class', 'trans'],
        'IRM': ['loss', 'nll', 'penalty'],
        'IIB': ['loss', 'inv', 'env', 'ib'],
        'IB_IRM': ['loss', 'domain', 'irm_penalty', 'ib'],
    }
    name = args.algorithm
    if name in loss_dict:
        return loss_dict[name]

    # Fall back to a case-insensitive match; the lower-cased keys are unique.
    canonical = {key.lower(): key for key in loss_dict}
    try:
        return loss_dict[canonical[name.lower()]]
    except (KeyError, AttributeError):
        raise KeyError(
            "Unknown algorithm {!r}; expected one of {}".format(
                name, sorted(loss_dict))
        ) from None


def print_args(args, print_list):
    """Format attributes of ``args`` as ``name:value`` lines under a header.

    Args:
        args: Object whose ``__dict__`` holds the values to report.
        print_list: Names to include; when it is empty, every attribute is
            included.

    Returns:
        str: a line of ``=`` characters followed by one ``name:value`` line
        per selected attribute, each terminated by a newline.
    """
    show_all = len(print_list) == 0
    rows = [
        "{}:{}\n".format(name, value)
        for name, value in args.__dict__.items()
        if show_all or name in print_list
    ]
    return "==========================================\n" + "".join(rows)


def print_row(row, colwidth=10, latex=False):
    """Print one table row with every cell forced to ``colwidth`` characters.

    Floating-point values are rendered with ten decimals before being padded
    or truncated; every other value goes through ``str``.

    Args:
        row: Iterable of cell values.
        colwidth: Exact width of each printed cell.
        latex: When true, cells are joined with ``" & "`` and the row is
            followed by a double backslash; otherwise two spaces are used
            and nothing is appended.
    """
    sep, end_ = (" & ", "\\\\") if latex else ("  ", "")

    cells = []
    for value in row:
        if np.issubdtype(type(value), np.floating):
            value = "{:.10f}".format(value)
        # Pad first, then cut, so short and long cells both end up colwidth wide.
        cells.append(str(value).ljust(colwidth)[:colwidth])

    # end_ is passed as a second positional argument, so print() puts a
    # single space between the joined cells and the terminator.
    print(sep.join(cells), end_)


def print_environ():
    """Print the versions of Python and of the main libraries in use.

    Emits an ``Environment:`` header followed by one tab-indented line each
    for Python, PyTorch, Torchvision, CUDA, cuDNN, NumPy and PIL.
    """
    # Values are wrapped in callables so each one is evaluated only when its
    # line is printed, keeping the evaluate-then-print order line by line.
    report = (
        ("\tPython: {}", lambda: sys.version.split(" ")[0]),
        ("\tPyTorch: {}", lambda: torch.__version__),
        ("\tTorchvision: {}", lambda: torchvision.__version__),
        ("\tCUDA: {}", lambda: torch.version.cuda),
        ("\tCUDNN: {}", lambda: torch.backends.cudnn.version()),
        ("\tNumPy: {}", lambda: np.__version__),
        ("\tPIL: {}", lambda: PIL.__version__),
    )
    print("Environment:")
    for template, read_version in report:
        print(template.format(read_version()))


class Tee:
    """Write-through stream that echoes to a stream and appends to a file.

    Intended as a replacement for ``sys.stdout`` / ``sys.stderr``: every
    ``write`` is forwarded to the wrapped stream and to the log file, and
    both are flushed immediately.

    Args:
        fname: Path of the log file.
        mode: Mode passed to ``open`` (default ``"a"``, i.e. append).
        stream: Stream to echo to. ``None`` (default) selects whatever
            ``sys.stdout`` is at construction time.

    The instance can be used as a context manager; leaving the block closes
    the log file. Attributes not defined here (``isatty``, ``encoding``,
    ``fileno``, ...) are looked up on the wrapped stream.
    """

    def __init__(self, fname, mode="a", stream=None):
        # Bound once, so later reassignments of sys.stdout do not affect us.
        self.stdout = sys.stdout if stream is None else stream
        self.file = open(fname, mode)

    def write(self, message):
        """Send ``message`` to the stream and the log file, then flush both."""
        self.stdout.write(message)
        # After close() keep echoing to the stream instead of raising, since
        # the object may still be installed as sys.stdout.
        if not self.file.closed:
            self.file.write(message)
        self.flush()

    def flush(self):
        """Flush the wrapped stream and, if still open, the log file."""
        self.stdout.flush()
        if not self.file.closed:
            self.file.flush()

    def close(self):
        """Close the log file; the wrapped stream is left open."""
        if not self.file.closed:
            self.file.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False

    def __getattr__(self, name):
        # Only reached when normal lookup fails. Guard our own attributes to
        # avoid infinite recursion when __init__ has not completed.
        if name in ("stdout", "file"):
            raise AttributeError(name)
        return getattr(self.stdout, name)

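# Reference values per dataset: input shape, number of classes, grid size.
# EMG    8,1,200   6   10
# DSADS  45,1,125  19  10
# PAMAP  27,1,200  8   10   (NOTE(review): act_param_init sets 18 classes for PAMAP - confirm which is right)
# USCHAD 6,1,200   12  10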
def act_param_init(args):
    """Attach the per-dataset tables and the selected dataset's spec to ``args``.

    Sets ``args.select_position``, ``args.select_channel``, ``args.hz_list``
    and ``args.act_people`` (dicts keyed by dataset name), then looks up
    ``args.dataset`` and sets ``args.num_classes``, ``args.input_shape`` and
    ``args.grid_size``.

    Args:
        args: Object exposing ``dataset`` (one of the dataset names below);
            it is modified in place.

    Returns:
        The same ``args`` object.

    Raises:
        KeyError: if ``args.dataset`` is not a known dataset name (the four
            table attributes have already been set at that point).
    """
    # Length of the index array stored per dataset in select_channel.
    # NOTE(review): UCIHAR (6) and Spurious_Fourier (2) differ from the first
    # dimension of their input shapes below (9 and 1) - confirm intent.
    channel_counts = {
        'EMG': 8,
        'DSADS': 45,
        'PAMAP': 27,
        'USCHAD': 6,
        'UCIHAR': 6,
        'SHAR': 3,
        'OPP': 77,
        'PCL': 48,
        'HHAR': 6,
        'Spurious_Fourier': 2,
        'WESAD': 8,
        'EEG': 1,
    }

    args.select_position = {name: [0] for name in channel_counts}
    args.select_channel = {
        name: np.arange(count) for name, count in channel_counts.items()
    }
    args.hz_list = dict.fromkeys(channel_counts, 1000)

    # Index groups per dataset; presumably subject ids grouped per domain -
    # verify against the data loaders.
    args.act_people = {
        'EMG': [list(range(start, start + 9)) for start in range(0, 36, 9)],
        'DSADS': [[0, 1], [2, 3], [4, 5], [6, 7]],
        'PAMAP': [[2, 3, 8], [1, 5], [0, 7], [4, 6]],
        'USCHAD': [[0, 1, 2, 11], [3, 5, 6, 9], [7, 8, 10, 13], [4, 12]],
        'UCIHAR': [[] for _ in range(5)],
        'SHAR': [[] for _ in range(4)],
        'OPP': [[] for _ in range(4)],
        'PCL': [[] for _ in range(3)],
        'HHAR': [[] for _ in range(5)],
        'Spurious_Fourier': [[] for _ in range(2)],
        'WESAD': [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14]],
        'EEG': [list(range(start, start + 5)) for start in range(0, 20, 5)],
    }

    # dataset -> (input_shape, num_classes, grid_size)
    dataset_specs = {
        'EMG': ((8, 1, 200), 6, 10),
        'DSADS': ((45, 1, 125), 19, 10),
        'PAMAP': ((27, 1, 200), 18, 10),
        'USCHAD': ((6, 1, 200), 12, 10),
        'UCIHAR': ((9, 1, 128), 6, 10),
        'SHAR': ((3, 1, 151), 17, 10),
        'OPP': ((77, 1, 30), 18, 10),
        'PCL': ((48, 1, 750), 2, 10),
        'HHAR': ((6, 1, 500), 6, 10),
        'Spurious_Fourier': ((1, 1, 50), 2, 10),
        'WESAD': ((8, 1, 200), 4, 10),
        'EEG': ((1, 1, 3000), 5, 10),
    }
    input_shape, num_classes, grid_size = dataset_specs[args.dataset]
    args.num_classes = num_classes
    args.input_shape = input_shape
    args.grid_size = grid_size

    return args

def save_checkpoint(filename, alg, args):
    """Save the run arguments and the model's weights under ``args.output``.

    The weights are copied to the CPU tensor by tensor, so the live model
    stays on its current device. Calling ``alg.cpu()`` instead would move
    the whole module in place and leave it on the CPU after saving.

    Args:
        filename: File name of the checkpoint, relative to ``args.output``.
        alg: Object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).
        args: Namespace-like object; ``vars(args)`` is stored under
            ``"args"`` and ``args.output`` is the target directory.

    The saved object is a dict with the keys ``"args"`` and ``"model_dict"``.
    """
    state = alg.state_dict()
    # Rebuild with the same mapping type so an OrderedDict stays ordered.
    cpu_state = type(state)(
        (key, value.detach().cpu() if torch.is_tensor(value) else value)
        for key, value in state.items()
    )
    # nn.Module state dicts carry version info in `_metadata`; keep it.
    metadata = getattr(state, "_metadata", None)
    if metadata is not None:
        cpu_state._metadata = metadata

    save_dict = {
        "args": vars(args),
        "model_dict": cpu_state,
    }
    torch.save(save_dict, os.path.join(args.output, filename))

def get_args(argv=None):
    """Parse the command line and prepare the run's output directory.

    Besides parsing, this function:

    * sets ``args.steps_per_epoch`` to a fixed large constant;
    * prefixes ``args.data_dir`` with ``args.data_file`` (plain string
      concatenation, no path separator is inserted);
    * replaces ``args.output`` with a run-specific sub-directory named after
      the dataset, first test environment, algorithm, seed, learning rate,
      epoch count and latent-domain count, and creates it;
    * replaces ``sys.stdout`` and ``sys.stderr`` with ``Tee`` objects that
      log to ``out.txt`` and ``err.txt`` in that directory;
    * fills in the dataset-dependent fields via ``act_param_init``.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (default)
            makes argparse read ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace: the populated arguments.
    """
    parser = argparse.ArgumentParser(description='DG')
    parser.add_argument('--algorithm', type=str, default="diversify")
    parser.add_argument('--alpha', type=float,
                        default=0.1, help="DANN dis alpha")
    parser.add_argument('--alpha1', type=float,
                        default=0.1, help="DANN dis alpha")
    parser.add_argument('--batch_size', type=int,
                        default=256, help="batch_size")
    parser.add_argument('--beta1', type=float, default=0.5, help="Adam")
    parser.add_argument('--bottleneck', type=int, default=256)
    parser.add_argument('--checkpoint_freq', type=int,
                        default=100, help='Checkpoint every N steps')
    parser.add_argument('--classifier', type=str,
                        default="linear", choices=["linear", "wn"])
    parser.add_argument('--data_file', type=str, default='')
    parser.add_argument('--dataset', type=str, default='DSADS')
    parser.add_argument('--data_dir', type=str, default='')
    parser.add_argument('--dis_hidden', type=int, default=256)
    parser.add_argument('--gpu_id', type=str, nargs='?',
                        default='0', help="device id to run")
    parser.add_argument('--layer', type=str, default="bn",
                        choices=["ori", "bn"])
    parser.add_argument('--lam', type=float, default=0.0)
    parser.add_argument('--latent_domain_num', type=int, default=3)
    parser.add_argument('--local_epoch', type=int,
                        default=1, help='local iterations')
    parser.add_argument('--lr', type=float, default=1e-2, help="learning rate")
    parser.add_argument('--lr_decay1', type=float,
                        default=1.0, help='for pretrained featurizer')
    parser.add_argument('--lr_decay2', type=float, default=1.0)
    parser.add_argument('--max_epoch', type=int,
                        default=120, help="max iterations")
    parser.add_argument('--model_size', default='median',
                        choices=['small', 'median', 'large', 'transformer'])
    parser.add_argument('--N_WORKERS', type=int, default=4)
    parser.add_argument('--old', action='store_true')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--task', type=str, default="cross_people")
    parser.add_argument('--test_envs', type=int, nargs='+', default=[0])
    parser.add_argument('--output', type=str, default="train_output")
    parser.add_argument('--weight_decay', type=float, default=5e-4)
    parser.add_argument('--use_freq', action="store_true")
    args = parser.parse_args(argv)

    args.steps_per_epoch = 10000000000
    args.data_dir = args.data_file + args.data_dir
    # NOTE: --gpu_id is parsed but not applied here; this function does not
    # set CUDA_VISIBLE_DEVICES.

    out_dir_name = "{}_domain{}_{}_seed{}_lr{}_epochs{}_latent_domain_num{}".format(
        args.dataset, args.test_envs[0], args.algorithm, args.seed,
        args.lr, args.max_epoch, args.latent_domain_num)
    args.output = os.path.join(args.output, out_dir_name)
    os.makedirs(args.output, exist_ok=True)

    # Both Tee objects echo to whatever sys.stdout is when they are created,
    # so text written to stderr is echoed through stdout as well.
    sys.stdout = Tee(os.path.join(args.output, 'out.txt'))
    sys.stderr = Tee(os.path.join(args.output, 'err.txt'))

    args = act_param_init(args)
    return args
