import os
import time
import sys

import itertools
from multiprocessing import Pool
import argparse
import numpy as np
import jax.numpy as jnp
from jax import random
from jax import pmap
from src.neural_nets.linear_nets.training_modules import(load_config,
                                                        train,
                                                        initialize_model,
                                                        initialize_wandb,
                                                        load_data,
                                                        aggregate_results,
                                                        save_aggregated_results)

# Entry-point script: trains several runs of the linear network in parallel
# worker processes and then aggregates and saves the results they logged.

# Set the WANDB_SILENT environment variable for this process before any wandb
# run is created (presumably suppresses wandb's console output -- confirm
# against the wandb environment-variable documentation).
os.environ["WANDB_SILENT"]="true"

def run_single_experiment(config, net, inputs=None, labels=None, train_generator=None):
    """
    Initialise one network and train it according to ``config``.

    Args:
        config (dict): Experiment configuration. The keys read here are
            'batch_size', 'include_head_property', 'include_bias_input',
            'step_size', 'num_epochs', 'scale', 'choice_temp', 'log_interval',
            'print_interval', 'wandb_project_name', 'group_name', 'model_name',
            'verbose', 'layer_sizes', 'log_to_wandb' and 'log_locally'.
            A missing key raises ``KeyError``.
        net (int): Index of this network run; printed and forwarded to
            ``initialize_wandb`` and ``train``.
        inputs (optional): Input data. Required in practice despite the
            ``None`` default, because ``inputs.shape[1]`` is read when the
            model is initialised.
        labels (optional): Target labels. Required in practice despite the
            ``None`` default, because ``labels.shape[1]`` is read when the
            model is initialised.
        train_generator (optional): Training data generator, forwarded
            unchanged to ``train``.
    Returns:
        None
    """
    # Extracting hyperparameters from the config
    batch_size = config['batch_size']
    include_head_property = config['include_head_property']
    include_bias_input = config['include_bias_input']

    step_size = config['step_size']
    num_epochs = config['num_epochs']
    scale = config['scale']
    choice_temp = config['choice_temp']
    log_interval = config['log_interval']
    print_interval = config['print_interval']
    wandb_project_name = config['wandb_project_name']
    group_name = config['group_name']
    model_name = config['model_name']
    verbose = config['verbose']

    print(f"Starting network {net}...")

    # Draw the seed from a freshly OS-seeded generator instead of the global
    # ``np.random`` state: pool workers created by fork inherit identical
    # copies of that global state, so several parallel runs could otherwise
    # draw the very same "random" seed. The range [0, 1000000) is unchanged.
    random_seed = int(np.random.default_rng().integers(0, 1000000))

    # initialise model
    params = initialize_model(inputs.shape[1], config['layer_sizes'][1],
                              labels.shape[1], random_seed, scale)

    # initialise wandb (only when requested); otherwise train() receives None
    if config['log_to_wandb']:
        wandb_run = initialize_wandb(step_size, config['layer_sizes'], num_epochs,
                                     scale, batch_size, random_seed, choice_temp,
                                     include_bias_input, include_head_property, net=net,
                                     project_name=wandb_project_name, group_name=group_name)
    else:
        wandb_run = None

    # train the model and log the results locally
    # NOTE: when 'log_locally' is false nothing is trained at all; training
    # without local logging is currently not supported by this script.
    if config['log_locally']:
        train(num_epochs, params, inputs, labels, step_size,
              train_generator, log_interval, print_interval,
              choice_temp, wandb_run, verbose=verbose,
              log_locally=config['log_locally'], net=net,
              group_name=group_name, model_name=model_name)

def parallel_run_experiments(config, inputs=None, labels=None, train_generator=None, n_processes=5):
    """
    Train ``config['num_runs']`` networks in a process pool, then aggregate
    and save what the individual runs produced.

    Args:
        config (dict): Experiment configuration; passed to every run and also
            read here for 'num_runs', 'group_name', 'model_name',
            'num_epochs' and 'log_interval'.
        inputs (optional): Input data handed to every run.
        labels (optional): Labels handed to every run.
        train_generator (optional): Training generator handed to every run.
        n_processes (int, optional): Size of the worker pool. Defaults to 5.
    Returns:
        None
    """
    with Pool(processes=n_processes) as pool:
        # One argument tuple per run index; starmap blocks until every run
        # has finished.
        jobs = [
            (config, run_idx, inputs, labels, train_generator)
            for run_idx in range(config['num_runs'])
        ]
        pool.starmap(run_single_experiment, jobs)

    # Collect the per-run results ...
    (metrics,
     singular_values,
     outputs,
     w1,
     w2,
     bias,
     log_dir) = aggregate_results(
        config["group_name"],
        config["model_name"],
        config["num_epochs"],
        config['num_runs'],
        config["log_interval"],
    )

    # ... and write the aggregated versions to the directory that
    # aggregate_results returned.
    save_aggregated_results(metrics, singular_values, outputs, w1, w2, bias, log_dir)

def str_to_bool(value):
    """
    Convert a command-line string into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty text -- including "False" -- is parsed as ``True``. This
    converter maps the usual spellings explicitly instead.

    Args:
        value (str | bool): Raw option value.
    Returns:
        bool: The parsed value.
    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "on", "1"):
        return True
    # The empty string stays False, as it was under ``type=bool``.
    if text in ("false", "f", "no", "n", "off", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Train a linear neural network.')
    parser.add_argument('--config',  type=str, default='shallow_net_config.yaml', help='Path to the configuration file.')
    parser.add_argument('--batch_size', type=int, help='Override the batch size in the configuration file.')
    parser.add_argument('--num_epochs', type=int, help='Override the number of epochs in the configuration file.')
    parser.add_argument('--step_size', type=float, help='Override the step size in the configuration file.')
    parser.add_argument('--scale', type=float, help='Override the init weight scale in the configuration file.')
    parser.add_argument('--choice_temp', type=float, help='Override the choice temperature in the configuration file.')
    parser.add_argument('--log_interval', type=int, help='Override the log interval in the configuration file.')
    parser.add_argument('--print_interval', type=int, help='Override the print interval')
    # The override flags below use default=None: with the implicit False
    # default of 'store_true', an omitted flag would still pass the
    # "is not None" test further down and silently overwrite the config file.
    parser.add_argument('--include_bias_input', action='store_true', default=None, help='Override the include_bias_input in the configuration file.')
    parser.add_argument('--log_to_wandb', action='store_true', default=None, help='Override the log_to_wandb in the configuration file.')
    parser.add_argument('--wandb_project_name', type=str, help='Override the wandb_project_name in the configuration file.')
    parser.add_argument('--group_name', type=str, help='Override the group_name in the configuration file.')
    # Boolean-valued options go through str_to_bool so that "False" is False.
    parser.add_argument('--log_locally', type=str_to_bool, default=True, help='Override the log_locally in the configuration file.if set, log to local directory')
    parser.add_argument('--num_runs', type=int, help='Override the number of runs in the configuration file.')
    parser.add_argument('--probabilistic', type=str_to_bool, help='Override if model makes deterministic or non-deterministic choices.')
    parser.add_argument('--grid_search', action='store_true', help='If set, run hyperparameter search over initial weight scale and choice temp.')
    parser.add_argument('--verbose', action='store_true', default=None, help='Override the verbose in the configuration file to output epoch-wise info.')
    parser.add_argument('--include_head_property', type=str_to_bool, help='Override the include_head_property in the configuration file.')
    parser.add_argument('--model_name', type=str, help='Override the model_name in the configuration file.')

    args = parser.parse_args()
    config = load_config(args.config)

    # Override config parameters if corresponding argument is provided
    for key, value in vars(args).items():
        if value is not None and key != "config":
            config[key] = value

    # Flags given neither on the command line nor in the config file fall
    # back to False, which is what an omitted flag used to produce.
    for flag in ('include_bias_input', 'log_to_wandb', 'verbose'):
        config.setdefault(flag, False)

    # Extracting subset of hyper-params from the config
    batch_size = config['batch_size']
    include_head_property = config['include_head_property']
    include_bias_input = config['include_bias_input']
    # Treat a missing or empty grid entry as "no grid".
    hyperparameter_grid = config.get("hyperparameter_grid") or {}

    # initialise the dataset
    train_inputs, train_labels, training_generator = load_data(batch_size,
                                                                include_head_property,
                                                                include_bias_input)

    hyper_param_combinations = list(itertools.product(*hyperparameter_grid.values()))

    # if we get the parameter run the grid search
    if args.grid_search:
        for combination in hyper_param_combinations:
            start_time = time.time()
            print(f"Starting hyperparameter combination: {combination}")
            hyperparam_strs = []  # "key=value" strings, one per hyperparameter
            # product() yields values in the same order as the grid's keys
            for key, value in zip(hyperparameter_grid, combination):
                config[key] = value
                hyperparam_strs.append(f"{key}={value}")

            # add the bias input to the hyperparam strings
            hyperparam_strs.append(f"bias_input={include_bias_input}")
            print(f"Running with hyperparameters: {hyperparam_strs}")
            # Join the hyperparam strings into a single string for group_name
            config['group_name'] = "_".join(hyperparam_strs)
            parallel_run_experiments(config, train_inputs, train_labels, training_generator)
            cell_time = round(time.time() - start_time, 4)
            print(f"Finished hyperparameter combination: {combination} in {cell_time} seconds")
    else:
        # change the group name to include the hyperparameters and the bias
        hyperparam_strs = [
            f"choice_temp={config['choice_temp']}",
            f"scale={config['scale']}",
            f"bias_input={include_bias_input}",
        ]
        # Join the hyperparam strings into a single string for group_name
        config['group_name'] = "_".join(hyperparam_strs)
        parallel_run_experiments(config, train_inputs, train_labels, training_generator)

