"""
This file holds all code to initialize and train a linear neural network correctly 
it will be imported to the scripts dir. 
"""

import time
import os
import glob

import yaml
import numpy as np
import jax.numpy as jnp
from jax import random
import wandb

from src.neural_nets.linear_nets.mlp_model import init_network_params, update, loss
from src.neural_nets.model_data.make_dataset import  HierarchyDatasetGenerator, HierarchyDataset,  NumpyLoader, Cast
from src.neural_nets.logging.save_functionality import save_model_to_file
from src.neural_nets.metrics.performance_metrics import (accuracy, true_positive_rate, true_negative_rate, 
                                                        level_bias, sibling_metric, chris_sibling_metric, 
                                                        sibling_bias, compute_singular_values)

def load_config(config_path):
    """
    Read a YAML configuration file from the ``configs`` directory.

    Args:
        config_path (str): Path of the configuration file, resolved against
            ``<directory of this file>/../configs``.
    Returns:
        The parsed configuration, as produced by ``yaml.safe_load``.
    """
    module_dir = os.path.dirname(__file__)
    resolved_path = os.path.abspath(
        os.path.join(module_dir, '..', 'configs', config_path)
    )

    with open(resolved_path, "r") as config_stream:
        parsed_config = yaml.safe_load(config_stream)
    return parsed_config

def load_data(batch_size:int, include_head_property:bool, include_bias_input:bool=True):
    """
    Generate the hierarchy dataset and wrap it in a batch loader.

    Args:
        batch_size (int): Batch size handed to the loader.
        include_head_property (bool): Forwarded to the dataset generator as ``include_headnode``.
        include_bias_input (bool): Forwarded to the dataset generator as ``include_bias_input``.
    Returns:
        Tuple: ``(inputs, labels, training_generator)`` where the first two are the
        generated dataset and the last is the ``NumpyLoader`` built on top of it.
    """
    dataset_factory = HierarchyDatasetGenerator(
        include_headnode=include_head_property,
        include_bias_input=include_bias_input,
    )
    inputs, labels = dataset_factory.create_dataset()

    # Samples pass through the Cast() transform before being batched.
    wrapped_dataset = HierarchyDataset((inputs, labels), transform=Cast())
    batch_loader = NumpyLoader(wrapped_dataset, batch_size=batch_size, num_workers=0)
    return inputs, labels, batch_loader

def load_imbalanced_data(batch_size:int, inputs, labels):
    """
    Wrap already-constructed inputs and labels in a batch loader.

    Unlike ``load_data`` this does not generate a dataset; the caller supplies it.

    Args:
        batch_size (int): Batch size handed to the loader.
        inputs (np.array): The input data for training.
        labels (np.array): The labels for the training data.
    Returns:
        Tuple: ``(inputs, labels, training_generator)``; inputs and labels are
        returned unchanged alongside the ``NumpyLoader`` built on top of them.
    """
    # Samples pass through the Cast() transform before being batched.
    wrapped_dataset = HierarchyDataset((inputs, labels), transform=Cast())
    batch_loader = NumpyLoader(wrapped_dataset, batch_size=batch_size, num_workers=0)
    return inputs, labels, batch_loader

def initialize_model(input_size:int, hidden_size:int, output_size:int, random_seed:int, scale:float, hidden_bias:bool=False):
    """
    Create the initial parameters of a network with a single hidden layer.

    Args:
        input_size (int): The size of the input layer.
        hidden_size (int): The size of the hidden layer.
        output_size (int): The size of the output layer.
        random_seed (int): Seed used to build the JAX PRNG key.
        scale (float): Forwarded to ``init_network_params`` as ``scale``.
        hidden_bias (bool): Forwarded to ``init_network_params`` as ``bias``. Defaults to False.
    Returns:
        The parameters returned by ``init_network_params``.
    """
    sizes = [input_size, hidden_size, output_size]
    prng_key = random.PRNGKey(random_seed)
    initial_params = init_network_params(sizes, prng_key, bias=hidden_bias, scale=scale)
    return initial_params

def initialize_wandb(step_size:float,
                     layer_sizes:list,
                     num_epochs:int,
                     scale:float,
                     batch_size:int,
                     random_seed:int,
                     choice_temp:float,
                     include_bias_input:bool,
                     include_head_property:bool,
                     net:int,
                     project_name:str="linear network runs",
                     group_name:str="standard"):
    """
    Start a Weights & Biases run that records the training hyperparameters.

    Args:
        step_size (float): The learning rate for training.
        layer_sizes (list): The sizes of the layers in the network; entry 1 is recorded as ``n_hidden``.
        num_epochs (int): The number of training epochs.
        scale (float): The scale factor for parameter initialization.
        batch_size (int): The batch size for training.
        random_seed (int): The random seed for parameter initialization.
        choice_temp (float): The temperature parameter for probabilistic choices.
        include_bias_input (bool): Whether to include a bias input in the dataset.
        include_head_property (bool): Whether to include the head property in the hierarchical dataset.
        net (int): The index of the network run; also part of the run name.
        project_name (str): The wandb project to log to.
        group_name (str): The wandb group the run belongs to.
    Returns:
        The object returned by ``wandb.init``.
    """
    # The run name carries the network index and the current UTC time.
    utc_stamp = time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime())
    run_name = f'linear_network_run_{net}_{utc_stamp}'

    run_config = {
        "learning_rate": step_size,
        "architecture": "LinearNetwork",
        "dataset": "hierarchy",
        "epochs": num_epochs,
        "init_scale": scale,
        "n_hidden": layer_sizes[1],
        "batch_size": batch_size,
        "random_seed": random_seed,
        "choice_temp": choice_temp,
        "include_bias_input": include_bias_input,
        "include_head_property": include_head_property,
        "network_run": net,
    }

    return wandb.init(project=project_name, name=run_name, group=group_name, config=run_config)
def get_metrics_dict(params:jnp.array,
                    train_inputs:np.array,
                    train_labels:np.array,
                    choice_temp:float,
                    probabilistic:bool=True,
                    model_name:str="linear_network",
                    bias:bool=False):
    """
    Evaluate the model on the training data and collect the metrics in a dictionary.

    Args:
        params (jnp.array): The parameters of the model.
        train_inputs (np.array): The input data for training.
        train_labels (np.array): The labels for the training data.
        choice_temp (float): The choice temperature passed to the metric functions.
        probabilistic (bool, optional): Passed through to the metric functions. Defaults to True.
        model_name (str, optional): If it contains "imbalance", only loss and accuracy are computed.
        bias (bool, optional): Passed to ``loss`` and to the overall ``accuracy`` call. Defaults to False.

    Returns:
        tuple: ``(model_outputs, metrics)``. The insertion order of ``metrics`` is
        relied upon by ``train`` to lay out its metric columns, so it must stay stable.
    """
    data = (params, train_inputs, train_labels)

    train_loss = loss(*data, bias)
    model_outputs, response, train_acc = accuracy(*data, temprature=choice_temp,
                                                  probabilistic=probabilistic, bias=bias)

    metrics = {"loss": float(train_loss), "accuracy": float(train_acc)}
    # Models whose name contains "imbalance" only report loss and accuracy.
    if "imbalance" in model_name:
        return model_outputs, metrics

    # Keyword bundles shared by most of the metric calls below.
    from_choices = {"response": response, "temprature": choice_temp, "probabilistic": probabilistic}
    from_outputs = {"model_outputs": model_outputs, "discreet_choices": False}

    # Overall TPR / TNR, first from the choices, then from the continuous outputs.
    tpr_overall = true_positive_rate(*data, split_partitions=None, **from_choices)
    tnr_overall = true_negative_rate(*data, split_partitions=None, **from_choices)
    tpr_overall_cont = true_positive_rate(*data, split_partitions=None, **from_outputs)
    tnr_overall_cont = true_negative_rate(*data, split_partitions=None, **from_outputs)

    # Per-level TPR / TNR (default partitions of the metric functions).
    tpr_by_level = true_positive_rate(*data, model_outputs=model_outputs, **from_choices)
    tnr_by_level = true_negative_rate(*data, model_outputs=model_outputs, **from_choices)
    tpr_by_level_cont = true_positive_rate(*data, **from_outputs)
    tnr_by_level_cont = true_negative_rate(*data, **from_outputs)

    # Per-level accuracy.
    acc_by_level = accuracy(*data, split_partitions=(2,4,8), **from_choices)

    # Bias metrics; the local is named so it does not shadow the `bias` argument.
    bias_by_level = level_bias(*data, **from_choices)
    adjusted_bias_by_level = level_bias(*data, adjust=True, **from_choices)

    # Sibling-based metrics take `response` positionally.
    sibling_by_level = sibling_metric(*data, response, temprature=choice_temp,
                                      probabilistic=probabilistic)
    chris_sibling_by_level = chris_sibling_metric(*data, response, temprature=choice_temp,
                                                  probabilistic=probabilistic)
    sibling_bias_by_level = sibling_bias(*data, response, temprature=choice_temp,
                                         probabilistic=probabilistic)

    singular_values = compute_singular_values(params)

    metrics.update({
        "acc_top_level": acc_by_level[0],
        "acc_mid_level": acc_by_level[1],
        "acc_bottom_level": acc_by_level[2],
        "TPR": float(tpr_overall),
        "TNR": float(tnr_overall),
        "TPR_continous": float(tpr_overall_cont),
        "TNR_continous": float(tnr_overall_cont),
        "TPR_top_level": tpr_by_level[0],
        "TPR_mid_level": tpr_by_level[1],
        "TPR_bottom_level": tpr_by_level[2],
        "TNR_top_level": tnr_by_level[0],
        "TNR_mid_level": tnr_by_level[1],
        "TNR_bottom_level": tnr_by_level[2],
        "TPR_continous_top_level": tpr_by_level_cont[0],
        "TPR_continous_mid_level": tpr_by_level_cont[1],
        "TPR_continous_bottom_level": tpr_by_level_cont[2],
        "TNR_continous_top_level": tnr_by_level_cont[0],
        "TNR_continous_mid_level": tnr_by_level_cont[1],
        "TNR_continous_bottom_level": tnr_by_level_cont[2],
        "bias_top_level": bias_by_level[0],
        "bias_mid_level": bias_by_level[1],
        "bias_bottom_level": bias_by_level[2],
        "fraction_of_bias_learned_top_level": adjusted_bias_by_level[0],
        "fraction_of_bias_learned_mid_level": adjusted_bias_by_level[1],
        "fraction_of_bias_learned_bottom_level": adjusted_bias_by_level[2],
        "sibling_metric_top_level": sibling_by_level[0],
        "sibling_metric_mid_level": sibling_by_level[1],
        "sibling_metric_bottom_level": sibling_by_level[2],
        "chris_siblings_top_level": chris_sibling_by_level[0],
        "chris_siblings_mid_level": chris_sibling_by_level[1],
        "chris_siblings_bottom_level": chris_sibling_by_level[2],
        "sibling_bias_top_level": sibling_bias_by_level[0],
        "sibling_bias_mid_level": sibling_bias_by_level[1],
        "sibling_bias_bottom_level": sibling_bias_by_level[2],
        "singular_values": singular_values,
    })
    return model_outputs, metrics

def log_to_wandb(epoch:int,
                 params:jnp.array,
                 run:wandb.run,
                 artifact_creation:bool=False,
                 metric_dict:dict=None,
                 model_outputs:np.array=None):
    """
    Send the metrics of one epoch to Weights & Biases, optionally with a model artifact.

    Args:
        epoch (int): The current epoch number; used as the wandb step.
        params (jnp.array): The model parameters.
        run (wandb.run): The wandb run object the artifact is logged to.
        artifact_creation (bool, optional): Whether to upload the parameters and outputs
            as a wandb artifact. Defaults to False.
        metric_dict (dict, optional): The metrics to log. Must contain "loss" and
            "accuracy" when ``artifact_creation`` is True. Defaults to None.
        model_outputs (np.array, optional): The model outputs for the current epoch. Defaults to None.
    Returns:
        None
    """
    wandb.log(metric_dict, step=epoch)

    if not artifact_creation:
        return

    artifact_metadata = {
        "epoch": epoch,
        "loss": metric_dict["loss"],
        "accuracy": metric_dict["accuracy"],
    }
    model_artifact = wandb.Artifact(
        name=f"model-epoch{epoch}",
        type="model",
        description=f"Model saved at epoch {epoch}",
        metadata=artifact_metadata,
    )

    # Write weights first, then outputs, into the wandb run directory.
    payloads = (
        (params, f"model_weights_epoch{epoch}.pt"),
        (model_outputs, f"model_outputs_epoch{epoch}.pt"),
    )
    saved_paths = [
        save_model_to_file(payload, os.path.join(wandb.run.dir, file_name))
        for payload, file_name in payloads
    ]

    for saved_path in saved_paths:
        model_artifact.add_file(saved_path)
    run.log_artifact(model_artifact)

    # The local copies are only needed until the artifact has been logged.
    for saved_path in saved_paths:
        os.remove(saved_path)


def save_intermediate_results(net, metrics, metric_labels,
                              singular_values, outputs,
                              params,
                              group_name, model_name):
    """
    Write the logged results of a single network run to per-network files.

    The files go to ``<this file>/../../../results/model_runs_linear/<model_name>/<group_name>``
    (created if missing) and carry ``net`` in their names.

    Args:
        net: Index of the network run, used in every file name.
        metrics: Array of logged metrics, saved as ``metrics_net_<net>.npy``.
        metric_labels: Iterable of label strings, written one per line to
            ``metric_labels_net_<net>.txt``.
        singular_values: Saved as ``singular_values_net_<net>.npy``.
        outputs: Saved as ``outputs_net_<net>.npy``.
        params: Indexable with three entries: first-layer weights, second-layer
            weights (may be None, which is saved as such) and hidden-layer bias
            (only saved when not None).
        group_name (str): Sub-directory of the model's results folder.
        model_name (str): Name of the model's results folder.
    Returns:
        None
    """
    module_dir = os.path.dirname(__file__)
    results_folder = os.path.abspath(
        os.path.join(module_dir, '..', '..', '..', 'results', 'model_runs_linear', model_name)
    )
    log_dir = os.path.join(results_folder, group_name)
    os.makedirs(log_dir, exist_ok=True)

    def in_log_dir(file_name):
        return os.path.join(log_dir, file_name)

    # Metrics and their column labels.
    np.save(in_log_dir(f"metrics_net_{net}.npy"), metrics)
    with open(in_log_dir(f"metric_labels_net_{net}.txt"), 'w') as label_stream:
        label_stream.write('\n'.join(metric_labels))

    # Singular values and model outputs.
    np.save(in_log_dir(f"singular_values_net_{net}.npy"), singular_values)
    np.save(in_log_dir(f"outputs_net_{net}.npy"), outputs)

    # Weights: W2 is written even when it is None so the file always exists.
    np.save(in_log_dir(f"W1_net_{net}.npy"), params[0])
    np.save(in_log_dir(f"W2_net_{net}.npy"), params[1])

    # The hidden-layer bias file is only written when a bias was recorded.
    recorded_bias = params[2]
    if recorded_bias is not None:
        np.save(in_log_dir(f"hidden_bias_net_{net}.npy"), recorded_bias)


def aggregate_results(group_name, model_name, num_epochs, num_runs, log_interval, hidden_layer_bias=False):
    """
    Aggregates the per-network intermediate files into stacked arrays.

    The files written by ``save_intermediate_results`` for networks
    ``0 .. num_runs - 1`` are loaded from
    ``<this file>/../../../results/model_runs_linear/<model_name>/<group_name>``
    and stacked along a new leading axis (one entry per network). The
    intermediate files are deleted only after every network has been loaded and
    stacked successfully, so a missing or mismatched file raises without
    destroying the data of the other networks.

    Args:
        group_name (str): The name of the group.
        model_name (str): The name of the model. For "shallow_net" no second-layer
            weights are aggregated.
        num_epochs (int): The number of training epochs (currently unused).
        num_runs (int): The number of network runs.
        log_interval (int): The logging interval (currently unused).
        hidden_layer_bias (bool): Whether hidden-layer bias files should be aggregated.
    Returns:
        tuple: ``(aggregated_metrics, aggregated_singular_values, aggregated_outputs,
        aggregated_w1, aggregated_w2, aggregated_hidden_layer_bias, log_dir)``.
        ``aggregated_metrics`` maps each metric label to a stacked array;
        ``aggregated_w2`` / ``aggregated_hidden_layer_bias`` are empty lists when
        they are not aggregated.
    Raises:
        FileNotFoundError: If an expected intermediate file is missing.
        ValueError: From ``np.stack`` if there is nothing to stack or shapes differ.
    """
    results_folder = os.path.abspath(os.path.join(os.path.dirname(__file__),
                                        '..', '..', '..', 'results', 'model_runs_linear',
                                        model_name))
    log_dir = os.path.join(results_folder, group_name)
    has_second_layer = model_name != "shallow_net"

    # Load metric labels from the first network (assuming all networks have the same labels)
    with open(os.path.join(log_dir, "metric_labels_net_0.txt"), 'r') as f:
        metric_labels = f.read().split('\n')

    aggregated_metrics = {label: [] for label in metric_labels}
    aggregated_singular_values = []
    aggregated_outputs = []
    aggregated_w1 = []
    aggregated_w2 = []
    aggregated_hidden_layer_bias = []

    # Files are only collected here; they are deleted once aggregation succeeded.
    consumed_files = []

    for net in range(num_runs):
        print(f'Aggregating {group_name} network {net}')

        metrics_file = os.path.join(log_dir, f"metrics_net_{net}.npy")
        singular_values_file = os.path.join(log_dir, f"singular_values_net_{net}.npy")
        outputs_file = os.path.join(log_dir, f"outputs_net_{net}.npy")
        w_1_file = os.path.join(log_dir, f"W1_net_{net}.npy")
        w_2_file = os.path.join(log_dir, f"W2_net_{net}.npy")
        hidden_layer_bias_file = os.path.join(log_dir, f"hidden_bias_net_{net}.npy")

        # NOTE: allow_pickle=True is needed because save_intermediate_results may
        # store None (an object array). Only point this at directories written by
        # this project, since unpickling untrusted files is unsafe.
        net_metrics = np.load(metrics_file, allow_pickle=True)
        for i, label in enumerate(metric_labels):
            aggregated_metrics[label].append(net_metrics[:, i])

        aggregated_singular_values.append(np.load(singular_values_file, allow_pickle=True))
        aggregated_outputs.append(np.load(outputs_file, allow_pickle=True))
        aggregated_w1.append(np.load(w_1_file, allow_pickle=True))
        consumed_files.extend([metrics_file, singular_values_file, outputs_file, w_1_file])

        if has_second_layer:
            aggregated_w2.append(np.load(w_2_file, allow_pickle=True))
            consumed_files.append(w_2_file)
        elif os.path.exists(w_2_file):
            # save_intermediate_results always writes a W2 file, even without a
            # second layer; clean it up so no stale intermediate is left behind.
            consumed_files.append(w_2_file)

        if hidden_layer_bias:
            aggregated_hidden_layer_bias.append(np.load(hidden_layer_bias_file, allow_pickle=True))
            consumed_files.append(hidden_layer_bias_file)

    # Stack the metrics and singular values across networks
    for label in metric_labels:
        aggregated_metrics[label] = np.stack(aggregated_metrics[label], axis=0)

    aggregated_singular_values = np.stack(aggregated_singular_values, axis=0)

    # stack outputs and params
    aggregated_outputs = np.stack(aggregated_outputs, axis=0)
    aggregated_w1 = np.stack(aggregated_w1, axis=0)
    if has_second_layer:
        aggregated_w2 = np.stack(aggregated_w2, axis=0)
    if hidden_layer_bias:
        aggregated_hidden_layer_bias = np.stack(aggregated_hidden_layer_bias, axis=0)

    # Everything is in memory now, so the intermediate files can go.
    for consumed_file in consumed_files:
        os.remove(consumed_file)
    # Only the label files are removed, not every .txt file in the directory.
    for label_file in glob.glob(os.path.join(log_dir, "metric_labels_net_*.txt")):
        os.remove(label_file)

    return aggregated_metrics, aggregated_singular_values, aggregated_outputs, aggregated_w1, aggregated_w2, aggregated_hidden_layer_bias, log_dir

def save_aggregated_results(aggregated_metrics, aggregated_singular_values,
                            aggre_outputs, aggre_w1, aggre_w2, agg_bias, log_dir):
    """
    Write the aggregated results to ``.npy`` files inside ``log_dir``.

    Args:
        aggregated_metrics (dict): Maps each metric label to its aggregated values;
            every entry is saved as ``<label>.npy``.
        aggregated_singular_values: Saved as ``singular_values.npy``.
        aggre_outputs: Saved as ``outputs.npy``.
        aggre_w1: Saved as ``W1.npy``.
        aggre_w2: Saved as ``W2.npy``.
        agg_bias: Saved as ``hidden_layer_bias.npy``.
        log_dir (str): The directory where the results will be saved.
    Returns:
        None
    """
    # Metrics come first; the fixed-name files are written afterwards in this order.
    named_arrays = [(f"{label}.npy", values) for label, values in aggregated_metrics.items()]
    named_arrays += [
        ("singular_values.npy", aggregated_singular_values),
        ("outputs.npy", aggre_outputs),
        ("W1.npy", aggre_w1),
        ("W2.npy", aggre_w2),
        ("hidden_layer_bias.npy", agg_bias),
    ]

    for file_name, values in named_arrays:
        np.save(os.path.join(log_dir, file_name), values)


def _second_layer_parts(params):
    """Split ``params`` into (second-layer weights, hidden bias); either may be None."""
    if len(params) <= 1:
        return None, None
    second_layer = params[1]
    # A second entry of length 2 is treated as a (weights, bias) pair.
    # NOTE(review): a plain weight array whose first dimension is 2 would also
    # match this check - confirm against init_network_params.
    if len(second_layer) == 2:
        return second_layer[0], second_layer[1]
    return second_layer, None


def train(num_epochs:int, params:jnp.array, train_inputs:np.array,
          train_labels:np.array, step_size:float, training_generator,
          log_interval:int=100, print_interval:int=10000,
          choice_temp:float=1.0, run:wandb.run=None, probabilistic:bool=True,
          verbose:bool=True, log_locally:bool=False, net:int=0,
          group_name:str="standard", model_name:str="linear_network", hidden_layer_bias:bool=False):
    """
    Trains a linear neural network for the specified number of epochs.

    Every ``log_interval`` epochs (starting at epoch 0) the metrics are computed
    and sent to wandb (if ``run`` is given) and/or recorded locally (if
    ``log_locally`` is set). The local records are written with
    ``save_intermediate_results`` once training has finished.

    Args:
        num_epochs (int): The number of training epochs.
        params (jnp.array): The initial model parameters.
        train_inputs (np.array): The training inputs.
        train_labels (np.array): The training labels.
        step_size (float): The learning rate for training.
        training_generator: The data generator for training.
        log_interval (int, optional): The interval for logging metrics. Defaults to 100.
        print_interval (int, optional): The interval for printing model performance. Defaults to 10000.
        choice_temp (float, optional): The temperature parameter for probabilistic choices. Defaults to 1.0.
        run (wandb.run, optional): The wandb run object. Defaults to None.
        probabilistic (bool, optional): Whether to use probabilistic choices for the model. Defaults to True.
        verbose (bool, optional): Whether to print loss and accuracy every ``print_interval`` epochs. Defaults to True.
        log_locally (bool, optional): Whether to log locally. Defaults to False.
        net (int, optional): The index of the network run. Defaults to 0.
        group_name (str, optional): The name of the group. Defaults to "standard".
        model_name (str, optional): The name of the model. Defaults to "linear_network".
        hidden_layer_bias (bool, optional): Passed as ``bias`` to ``update``, ``loss``,
            ``accuracy`` and ``get_metrics_dict``. Defaults to False.

    Returns:
        None
    """

    # Local logging buffers; allocated lazily at the first logging step because
    # their shapes depend on the first metric dictionary and the parameters.
    metrics_array = None
    singular_values_array = None
    metric_labels = None
    outputs_array = None
    weight_1_array = None
    weight_2_array = None
    hidden_layer_bias_array = None

    for epoch in range(num_epochs):
        start_time = time.time()
        for x, y in training_generator:
            params = update(params, x, y, step_size, bias=hidden_layer_bias)
        epoch_time = round(time.time() - start_time, 4)

        # Print model performance at intervals if requested
        if verbose and epoch % print_interval == 0:
            train_loss = loss(params, train_inputs, train_labels, hidden_layer_bias)
            _, _, train_acc = accuracy(params, train_inputs, train_labels, temprature=choice_temp, bias=hidden_layer_bias, probabilistic=probabilistic)
            print(f"Epoch {epoch} in {epoch_time} sec")
            print(f"Training set loss {round(train_loss, 4)}")
            print(f"Training set accuracy {round(train_acc, 4)}")

        # Everything below only happens at logging epochs
        if epoch % log_interval != 0:
            continue

        model_outputs, metric_dict = get_metrics_dict(params, train_inputs, train_labels, choice_temp, probabilistic=probabilistic, model_name=model_name, bias=hidden_layer_bias)

        # optional log to wandb if the wandb object is given
        if run is not None:
            log_to_wandb(epoch, params, run, metric_dict=metric_dict, model_outputs=model_outputs)

        # optional log locally; independent of wandb so that requesting both
        # does not leave the local buffers unfilled
        if not log_locally:
            continue

        second_layer_weights, second_layer_bias = _second_layer_parts(params)

        # Initialize arrays if this is the first logging interval
        if metrics_array is None:
            # Logging happens at epochs 0, log_interval, 2 * log_interval, ...,
            # i.e. ceil(num_epochs / log_interval) times.
            num_logs = (num_epochs + log_interval - 1) // log_interval
            # 'singular_values' is stored separately from the scalar metrics
            metric_labels = [key for key in metric_dict if key != 'singular_values']
            metrics_array = np.zeros((num_logs, len(metric_labels)))
            # Some metric dictionaries (e.g. for "imbalance" models) carry no singular values
            if 'singular_values' in metric_dict:
                singular_values_array = np.zeros((num_logs, len(metric_dict['singular_values'])))
            outputs_array = np.zeros((num_logs, *model_outputs.shape))
            weight_1_array = np.zeros((num_logs, *params[0].shape))
            if second_layer_weights is not None:
                weight_2_array = np.zeros((num_logs, *second_layer_weights.shape))
            if second_layer_bias is not None:
                hidden_layer_bias_array = np.zeros((num_logs, *second_layer_bias.shape))

        row = epoch // log_interval
        # Save scalar metrics
        metrics_array[row] = [metric_dict[label] for label in metric_labels]
        # Save singular values, outputs and parameters
        if singular_values_array is not None:
            singular_values_array[row] = metric_dict['singular_values']
        outputs_array[row] = model_outputs
        weight_1_array[row] = params[0]
        if second_layer_weights is not None:
            weight_2_array[row] = second_layer_weights
        if second_layer_bias is not None:
            hidden_layer_bias_array[row] = second_layer_bias

    # Nothing is written when no logging step ran (e.g. num_epochs == 0)
    if log_locally and metrics_array is not None:
        save_intermediate_results(net, metrics_array, metric_labels, singular_values_array,
                                  outputs=outputs_array, params=[weight_1_array, weight_2_array, hidden_layer_bias_array],
                                   group_name=group_name, model_name=model_name)
    if run is not None:
        wandb.finish()