# Code to run my CNN on CelebA on the HPC cluster with submitit.
import argparse
from pathlib import Path
import datetime
import src

import submitit
import jax
from flax import linen as nn
import jax.numpy as jnp
import optax
import numpy as np

from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Union

from src.neural_nets.linear_nets.training_modules import load_data
from src.neural_nets.model_data.load_celeba import load_hieararchy_celeba
from src.neural_nets.non_linear_nets.cnn_model import CNN, CNN_mutliclass
from src.neural_nets.non_linear_nets.compute_metrics import compute_metrics_celeba
from src.neural_nets.non_linear_nets.train_state import create_train_state_celeba
from src.neural_nets.non_linear_nets.training_modules import train_step, squared_error_loss, bin_cross_entropy_loss
from src.neural_nets.non_linear_nets.save_modules import save_checkpoint, save_metrics_to_json, aggregate_metrics, save_aggregated_metrics_to_json
from src.neural_nets.linear_nets.training_modules import load_config
from src.neural_nets.submit_cluster.submit import train_cnns_parallel, get_timestamp


def _record_and_reset_metrics(state, metrics_history):
    """Append the current metric values to ``metrics_history`` and reset them.

    Only metrics whose name is already a key of ``metrics_history`` are
    recorded. The reset happens exactly once, whether or not any metric
    was recorded.

    Args:
        state: Train state exposing ``metrics.compute()``, ``metrics.empty()``
            and ``replace(metrics=...)``.
        metrics_history: Dict mapping metric name to a list of floats;
            mutated in place.

    Returns:
        The train state with its metrics replaced by ``state.metrics.empty()``.
    """
    for metric, value in state.metrics.compute().items():
        if metric in metrics_history:
            metrics_history[metric].append(float(value))
    return state.replace(metrics=state.metrics.empty())


def train_single_network(config, network_id):
    """Train one CNN on the hierarchical CelebA data and save its metrics.

    Args:
        config: Mapping of hyperparameters. The keys read here are
            ``batch_size``, ``include_head_property``, ``orthogonalise``,
            ``dataset_name``, ``data_normalise``, ``loss_function``,
            ``learning_rate``, ``init_scale``, ``init_scale_conv``,
            ``choice_temp``, ``num_epochs``, ``hidden_sizes``,
            ``init_method``, ``use_hidden_layer_bias``, ``bias_init``,
            ``print_interval``, ``model_type``, ``num_eval_steps`` and
            ``optimizer``.
        network_id: Identifier of this network; forwarded unchanged to
            ``save_metrics_to_json``.

    Returns:
        Dict mapping each metric name to the list of float values recorded
        before training and at every evaluation step during training.

    Raises:
        ValueError: If ``dataset_name`` is not ``'celeba'`` or
            ``loss_function`` is neither ``'squared_error'`` nor
            ``'sig_bin_crossentropy'``.
    """
    # Dataset hyperparameters
    batch_size = config['batch_size']
    include_head_property = config['include_head_property']
    orthogonalise = config['orthogonalise']  # orthogonalise data?
    dataset_name = config['dataset_name']
    data_normalise = config['data_normalise']

    # Fail early with a clear message: only the CelebA loader is wired up, and
    # any other name would otherwise surface later as an unbound `generator`.
    if dataset_name != 'celeba':
        raise ValueError(
            f"Unsupported dataset_name {dataset_name!r}; only 'celeba' is implemented."
        )

    # Choose the loss and the matching model class together, so that an
    # unknown loss name cannot leave the model undefined.
    loss_name = config['loss_function']
    if loss_name == 'squared_error':
        loss_function, model_cls = squared_error_loss, CNN
    elif loss_name == 'sig_bin_crossentropy':
        loss_function, model_cls = bin_cross_entropy_loss, CNN_mutliclass
    else:
        raise ValueError(
            f"Unsupported loss_function {loss_name!r}; "
            "expected 'squared_error' or 'sig_bin_crossentropy'."
        )

    def initialize_generator():
        # Returns (mean_labels, num_steps_per_epoch, generator)
        return load_hieararchy_celeba(batch_size=batch_size,
                                      include_headnode=include_head_property,
                                      flat=False,
                                      orthogonalise=orthogonalise,
                                      normalise=data_normalise)

    # Load data
    mean_labels, _, generator = initialize_generator()

    # Generate a template image for initializing the model.
    # Note that this consumes the first batch of the initial generator.
    template_image, _ = next(generator)
    template_image = jnp.array(template_image[0])
    template_image = template_image[None, :]  # Add a batch dimension

    # Other hyperparameters
    learning_rate = config['learning_rate']
    init_scale = config['init_scale']
    init_scale_conv = config['init_scale_conv']
    choice_temp = config['choice_temp']
    num_epochs = config['num_epochs']
    hidden_sizes = config['hidden_sizes']
    init_method_name = config['init_method']
    use_hidden_layer_bias = config['use_hidden_layer_bias']
    bias_init_str = config['bias_init']
    print_interval = config['print_interval']
    model_type = config['model_type']
    num_eval_steps = config['num_eval_steps']

    # These two strings are passed to save_metrics_to_json to identify the run.
    model_name = f'cnn_relu_{dataset_name}_orth={orthogonalise}_size={hidden_sizes}_init_method={init_method_name}_learning_rate={learning_rate}'
    hyperparam_path = f'choice_temp={choice_temp}_bias_term={use_hidden_layer_bias}_bias_init={bias_init_str}_loss_fn={loss_name}'

    # The initializer factory itself (not an instance) is handed to the model;
    # any name other than 'he'/'xavier' falls back to the normal initializer.
    if init_method_name == 'he':
        init_method = nn.initializers.he_normal
    elif init_method_name == 'xavier':
        init_method = nn.initializers.xavier_uniform
    else:
        init_method = nn.initializers.normal

    # Random seed and split
    # NOTE(review): the seed is fixed at 0 and does not depend on network_id,
    # so every network starts from the same key — confirm this is intended.
    init_rng = jax.random.PRNGKey(0)
    rng_key, sub_key = jax.random.split(init_rng)

    # Get a new random key for each new model
    rng_key, sub_key = jax.random.split(rng_key)

    # Instantiate model
    model = model_cls(use_bias=use_hidden_layer_bias,
                      init_scale=init_scale,
                      convolution_init_scale=init_scale_conv,
                      kernel_init=init_method,
                      output_size=mean_labels.shape[0])

    # Initialize the train state
    state = create_train_state_celeba(model, sub_key,
                                      learning_rate=learning_rate,
                                      input_size=template_image.shape,
                                      optimizer=config['optimizer'])

    # Initialize the history with one empty list per non-callable,
    # non-dunder attribute of the metrics collection.
    metrics_history = {
        attr: []
        for attr in dir(state.metrics)
        if not attr.startswith('__') and not callable(getattr(state.metrics, attr))
    }

    # Compute metrics before training. Metrics accumulate over the first
    # `print_interval` batches, or over the whole remaining generator if it
    # yields fewer batches than that.
    for step, batch in enumerate(generator):
        state = compute_metrics_celeba(state=state, batch=batch, key=sub_key,
                                       mean_labels=mean_labels, loss_fn=loss_function)
        if (step + 1) % print_interval == 0:
            # current loss
            print(f"Epoch 0 - Train loss: {round(float(state.metrics.compute()['loss']), 4)}")
            break

    state = _record_and_reset_metrics(state, metrics_history)

    # Training loop
    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}")
        # Reinitialize the generator at the start of each epoch
        mean_labels, _, generator = initialize_generator()
        for step, batch in enumerate(generator):
            state = train_step(state, batch, loss_function)
            rng_key, sub_key = jax.random.split(rng_key)
            # Evaluate every num_eval_steps steps, starting with step 0, so
            # that metrics_history always has an entry by the time we print.
            if step % num_eval_steps == 0:
                state = compute_metrics_celeba(state=state, batch=batch, key=sub_key,
                                               mean_labels=mean_labels, loss_fn=loss_function)
                state = _record_and_reset_metrics(state, metrics_history)

            if (step + 1) % print_interval == 0:
                # Report the most recently recorded loss
                loss_value = metrics_history['loss'][-1]
                print(f"Step {step + 1}, epoch {epoch + 1} - Train loss: {round(float(loss_value), 4)}")

    # Save the metrics
    save_metrics_to_json(metrics_history, model_type, data_normalise, model_name, hyperparam_path, network_id)
    return metrics_history


def get_timestamp() -> str:
    """Return the current UTC date and time formatted as ``YYYY-MM-DD-HH:MM``."""
    utc_now = datetime.datetime.now(tz=datetime.timezone.utc)
    return f"{utc_now:%Y-%m-%d-%H:%M}"


# Command-line option names that differ from the config key they override.
# NOTE(review): `--scale` is assumed to target `init_scale` (its help text says
# "init weight scale") — confirm against the config file.
_CLI_TO_CONFIG_KEY = {'scale': 'init_scale'}


def _parse_bool(text):
    """Convert a command-line string such as ``true`` or ``false`` to a bool.

    ``type=bool`` cannot be used with argparse because any non-empty string,
    including ``"False"``, is truthy.
    """
    lowered = text.strip().lower()
    if lowered in {'true', 't', 'yes', 'y', '1'}:
        return True
    if lowered in {'false', 'f', 'no', 'n', '0'}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {text!r}")


def _parse_int_tuple(text):
    """Convert a comma-separated string such as ``128,64`` to a tuple of ints.

    Surrounding parentheses or brackets are accepted and ignored.
    """
    stripped = text.strip().strip('()[]')
    try:
        return tuple(int(part) for part in stripped.split(',') if part.strip())
    except ValueError as err:
        raise argparse.ArgumentTypeError(
            f"expected comma-separated integers, got {text!r}"
        ) from err


def _build_parser():
    """Build the argument parser; every override option defaults to ``None``."""
    parser = argparse.ArgumentParser(description='Train an CNN model with bespoke parameters.')

    parser.add_argument('--config',
                        type=str,
                        default='cnn_celeba_config.yaml',
                        help='Path to the config file.')

    parser.add_argument('--batch_size',
                        type=int,
                        help='Override the batch size in the config file.')

    # default=None so that omitting the flag does not count as an override
    parser.add_argument('--include_bias_input',
                        action='store_true',
                        default=None,
                        help='Override the include_bias_input in the config file.')

    parser.add_argument('--include_head_property',
                        type=_parse_bool,
                        help='Override the include_head_property in the config file.')

    parser.add_argument('--num_epochs',
                        type=int,
                        help='Override the number of epochs in the config file.')

    parser.add_argument('--learning_rate',
                        type=float,
                        help='Override the learning_rate in the config file.')

    parser.add_argument('--scale',
                        type=float,
                        help='Override the init weight scale in the config file.')

    parser.add_argument('--choice_temp',
                        type=float,
                        help='Override the choice temperature in the config file.')

    parser.add_argument('--hidden_sizes',
                        type=_parse_int_tuple,
                        help='Override the hidden layer sizes in the config file.')

    parser.add_argument('--use_hidden_layer_bias',
                        type=_parse_bool,
                        help='Override the use_hidden_layer_bias in the config file.')

    parser.add_argument('--print_interval',
                        type=int,
                        help='Override the print interval in the config file.')

    parser.add_argument('--activation_function',
                        type=str,
                        help='Override the activation function in the config file.')

    return parser


def _apply_cli_overrides(config, args):
    """Return ``config`` with every explicitly given CLI option applied.

    Options left at ``None`` are ignored. When nothing was overridden the
    original ``config`` object is returned untouched; otherwise a new dict is
    returned and ``config`` itself is not mutated.
    """
    overrides = {
        _CLI_TO_CONFIG_KEY.get(name, name): value
        for name, value in vars(args).items()
        if name != 'config' and value is not None
    }
    if not overrides:
        return config
    return {**config, **overrides}


if __name__ == "__main__":
    args = _build_parser().parse_args()
    # Load the config file, then apply the command-line overrides on top of it
    config = _apply_cli_overrides(load_config(args.config), args)

    # One log directory per launch, keyed by dataset, timestamp and bias setting
    submit_it_log_dir = Path(
        src.SCRATCH_DIR,
        f"celeba_{config['dataset_name']}_cnn",
        get_timestamp(),
        str(config['use_hidden_layer_bias']),
    )

    executor = submitit.AutoExecutor(folder=submit_it_log_dir)
    # Executor parameters
    executor_params = {
        'timeout_min': 60 * 3,
        'mem_gb': 8,
        'gpus_per_node': 1,
        'cpus_per_task': 4,
        'slurm_array_parallelism': 256,
        'slurm_partition': 'gpu',
    }

    results = train_cnns_parallel(config, executor, executor_params, train_single_network)
    # For debugging, run a single network locally instead. Remember to un-jit
    # the relevant parts of the code, and to re-jit them afterwards so that
    # training is not slow.
    # train_single_network(config, 0)