# coding=utf-8

import os
import sys
import time
import numpy as np
import argparse
import torch
import pywt
from alg.opt import *
from alg import alg, modelopera
from utils.util import set_random_seed, save_checkpoint, print_args, train_valid_target_eval_names, alg_loss_dict, Tee, act_param_init, print_environ
# from datautil.getdataloader import get_img_dataloader
from datautil.getdataloader_single import get_act_dataloader, get_dataloaders
import warnings
from plot_a_distance import proxy_a_distance, proxy_mlp_a_distance
warnings.filterwarnings("ignore")

def get_a_distance(args, algorithm, train_loaders_not_infinity, target_loader, device):
    """Estimate the proxy A-distance between source and target feature sets.

    Reloads the checkpoint stored at ``<args.output>/model.pkl`` into
    ``algorithm``, extracts features for every batch of both loaders with
    ``algorithm.get_features`` and passes the two feature matrices to
    ``proxy_mlp_a_distance``.

    Args:
        args: parsed arguments; only ``args.output`` is read.
        algorithm: model exposing ``load_state_dict``, ``eval``, ``to`` and
            ``get_features``. It is left in eval mode on CPU afterwards
            (``.to`` is in-place for a ``torch.nn.Module`` — presumably what
            is passed here).
        train_loaders_not_infinity: iterable of source batches, each a
            6-tuple ``(x, y, d, pctarget, pdtarget, index)``.
        target_loader: iterable of target batches with the same layout.
        device: ignored; the computation always runs on CPU. Kept only for
            interface compatibility.

    Returns:
        Whatever ``proxy_mlp_a_distance`` returns (printed as well).
    """
    # The caller's device is deliberately overridden: everything runs on CPU.
    device = "cpu"
    # Load straight onto CPU; load_state_dict copies the values into the existing
    # parameters anyway, so a temporary GPU copy of the weights is unnecessary.
    checkpoint = torch.load(os.path.join(args.output, 'model.pkl'), map_location=device)
    algorithm.load_state_dict(checkpoint['model_dict'])
    algorithm.eval()
    algorithm = algorithm.to(device)

    def _extract(loader):
        # Concatenate get_features() outputs over all batches into one numpy array.
        feats = []
        for x, _y, _d, _pctarget, _pdtarget, _index in loader:
            feats.append(algorithm.get_features(x.float().to(device)))
        return torch.cat(feats, dim=0).cpu().numpy()

    # Inference only: without no_grad the autograd graph of every batch would be
    # kept alive until the concatenation, so memory would grow with dataset size.
    with torch.no_grad():
        source_X = _extract(train_loaders_not_infinity)
        target_X = _extract(target_loader)
    print(source_X.shape, target_X.shape)

    # result = proxy_a_distance(source_X, target_X, verbose=True)
    result = proxy_mlp_a_distance(source_X, target_X, verbose=True)
    print("A-distance:", result)
    return result

def get_args():
    """Parse command-line arguments and prepare the run's output directory.

    Side effects performed here (beyond parsing):
      * ``args.steps_per_epoch`` is fixed to 10;
      * ``args.data_dir`` is prefixed (plain string concatenation) with
        ``args.data_file``;
      * a run-specific sub-directory name built from dataset, first test
        domain, algorithm, seed, lr, epochs and frequency options is appended
        to ``args.output`` and created;
      * ``sys.stdout`` / ``sys.stderr`` are replaced by ``Tee`` objects that
        write to ``out.txt`` / ``err.txt`` in that directory;
      * the namespace is passed through ``act_param_init`` and
        ``print_environ`` is called.

    Returns:
        argparse.Namespace: the post-processed arguments returned by
        ``act_param_init``.
    """
    parser = argparse.ArgumentParser(description='DG')
    parser.add_argument('--algorithm', type=str, default="ERM")
    parser.add_argument('--alpha', type=float,
                        default=1, help='DANN dis alpha')
    parser.add_argument('--anneal_iters', type=int,
                        default=1000, help='Penalty anneal iters used in VREx')
    parser.add_argument('--batch_size', type=int,
                        default=64, help='batch_size')
    parser.add_argument('--beta', type=float,
                        default=1, help='DIFEX beta')
    parser.add_argument('--beta1', type=float, default=0.5,
                        help='Adam hyper-param')
    parser.add_argument('--bottleneck', type=int, default=256)
    parser.add_argument('--checkpoint_freq', type=int,
                        default=3, help='Checkpoint every N epoch')
    parser.add_argument('--classifier', type=str,
                        default="linear", choices=["linear", "wn"])
    parser.add_argument('--data_file', type=str, default='',
                        help='root_dir')
    parser.add_argument('--dataset', type=str, default='office')
    parser.add_argument('--data_dir', type=str, default='', help='data dir')
    parser.add_argument('--dis_hidden', type=int,
                        default=256, help='dis hidden dimension')
    parser.add_argument('--disttype', type=str, default='2-norm',
                        choices=['1-norm', '2-norm', 'cos', 'norm-2-norm', 'norm-1-norm'])
    parser.add_argument('--gpu_id', type=str, nargs='?',
                        default='0', help="device id to run")
    parser.add_argument('--groupdro_eta', type=float,
                        default=1e-2, help="groupdro eta")
    parser.add_argument('--inner_lr', type=float,
                        default=1e-2, help="learning rate used in MLDG")
    parser.add_argument('--lam', type=float,
                        default=1, help="tradeoff hyperparameter used in VREx")
    parser.add_argument('--layer', type=str, default="bn",
                        choices=["ori", "bn"])
    parser.add_argument('--lr', type=float, default=1e-2, help="learning rate")
    parser.add_argument('--lr_decay', type=float, default=0.8, help='for sgd')
    parser.add_argument('--lr_decay1', type=float,
                        default=1.0, help='for pretrained featurizer')
    parser.add_argument('--lr_decay2', type=float, default=1.0,
                        help='inital learning rate decay of network')
    parser.add_argument('--lr_gamma', type=float,
                        default=0.0003, help='for optimizer')
    # This is the number of epochs (the training loop iterates over epochs).
    parser.add_argument('--max_epoch', type=int,
                        default=120, help="max epochs")
    parser.add_argument('--mixupalpha', type=float,
                        default=0.2, help='mixup hyper-param')
    parser.add_argument('--mldg_beta', type=float,
                        default=1, help="mldg hyper-param")
    parser.add_argument('--mmd_gamma', type=float,
                        default=1, help='MMD, CORAL hyper-param')
    parser.add_argument('--momentum', type=float,
                        default=0.9, help='for optimizer')
    parser.add_argument('--net', type=str, default='resnet50',
                        help="featurizer: vgg16, resnet50, resnet101,DTNBase")
    parser.add_argument('--N_WORKERS', type=int, default=4)
    parser.add_argument('--rsc_f_drop_factor', type=float,
                        default=1/3, help='rsc hyper-param')
    parser.add_argument('--rsc_b_drop_factor', type=float,
                        default=1/3, help='rsc hyper-param')
    parser.add_argument('--save_model_every_checkpoint', action='store_true')
    parser.add_argument('--schuse', action='store_true')
    parser.add_argument('--schusech', type=str, default='cos')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--split_style', type=str, default='strat',
                        help="the style to split the train and eval datasets")
    parser.add_argument('--task', type=str, default="cross_people",
                        choices=["cross_people"], help='now only support cross_people')
    parser.add_argument('--tau', type=float, default=1, help="andmask tau")
    parser.add_argument('--test_envs', type=int, nargs='+',
                        default=[0], help='target domains')
    parser.add_argument('--output', type=str,
                        default="train_output", help='result output path')
    parser.add_argument('--weight_decay', type=float, default=5e-4)

    # for adarnn
    parser.add_argument('--hidden_size', type=int, default=64)
    parser.add_argument('--num_layers', type=int, default=2)
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--pre_epoch', type=int, default=40)  # 20, 30, 50
    parser.add_argument('--dw', type=float, default=0.5)  # 0.01, 0.05, 5.0
    parser.add_argument('--trans_loss', type=str, default='adv')
    parser.add_argument('--len_win', type=int, default=0)
    parser.add_argument('--model_type', type=str, default='AdaRNN')

    # for irm + freq
    parser.add_argument('--use_freq', action="store_true")
    parser.add_argument('--freq_alpha', type=float, default=0.2)
    parser.add_argument('--freq_type', type=str, default='fft')
    parser.add_argument('--penalty_weight', type=int, default=1)
    parser.add_argument('--ib_lambda', type=float, default=10)
    parser.add_argument('--ib_penalty_anneal_iters', type=int, default=500)
    parser.add_argument('--enable_bn', action="store_true")
    # NOTE: store_false -> defaults to True; passing the flag turns it OFF.
    parser.add_argument('--nonlinear_classifier', action="store_false")
    # float, not int: the default is 1e-3, and type=int would make argparse
    # reject any fractional value supplied on the command line.
    parser.add_argument('--lambda_beta', type=float, default=1e-3)
    parser.add_argument('--lambda_inv_risks', type=int, default=1)

    args = parser.parse_args()
    args.steps_per_epoch = 10
    args.data_dir = args.data_file+args.data_dir

    out_dir_name = "{}_domain{}_{}_seed{}_lr{}_epochs{}_use_freq_{}_freq_alpha_{}".format(args.dataset, args.test_envs[0], args.algorithm, args.seed, args.lr, args.max_epoch, args.use_freq, args.freq_alpha)
    args.output = os.path.join(args.output, out_dir_name)
    os.makedirs(args.output, exist_ok=True)
    # From here on, everything printed is also written to the run directory.
    sys.stdout = Tee(os.path.join(args.output, 'out.txt'))
    sys.stderr = Tee(os.path.join(args.output, 'err.txt'))
    args = act_param_init(args)
    print_environ()
    return args

def get_mask_spectrum(train_loader, freq_alpha = 0.2, freq_type="fft"):
    """
    Select the shared frequency components with the largest average amplitude.

    For each batch, ``data[0]`` is squeezed on dim 2 and permuted to
    ``(B, L, C)`` (so presumably the raw input is ``(B, C, 1, L)`` — TODO
    confirm against the loader). Its spectrum along the time axis is computed
    either with a real FFT (``freq_type="fft"``) or with a single-level
    discrete wavelet transform (approximation and detail coefficients
    concatenated) for the supported pywt wavelet names. Magnitudes are
    averaged over batch and channel, summed over batches, and the indices of
    the top ``int(n_components * freq_alpha)`` components are returned.

    Args:
        train_loader: iterable of batches whose first element is a tensor.
        freq_alpha: fraction of spectral components to keep.
        freq_type: "fft" or one of 'db2', 'sym2', 'coif1', 'bior1.3', 'rbio1.3'.

    Returns:
        1-D index tensor (from ``torch.topk``) of the selected components,
        used as the spectrum of the time-invariant component.

    Raises:
        ValueError: if ``freq_type`` is unsupported or the loader is empty.
    """
    wavelet_types = ('db2', 'sym2', 'coif1', 'bior1.3', 'rbio1.3')
    # Validate up front instead of asserting per batch: asserts vanish under -O.
    if freq_type != "fft" and freq_type not in wavelet_types:
        raise ValueError("unsupported freq_type: {!r}".format(freq_type))

    amps = None
    for data in train_loader:
        lookback_window = data[0].squeeze(dim=2).permute(0, 2, 1)  # B L C
        if freq_type == "fft":
            frequency_feature = torch.fft.rfft(lookback_window, dim=1)
        else:
            wavelet = pywt.Wavelet(freq_type)
            device = lookback_window.device
            # pywt.dwt works on numpy along the last axis, i.e. time here (B C L).
            X = lookback_window.permute(0, 2, 1).detach().cpu().numpy()
            cA, cD = pywt.dwt(X, wavelet)
            frequency_feature = np.concatenate((cA, cD), axis=2).transpose((0, 2, 1))  # B D C
            frequency_feature = torch.from_numpy(frequency_feature).to(device)

        # Mean magnitude per component: average over batch, then over channels.
        batch_amps = frequency_feature.abs().mean(dim=0).mean(dim=1)
        amps = batch_amps if amps is None else amps + batch_amps

    if amps is None:
        raise ValueError("train_loader yielded no batches")

    mask_spectrum = amps.topk(int(amps.shape[0]*freq_alpha)).indices
    print("mask_spectrum:", mask_spectrum)
    return mask_spectrum # as the spectrums of time-invariant component

if __name__ == '__main__':
    # Script entry point: parse args, build loaders/model, train for
    # args.max_epoch epochs of args.steps_per_epoch updates, evaluate every
    # args.checkpoint_freq epochs, then save the model, compute the A-distance
    # and write a summary to <args.output>/done.txt.
    args = get_args()
    set_random_seed(args.seed)
    # os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
    torch.cuda.set_device("cuda:"+args.gpu_id)
    device = torch.device("cuda:{}".format(args.gpu_id))

    loss_list = alg_loss_dict(args)
    # Per-domain loaders that drive training and evaluation.
    train_loaders, target_loaders, eval_loaders = get_dataloaders(args)
    # Single loaders, used below for the frequency mask and the A-distance; the
    # original note says this call is also needed to get args.domain_num
    # (presumably set as a side effect — TODO confirm).
    train_loader, train_loaders_not_infinity, valid_loader, target_loader, _, _, _ = get_act_dataloader(args)

    # add frequency spectrum mask for IRM
    if args.use_freq:
        mask_spectrum = get_mask_spectrum(train_loaders_not_infinity, freq_alpha = args.freq_alpha, freq_type=args.freq_type)
        args.mask_spectrum = mask_spectrum

    eval_name_dict = train_valid_target_eval_names(args)
    algorithm_class = alg.get_algorithm_class(args.algorithm)
    algorithm = algorithm_class(args).to(device)
    algorithm.train()
    opt = get_optimizer_DGModel(algorithm, args)
    sch = get_scheduler_DGModel(opt, args)

    s = print_args(args, [])
    print('=======hyper-parameter used========')
    print(s)

    if hasattr(algorithm, 'device'):
        print("algorithm device:", algorithm.device)
    else:
        print("set device", device)
        algorithm.device = device

    if 'DIFEX' in args.algorithm:
        ms = time.time()
        n_steps = args.max_epoch*args.steps_per_epoch
        print('start training fft teacher net')
        opt1 = get_optimizer_DGModel(algorithm.teaNet, args, isteacher=True)
        sch1 = get_scheduler_DGModel(opt1, args)
        algorithm.teanettrain(train_loaders, n_steps, opt1, sch1)
        print('complet time:%.4f' % (time.time()-ms))

    acc_record = {}
    train_minibatches_iterator = zip(*train_loaders)
    best_valid_acc, target_acc = 0, 0
    # Initialised so the final report cannot raise NameError when no evaluation
    # ever improves on best_valid_acc (e.g. validation accuracy stays at 0).
    best_model_precision = best_model_recall = best_model_f1 = 0.0
    speed_list = []
    itercnt = 0

    step_vals = []
    print('===========start training===========')
    # Two separate clocks: one for the overall run, one for the current
    # 100-update speed window (previously a single variable served both, so
    # "total cost time" only measured time since the last speed sample).
    train_start = time.time()
    speed_start = train_start
    for epoch in range(args.max_epoch):
        for iter_num in range(args.steps_per_epoch):
            minibatches_device = list(next(train_minibatches_iterator))
            if args.algorithm in ('VREx', 'IRM', 'IB_IRM') and algorithm.update_count == args.anneal_iters:
                print("Reset optimizer, because it doesn't like the sharp jump in gradient")
                opt = get_optimizer_DGModel(algorithm, args)
                sch = get_scheduler_DGModel(opt, args)
            step_vals = algorithm.update(minibatches_device, opt, sch)
            itercnt += 1
            # Record mean ms/update over each complete window of 100 updates.
            if itercnt % 100 == 0:
                now = time.time()
                speed_list.append((now - speed_start) * 1000 / 100)
                speed_start = now

        if (epoch in [int(args.max_epoch*0.7), int(args.max_epoch*0.9)]) and (not args.schuse):
            print('manually descrease lr')
            for params in opt.param_groups:
                params['lr'] = params['lr']*0.1

        if (epoch == (args.max_epoch-1)) or (epoch % args.checkpoint_freq == 0):
            print('===========epoch %d===========' % (epoch))
            s = ''
            for item in loss_list:
                s += (item+'_loss:%.4f,' % step_vals[item])
            print(s[:-1])
            s = ''
            # eval and test are the same
            if eval_loaders is None:
                acc_record["valid"], _ = modelopera.accuracy(algorithm, target_loaders[0], None)
                acc_record["target"], metrics = modelopera.accuracy(algorithm, target_loaders[0], None)
            else:
                valid_acc_list = []
                for item in eval_name_dict["valid"]:
                    valid_acc, _ = modelopera.accuracy(algorithm, eval_loaders[item], None)
                    valid_acc_list.append(valid_acc)
                    s += (item+'_acc:%.4f,' % valid_acc)
                acc_record["valid"] = np.mean(np.array(valid_acc_list))
                acc_record["target"], metrics = modelopera.accuracy(algorithm, target_loaders[0], None)
            s += ('valid_acc:%.4f,' % acc_record["valid"])
            s += ('target_acc:%.4f,' % acc_record["target"])
            print(s[:-1])
            if acc_record['valid'] > best_valid_acc:
                best_valid_acc = acc_record['valid']
                target_acc = acc_record['target']
                # metrics is indexed as (precision, recall, f1)
                best_model_precision = metrics[0]
                best_model_recall = metrics[1]
                best_model_f1 = metrics[2]
            # if args.save_model_every_checkpoint:
            #     save_checkpoint(f'model_epoch{epoch}.pkl', algorithm, args)
            print('total cost time: %.4f' % (time.time()-train_start))

    save_checkpoint('model.pkl', algorithm, args)
    # No speed sample exists when fewer than 100 updates ran; report nan then.
    mean_speed = sum(speed_list) / len(speed_list) if speed_list else float('nan')
    print('valid acc: %.4f' % best_valid_acc)
    print('DG result: %.4f' % target_acc)
    print('precision: %.4f' % best_model_precision)
    print('recall: %.4f' % best_model_recall)
    print('f1: %.4f' % best_model_f1)
    print("mean speed: {:.2f}ms/iter".format(mean_speed))
    a_distance = get_a_distance(args, algorithm, train_loader, target_loader, device)
    with open(os.path.join(args.output, 'done.txt'), 'w') as f:
        f.write('done\n')
        f.write('total cost time:%s\n' % (str(time.time()-train_start)))
        f.write('valid acc:%.4f\n' % (best_valid_acc))
        f.write('target acc:%.4f\n' % (target_acc))
        f.write('precision:%.4f\n' % (best_model_precision))
        f.write('recall:%.4f\n' % (best_model_recall))
        f.write('f1:%.4f\n' % (best_model_f1))
        f.write("mean speed: {:.2f}ms/iter\n".format(mean_speed))
        f.write("A-distance: {:.4f}\n".format(a_distance))
