import sys, os
sys.path.append(os.getcwd())

from experiments.resources import *
from source import *

import argparse
import torch
import random
import numpy
import time
import yaml

# Default compute device: CUDA when a GPU is visible, otherwise the CPU.
# A ``--device`` command-line value replaces this after parsing below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Folder containing this script, with a trailing "/", used as the root
# for experiment output.
directory = os.path.dirname(os.path.abspath(__file__)) + "/"

# Reproducibility first: deterministic cuDNN kernels and no autotuning.
# ``--fast`` flips both settings once the arguments are parsed.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# ============================================================
# Parsing arguments to construct experiments.
# ============================================================

# The project helper registers every experiment option. Options the parser
# does not recognise are returned in ``args_unknown`` instead of being rejected.
parser = argparse.ArgumentParser(description="Experiment Runner")
register_configurations(parser)
args, args_unknown = parser.parse_known_args()

# An explicit device given on the command line wins over auto-detection.
device = device if args.device is None else args.device

# Trade determinism for speed: let cuDNN benchmark and pick possibly
# non-deterministic kernels.
if args.fast:
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False

# ============================================================
# Constructing and executing experiments.
# ============================================================


def _run_experiment(dataset, model, config, random_state):
    """Pretrain one relation-network model for a single seed and save it.

    The function performs these steps in order:

    - Seeds the Python, NumPy and PyTorch RNGs when ``random_state`` is given.
    - Builds the pretraining data splits and the base model.
    - Runs ``meta_training_relation`` with the optimizer and scheduler named
      in ``config``.
    - Exports the trained model under ``source/models/pretrained/<dataset>/``.
    - Exports a results record under ``directory + config["output_path"]``.
      The record holds start/end timestamps, a copy of the configuration,
      the training history and the command line.

    Args:
        dataset: Callable returning a three-tuple of splits. Only the first
            two (training, validation) are used here.
        model: Callable building the base model from ``config``. It must
            support ``.to(device)`` and ``.parameters()``; presumably a
            ``torch.nn.Module``.
        config: Merged experiment configuration dictionary.
        random_state: Seed value, or ``None`` to leave the RNGs untouched.

    Note:
        The dataset and model *names* used in output file names and log
        lines come from the module-level ``args``, not from ``config``.
    """
    if random_state is not None:
        # Seed every RNG source the run may touch.
        for seed_fn in (torch.cuda.manual_seed_all, torch.cuda.manual_seed,
                        torch.manual_seed, numpy.random.seed, random.seed):
            seed_fn(random_state)

    # Only the training and validation splits take part in pretraining.
    training, validation, _ = dataset(pretraining=True, device=device, **config)
    base_model = model(**config).to(device)

    # Meta (outer-loop) optimizer over all model parameters, plus its schedule.
    make_optimizer = optimizer_archive[config["meta_optimizer_name"]]
    optimizer = make_optimizer(list(base_model.parameters()),
                               **config["meta_optimizer_settings"])
    make_scheduler = scheduler_archive[config["meta_scheduler_name"]]
    scheduler = make_scheduler(optimizer, **config["meta_scheduler_settings"])

    stamp_format = "%Y-%m-%d %H:%M:%S"
    started = time.strftime(stamp_format, time.localtime())
    print("relation", args.dataset, args.model, "seed", str(random_state), "started")

    base_model, meta_history = meta_training_relation(
        base_model, optimizer, scheduler, training, validation,
        gradient_steps=config["meta_gradient_steps"],
        performance_metric=objective_archive[config["evaluation_metric"]],
        **config
    )

    finished = time.strftime(stamp_format, time.localtime())

    # "<dataset>-<model>-<ways>way-<seed>" names both the model and the results.
    run_tag = "-".join([args.dataset, args.model,
                        str(config["num_ways"]) + "way", str(random_state)])

    export_model(base_model, "source/models/pretrained/" + args.dataset + "/",
                 run_tag, separate_models_directory=False)

    # Key order matches the order in which the record was originally filled in.
    results = {
        "start_time": started,
        "end_time": finished,
        "experiment_configuration": config.copy(),
        "meta_history": meta_history,
        "command": "python " + " ".join(sys.argv),
    }
    export_results(results, directory + config["output_path"], "pretraining-" + run_tag)

    print("relation", args.dataset, args.model, "seed", str(random_state), "complete")


def _load_yaml(path):
    """Parse the YAML file at ``path`` with ``yaml.safe_load`` and close it.

    The ``with`` block guarantees the file handle is released even if
    parsing raises, unlike ``yaml.safe_load(open(path))``.
    """
    with open(path) as stream:
        return yaml.safe_load(stream)


# Loading the dataset-specific and method-specific configuration files.
dataset_config = _load_yaml(dataset_config_archive[args.dataset])
method_config = _load_yaml(method_config_archive["relation"])

# Merging the parsed CLI arguments (known and unknown) with both YAML files
# into the final configuration. NOTE(review): required_args presumably lists
# the keys override_configurations treats specially or excludes; confirm
# against its definition.
required_args = {"dataset", "model", "seeds", "device"}
config = override_configurations(args, args_unknown, required_args, dataset_config, method_config)

# Resolving the dataset and model constructors selected on the command line.
dataset_fn = dataset_archive[args.dataset]
model_fn = model_archive[args.model]

# One independent pretraining run per requested seed.
for random_state in args.seeds:
    _run_experiment(dataset_fn, model_fn, config, random_state)
