#coding=utf-8
from argparse import Action
from torch.nn.modules import module
import datautil.actdata.cross_dataset as cross_dataset
import datautil.actdata.cross_people as cross_people
import datautil.actdata.cross_position as cross_position
import random
import numpy as np
import torch
import sys
import os
import argparse
import torchvision
import PIL
import types
import json

class MyEncoder(json.JSONEncoder):
    """JSON encoder that additionally handles numpy values and modules.

    numpy integers/floats become Python ``int``/``float``, numpy arrays become
    nested lists, and module objects are encoded as the placeholder ``-1``.
    Anything else is delegated to :class:`json.JSONEncoder`, which raises
    ``TypeError`` for unsupported types.
    """

    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        # Modules are not serializable; emit a sentinel instead of failing.
        if isinstance(obj, types.ModuleType):
            return -1
        return super().default(obj)

def set_random_seed(seed=0):
    """Seed the Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Args:
        seed: value passed to every seeding function.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for seeder in seeders:
        seeder(seed)
    # Disable cuDNN autotuning and request deterministic kernels for reproducibility.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

def save_checkpoint(filename, alg, args, opt, sch, opt1=None, sch1=None):
    """Save model, optimizer and scheduler progress to ``args.output/filename``.

    Args:
        filename: checkpoint file name, joined onto ``args.output``.
        alg: module whose ``state_dict()`` is stored under ``'model_dict'``.
        args: namespace; ``vars(args)`` is stored under ``'args'``.
        opt: optimizer whose state is stored under ``'opt'``.
        sch: optional scheduler; its ``last_epoch`` is stored as ``'epo'``.
        opt1: optional second optimizer, stored under ``'opt1'``.
        sch1: optional second scheduler; its ``last_epoch`` is stored as
            ``'epo1'`` (only when ``opt1`` is also given).
    """
    # Copy tensors to CPU instead of calling ``alg.cpu()``: that would move the
    # live model off its device as a side effect of saving.
    model_state = alg.state_dict()
    for key in list(model_state.keys()):
        value = model_state[key]
        if torch.is_tensor(value):
            model_state[key] = value.cpu()

    save_dict = {
        "args": vars(args),
        "model_dict": model_state,
        'opt': opt.state_dict(),
    }
    if sch is not None:
        save_dict['epo'] = sch.last_epoch
    if opt1 is not None:
        save_dict['opt1'] = opt1.state_dict()
        if sch1 is not None:
            save_dict['epo1'] = sch1.last_epoch
    torch.save(save_dict, os.path.join(args.output, filename))

def has_been_trained(root_dir, file_suff=''):
    """Return True if the run in *root_dir* has a completion marker saying 'done'.

    The marker is searched, in order, as ``root_dir + '/newdone'``,
    ``root_dir + '/newdone.txt'`` and ``root_dir + '/newdone' + file_suff``.
    The first candidate that can be opened and read decides the result: True
    when its first line contains the substring ``'done'``.

    Args:
        root_dir: directory path (string) that may contain the marker.
        file_suff: suffix appended to ``newdone`` for the third candidate.

    Returns:
        True if the first readable marker's first line contains 'done';
        False if no marker is readable or the marker file is empty.
    """
    candidates = (
        root_dir + '/newdone',
        root_dir + '/newdone.txt',
        root_dir + '/newdone' + file_suff,
    )
    for path in candidates:
        try:
            with open(path, 'r') as f:
                first_line = f.readline()
        except (OSError, ValueError):
            # Missing/unreadable/undecodable marker: fall back to the next one.
            continue
        # An empty file yields '' here, so it counts as "not done" instead of
        # raising IndexError.
        return 'done' in first_line
    return False

def train_valid_target_eval_names(args):
    """Group evaluation split names into train / valid / target lists.

    Every domain index in ``range(args.domain_num)`` that is listed in
    ``args.test_envs`` contributes ``'eval<i>_out'`` to ``'target'``. Every
    other domain contributes ``'eval<i>_in'`` to ``'train'`` and
    ``'eval<i>_out'`` to ``'valid'``.
    """
    names = {'train': [], 'valid': [], 'target': []}
    for env in range(args.domain_num):
        if env in args.test_envs:
            names['target'].append(f'eval{env}_out')
            continue
        names['train'].append(f'eval{env}_in')
        names['valid'].append(f'eval{env}_out')
    return names

def alg_loss_dict(args):
    """Return the list of loss names tracked for ``args.algorithm``.

    Only 'TDB' and 'TDBself' are known; any other algorithm raises KeyError.
    A fresh list is returned on every call.
    """
    supported = ('TDB', 'TDBself')
    table = {name: ['class', 'dis', 'total'] for name in supported}
    return table[args.algorithm]

def print_args(args, print_list):
    """Format ``args`` attributes as ``name:value`` lines under a separator.

    Args:
        args: object whose ``__dict__`` entries are formatted in order.
        print_list: attribute names to include; an empty collection
            includes every attribute.

    Returns:
        The formatted string (nothing is printed).
    """
    header = "==========================================\n"
    show_all = len(print_list) == 0
    lines = [
        "{}:{}\n".format(name, value)
        for name, value in args.__dict__.items()
        if show_all or name in print_list
    ]
    return header + "".join(lines)

def print_row(row, colwidth=10, latex=False):
    """Print one table row with fixed-width cells.

    Floating-point cells are rendered with 10 decimals. Every cell is padded
    or truncated to exactly ``colwidth`` characters. In LaTeX mode cells are
    joined with `` & `` and the row is terminated by ``\\\\``.

    Args:
        row: iterable of cell values.
        colwidth: width of every cell.
        latex: emit LaTeX table syntax instead of plain spacing.
    """
    sep, end_ = (" & ", "\\\\") if latex else ("  ", "")

    def _cell(value):
        text = "{:.10f}".format(value) if np.issubdtype(type(value), np.floating) else value
        return str(text).ljust(colwidth)[:colwidth]

    cells = [_cell(value) for value in row]
    # end_ is passed positionally on purpose: it is printed after a space and
    # the line still ends with the default newline.
    print(sep.join(cells), end_)

def print_environ():
    """Print the Python, PyTorch, Torchvision, CUDA, cuDNN, NumPy and PIL versions."""
    python_version = sys.version.partition(" ")[0]
    print("Environment:")
    print(f"\tPython: {python_version}")
    print(f"\tPyTorch: {torch.__version__}")
    print(f"\tTorchvision: {torchvision.__version__}")
    print(f"\tCUDA: {torch.version.cuda}")
    print(f"\tCUDNN: {torch.backends.cudnn.version()}")
    print(f"\tNumPy: {np.__version__}")
    print(f"\tPIL: {PIL.__version__}")

class Tee:
    """File-like object mirroring writes to the captured stdout and a log file.

    ``sys.stdout`` is captured when the instance is created. A Tee constructed
    after ``sys.stdout`` has already been replaced therefore writes through that
    replacement as well.

    Args:
        fname: path of the log file.
        mode: mode used to open the log file (append by default).
    """

    def __init__(self, fname, mode="a"):
        self.stdout = sys.stdout
        self.file = open(fname, mode)

    def write(self, message):
        self.stdout.write(message)
        self.file.write(message)
        # Flush after every write so the log file stays current.
        self.flush()

    def flush(self):
        self.stdout.flush()
        self.file.flush()

    def isatty(self):
        """Report whether the mirrored stream is a terminal (False if unknown).

        Libraries frequently probe ``sys.stdout.isatty()``. Without this method
        they raise AttributeError once a Tee replaces the standard stream.
        """
        probe = getattr(self.stdout, "isatty", None)
        return bool(probe()) if callable(probe) else False

    @property
    def encoding(self):
        """Encoding of the mirrored stream, or None if it does not define one."""
        return getattr(self.stdout, "encoding", None)

    def close(self):
        """Close the log file; the mirrored stream is left open."""
        if not self.file.closed:
            self.file.close()

def act_param_init(args):
    """Attach activity-recognition dataset metadata and task shapes to *args*.

    Stores several per-dataset lookup tables on ``args``. Then, depending on
    ``args.task``, sets ``args.num_classes``, ``args.input_shape`` and
    ``args.grid_size``.

    Args:
        args: parsed options; must provide ``task`` and, for tasks other than
            'cross_dataset', ``dataset``.

    Returns:
        The same ``args`` object, updated in place.

    Raises:
        KeyError: if ``args.dataset`` has no entry for 'cross_people' /
            'cross_position'.
        UnboundLocalError: if ``args.task`` is none of 'cross_dataset',
            'cross_people', 'cross_position' (``tmp`` is never assigned).
            ``get_args`` restricts ``--task`` choices, so this is not reachable
            from there.
    """
    # Per-dataset position index list (presumably sensor placements -- verify
    # against the data loaders).
    args.select_position={'dsads':[0],'usc':[0],'har':[0],'pamap':[1],'emg':[0],'spcmd':[0],'wesad':[0]}
    # Per-dataset channel indices to keep; np.arange(k) selects the first k.
    args.select_channel={'dsads':np.arange(6),'usc':np.arange(6),'har':np.arange(6),'pamap':np.arange(6),'emg':np.arange(8),'spcmd':np.arange(20),'wesad':np.arange(8)}
    # Per-dataset groups of original label ids, six groups per dataset.
    # Presumably these map onto the 6 shared classes of 'cross_dataset'
    # (num_classes=6 below) -- TODO confirm.
    args.label_cor={'dsads':[[0],[1],[2,3],[4],[5],[8]],'usc':[[7],[8],[9],[3],[4],[0]],'har':[[3],[4],[5],[1],[2],[0]],'pamap':[[1],[2],[0],[7],[8],[3]]}
    # Presumably per-dataset sampling rates in Hz -- TODO confirm.
    args.hz_list={'dsads':25,'usc':100,'har':50,'pamap':100,'emg':1000,'spcmd':81,'wesad':33}
    # Datasets that have a label_cor entry (presumably the cross_dataset pool).
    args.act_dataset=['dsads','usc','har','pamap']
    # Per-dataset groups of subject ids; presumably each inner list forms one
    # domain for 'cross_people'.
    args.act_people={'dsads':[[i*2,i*2+1] for i in range(4)],'usc':[[1,11,2,0],[6,3,9,5],[7,13,8,10],[4,12]],'har':[[i*6+j for j in range(6)] for i in range(5)],
        'pamap':[[3,2,8],[1,5],[0,7],[4,6]],'shar':[[0],[1],[2],[3]],'emg':[[i*9+j for j in range(9)]for i in range(4)],'spcmd':[[0]],
        'wesad':[[0,1,2,3],[4,5,6,7],[8,9,10,11],[12,13,14]]}
    # Per-dataset groups of position indices (presumably one domain per group
    # for 'cross_position'). NOTE: the attribute name is misspelled
    # ('positon'); it is kept because other code may read it under this name.
    args.act_positon={'dsads':[[i] for i in range(5)],'usc':[[i] for i in range(1)],'har':[[i] for i in range(1)],'pamap':[[i] for i in range(3)]}
    if args.task=='cross_dataset':
        # Fixed shapes shared by all datasets in the cross-dataset setting.
        args.num_classes=6
        args.input_shape=(6,1,50)
        args.grid_size=5
    else:
        # tmp maps dataset -> (input_shape, num_classes, grid_size); see the
        # unpacking after this if/elif.
        if args.task=='cross_people':
            tmp={'dsads':((45,1,125),19,5),'usc':((6,1,200),12,10),'har':((6,1,128),6,8),'pamap':((27,1,200),12,10),'shar':((3,1,151),17,10),'emg':((8,1,200),6,10),'spcmd':((20,1,81),10,9),'wesad':((8,1,200),4,10)}
        elif args.task=='cross_position':
            tmp={'dsads':((9,1,125),19,5),'usc':((6,1,200),12,10),'har':((6,1,128),6,8),'pamap':((9,1,200),12,10)}
        args.num_classes,args.input_shape,args.grid_size=tmp[args.dataset][1],tmp[args.dataset][0],tmp[args.dataset][2]

    return args

def _build_parser():
    """Create the argument parser holding every training option."""
    parser = argparse.ArgumentParser(description='DG')
    parser.add_argument('--algorithm', type=str, default="ERM")
    parser.add_argument('--alpha', type=float, default=0.1, help="DANN dis alpha")
    parser.add_argument('--alpha1', type=float, default=0.1, help="DANN dis alpha")
    parser.add_argument('--allnettype', type=str, default="all", choices=["all", "domain","class"])
    parser.add_argument('--anneal_iters', type=int,
                        default=500, help='Penalty anneal iters used in VREx')
    parser.add_argument('--batch_size', type=int, default=32, help="batch_size")
    parser.add_argument('--beta', type=float, default=6, help="icassph")
    parser.add_argument('--beta1', type=float, default=0.5, help="Adam")
    parser.add_argument('--bottleneck', type=int, default=256)
    parser.add_argument('--checkpoint_freq', type=int, default=100,help='Checkpoint every N steps')
    parser.add_argument('--classifier', type=str, default="linear", choices=["linear", "wn"])
    parser.add_argument('--class_balanced', type=int, default=0)
    parser.add_argument('--cls_par', type=float, default=1.0)
    parser.add_argument('--confuse',action='store_true')
    parser.add_argument('--data_file', type=str, default='')
    parser.add_argument('--dataset', type=str, default='dsads')
    parser.add_argument('--data_dir', type=str, default='')
    parser.add_argument('--dis_hidden',type=int,default=256)
    parser.add_argument('--disttype',type=str,default='2-norm',choices=['1-norm','2-norm','cos','norm-2-norm','norm-1-norm'])
    parser.add_argument('--disttype1',type=str,default='2-norm',choices=['1-norm','2-norm','cos','norm-2-norm','norm-1-norm'])
    parser.add_argument('--diversity_alpha1',type=float,default=1.0)
    parser.add_argument('--diversity_alpha2',type=float,default=1.0)
    parser.add_argument('--diversity_alpha3',type=float,default=1.0)
    parser.add_argument('--domain_selected', type=int, nargs='+', default=[0,1])
    parser.add_argument('--earlystopcount',type=int,default=30)
    parser.add_argument('--enhanced_simclr_transform',action='store_true')
    parser.add_argument('--entropy',action='store_true')
    parser.add_argument('--eposel',type=str,default='sim',choices=['sim','spe'])
    parser.add_argument('--free_param',action='store_true')
    parser.add_argument('--getwei',type=str,default='whole',choices=['whole','batch'])
    parser.add_argument('--getparam',type=str,default='fea',choices=['whole','fea'])
    parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run")
    parser.add_argument('--grid_size',type=int,default=3)
    parser.add_argument('--groupdro_eta', type=float, default=1, help="groupdro eta")
    parser.add_argument('--gstyle', type=str, default="avgpool", choices=["avgpool", "linear",'cnn'])
    parser.add_argument('--layer', type=str, default="bn", choices=["ori", "bn"])
    parser.add_argument('--inner_lr', type=float, default=1e-2, help="learning rate")
    parser.add_argument('--lam',type=float,default=0.001,help='for dlda, VREx')
    parser.add_argument('--latent_domain_num',type=int,default=3)
    parser.add_argument('--local_epoch',type=int,default=1,help='local iterations')
    parser.add_argument('--lr', type=float, default=1e-2, help="learning rate")
    parser.add_argument('--lr_decay', type=float, default=0.75,help='for sgd')
    parser.add_argument('--lr_decay_step', type=float, default=80,help='for sgd')
    parser.add_argument('--lr_decay_stepd', type=float, default=50,help='for sgd')
    parser.add_argument('--lr_decay1', type=float, default=1.0,help='for pretrained featurizer')
    parser.add_argument('--lr_decay2', type=float, default=1.0)
    parser.add_argument('--lr_gamma', type=float, default=0.0003)
    parser.add_argument('--max_epoch', type=int, default=120, help="max iterations")
    parser.add_argument('--mixtype',type=str,default='singley',choices=['onehot','singley'])
    parser.add_argument('--mixupalpha', type=float, default=0.2)
    parser.add_argument('--mixup_ld_margin', type=float, default=10000)
    parser.add_argument('--mixup_lsoftmax_margin', type=int, default=2)
    parser.add_argument('--mixupregtype',type=str,default='l-softmax',choices=['l-softmax','ld-margin','l-smooth','l-softmax+l-smooth','origin'])
    parser.add_argument('--mldg_beta', type=float, default=1, help="mldg beta, 0.1, 1, 10")
    parser.add_argument('--mmd_gamma', type=float, default=1)
    parser.add_argument('--model_size',default = 'median',choices=['small','median','large'])
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--net', type=str, default='resnet50', help="vgg16, resnet50, resnet101,DTNBase,ActNetwork")
    parser.add_argument('--nogen', action='store_true')
    parser.add_argument('--nomix', action='store_true')
    parser.add_argument('--normstyle',type=str,default='avg',choices=['max','avg'])
    parser.add_argument('--N_WORKERS', type=int, default=4)
    parser.add_argument('--old', action='store_true')
    parser.add_argument('--pdltype', type=str, default="random",
        choices=['random','order'])
    parser.add_argument('--percent', type=float, default=0.1,help='data preserverd, 0.-1.')
    parser.add_argument('--predata', action='store_true')
    parser.add_argument('--rela_gamma', type=float, default=1)
    parser.add_argument('--relaxed', action='store_true')
    parser.add_argument('--rsc_f_drop_factor', type=float, default=1/3)
    parser.add_argument('--rsc_b_drop_factor', type=float, default=1/3)
    parser.add_argument('--r_update_feq', type=int, default=1)
    parser.add_argument('--save_model_every_checkpoint', action='store_true')
    parser.add_argument('--save_to_blob', action='store_true')
    parser.add_argument('--schuse',action='store_true')
    parser.add_argument('--schusech',type=str,default='cos')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--seed1', type=int, default=0)
    parser.add_argument('--simclr_output_dim',type=int,default=256)
    parser.add_argument('--splitalpha', type=float, default=1.0)
    parser.add_argument('--splitrever', action='store_true')
    parser.add_argument('--split_style', type=str, default='strat', help="the style to split the train and eval datasets")
    parser.add_argument('--task', type=str, default="cross_dataset",
        choices=["img_dg",'cross_dataset','cross_people','cross_position'])
    parser.add_argument('--tau', type=float, default=1, help="andmask tau")
    parser.add_argument('--test_envs', type=int, nargs='+', default=[0])
    parser.add_argument('--top_k', type=int, default=1)
    parser.add_argument('--trainvalidsplit', type=str, default="SPRandom")
    parser.add_argument('--trainvalidseed', type=int, default=0)
    parser.add_argument('--output', type=str, default="train_output")
    parser.add_argument('--valid',action='store_true')
    parser.add_argument('--valid_size', type=float, default=0.2)
    parser.add_argument('--weight_decay', type=float, default=5e-4)
    parser.add_argument('--wtype',type=str,default='ori',choices=['ori','abs','fea'])
    return parser


def get_args():
    """Parse command-line options and prepare the run environment.

    Side effects:
        * sets ``CUDA_VISIBLE_DEVICES`` from ``--gpu_id`` (when a value is
          given). This overrides any value already in the environment;
        * creates ``args.output``;
        * replaces ``sys.stdout`` / ``sys.stderr`` with ``Tee`` objects that
          also append to ``out.txt`` / ``err.txt`` inside ``args.output``.

    Returns:
        The parsed namespace. For 'cross*' tasks it is further populated by
        ``act_param_init``.
    """
    args = _build_parser().parse_args()
    args.steps_per_epoch = 10000000000
    if args.task.startswith('cross'):
        args.data_dir = args.data_file + args.data_dir
    elif args.task.startswith('reg'):
        # Not reachable with the current --task choices; kept for compatibility.
        pass
    else:
        args.data_dir = args.data_file + args.data_dir + args.dataset + '/'
    args.blob_output = args.output.split('/')[-1]
    # The variable name was previously misspelled, so --gpu_id had no effect.
    # It only takes effect if CUDA has not been initialized yet. A bare
    # `--gpu_id` (nargs='?') yields None, which os.environ would reject.
    if args.gpu_id is not None:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
    os.makedirs(args.output, exist_ok=True)
    sys.stdout = Tee(os.path.join(args.output, 'out.txt'))
    # This Tee captures sys.stdout at construction, which is already the
    # stdout Tee above, so stderr text is also echoed to stdout and out.txt.
    sys.stderr = Tee(os.path.join(args.output, 'err.txt'))
    if args.task.startswith('cross'):
        args = act_param_init(args)
    return args