import argparse
import torch
import os
import numpy as np
from modelMetaDA import SDSModel
from utilsMetaDA import record_acc


# Command-line configuration for the SDS meta-learning runs.
#
# All options are registered before the single ``parse_args()`` call. The
# original code parsed ``sys.argv`` once before ``--log_path``/``--cuda`` were
# added. That early pass rejected those two options as unrecognized, so they
# could never be supplied on the command line. ``--log_path`` depends on
# ``--total_log_dir`` and ``--name``, so its default is resolved after parsing.
train_arg_parser = argparse.ArgumentParser()
train_arg_parser.add_argument("--i", type=int, default=1,
                              help='The index of train')
train_arg_parser.add_argument("--lr", type=float, default=0.0005,
                              help='learning rate of phi, theta, default is 0.0005')
train_arg_parser.add_argument("--weight_decay", type=float, default=0.0001)
train_arg_parser.add_argument("--seed", type=int, default=-1,
                              help='the seed of random, if not set, it will be a random number')
train_arg_parser.add_argument("--total_log_dir", type=str, default=os.path.join('logs', 'SDS_model'),
                              help='log_dir, default is logs/SDS_model')
train_arg_parser.add_argument("--weight_decay_omega", type=float, default=0.001)
train_arg_parser.add_argument("--batch_size", type=int, default=1024,
                              help='batch size for meta learning, default is 1024')
train_arg_parser.add_argument("--ite_find_omega", type=int, default=10000)
train_arg_parser.add_argument("--max_ite_fo", type=int, default=2500)
train_arg_parser.add_argument("--ite_omega_mt", type=int, default=2)
train_arg_parser.add_argument("--ite_opt_theta_phi", type=int, default=150)
# Number of outer rounds run by the experiment loop of this script.
train_arg_parser.add_argument("--total_iteration", type=int, default=20)
train_arg_parser.add_argument("--save", type=int, default=0)
# Number of test-domain indices iterated when --test_domain is negative.
train_arg_parser.add_argument("--test_num", type=int, default=15)
train_arg_parser.add_argument("--lr_maml", type=float, default=0.0002)
train_arg_parser.add_argument("--lr_maml_test", type=float, default=0.0002)
train_arg_parser.add_argument("--lr_omega", type=float, default=0.0002)
train_arg_parser.add_argument("--ite_test", type=int, default=10)
train_arg_parser.add_argument("--meta_train_len", type=int, default=7)
train_arg_parser.add_argument("--pr_coupled", type=float, default=0.1)
# Run name; also the last component of the default --log_path.
train_arg_parser.add_argument("--name", type=str, default='unclassified')
train_arg_parser.add_argument("--dadg", type=str, default='dg')
# A non-negative value restricts the run to that single domain index.
train_arg_parser.add_argument("--test_domain", type=int, default=-1)
# Default (None) is replaced below by <total_log_dir>/<name>.
train_arg_parser.add_argument("--log_path", type=str, default=None,
                              help='log directory, default is <total_log_dir>/<name>')
train_arg_parser.add_argument("--cuda", type=int, default=0, help='cuda_index')

args = train_arg_parser.parse_args()
if args.log_path is None:
    args.log_path = os.path.join(args.total_log_dir, args.name)
if args.seed == -1:
    # No seed requested: draw one (the chosen seed is recorded with the results).
    args.seed = np.random.randint(2 ** 31 - 1)
flags = args


# Experiment driver: for each round, train one SDSModel per selected test
# domain, then append that round's accuracies and seeds to total_result.csv.
flags.i = 0

# A non-negative --test_domain selects a single domain; otherwise every
# index in range(test_num) is evaluated.
test_list = (flags.test_domain,) if flags.test_domain >= 0 else range(flags.test_num)

for iteration in range(flags.total_iteration):
    flags.i = iteration
    acc_list, seed_list = [], []
    for num in test_list:
        # Round 0 keeps the configured seed; every later round draws a fresh
        # seed for each domain.
        if iteration > 0:
            flags.seed = np.random.randint(2 ** 31 - 1)
        flags.num = num
        model = SDSModel(flags)
        acc_list.append(model.my_train(flags))
        seed_list.append(flags.seed)

    result_file = os.path.join(flags.log_path, 'total_result.csv')
    record_acc(result_file,
               f'{iteration}',
               f'\n{acc_list}',
               f'\n{seed_list}')
