# code to run my CNN on the hpc cluster with submitit
import argparse
from pathlib import Path
import datetime
import src

import submitit
import jax
from flax import linen as nn
import jax.numpy as jnp
import optax
import numpy as np

from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Union

from src.neural_nets.linear_nets.training_modules import load_data
from src.neural_nets.model_data.load_mnist import load_hieararchy_mnist, load_imbalanced_mnist
from src.neural_nets.model_data.load_cifar import load_hieararchy_cifar
from src.neural_nets.non_linear_nets.cnn_model import CNN
from src.neural_nets.non_linear_nets.compute_metrics import compute_metrics
from src.neural_nets.non_linear_nets.train_state import create_train_state
from src.neural_nets.non_linear_nets.training_modules import train_step, squared_error_loss
from src.neural_nets.non_linear_nets.save_modules import save_checkpoint, save_metrics_to_json, aggregate_metrics, save_aggregated_metrics_to_json
from src.neural_nets.linear_nets.training_modules import load_config
from src.neural_nets.submit_cluster.submit import train_cnns_parallel, train_cnns, get_timestamp

def _record_metrics(state, metrics_history, outputs, batch, save_outputs):
    """Flush the metrics accumulated on ``state`` into ``metrics_history``.

    Every value returned by ``state.metrics.compute()`` whose name is already
    a key of ``metrics_history`` is appended as a Python float. When
    ``save_outputs`` is true, ``outputs`` and the second element of ``batch``
    are appended (converted with ``.tolist()``) under the 'outputs' and
    'labels' keys.

    Args:
        state: Train state carrying a ``metrics`` collection.
        metrics_history: Dict of lists that is extended in place.
        outputs: Model outputs returned by the latest ``compute_metrics`` call.
        batch: The batch those outputs were computed on; ``batch[1]`` is
            stored as the labels.
        save_outputs: Whether to store ``outputs`` and labels as well.

    Returns:
        ``state`` with its metrics replaced by ``state.metrics.empty()`` so the
        next evaluation window starts from scratch.
    """
    for metric, value in state.metrics.compute().items():
        if metric in metrics_history:
            metrics_history[metric].append(float(value))

    if save_outputs:
        metrics_history['outputs'].append(outputs.tolist())
        metrics_history['labels'].append(batch[1].tolist())

    return state.replace(metrics=state.metrics.empty())


def train_single_network(config, network_id):
    """Train one CNN as described by ``config`` and save its metric history.

    The function loads the dataset selected by ``config['dataset_name']``,
    builds a ``CNN`` and its train state, records metrics once before any
    training, then trains for ``config['num_epochs']`` passes over the data
    generator, recording metrics every ``config['num_eval_steps']`` steps.
    The resulting history is written with ``save_metrics_to_json`` and returned.

    Args:
        config: Mapping of hyperparameters. Keys read here: batch_size,
            include_head_property, orthogonalise, dataset_name, data_normalise,
            learning_rate, init_scale, init_scale_conv, choice_temp, num_epochs,
            hidden_sizes, init_method, use_hidden_layer_bias, bias_init,
            loss_function, print_interval, model_type, num_eval_steps,
            save_outputs.
        network_id: Identifier of this network; forwarded to
            ``save_metrics_to_json``.

    Returns:
        dict mapping each tracked metric name to the list of recorded values
        (plus 'outputs' and 'labels' lists when ``config['save_outputs']`` is
        truthy).

    Raises:
        ValueError: If ``config['dataset_name']`` is not one of 'mnist',
            'cifar10' or 'binary_imbalanced_mnist'.
    """
    # Dataset hyperparameters
    batch_size = config['batch_size']
    include_head_property = config['include_head_property']
    orthogonalise = config['orthogonalise']  # orthogonalise the data?
    dataset_name = config['dataset_name']
    data_normalise = config['data_normalise']

    # Load data
    if dataset_name == 'mnist':
        mean_labels, generator = load_hieararchy_mnist(batch_size=batch_size,
                                                       include_headnode=include_head_property,
                                                       flat=False,
                                                       orthogonalise=orthogonalise,
                                                       normalise=data_normalise)
    elif dataset_name == 'cifar10':
        mean_labels, generator = load_hieararchy_cifar(batch_size=batch_size,
                                                       include_headnode=include_head_property,
                                                       flat=False,
                                                       orthogonalise=orthogonalise,
                                                       normalise=data_normalise)
    elif dataset_name == 'binary_imbalanced_mnist':
        mean_labels, generator = load_imbalanced_mnist(batch_size=batch_size,
                                                       normalise=data_normalise)
    else:
        # Fail here with a clear message instead of an UnboundLocalError on
        # `generator` a few lines further down.
        raise ValueError(
            f"Unknown dataset_name {dataset_name!r}; expected one of "
            "'mnist', 'cifar10', 'binary_imbalanced_mnist'."
        )

    # Generate a template image for initializing the model
    template_image = next(iter(generator))[0][0]
    template_image = jnp.array(template_image)
    template_image = template_image[None, :]  # Add a batch dimension

    # Other hyperparameters
    learning_rate = config['learning_rate']
    init_scale = config['init_scale']
    init_scale_conv = config['init_scale_conv']
    choice_temp = config['choice_temp']
    num_epochs = config['num_epochs']
    # Clamp to 1 so the modulo below can never divide by zero when the
    # generator is shorter than one batch_size.
    # NOTE(review): if len(generator) already counts batches (not samples),
    # dividing by batch_size again shortens the recording window - confirm.
    num_steps_per_epoch = max(1, len(generator) // batch_size)
    hidden_sizes = config['hidden_sizes']
    model_name = f'cnn_relu_{dataset_name}_orth={orthogonalise}_size={hidden_sizes}_init_method={config["init_method"]}_learning_rate={learning_rate}'
    use_hidden_layer_bias = config['use_hidden_layer_bias']
    bias_init_str = config['bias_init']
    hyperparam_path = f'choice_temp={choice_temp}_bias_term={use_hidden_layer_bias}_bias_init={bias_init_str}_loss_fn={config["loss_function"]}'
    print_interval = config['print_interval']
    model_type = config['model_type']
    num_eval_steps = config['num_eval_steps']
    save_outputs = config["save_outputs"]
    # NOTE(review): config['activation_function'] is not forwarded to the model
    # here (the previous unused local derived from it was removed).

    # Pick the initializer factory; anything other than 'he'/'xavier' falls
    # back to a plain normal initializer.
    if config['init_method'] == 'he':
        init_method = nn.initializers.he_normal
    elif config['init_method'] == 'xavier':
        init_method = nn.initializers.xavier_uniform
    else:
        init_method = nn.initializers.normal

    # Random seed and split
    # NOTE(review): the seed is fixed at 0 and does not depend on network_id,
    # so every call initialises from the same key - confirm this is intended.
    init_rng = jax.random.PRNGKey(0)
    rng_key, sub_key = jax.random.split(init_rng)

    # Get a new random key for the model
    rng_key, sub_key = jax.random.split(rng_key)

    # Instantiate model
    model = CNN(use_bias=use_hidden_layer_bias,
                init_scale=init_scale,
                convolution_init_scale=init_scale_conv,
                kernel_init=init_method,
                output_size=mean_labels.shape[0])

    # Initialize the train state
    state = create_train_state(model, sub_key, learning_rate=learning_rate, input_size=template_image.shape)

    # One history list per public, non-callable attribute of the metrics object
    metrics_history = {
        attr: []
        for attr in dir(state.metrics)
        if not attr.startswith('__') and not callable(getattr(state.metrics, attr))
    }

    # add outputs to the metrics_history
    if save_outputs:
        metrics_history['outputs'] = []
        metrics_history['labels'] = []

    # compute metrics before training
    for step, batch in enumerate(generator):
        state, outputs = compute_metrics(state=state, batch=batch, key=sub_key, mean_labels=mean_labels)
        if (step + 1) % num_steps_per_epoch == 0:  # one recording window has passed
            state = _record_metrics(state, metrics_history, outputs, batch, save_outputs)
    print(f"Epoch 0 - Train loss: {round(metrics_history['loss'][-1], 4)}")

    # Training loop
    # NOTE(review): train_step always receives squared_error_loss even though
    # hyperparam_path records config["loss_function"] - confirm they agree.
    for epoch in range(num_epochs):
        for step, batch in enumerate(generator):
            state = train_step(state, batch, squared_error_loss)
            rng_key, sub_key = jax.random.split(rng_key)

            if (step + 1) % num_eval_steps == 0:  # evaluate the model every num_eval_steps
                state, outputs = compute_metrics(state=state, batch=batch, key=sub_key, mean_labels=mean_labels)
                state = _record_metrics(state, metrics_history, outputs, batch, save_outputs)

        if epoch % print_interval == 0:
            print(f"Epoch {epoch +1} - Train loss: {round(metrics_history['loss'][-1],4)}")

    # Save the metrics
    save_metrics_to_json(metrics_history, model_type, data_normalise, model_name, hyperparam_path, network_id)
    return metrics_history

def str_to_bool(text):
    """Parse a command-line boolean such as 'true', 'False', '1' or 'no'.

    ``type=bool`` cannot be used with argparse because ``bool('False')`` is
    ``True`` (any non-empty string is truthy).

    Raises:
        argparse.ArgumentTypeError: If ``text`` is not a recognised boolean.
    """
    if isinstance(text, bool):
        return text
    lowered = text.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {text!r}")


def parse_hidden_sizes(text):
    """Parse hidden layer sizes such as '128,64' or '(128, 64)' into a tuple of ints.

    ``type=tuple`` would split the raw string into single characters
    (``tuple('128') == ('1', '2', '8')``).

    Raises:
        argparse.ArgumentTypeError: If no sizes are given or one is not an int.
    """
    stripped = text.strip().strip('()[]')
    pieces = [piece.strip() for piece in stripped.split(',') if piece.strip()]
    if not pieces:
        raise argparse.ArgumentTypeError(f"expected comma-separated integers, got {text!r}")
    try:
        return tuple(int(piece) for piece in pieces)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"expected comma-separated integers, got {text!r}"
        ) from None


def build_arg_parser():
    """Build the command-line parser for this script.

    Every override defaults to ``None`` so that "not given on the command
    line" can be told apart from an explicit value.
    """
    parser = argparse.ArgumentParser(description='Train an CNN model with bespoke parameters.')

    parser.add_argument('--config',
                        type=str,
                        default='cnn_hieararchy_config.yaml',
                        help='Path to the config file.')

    parser.add_argument('--batch_size',
                        type=int,
                        help='Override the batch size in the config file.')

    parser.add_argument('--include_head_property',
                        type=str_to_bool,
                        help='Override the include_head_property in the config file.')

    parser.add_argument('--num_epochs',
                        type=int,
                        help='Override the number of epochs in the config file.')

    parser.add_argument('--learning_rate',
                        type=float,
                        help='Override the learning_rate in the config file.')

    parser.add_argument('--scale',
                        type=float,
                        help='Override the init weight scale in the config file.')

    parser.add_argument('--choice_temp',
                        type=float,
                        help='Override the choice temperature in the config file.')

    parser.add_argument('--hidden_sizes',
                        type=parse_hidden_sizes,
                        help='Override the hidden layer sizes in the config file.')

    # default=None (instead of store_true's implicit False) so that an absent
    # flag leaves the value from the config file untouched.
    parser.add_argument('--use_hidden_layer_bias',
                        action='store_true',
                        default=None,
                        help='Override the use_hidden_layer_bias in the config file.')

    parser.add_argument('--print_interval',
                        type=int,
                        help='Override the print interval in the config file.')

    parser.add_argument('--activation_function',
                        type=str,
                        help='Override the activation function in the config file.')

    parser.add_argument('--orthogonalise',
                        action='store_true',
                        default=None,
                        help='Override the orthogonalise in the config file.')

    return parser


def apply_cli_overrides(config, args):
    """Copy every command-line value that was actually provided into ``config``.

    Arguments left at ``None`` and the ``config`` path itself are skipped.
    ``--scale`` is additionally mirrored to ``config['init_scale']``, which is
    the key the training code reads; ``config['scale']`` is still set as before.

    Returns:
        The same ``config`` object, updated in place.
    """
    for key, value in vars(args).items():
        if value is not None and key != "config":
            config[key] = value

    scale = getattr(args, 'scale', None)
    if scale is not None:
        config['init_scale'] = scale
    return config


if __name__ == "__main__":
    args = build_arg_parser().parse_args()
    config = load_config(args.config)
    # Override config parameters if corresponding argument is provided
    apply_cli_overrides(config, args)
    # print(args.use_hidden_layer_bias, args.orthogonalise)
    # code for running on slurm

    # submit_it_log_dir = log_dir=Path(
    #   src.SCRATCH_DIR,
    #   f"hieararchical_{config['dataset_name']}_cnn",
    #   get_timestamp(), str(config['use_hidden_layer_bias'])
    # )
    
    # executor = submitit.AutoExecutor(folder=submit_it_log_dir)
    # # Executor parameters
    # executor_params = {
    #     'timeout_min': 60*3,
    #     'mem_gb': 8,
    #     'gpus_per_node': 1,
    #     'cpus_per_task': 4,
    #     'slurm_array_parallelism':256, 
    #     'slurm_partition': 'gpu',
    # }
    
    # results = train_cnns_parallel(config, executor, executor_params, train_single_network)
    results = train_cnns(config, train_single_network)
    # # # for debugging, also remember to un-jit relevant parts of code
    # train_single_network(config, 0) 