import torch
import numpy as np
import os
import pickle

from data import get_the_data
from train import train
from model import define_the_model
from test import test
from optim import define_optimizer
import input_args
from stamp import param_stamp
from store import storeit_v1, label_the_result
from make_data_fig2 import single_run


def run(args):
    """Run one experiment per (seed, noise pair) configured on ``args``.

    Iterates over ``args.list_of_seeds`` and, for each seed, over the
    ``(train_noise, test_noise)`` pairs in ``args.list_of_noise``. ``args`` is
    mutated in place (``seed``, ``train_noise``, ``test_noise``) before each
    run, so those attributes hold the last combination visited when this
    returns.

    A combination is skipped when ``label_the_result(args) + ".pkl"`` already
    exists on disk, which makes the sweep resumable. Otherwise
    ``single_run(args)`` is executed and its result handed to ``storeit_v1``.
    """
    for current_seed in args.list_of_seeds:
        args.seed = current_seed
        print(f"Seed value: {current_seed}")
        for noise_pair in args.list_of_noise:
            train_amp, test_amp = noise_pair
            args.train_noise, args.test_noise = train_amp, test_amp
            print(f"Noise Amplitude: {train_amp, test_amp}")
            result_file = label_the_result(args) + ".pkl"
            # Only compute combinations that have no stored result yet.
            if not os.path.exists(result_file):
                storeit_v1(args, single_run(args))

    print("END OF RUN")

def adjust_args_c2(args):
    """Configure ``args`` in place for column ``c2`` (the CIFAR20 dataset, 20 classes).

    Only ``dataset``, ``column`` and ``num_classes`` are overwritten. The seed
    and noise grids (``list_of_seeds`` / ``list_of_noise``) are NOT set here,
    so whatever ``args`` already carries is reused. In this file,
    ``meta_run`` calls ``adjust_args_c1`` first, which provides them.
    """
    c2_overrides = {"dataset": "CIFAR20", "column": "c2", "num_classes": 20}
    for attr_name, attr_value in c2_overrides.items():
        setattr(args, attr_name, attr_value)

def adjust_args_c1(args):
    """Configure ``args`` in place for column ``c1``.

    Sets:
      * ``column`` to ``"c1"``;
      * ``list_of_seeds`` to ten consecutive seeds starting at the current
        ``args.seed``;
      * ``list_of_noise`` to the twelve ``(train_noise, test_noise)``
        amplitude pairs to evaluate.

    ``dataset`` and ``num_classes`` are left untouched, so the values already
    on ``args`` (e.g. from the CLI defaults) are used for this column.
    """
    first_seed = args.seed
    args.column = "c1"
    args.list_of_seeds = range(first_seed, first_seed + 10)
    args.list_of_noise = [
        (0.0, 0.0),
        (0.0, 0.5),
        (0.5, 0.0),
        (0.5, 0.5),
        (0.5, 0.9),
        (0.9, 0.5),
        (0.9, 0.9),
        (0.5, 1.5),
        (0.9, 1.5),
        (1.5, 0.5),
        (1.5, 0.9),
        (1.5, 1.5),
    ]

    
def get_the_directory(i, j, k):
    """Build the ``tab1`` column-c1 result-file path for one combination.

    ``i`` and ``j`` are divided by 10 and rounded to one decimal place for the
    ``ntr_`` / ``nte_`` tags (presumably train/test noise amplitudes given in
    tenths; confirm with callers). ``k`` is offset by 11 to form the ``se_``
    (seed) tag. The separator is always ``/``, independent of the OS.
    """
    seed_tag = k + 11
    train_tag = round(i / 10, 1)
    test_tag = round(j / 10, 1)
    return f"tab1/c1_res-se_{seed_tag}-ntr_{train_tag}-nte_{test_tag}.pkl"

def meta_run(args):
    """Run the full table-1 sweep: column ``c1`` first, then column ``c2``.

    Sets ``args.folder_name`` to ``"tab1"`` (creating that directory if it is
    missing) and ``args.num_samples`` to 10000, then:

      1. configures ``args`` for column c1 via ``adjust_args_c1`` and runs the
         seed x noise sweep with ``run``;
      2. switches ``args`` to CIFAR20 via ``adjust_args_c2`` and runs the sweep
         again.

    Presumably ``folder_name`` is consumed by the result-labelling/storing
    helpers used inside ``run``; that is not visible from this file.

    Raises:
        FileExistsError: if ``tab1`` exists but is not a directory.
    """
    args.folder_name = "tab1"
    # makedirs(exist_ok=True) replaces the isdir()+mkdir() pair, which could
    # raise FileExistsError if another (concurrent) run created the folder in
    # between the check and the creation.
    os.makedirs(args.folder_name, exist_ok=True)

    args.num_samples = 10000

    # Column c1: also builds the seed and noise grids that c2 reuses.
    adjust_args_c1(args)
    run(args)

    # Column c2: adjust_args_c2 only changes dataset/column/num_classes. run()
    # left args.seed at the last c1 seed, but list_of_seeds is the range that
    # was already built above, so c2 sweeps exactly the same seeds.
    adjust_args_c2(args)
    run(args)


## Function for specifying input-options and organizing / checking them
def handle_inputs():
    """Define the command-line options, parse them and fill in defaults.

    The parser comes from ``input_args.define_args`` extended by
    ``input_args.add_options``. Its ``parse_args()`` result is passed to
    ``input_args.set_defaults`` and then returned. The return value of
    ``set_defaults`` is ignored, so it is expected to update the parsed
    options in place (TODO confirm against ``input_args``).
    """
    base_parser = input_args.define_args(filename="main", description='Train & test the generative classifier.')
    full_parser = input_args.add_options(base_parser)
    parsed_options = full_parser.parse_args()
    input_args.set_defaults(parsed_options)
    return parsed_options


if __name__ == "__main__":
    # Script entry point: parse the CLI options, then run the whole table-1 sweep.
    cli_options = handle_inputs()
    meta_run(cli_options)