import torch
import numpy as np
import os


from data import get_the_data
from train import train
from model import define_the_model
from test import test
from optim import define_optimizer
import input_args
from stamp import param_stamp
from store import storeit, label_the_result, storeit_v1


def run(args, num_seeds=10, noise_amplitudes=None):
    """Sweep train/test noise amplitudes over a range of consecutive seeds.

    For every seed in ``range(args.seed, args.seed + num_seeds)`` and every
    ``(train_noise, test_noise)`` pair taken from ``noise_amplitudes``, one
    model is trained and evaluated via ``single_run`` and its accuracy is
    stored with ``storeit_v1``. A combination is skipped when the file
    ``label_the_result(args) + ".pkl"`` already exists, so an interrupted
    sweep can be resumed.

    ``args`` is mutated in place: ``seed``, ``train_noise`` and
    ``test_noise`` are overwritten during the sweep, and ``single_run``
    sets further attributes. After the call, ``args.seed`` holds the last
    seed processed.

    Args:
        args: Parsed options namespace. It must provide at least ``seed``,
            ``folder_name`` and ``file_directory_counter``, plus whatever
            ``label_the_result``, ``single_run`` and ``storeit_v1`` read.
        num_seeds: Number of consecutive seeds starting at ``args.seed``.
            Defaults to 10, the previously hard-coded value.
        noise_amplitudes: Iterable of noise amplitudes used for BOTH the
            train and the test axis. Defaults to
            ``np.linspace(0.0, 2.0, num=20, endpoint=False)``, i.e. a
            20 x 20 = 400-point grid.
    """
    if noise_amplitudes is None:
        # The same range is used for training and test noise (20*20 grid).
        noise_amplitudes = np.linspace(0.0, 2.0, num=20, endpoint=False)
    else:
        # Materialize once: the inner loop iterates it repeatedly, which
        # would silently yield nothing after the first pass for a generator.
        noise_amplitudes = tuple(noise_amplitudes)

    print(label_the_result(args))

    # makedirs(exist_ok=True) also creates missing parent directories and
    # avoids the check-then-create race of isdir() followed by mkdir().
    os.makedirs(args.folder_name, exist_ok=True)

    # Evaluate the seed range once, before args.seed is overwritten below.
    first_seed = args.seed
    for seed in range(first_seed, first_seed + num_seeds):
        args.seed = seed
        print(f"Seed value: {seed}")
        # Record the seed currently being processed.
        np.save(args.file_directory_counter, args.seed)
        for x in noise_amplitudes:
            args.train_noise = x
            for y in noise_amplitudes:
                # Format each value separately: a tuple of np.float64 would
                # print as "np.float64(...)" under NumPy >= 2.
                print(f"Noise Amplitude: ({x}, {y})")
                args.test_noise = y
                # Resume support: skip combinations already stored on disk.
                if os.path.exists(label_the_result(args) + ".pkl"):
                    continue
                accuracy = single_run(args)
                storeit_v1(args, accuracy)

    print("END OF RUN")



def single_run(args):
    """Train and evaluate one model for the configuration stored in ``args``.

    Side effects on ``args``: sets ``args.param_stamp`` (from ``param_stamp``)
    and ``args.device`` (CUDA when available and requested via ``args.cuda``,
    otherwise CPU). It also seeds NumPy's global RNG and torch with
    ``args.seed`` before building data, model and optimizer.

    Returns:
        Whatever ``test(args, model, test_data)`` returns. The caller treats
        it as the accuracy.
    """
    stamp = param_stamp(args)
    args.param_stamp = stamp
    print(stamp)

    use_cuda = torch.cuda.is_available() and args.cuda
    args.device = torch.device("cuda" if use_cuda else "cpu")

    _seed_everything(args.seed, use_cuda)

    train_data, test_data = get_the_data(args)
    model = define_the_model(args)
    optimizer = define_optimizer(args, model)
    train(args, model, train_data, optimizer)
    return test(args, model, test_data)


def _seed_everything(seed, use_cuda):
    """Seed NumPy's global RNG and torch (plus the CUDA RNG when ``use_cuda``)."""
    np.random.seed(seed)
    torch.manual_seed(seed)
    if use_cuda:
        torch.cuda.manual_seed(seed)

def handle_inputs():
    """Build the command-line parser, parse the options and fill in defaults.

    The parser comes from ``input_args.define_args`` and is extended by
    ``input_args.add_options``. After parsing, ``input_args.set_defaults``
    processes the namespace in place (its return value is not used).

    Returns:
        The namespace produced by ``parser.parse_args()``, after
        ``set_defaults`` has run on it.
    """
    parser = input_args.add_options(
        input_args.define_args(filename="main", description='Train & test the generative classifier.')
    )
    parsed = parser.parse_args()
    input_args.set_defaults(parsed)
    return parsed



if __name__ == '__main__':
    # Script entry point: parse command-line options, then launch the sweep.
    args = handle_inputs()
    run(args)