import sys, os
sys.path.append(os.getcwd())

from experiments.resources import *
from source import *

import argparse
import torch
import random
import numpy
import time
import yaml

# Default compute device: CUDA whenever PyTorch can see a GPU, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Absolute path of the folder containing this script (with a trailing "/"),
# used as the base for loading and saving experiment files.
directory = "{}/".format(os.path.dirname(os.path.abspath(__file__)))

# Make cuDNN reproducible: disable the kernel auto-tuner and require
# deterministic algorithms so repeated runs with the same seed agree.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# ============================================================
# Parsing arguments to construct experiments.
# ============================================================

parser = argparse.ArgumentParser(description="Experiment Runner")

# Core experiment selectors: dataset and model names plus one or more integer
# seeds are mandatory; --device optionally overrides the auto-detected device.
parser.add_argument("--dataset", type=str, required=True)
parser.add_argument("--model", type=str, required=True)
parser.add_argument("--seeds", type=int, nargs="+", required=True)
parser.add_argument("--device", type=str, required=False)

# Let the project register all of its optional hyper-parameter flags on the
# same parser.
register_configurations(parser)

# Unrecognised flags are collected in args_unknown instead of causing an
# error; they are handed to override_configurations further down.
args, args_unknown = parser.parse_known_args()

# An explicit --device takes precedence over the CUDA/CPU auto-detection.
device = device if args.device is None else args.device

# ============================================================
# Constructing and executing experiments.
# ============================================================


def _run_experiment(dataset, model, config, random_state):
    """Run one offline loss-learning experiment for a single seed.

    A learned loss network (``LossNetwork``) is first fitted offline with
    ``unrolled_differentiation``. The base model is then trained with
    ``backpropagation`` using that loss network as its loss function, and
    evaluated on the training and testing splits. Finally the results
    dictionary, the offline loss-network history and the trained base model
    are exported under ``directory + config["output_path"]``.

    Args:
        dataset: Zero-argument callable returning a
            ``(training, validation, testing)`` triple.
        model: Base-model constructor, called with
            ``config["base_model_settings"]`` as keyword arguments.
        config: Experiment configuration dictionary. Keys read here include
            the base/offline-meta optimizer names and settings, the gradient
            step counts, ``base_batch_size``, ``task_type``, ``task_loss``,
            ``performance_metric``, ``verbose`` and ``output_path``.
        random_state: Seed applied to torch (CPU and CUDA), numpy and
            ``random`` before anything is built. ``None`` skips seeding.

    Note:
        Also reads the module-level ``args`` (dataset/model names for the
        output file name and progress messages), ``directory`` and ``device``.
    """
    # Seed every RNG before the data and networks are created, so the whole
    # run is repeatable for a given seed.
    if random_state is not None:
        torch.cuda.manual_seed(random_state)
        torch.manual_seed(random_state)
        numpy.random.seed(random_state)
        random.seed(random_state)

    def _timestamp():
        # Local wall-clock time, used to mark the phase boundaries in the results.
        return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())

    # Generating the PyTorch dataset.
    training, validation, testing = dataset()

    # Creating the base model and its optimizer.
    base_model = model(**config["base_model_settings"]).to(device)
    base_optimizer = optimizer_archive[config["base_optimizer_name"]](
        base_model.parameters(), **config["base_optimizer_settings"])

    # Creating the meta-learned loss function. For classification tasks both
    # flags are enabled. Judging by their names, they presumably make the loss
    # network turn logits into probabilities and one-hot encode the targets;
    # confirm in LossNetwork.
    is_classification = config["task_type"] == "classification"
    meta_model = LossNetwork(
        input_dim=config["base_model_settings"]["output_dim"],
        logits_to_prob=is_classification,
        one_hot_encode=is_classification,
    ).to(device)

    # Creating the meta model's offline optimizer.
    offline_meta_optimizer = optimizer_archive[config["offline_meta_optimizer_name"]](
        meta_model.parameters(), **config["offline_meta_optimizer_settings"])

    # One performance metric is shared by both training phases and the final
    # evaluation, so resolve it once.
    performance_metric = objective_archive[config["performance_metric"]]

    # Defining the output results directory and file name.
    res_directory = directory + config["output_path"]
    file_name = "offline-" + args.dataset + "-" + args.model + "-" + str(random_state)

    print("offline", args.dataset, args.model, "seed", str(random_state), "started")

    # Creating a dictionary for recording experiment results.
    results = {"start_time": _timestamp()}

    # Offline phase: loss function learning to initialize the loss network.
    meta_model, _, offline_training_history, offline_meta_model_history = unrolled_differentiation(
        meta_model, offline_meta_optimizer, base_model, base_optimizer, training, validation,
        gradient_steps=config["offline_meta_gradient_steps"],
        inner_gradient_steps=config["offline_inner_gradient_steps"],
        batch_size=config["base_batch_size"],
        task_loss_fn=objective_archive[config["task_loss"]],
        performance_metric=performance_metric,
        verbose=config["verbose"], device=device, offline=True
    )

    # Exporting the loss functions created in the initialization phase.
    export_online_loss(offline_meta_model_history, res_directory, file_name + "-offline")

    # Recording the end time of the initialization phase.
    results["train_time"] = _timestamp()

    # Training the base model with the learned loss network as its loss
    # function. NOTE(review): the third positional argument is None. That is
    # presumably "no meta optimizer", i.e. the loss network is not updated in
    # this phase. Confirm against backpropagation's signature.
    training_history = backpropagation(
        base_model, base_optimizer, None, training,
        gradient_steps=config["online_meta_gradient_steps"],
        batch_size=config["base_batch_size"],
        loss_function=meta_model,
        performance_metric=performance_metric,
        verbose=config["verbose"], device=device, terminate_divergence=False
    )

    # Recording the end time of base-model training.
    results["test_time"] = _timestamp()

    # Computing the final inference performance of the trained base model.
    results["training_inference"] = evaluate(
        model=base_model, task=training, device=device,
        performance_metric=performance_metric
    )
    results["testing_inference"] = evaluate(
        model=base_model, task=testing, device=device,
        performance_metric=performance_metric
    )

    # Recording a (shallow) copy of the experiment configuration.
    results["experiment_configuration"] = config.copy()

    # Recording the training history of both phases.
    results["offline_training_history"] = offline_training_history
    results["online_training_history"] = training_history

    # Exporting the results and the trained base model.
    export_results(results, res_directory, file_name)
    export_model(base_model, res_directory, file_name)

    print("offline", args.dataset, args.model, "seed", str(random_state), "complete")


# Load the selected dataset's YAML configuration. Decode it explicitly as UTF-8
# (the normal YAML encoding) instead of the platform's locale-dependent default,
# which would mis-read non-ASCII configs on systems such as Windows (cp1252).
with open(dataset_archive[args.dataset]["config"], encoding="utf-8") as config_file:
    config = yaml.safe_load(config_file)

# Apply command-line overrides, i.e. the flags parse_known_args did not recognise,
# to the loaded configuration. NOTE(review): required_args presumably lists the
# runner's own flags so they are not treated as config overrides; confirm in
# override_configurations.
required_args = {"dataset", "model", "seeds", "device"}
override_configurations(args, args_unknown, required_args, config)

# Retrieving the data-generating function for the selected dataset.
dataset_fn = dataset_archive[args.dataset]["data"]

# Retrieving the constructor for the selected model.
model_fn = model_archive[args.model]

# Run one experiment per requested seed, in the order given on the command line.
for random_state in args.seeds:
    _run_experiment(dataset_fn, model_fn, config, random_state)
