##############################################################################################################################################################
##############################################################################################################################################################
"""
Training script.

Define the hyperparameters in a TOML config file inside the cfg folder.
Then specify which config file to use by setting cfg_file in the __main__ block at the bottom of this script:

# cfg_file = "cfg/cfg_easy.toml"
# cfg_file = "cfg/cfg_hard.toml"
"""
##############################################################################################################################################################
##############################################################################################################################################################

import os
import csv
import sys
import time
import toml
import torch
import random

import numpy as np
from torch import optim
from pathlib import Path
from shutil import copyfile

import utils as utils
from model import MLP
from dataset import create_dataset, plot_dataset

# reproducibility
# seed = 42
# torch.manual_seed(seed)
# torch.cuda.manual_seed(seed)
# torch.cuda.manual_seed_all(seed)
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False
# np.random.seed(seed)
# random.seed(seed)

##############################################################################################################################################################
##############################################################################################################################################################

def folder_setup(config):
    """
    Create the result folders of a new experiment and back up scripts and config.

    The experiment name is built from the current time, the dataset name, the
    largest entry of the hidden layer list, the number of hidden layers and the
    lower-cased activation name.

    Side effect: sys.stdout is replaced by a utils.Tee writing to
    logs/training.log. The caller is responsible for restoring it.

    Args:
        config: Nested configuration dict. Reads config["dataset"]["name"],
            config["model"]["hidden_layers"] and config["model"]["activation"].
            A top-level "device" entry, if present, is left out of the saved copy.

    Returns:
        Dict mapping "main", "checkpoints", "images", "data", "scripts" and
        "logs" to the created pathlib.Path folders.

    Raises:
        FileExistsError: If the experiment folder already exists.
    """

    # define the name of the experiment
    hidden_layers = config["model"]["hidden_layers"]
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    activation = config["model"]["activation"].lower()
    experiment_name = (
        f"{timestamp}_{config['dataset']['name']}"
        f"_maxWidth_{np.max(hidden_layers)}"
        f"_depth_{len(hidden_layers)}"
        f"_activation_{activation}"
    )

    # define the paths to save the experiment
    main_folder = Path("results") / experiment_name
    save_folder = {"main": main_folder}
    for sub_folder in ("checkpoints", "images", "data", "scripts", "logs"):
        save_folder[sub_folder] = main_folder / sub_folder

    # create all the folders
    # parents=True also creates "results" on the very first run
    # an already existing experiment folder still raises, so that an earlier
    # run is never silently overwritten
    for item in save_folder.values():
        item.mkdir(parents=True)

    # save the console output to a file and to the console
    sys.stdout = utils.Tee(original_stdout=sys.stdout, file=save_folder["logs"] / "training.log")

    # copy files as a version backup
    # this way we know exactly what we did
    # these can also be loaded automatically for testing the models
    # the sibling modules are resolved next to this script instead of the
    # current working directory, such that the backup also works when the
    # training is started from another directory
    this_script = Path(__file__).absolute()
    copyfile(this_script, save_folder["scripts"] / "train.py")
    for script_name in ("dataset.py", "model.py", "utils.py"):
        copyfile(this_script.parent / script_name, save_folder["scripts"] / script_name)

    # save config file
    # remove device info, as it is not properly saved
    # the shallow copy keeps the config of the caller untouched
    config_to_save = config.copy()
    config_to_save.pop("device", None)
    utils.write_toml_to_file(config_to_save, save_folder["main"] / "cfg.toml")

    return save_folder

##############################################################################################################################################################

def model_setup(config, save_folder):
    """
    Build the training data loader, the model and the optimizer from the config.

    Args:
        config: Nested configuration dict. Reads the "dataset", "model" and
            "training" sections as well as the top-level "device" entry.
        save_folder: Dict of result folders; only save_folder["images"] is used.

    Returns:
        Tuple (model, optimizer, train_loader).
    """

    dataset_cfg = config["dataset"]
    training_cfg = config["training"]

    # location handed to both the dataset creation and the dataset plot
    dataset_image_folder = save_folder["images"] / "dataset"

    # create the training split of the dataset
    train_loader = create_dataset(
        which_dataset=dataset_cfg["name"],
        labels=dataset_cfg["labels"],
        centers=dataset_cfg["centers"],
        radii=dataset_cfg["radii"],
        nbr_samples=dataset_cfg["nbr_samples"],
        dimension=dataset_cfg["dimension"],
        batch_size=training_cfg["batch_size"],
        split="train",
        save_folder=dataset_image_folder,
    )

    # keep a picture of the data the model is trained on
    plot_dataset(data_loader=train_loader, save_folder=dataset_image_folder)

    # instantiate the network and move it to the selected device
    model = MLP(config["model"], dataset_cfg).to(config["device"])

    # look up the optimizer class by its name in torch.optim and
    # instantiate it with the configured learning rate
    optimizer_class = getattr(optim, training_cfg["optimizer"])
    optimizer = optimizer_class(model.parameters(), lr=training_cfg["learning_rate"])

    return model, optimizer, train_loader

##############################################################################################################################################################

def train_one_epoch(model, optimizer, train_loader, config, nbr_epoch):
    """
    Run a single pass over all batches and update the model after each batch.

    Args:
        model: Callable network that provides train(), zero_grad() and
            loss(prediction=..., target=...).
        optimizer: Optimizer whose step() is called once per batch.
        train_loader: Iterable yielding (input, target) tensor pairs.
        config: Configuration dict; only config["device"] is read.
        nbr_epoch: Index of the current epoch (not used inside this function).

    Returns:
        The same model object that was passed in.
    """

    # switch the network to training mode
    model.train()

    def on_device(tensor):
        # cast to float and move to the configured device
        return tensor.float().to(config["device"])

    for batch_inputs, batch_targets in train_loader:

        batch_inputs = on_device(batch_inputs)
        batch_targets = on_device(batch_targets)

        # forward pass followed by the loss of this batch
        batch_loss = model.loss(prediction=model(batch_inputs), target=batch_targets)

        # drop the gradients of the previous batch, backpropagate, then update
        model.zero_grad()
        batch_loss.backward()
        optimizer.step()

    return model

##############################################################################################################################################################

def evaluate(model, train_loader, config, save_folder, nbr_epoch, text):
    """
    Evaluate the model on the given loader and log mean loss and maximum error.

    The mean of the per-batch losses and the largest absolute difference
    between target and prediction are printed and appended as one csv row to
    save_folder["logs"] / "loss.log".

    Args:
        model: Callable network that provides eval() and
            loss(prediction=..., target=...).
        train_loader: Sized iterable yielding (input, target) tensor pairs.
        config: Configuration dict. Reads config["device"] and
            config["training"]["epochs"].
        save_folder: Dict of result folders; only save_folder["logs"] is used.
        nbr_epoch: Zero-based epoch index, printed as nbr_epoch + 1.
        text: Label of the evaluation (currently unused, kept for the callers).

    Raises:
        ZeroDivisionError: If train_loader contains no batches.
    """

    # prevent new line at the end of the print
    print(f"Epoch [{nbr_epoch+1}/{config['training']['epochs']}] : Evaluate model ... ", end = '')

    # accumulate loss
    running_loss = 0

    # make sure we are evaluating
    model.eval()

    stock_targets = []
    stock_predictions = []

    # do not keep track of the gradients
    with torch.no_grad():

        # for each batch
        for x, y in train_loader:

            # push to gpu
            x = x.float().to(config["device"])
            y = y.float().to(config["device"])

            # inference
            prediction = model(x)

            # calculate the loss
            loss = model.loss(prediction=prediction, target=y)

            # accumulate loss
            running_loss += loss.item()

            stock_targets.extend(y.cpu().numpy())
            stock_predictions.extend(prediction.cpu().numpy())

    # make sure its a numpy array
    stock_targets = np.array(stock_targets)
    stock_predictions = np.array(stock_predictions)

    # bring the predictions into the shape of the targets before subtracting
    # squeezing only the predictions would turn (N, 1) targets minus (N,)
    # predictions into an (N, N) matrix and report a wrong maximum error
    if stock_predictions.size == stock_targets.size:
        stock_predictions = stock_predictions.reshape(stock_targets.shape)
    else:
        stock_predictions = stock_predictions.squeeze()

    # model performance for epoch
    running_loss = running_loss / len(train_loader)

    # maximum error, i.e. max(x \in M) |f(x) - F(x)|
    max_error = np.max(np.abs(stock_targets - stock_predictions))

    # print running losses
    print(f"\tTotal Loss: {running_loss:.5f} \tMax Error: {max_error}")

    # log away the error
    # newline="" is what the csv module expects, otherwise blank rows can
    # show up on platforms that translate line endings
    with open(save_folder["logs"] / "loss.log", "a", newline="") as f:
        writer = csv.writer(f)
        writer.writerow([running_loss, max_error])

##############################################################################################################################################################

def plot_decision_regions(model, config, save_folder, text):
    """
    Evaluate the model on a grid over [0, 1] and plot the resulting regions.

    Args:
        model: Callable network that provides eval().
        config: Configuration dict. Reads config["dataset"]["dimension"] and
            config["device"].
        save_folder: Dict of result folders; only save_folder["images"] is used.
        text: Label forwarded to utils.plot_regions.
    """

    # switch the network to evaluation mode
    model.eval()

    # inspecting the learned decision regions does not need any gradients
    with torch.no_grad():

        # grid on which the classifier is applied
        # (100 points between 0.0 and 1.0 are requested from utils.mesh_grid)
        grid = utils.mesh_grid(
            min_value=0.0,
            max_value=1.0,
            num_points=100,
            dim=config["dataset"]["dimension"],
        )

        # move the grid to the device and push all points through the network
        grid = grid.to(config["device"])
        responses = model(grid)

        # hand everything back on the cpu to the plotting routine
        utils.plot_regions(grid.cpu(), responses.cpu(), save_folder["images"], text)

##############################################################################################################################################################

def train(config):
    """
    Run a complete training as described by the config.

    Selects the GPU, creates the result folders, builds model, optimizer and
    data loader, trains for config["training"]["epochs"] epochs, evaluates and
    plots every config["training"]["frequency"] epochs, evaluates once more at
    the end and stores the final weights as checkpoints/last_model.pth.

    Side effects: sets the CUDA_VISIBLE_DEVICES environment variable and adds
    a "device" entry to the given config dict.

    Args:
        config: Nested configuration dict, e.g. as loaded from a TOML file.
            config["training"]["gpu"] may be a string, an integer or a
            list/tuple of ids.
    """

    #########################################################
    # GPU
    #########################################################

    # specify which gpu should be visible
    # environment values have to be strings, while a config file may deliver
    # the id as an integer or as a list of ids
    gpu = config["training"]["gpu"]
    if isinstance(gpu, (list, tuple)):
        gpu = ",".join(str(gpu_id) for gpu_id in gpu)
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu)

    # save the gpu settings
    config["device"] = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # remember the current stdout, since folder_setup replaces it with a tee
    stdout_before_setup = sys.stdout

    try:

        #########################################################
        # Setup
        #########################################################

        # create the folders for saving
        save_folder = folder_setup(config)

        # create the model, optimizer and data loader
        model, optimizer, train_loader = model_setup(config, save_folder)

        #########################################################
        # Training
        #########################################################

        # keep track of time
        timer = utils.TrainingTimer()

        # defined up front such that the final evaluation also works
        # when the number of epochs is zero
        nbr_epoch = -1

        # for each epoch
        for nbr_epoch in range(config["training"]["epochs"]):

            # train a single epoch
            model = train_one_epoch(model, optimizer, train_loader, config, nbr_epoch)

            # sometimes do something
            if nbr_epoch % config["training"]["frequency"] == 0:

                # evaluate the model and plot the result
                evaluate(model, train_loader, config, save_folder, nbr_epoch=nbr_epoch, text=f"epoch_{nbr_epoch+1}")

                # plot decision regions of the classifier
                plot_decision_regions(model, config, save_folder, text=f"epoch_{nbr_epoch+1}")

        #########################################################
        # Aftermath
        #########################################################

        # evaluate the model and plot the result
        evaluate(model, train_loader, config, save_folder, nbr_epoch=nbr_epoch, text="final")

        # save the last model
        torch.save(model.state_dict(), save_folder["checkpoints"] / "last_model.pth")

        print("=" * 37)
        timer.print_end_time()
        print("=" * 37)

    finally:

        # reset the stdout with the original one
        # this is necessary when the train function is called several times
        # by another script, and it has to happen as well when the training
        # fails, otherwise the tee stays installed
        # nothing to undo if folder_setup failed before installing the tee
        if sys.stdout is not stdout_before_setup:
            sys.stdout = sys.stdout.end()
    
##############################################################################################################################################################

if __name__ == "__main__":

    # choose the experiment configuration (swap the comment to switch)
    # cfg_file = "cfg/cfg_hard.toml"
    cfg_file = "cfg/cfg_easy.toml"

    # parse the TOML file into a nested dict
    config = toml.load(cfg_file)

    # run the training described by the config
    train(config)