# Train CNN models on the HPC cluster, submitting the jobs with submitit.
import argparse
from pathlib import Path
import datetime
import src

import submitit
import jax
from flax import linen as nn
import jax.numpy as jnp
import optax

from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Union

from src.neural_nets.linear_nets.training_modules import load_data
from src.neural_nets.model_data.load_mnist import load_imbalanced_mnist
from src.neural_nets.non_linear_nets.cnn_model import CNN
from src.neural_nets.non_linear_nets.compute_metrics import compute_metrics_imbalance
from src.neural_nets.non_linear_nets.train_state import create_train_state
from src.neural_nets.non_linear_nets.training_modules import train_step, squared_error_loss, cross_entropy_loss
from src.neural_nets.non_linear_nets.save_modules import save_checkpoint, save_metrics_to_json, aggregate_metrics, save_aggregated_metrics_to_json
from src.neural_nets.linear_nets.training_modules import load_config
from src.neural_nets.submit_cluster.submit import train_cnns_parallel, get_timestamp

def _record_and_reset_metrics(state, metrics_history):
    """Append the accumulated metric values to ``metrics_history`` and clear them.

    Only metrics whose name is already a key of ``metrics_history`` are
    recorded. Returns the state with an emptied metrics accumulator.
    """
    for metric, value in state.metrics.compute().items():
        name = f'{metric}'
        if name in metrics_history:
            metrics_history[name].append(float(value))
    return state.replace(metrics=state.metrics.empty())


def _latest_loss(metrics_history):
    """Return the most recent recorded loss rounded to 4 decimals, or ``'n/a'``.

    The fallback avoids an IndexError/KeyError when no loss has been recorded
    yet (e.g. the data loader yields fewer batches than the recording interval).
    """
    losses = metrics_history.get('loss')
    return round(losses[-1], 4) if losses else 'n/a'


def train_single_network(config: Mapping[str, Any], network_id: Any) -> Dict[str, list]:
    """Train one CNN as described by ``config`` and save its metric history.

    Args:
        config: Experiment configuration. The keys read here are
            ``batch_size``, ``orthogonalise``, ``dataset_name``,
            ``data_normalise``, ``learning_rate``, ``init_scale``,
            ``init_scale_conv``, ``choice_temp``, ``num_epochs``,
            ``hidden_sizes``, ``init_method``, ``use_hidden_layer_bias``,
            ``bias_init``, ``loss_function``, ``print_interval``,
            ``model_type`` and ``num_eval_steps``.
        network_id: Identifier of this network; forwarded unchanged to
            ``save_metrics_to_json``.

    Returns:
        A dict mapping each metric name to the list of values recorded
        before and during training.

    Raises:
        ValueError: If ``config['dataset_name']`` is not a supported dataset.
    """
    # Dataset hyperparameters
    batch_size = config['batch_size']
    orthogonalise = config['orthogonalise']  # only used in the model name below
    dataset_name = config['dataset_name']
    data_normalise = config['data_normalise']

    # Load data. Fail loudly on an unknown dataset instead of hitting an
    # unbound-variable error further down.
    if dataset_name != 'binary_imbalanced_mnist':
        raise ValueError(
            f"Unsupported dataset_name {dataset_name!r}; "
            "expected 'binary_imbalanced_mnist'."
        )
    mean_labels, generator = load_imbalanced_mnist(batch_size=batch_size,
                                                   normalise=data_normalise)

    # Template input for initialising the model: element [0][0] of the first
    # batch (presumably the first image -- confirm against the loader), with a
    # leading batch dimension added.
    template_image = jnp.array(next(iter(generator))[0][0])[None, :]

    # Other hyperparameters
    learning_rate = config['learning_rate']
    init_scale = config['init_scale']
    init_scale_conv = config['init_scale_conv']
    choice_temp = config['choice_temp']
    num_epochs = config['num_epochs']
    # Interval (in batches) at which the pre-training metrics are recorded.
    # Clamped to >= 1 so the modulo below cannot divide by zero.
    # NOTE(review): if len(generator) already counts batches, dividing by
    # batch_size again may not be the intended "steps per epoch" -- confirm.
    num_steps_per_epoch = max(1, len(generator) // batch_size)
    hidden_sizes = config['hidden_sizes']
    model_name = f'cnn_relu_{dataset_name}_orth={orthogonalise}_size={hidden_sizes}_init_method={config["init_method"]}_learning_rate={learning_rate}'
    use_hidden_layer_bias = config['use_hidden_layer_bias']
    bias_init_str = config['bias_init']
    hyperparam_path = f'choice_temp={choice_temp}_bias_term={use_hidden_layer_bias}_bias_init={bias_init_str}_loss_fn={config["loss_function"]}'
    print_interval = config['print_interval']
    model_type = config['model_type']
    num_eval_steps = config['num_eval_steps']

    # Initialiser factories by name; anything unrecognised falls back to normal.
    initialisers = {
        'he': nn.initializers.he_normal,
        'xavier': nn.initializers.xavier_uniform,
    }
    init_method = initialisers.get(config['init_method'], nn.initializers.normal)
    loss_function = squared_error_loss if config['loss_function'] == 'squared_error' else cross_entropy_loss

    # Random seed and splits (two successive splits, as before, so the derived
    # keys are unchanged).
    # NOTE(review): the seed is fixed and does not depend on network_id, so
    # every network starts from the same key -- confirm this is intended.
    init_rng = jax.random.PRNGKey(0)
    rng_key, _ = jax.random.split(init_rng)
    rng_key, sub_key = jax.random.split(rng_key)

    # Instantiate model
    model = CNN(use_bias=use_hidden_layer_bias, init_scale=init_scale, convolution_init_scale=init_scale_conv, kernel_init=init_method, output_size=mean_labels.shape[0])

    # Initialize the train state
    state = create_train_state(model, sub_key, learning_rate=learning_rate, input_size=template_image.shape)

    # One history list per non-dunder, non-callable attribute of the metrics object
    metrics_history = {}
    for attr in dir(state.metrics):
        if not attr.startswith('__') and not callable(getattr(state.metrics, attr)):
            metrics_history[f'{attr}'] = []

    # Compute metrics before training
    for step, batch in enumerate(generator):
        state = compute_metrics_imbalance(state=state, batch=batch, key=sub_key, mean_labels=mean_labels, loss_fn=loss_function)
        if (step + 1) % num_steps_per_epoch == 0:
            state = _record_and_reset_metrics(state, metrics_history)
    print(f"Epoch 0 - Train loss: {_latest_loss(metrics_history)}")

    # Training loop
    for epoch in range(num_epochs):
        for step, batch in enumerate(generator):
            state = train_step(state, batch, loss_function)
            rng_key, sub_key = jax.random.split(rng_key)
            state = compute_metrics_imbalance(state=state, batch=batch, key=sub_key, mean_labels=mean_labels, loss_fn=loss_function)

            # Record (and reset) the accumulated metrics every num_eval_steps batches
            if (step + 1) % num_eval_steps == 0:
                state = _record_and_reset_metrics(state, metrics_history)

        if epoch % print_interval == 0:
            print(f"Epoch {epoch +1} - Train loss: {_latest_loss(metrics_history)}")

    # Save the metrics
    save_metrics_to_json(metrics_history, model_type, data_normalise, model_name, hyperparam_path, network_id)
    return metrics_history


def get_timestamp() -> str:
    """Format the current UTC date and time as ``YYYY-MM-DD-HH:MM``."""
    now_utc = datetime.datetime.now(tz=datetime.timezone.utc)
    return f"{now_utc:%Y-%m-%d-%H:%M}"


# Command-line options whose name differs from the config key they override.
# NOTE(review): '--scale' is documented as the "init weight scale"; mapping it
# to the 'init_scale' config key is an inference -- confirm.
_CLI_TO_CONFIG_KEY = {'scale': 'init_scale'}


def _parse_bool(text):
    """Convert a command-line token such as 'True'/'false'/'1'/'no' to a bool.

    ``type=bool`` cannot be used with argparse because any non-empty string
    (including 'False') is truthy.
    """
    lowered = text.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {text!r}')


def _parse_int_tuple(text):
    """Convert '128,64' (optionally wrapped in () or []) to ``(128, 64)``.

    ``type=tuple`` cannot be used with argparse because it splits the raw
    string into single characters.
    """
    stripped = text.strip().strip('()[]')
    try:
        return tuple(int(part) for part in stripped.split(',') if part.strip())
    except ValueError:
        raise argparse.ArgumentTypeError(
            f'expected comma-separated integers, got {text!r}'
        ) from None


def build_arg_parser():
    """Build the command-line parser for this script.

    Every option except ``--config`` defaults to ``None`` so that "not given"
    can be told apart from an explicit value when applying overrides.
    """
    parser = argparse.ArgumentParser(description='Train an CNN model with bespoke parameters.')

    parser.add_argument('--config',
                        type=str,
                        default='cnn_binary_imbalance_config.yaml',
                        help='Path to the config file.')

    parser.add_argument('--batch_size',
                        type=int,
                        help='Override the batch size in the config file.')

    parser.add_argument('--include_bias_input',
                        action='store_true',
                        default=None,
                        help='Override the include_bias_input in the config file.')

    parser.add_argument('--include_head_property',
                        type=_parse_bool,
                        help='Override the include_head_property in the config file.')

    parser.add_argument('--num_epochs',
                        type=int,
                        help='Override the number of epochs in the config file.')

    parser.add_argument('--learning_rate',
                        type=float,
                        help='Override the learning_rate in the config file.')

    parser.add_argument('--scale',
                        type=float,
                        help='Override the init weight scale in the config file.')

    parser.add_argument('--choice_temp',
                        type=float,
                        help='Override the choice temperature in the config file.')

    parser.add_argument('--hidden_sizes',
                        type=_parse_int_tuple,
                        help='Override the hidden layer sizes in the config file.')

    parser.add_argument('--use_hidden_layer_bias',
                        type=_parse_bool,
                        help='Override the use_hidden_layer_bias in the config file.')

    parser.add_argument('--print_interval',
                        type=int,
                        help='Override the print interval in the config file.')

    parser.add_argument('--activation_function',
                        type=str,
                        help='Override the activation function in the config file.')

    return parser


def apply_cli_overrides(config, args):
    """Write every explicitly given command-line option into ``config``.

    Options left at ``None`` (not given) and the ``config`` path itself are
    skipped. ``config`` is modified in place and also returned.
    """
    for dest, value in vars(args).items():
        if dest == 'config' or value is None:
            continue
        config[_CLI_TO_CONFIG_KEY.get(dest, dest)] = value
    return config


if __name__ == "__main__":
    args = build_arg_parser().parse_args()
    # Load the config file, then let explicit command-line options win.
    config = apply_cli_overrides(load_config(args.config), args)

    # submitit logs go to <scratch>/<dataset>_cnn/<timestamp>/<use_hidden_layer_bias>
    submit_it_log_dir = Path(
        src.SCRATCH_DIR,
        f"{config['dataset_name']}_cnn",
        get_timestamp(),
        str(config['use_hidden_layer_bias']),
    )

    executor = submitit.AutoExecutor(folder=submit_it_log_dir)
    # Executor parameters
    executor_params = {
        'timeout_min': 60*3,
        'mem_gb': 8,
        'gpus_per_node': 1,
        'cpus_per_task': 4,
        'slurm_array_parallelism': 256,
        'slurm_partition': 'gpu',
    }

    results = train_cnns_parallel(config, executor, executor_params, train_single_network)
    # # for debugging, also remember to un-jit relevant parts of code
    # train_single_network(config, 0)