##############################################################################################################################################################
##############################################################################################################################################################
"""
Helper functions: console/file logging (Tee), config saving, training timers, plotting, and grid generation.
"""
##############################################################################################################################################################
##############################################################################################################################################################

import toml
import time
import torch
import datetime
import dateutil

import numpy as np

from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
# matplotlib >= 3.6 renamed the bundled seaborn styles to "seaborn-v0_8-*" and
# matplotlib >= 3.8 removed the old names, so fall back to the renamed style
# whenever the legacy name is not available in the installed matplotlib.
_PLOT_STYLES = [
    style if style in plt.style.available else style.replace('seaborn', 'seaborn-v0_8', 1)
    for style in ('seaborn-white', 'seaborn-paper')
]
plt.style.use(_PLOT_STYLES)
plt.rc('font', family='serif')

##############################################################################################################################################################
##############################################################################################################################################################

class Tee(object):
    """
    File-like object that duplicates everything written to it onto the
    original stdout and a log file, so printed text shows up on the console
    and is also saved to disk.

    Call ``end()`` when done: it closes the log file and returns the original
    stdout so the caller can restore it.
    """

    def __init__(self, original_stdout, file):
        """
        :param original_stdout:
        Stream to keep echoing to (e.g. the console stream being replaced).

        :param file:
        Path of the log file; it is created or truncated.
        """

        # keep the original stdout
        self.original_stdout = original_stdout

        # the file to write to; UTF-8 so non-ASCII output cannot raise
        # UnicodeEncodeError on platforms with a legacy default encoding
        self.log_file_handler = open(file, 'w', encoding='utf-8')

        # all the streams every write should be copied to
        self.files = [self.original_stdout, self.log_file_handler]

    def write(self, obj):
        """
        Write ``obj`` to every target stream and flush each one.

        :return:
        Number of characters written, matching the text-stream ``write`` contract.
        """

        for f in self.files:
            f.write(obj)

            # flush immediately so the output is visible right away
            f.flush()

        return len(obj)

    def flush(self):
        """
        Flush every target stream.
        """

        for f in self.files:
            f.flush()

    def end(self):
        """
        Stop logging to the file, close it and return the original stdout.

        Safe to call more than once. Writes made after ``end()`` still reach
        the original stdout instead of failing on the closed log file.
        """

        # detach the log file first so later writes/flushes don't hit a closed file
        if self.log_file_handler in self.files:
            self.files.remove(self.log_file_handler)

        # close the file (a no-op if it is already closed)
        self.log_file_handler.close()

        # return the original stdout
        return self.original_stdout

##############################################################################################################################################################

def write_toml_to_file(cfg, file_path):
    """
    Write a configuration to ``file_path`` in TOML format and report the path.

    :param cfg:
    Configuration object passed directly to ``toml.dump`` (it must be
    something that library can serialize, presumably a dict).

    :param file_path:
    Destination path; the file is created or overwritten.
    """

    # TOML documents must be UTF-8 encoded; don't depend on the platform default
    with open(file_path, 'w', encoding='utf-8') as output_file:
        toml.dump(cfg, output_file)

    print('=' * 57)
    print("Config file saved: ", file_path)
    print('=' * 57)

##############################################################################################################################################################

class TrainingTimer():
    """
    Keep track of training times.

    The wall-clock start is recorded on construction. ``print_end_time``
    reports the duration of the whole run. ``print_time_delta`` reports the
    time since the previous call (or since construction) and then resets
    that interval.
    """

    def __init__(self):

        # get the current (naive, local) time
        self.start = datetime.datetime.now()
        self.eval_start = datetime.datetime.now()

        # print it human readable
        print("Training start: ", self.start)
        print('=' * 57)

    @staticmethod
    def _split_duration(delta):
        """
        Split a ``datetime.timedelta`` into whole (hours, minutes, seconds).

        Hours are not capped at 24, so multi-day durations are reported in
        full. Sub-second remainders are truncated.
        """
        minutes, seconds = divmod(int(delta.total_seconds()), 60)
        hours, minutes = divmod(minutes, 60)
        return hours, minutes, seconds

    def print_end_time(self):
        """
        Print the finish time and the total training duration since construction.
        """

        datetime_now = datetime.datetime.now()

        print("Training finish: ", datetime_now)

        # total elapsed time, including any full days
        hours, minutes, seconds = self._split_duration(datetime_now - self.start)

        # print the duration in human readable
        print(f"Training duration: {hours} hours, {minutes} minutes, {seconds} seconds")

    def print_time_delta(self):
        """
        Print the time elapsed since the last evaluation and restart that interval.
        """

        datetime_now = datetime.datetime.now()

        # elapsed time since the previous evaluation, including any full days
        hours, minutes, seconds = self._split_duration(datetime_now - self.eval_start)

        # print the duration in human readable
        print(f"Duration since last evaluation: {hours} hours, {minutes} minutes, {seconds} seconds")
        print('=' * 57)

        # update starting time
        self.eval_start = datetime.datetime.now()

##############################################################################################################################################################

def plot_result(x, targets, predictions, save_folder=None, text=""):
    """
    Scatter predictions (blue) and targets (red) over the inputs in a 3-D plot.

    :param x:
    Inputs; columns 0 and 1 are used as the two horizontal coordinates
    (assumes a 2-D array/tensor with at least two columns — TODO confirm).

    :param targets:
    Target values, one per row of ``x``; plotted on the vertical axis.

    :param predictions:
    Predicted values, one per row of ``x``; plotted on the vertical axis.

    :param save_folder:
    ``None`` to show the figure interactively. Otherwise, a mapping whose
    ``"images"`` entry is a directory path supporting ``/`` (presumably a
    ``pathlib.Path``); the figure is saved there and closed.

    :param text:
    Suffix inserted into the file name ``prediction_vs_target{text}.png``.
    """

    fig = plt.figure()

    # Axes3D(fig) no longer adds itself to the figure in recent matplotlib,
    # which leaves the saved image empty; add_subplot is the supported way
    ax = fig.add_subplot(projection="3d")
    ax.set_title("Prediction vs. Targets")
    ax.scatter(x[:,0], x[:,1], predictions, color="blue", label="prediction")
    ax.scatter(x[:,0], x[:,1], targets, color="red", label="target")
    ax.set_xlabel(r'$\mathregular{x}_1$')
    ax.set_ylabel(r'$\mathregular{x}_2$')
    ax.set_zlabel("y")
    ax.grid()
    ax.legend()

    if save_folder is not None:

        # save the figure
        fig.savefig(save_folder["images"] / f"prediction_vs_target{text}.png")

        # make sure to close the figure
        plt.close(fig)

    else:
        # or plot it
        plt.show()

##############################################################################################################################################################

def plot_regions(inputs, predicted, save_folder, text):
    """
    Colour each input point by its rounded prediction and save the figure.

    The image is written to ``save_folder / "decision" / f"decision_regions_{text}.png"``.
    The ``decision`` subdirectory is created if it does not exist yet.

    :param inputs:
    Points to plot; columns 0 and 1 are the coordinates. The axes are fixed
    to [0, 1] x [0, 1].

    :param predicted:
    Model outputs, one per input. They are rounded to get the class labels
    and squeezed before masking (presumably shape (N,) or (N, 1) — confirm).

    :param save_folder:
    Base directory supporting ``/`` (presumably a ``pathlib.Path``).

    :param text:
    Suffix for the output file name.
    """

    save_path = save_folder / "decision"
    save_path.mkdir(exist_ok=True)

    # round once; reuse for the label set and for every per-label mask
    rounded = np.round(predicted)
    unique_labels = np.unique(rounded)
    rounded_flat = rounded.squeeze()

    # create the 2d figure
    fig, ax = plt.subplots()

    colors = ['royalblue', 'firebrick', "green", "yellow", "orange", "black"]

    for idx, label in enumerate(unique_labels):

        # select the points predicted as this label
        mask = rounded_flat == label

        # cycle through the palette so more labels than colours don't raise IndexError
        ax.scatter(inputs[mask, 0], inputs[mask, 1], label=int(label), c=colors[idx % len(colors)])

    ax.set_xlim(0,1)
    ax.set_ylim(0,1)
    ax.set_aspect('equal')
    ax.set_xlabel(r'$\mathregular{x}_1$')
    ax.set_ylabel(r'$\mathregular{x}_2$')
    legend = ax.legend(frameon = 1, fontsize=11, loc="upper left")
    frame = legend.get_frame()
    frame.set_facecolor('white')

    # ``legendHandles`` was renamed ``legend_handles`` in matplotlib 3.7 and
    # the old name was removed in 3.9; support both
    handles = getattr(legend, "legend_handles", None)
    if handles is None:
        handles = legend.legendHandles
    for handle in handles:
        handle.set_sizes([25.0])

    for item in ([ax.title, ax.xaxis.label, ax.yaxis.label] + ax.get_xticklabels() + ax.get_yticklabels()):
        item.set_fontsize(11)

    fig.savefig(save_path / f'decision_regions_{text}.png', bbox_inches='tight')
    plt.close(fig)

##############################################################################################################################################################

def mesh_grid(min_value, max_value, num_points, dim):
    """
    Define a regular grid of points which will be used to check the learned decision region.

    :param min_value:
    Minimum value for each grid dimension.

    :param max_value:
    Maximum value for each grid dimension.

    :param num_points:
    Number of points for each grid dimension.

    :param dim:
    Dimension of the grid (typically 2 or 3; any positive integer is supported).

    :return:
    Float32 torch tensor of shape ``(num_points ** dim, dim)``. Rows enumerate
    the grid in row-major order, with the last coordinate varying fastest.

    :raises ValueError:
    If ``dim`` is smaller than 1.
    """

    if dim < 1:
        raise ValueError(f"dim must be a positive integer, got {dim}")

    # define num_points equally spaced points between min_value and max_value
    x = np.linspace(min_value, max_value, num_points)

    # one copy of the axis per dimension (square for 2, cube for 3, ...)
    grid = np.meshgrid(*([x] * dim), indexing='ij')

    # stack the coordinates into the last axis
    grid = np.stack(grid, axis=-1)

    # flatten to (num_points ** dim, dim), matching the train/test data layout
    grid = grid.reshape(-1, dim)

    # return a float32 torch tensor of this
    return torch.from_numpy(grid).float()

##############################################################################################################################################################
