##############################################################################################################################################################
##############################################################################################################################################################
"""
This script takes an experiment and saves images of the transformations learned by the neural network.
After each layer and each activation function (except the final one), we plot the transformed training data.

To apply this to an experiment (located inside results/), specify its folder name in the `__main__` block at the bottom of this file:

# the experiment to investigate
experiment = ""
"""
##############################################################################################################################################################
##############################################################################################################################################################

import os
import sys
import toml
import torch
import numpy as np
from pathlib import Path
from importlib import import_module
from collections import defaultdict

from matplotlib import pyplot as plt
# Matplotlib 3.6 renamed its bundled seaborn styles to "seaborn-v0_8-*" and 3.8
# removed the old names; fall back to the new spelling when the old one is missing.
plt.style.use([
    style if style in plt.style.available else style.replace('seaborn-', 'seaborn-v0_8-', 1)
    for style in ('seaborn-white', 'seaborn-paper')
])
plt.rc('font', family='serif')

from model import MLP
from train import plot_decision_regions
from dataset import plot_dataset

##############################################################################################################################################################
##############################################################################################################################################################

def get_folder(experiment):
    """Return the paths used by one experiment stored below ``results/``.

    Nothing is created or checked on disk; only ``Path`` objects are built.

    Args:
        experiment: name of the experiment folder inside ``results/``.

    Returns:
        dict with the keys "experiment", "scripts", "logs", "checkpoint"
        (``checkpoints/last_model.pth``), "images" and "config" (``cfg.toml``).
    """
    root = Path("results") / experiment
    return {
        "experiment": root,
        "scripts": root / "scripts",
        "logs": root / "logs",
        "checkpoint": root / "checkpoints" / "last_model.pth",
        "images": root / "images",
        "config": root / "cfg.toml",
    }

##############################################################################################################################################################        

def get_layers(config, folders):
    """Rebuild the trained MLP and return it along with its individual layers.

    The model is built from ``config["model"]`` / ``config["dataset"]``, moved to
    ``config["device"]``, loaded with the state dict stored at
    ``folders["checkpoint"]`` and put into eval mode.

    Returns:
        tuple ``(model, layers)``: ``layers`` holds every module yielded by
        ``model.regression.modules()`` whose type is not exactly
        ``torch.nn.Sequential`` (i.e. the container itself is dropped).
    """
    device = config["device"]

    model = MLP(config["model"], config["dataset"]).to(device)
    state_dict = torch.load(folders["checkpoint"], map_location=device)
    model.load_state_dict(state_dict)
    model.eval()

    # flatten the regression head into a plain list, skipping Sequential containers
    layers = []
    for module in model.regression.modules():
        if type(module) is torch.nn.Sequential:
            continue
        layers.append(module)

    return model, layers

##############################################################################################################################################################        

def get_train_loaders(config, folders):
    """Create the training loader using the dataset script saved with the experiment.

    The copy of ``dataset.py`` stored in ``folders["scripts"]`` is imported as a
    module and its ``create_dataset`` is called with the settings from
    ``config["dataset"]`` and ``config["training"]["batch_size"]`` for the
    train split. As a side effect the loader is plotted via ``plot_dataset``
    into ``folders["images"] / "dataset"``.

    Args:
        config: parsed ``cfg.toml`` of the experiment.
        folders: path dict as returned by ``get_folder``.

    Returns:
        Whatever the experiment's ``create_dataset`` returns for ``split="train"``.
    """
    # Dotted module name of the saved script, e.g. "results.<experiment>.scripts.dataset".
    # Built from the path components so it does not depend on the OS separator
    # (replacing "/" in str(path) broke on Windows, where the separator is "\\").
    module_name = ".".join((folders["scripts"] / "dataset").parts)

    # `package` is only used for relative names, so none is passed for this absolute one
    loader_def = import_module(module_name)
    create_dataset = getattr(loader_def, "create_dataset")

    dataset_cfg = config["dataset"]
    train_loader = create_dataset(
        which_dataset=dataset_cfg["name"],
        labels=dataset_cfg["labels"],
        centers=dataset_cfg["centers"],
        radii=dataset_cfg["radii"],
        nbr_samples=dataset_cfg["nbr_samples"],
        dimension=dataset_cfg["dimension"],
        batch_size=config["training"]["batch_size"],
        split="train",
    )

    plot_dataset(data_loader=train_loader, save_folder=folders["images"] / "dataset")

    return train_loader

##############################################################################################################################################################

def _collect_layer_outputs(model_layers, train_loader, device):
    """Run every batch through ``model_layers`` one by one and record each layer's output.

    Returns:
        tuple ``(outputs, labels)``: ``outputs`` maps layer index -> numpy array of
        that layer's outputs for all samples (batches concatenated in loader order),
        ``labels`` is the numpy array of the matching targets. ``outputs`` is empty
        when ``train_loader`` yields no batches.
    """
    per_layer = defaultdict(list)
    label_batches = []

    # evaluation mode only needs to be set once, not for every batch
    for layer in model_layers:
        layer.eval()

    with torch.no_grad():
        for x, y in train_loader:
            x = x.float().to(device)
            for idx, layer in enumerate(model_layers):
                # each layer consumes the previous layer's output
                x = layer(x)
                per_layer[idx].append(x.cpu().numpy())
            label_batches.append(y.cpu().numpy())

    outputs = {idx: np.concatenate(batches) for idx, batches in per_layer.items()}
    labels = np.concatenate(label_batches) if label_batches else np.array([])
    return outputs, labels


def _plot_layer_output(values, labels, title, save_path):
    """Scatter columns 0 and 1 of ``values``, one colour per label, and save the figure."""
    # the first two colours match the original two-class palette; more classes cycle
    colors = ['royalblue', 'firebrick', 'forestgreen', 'darkorange', 'mediumpurple', 'goldenrod']

    fig, ax = plt.subplots()
    try:
        for idx, label in enumerate(np.unique(labels)):
            mask = labels == label
            ax.scatter(values[mask, 0], values[mask, 1], label=label,
                       c=colors[idx % len(colors)], s=5)

        ax.axis('square')
        ax.set_xlabel(r'$\mathregular{x}_1$')
        ax.set_ylabel(r'$\mathregular{x}_2$')
        ax.set_title(title)

        legend = ax.legend(frameon=True, fontsize=11)
        legend.get_frame().set_facecolor('white')
        # `legendHandles` was renamed `legend_handles` in Matplotlib 3.7 (old name removed in 3.9)
        handles = getattr(legend, "legend_handles", None)
        if handles is None:
            handles = legend.legendHandles
        for handle in handles:
            handle.set_sizes([25.0])

        for item in ([ax.title, ax.xaxis.label, ax.yaxis.label]
                     + ax.get_xticklabels() + ax.get_yticklabels()):
            item.set_fontsize(11)

        fig.savefig(save_path, bbox_inches='tight')
    finally:
        # always release the figure, even if plotting or saving fails
        plt.close(fig)


def visualize_layers(model_layers, train_loader, config, save_folder):
    """Save a 2-D scatter plot of the data after each layer, except the last one.

    Args:
        model_layers: layers applied in sequence (as returned by ``get_layers``).
        train_loader: iterable of ``(x, y)`` batches.
        config: dict providing ``config["device"]``.
        save_folder: directory (``Path``) receiving ``layer_<idx>.png`` files.
    """
    outputs, labels = _collect_layer_outputs(model_layers, train_loader, config["device"])

    for layer_idx, values in outputs.items():
        # the final layer's output is not plotted (as in the original script)
        if layer_idx == len(model_layers) - 1:
            continue

        # NOTE(review): index 0 holds the output of model_layers[0], not the raw
        # data, yet it is titled "Input"; kept as-is — confirm this label is intended.
        layer_name = str(model_layers[layer_idx]).split("(")[0] if layer_idx != 0 else "Input"
        title = "Layer definition: " + layer_name + " - Layer idx: " + str(layer_idx)

        _plot_layer_output(values, labels, title, save_folder / f'layer_{layer_idx}.png')


##############################################################################################################################################################

if __name__ == "__main__":

    # the experiment to investigate
    experiment = "20210531-070008_sphere_maxWidth_2_depth_3_activation_relu"

    # specify which gpu should be visible; this has to happen before the first
    # CUDA query (torch.cuda.is_available() below) to take effect
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    # get the path to all the folders and files of the experiment (below results/)
    folders = get_folder(experiment)

    # abort with a non-zero exit status if the experiment does not exist
    if not folders["experiment"].is_dir():
        sys.exit("The specified experiment does not exist: Exit script.")

    # load the config file
    config = toml.load(folders["config"])

    # define the device
    config["device"] = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # get the layers of the model
    model, model_layers = get_layers(config, folders)

    # get the training data (also plots the dataset)
    train_loader = get_train_loaders(config, folders)

    # define the folder to save the images; create missing parents as well
    save_folder = folders["images"] / "layers"
    save_folder.mkdir(parents=True, exist_ok=True)

    # visualize the model transformations
    visualize_layers(model_layers, train_loader, config, save_folder)

    # plot the decision region again
    plot_decision_regions(model, config, folders, text="eval")