import argparse
import os


def _str2bool(value):
    """Convert a command-line token such as ``"true"`` / ``"False"`` to a bool.

    argparse's ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string (including ``"False"``), so boolean options need an
    explicit converter.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean
            spelling; argparse reports this as a usage error.
    """
    if isinstance(value, bool):
        # Defensive: let an actual bool pass through unchanged.
        return value
    token = value.strip().lower()
    if token in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if token in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value (true/false), got {!r}'.format(value))


def parse_args():
    """Build the training/meta-testing option parser and parse ``sys.argv``.

    Side effect: creates the ``simulations_gap`` folder in the current working
    directory if it does not exist yet. This happens before parsing, so it
    also happens for ``--help``.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(description="Pytorch implementation.")
    folder_text = 'simulations_gap'
    task_test_size = 10000
    buff_size = 200000  # 38000*200=7600000
    parser.add_argument('--train_task_number', type=int, default=2)  # 40
    parser.add_argument('--observation_dims', type=int, default=27)  # 17
    parser.add_argument('--action_dims', type=int, default=8)  # 6
    parser.add_argument('--replay_buffer_size', type=int, default=buff_size)
    parser.add_argument('--trj_number', type=int, default=50)
    parser.add_argument('--task_batch_size', type=int, default=128)
    parser.add_argument('--hidden_feature_number', type=int, default=200)
    parser.add_argument('--hidden_layer_number', type=int, default=3)
    parser.add_argument('--learning_rate', type=float, default=0.0001)
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--alpha', type=float, default=0.05)
    parser.add_argument('--training_iteration_number', type=int, default=20000)
    parser.add_argument('--training_task_number', type=int, default=20)
    parser.add_argument('--task_training_iterations', type=int, default=50)
    # type=bool would turn "--cuda False" into True; parse the token explicitly.
    parser.add_argument('--cuda', type=_str2bool, default=True)
    parser.add_argument('--cuda_index', type=int, default=0)
    parser.add_argument('--save_folder', type=str, default=folder_text + '/data_folder')
    parser.add_argument('--data_dir', type=str, default='offlinerl/data/ant-dir')  # 'offlinerl/data/walker_rand_params'
    # parser.add_argument('--checkpoint_paths', type=str,
    #                     default=[folder_text + '/data_folder_28/checkpoint_Iter_999',
    #                              folder_text + '/data_folder_28/test_checkpoint_Iter_990'])
    parser.add_argument('--base_folder', type=str, default=folder_text + '/saved_baseline')
    parser.add_argument('--test_task_iter_num', type=int, default=25)
    parser.add_argument('--task_test_sample_number', type=int, default=task_test_size)
    # Default: the replay-buffer size minus the test-sample count.
    parser.add_argument('--task_test_train_sample_number', type=int, default=buff_size - task_test_size)

    parser.add_argument('--save_interval', type=int, default=10)
    parser.add_argument('--meta_test_interval', type=int, default=10000000)
    parser.add_argument('--meta_test_number', type=int, default=1)

    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs(folder_text, exist_ok=True)

    args = parser.parse_args()
    # setattr(args, 'test_batch_size', 20 - args.task_batch_size)

    return args


# Metric name -> (value log file, iteration log file), relative to save_fld.
_METRIC_FILES = {
    "training_loss": ("Training_Loss.txt", "loss_iterations.txt"),
    "meta_test_loss": ("meta_test_loss.txt", "meta_test_loss_iterations.txt"),
    "meta_train_loss": ("meta_train_loss.txt", "meta_train_loss_iterations.txt"),
}


def _append_csv_value(path, value):
    """Append ``str(value)`` to ``path``, comma-separated from existing content."""
    # A missing or empty file gets no leading comma, so each file is always a
    # single well-formed comma-separated line.
    needs_separator = os.path.exists(path) and os.path.getsize(path) > 0
    with open(path, "a") as f:
        f.write((',' if needs_separator else '') + str(value))


def save_module(arg, save_fld, iters, save_vector, vector_content):
    """Append selected metric values and the current iteration to log files.

    For each known metric name in ``vector_content``, the value at the same
    index of ``save_vector`` is appended to that metric's value file, and
    ``iters`` is appended to its matching iteration file. Unrecognised names
    are skipped.

    Args:
        arg: Unused; kept for call-site compatibility.
        save_fld: Directory holding the log files. It is not created here.
        iters: Iteration number recorded next to each value.
        save_vector: Metric values, index-aligned with ``vector_content``.
        vector_content: Metric names: "training_loss", "meta_test_loss" or
            "meta_train_loss".

    Raises:
        IndexError: if ``save_vector`` has no entry for a known metric's index.
    """
    for i, name in enumerate(vector_content):
        files = _METRIC_FILES.get(name)
        if files is None:
            continue
        value_file, iter_file = files
        _append_csv_value(os.path.join(save_fld, value_file), save_vector[i])
        _append_csv_value(os.path.join(save_fld, iter_file), iters)


def parse_args_fid(setup_file="figure_setup.txt"):
    """Parse plotting options from ``sys.argv`` plus a ';'-separated setup file.

    The setup file is read line by line. Fields within a line are split on ';':
      line 0 -> ``args.labels`` (its first three fields),
      line 1 -> ``args.legend_text``,
      line 2 -> ``args.iteration_files``,
      line 3 -> ``args.fid_files``.
    Any further lines are ignored. Attributes for missing lines default to [].

    Args:
        setup_file: Path of the setup file. The default is resolved against
            the current working directory.

    Returns:
        argparse.Namespace with ``window``, ``pick_data_per`` and the four
        list attributes above.

    Raises:
        IndexError: if the first line has fewer than three fields.
    """
    parser = argparse.ArgumentParser(description="Pytorch implementation.")
    parser.add_argument('--window', type=int, default=1000)
    parser.add_argument('--pick_data_per', type=int, default=1)

    args = parser.parse_args()

    labels = []
    legend_text = []
    iteration_files = []
    fid_files = []
    # Alternative setup file: figure_setup_cum_loss.txt
    with open(setup_file, "r") as f:
        lines = f.read().splitlines()
    for k, line in enumerate(lines):
        fields = line.split(";")
        if k == 0:
            labels = [fields[0], fields[1], fields[2]]
        elif k == 1:
            legend_text = fields
        elif k == 2:
            iteration_files = fields
        elif k == 3:
            fid_files = fields
    setattr(args, 'labels', labels)
    setattr(args, 'legend_text', legend_text)
    setattr(args, 'iteration_files', iteration_files)
    setattr(args, 'fid_files', fid_files)

    return args
