import datetime

import submitit
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Union
from src.neural_nets.non_linear_nets.save_modules import save_checkpoint, save_metrics_to_json, aggregate_metrics, save_aggregated_metrics_to_json

def train_cnns_parallel(
    config: Dict[str, Any],
    executor: submitit.Executor,
    executor_params: Dict[str, Any],
    train_single_network_func: Callable[[Dict[str, Any], int], Any],
):
    """
    Train ``config['n_network_runs']`` networks as parallel submitit jobs, then aggregate their metrics.

    :param config: Configuration dictionary for the training. Keys read here: ``n_network_runs``,
                   ``choice_temp``, ``hidden_sizes``, ``learning_rate``, ``dataset_name``,
                   ``bias_init``, ``data_normalise``, ``orthogonalise``, ``init_method``,
                   ``use_hidden_layer_bias``, ``loss_function`` and ``model_type``.
    :param executor: An instance of submitit.Executor (or a derived class like AutoExecutor).
    :param executor_params: Keyword arguments forwarded to ``executor.update_parameters``.
    :param train_single_network_func: Function that trains one network. Each job calls it as
                                      ``train_single_network_func(config, network_id)`` with
                                      ``network_id`` in ``range(config['n_network_runs'])``.
    :return: A list with each job's result, in ``network_id`` order.

    If any job failed, the exception raised by ``job.result()`` propagates and no aggregation
    is performed.
    """
    executor.update_parameters(**executor_params)

    # Submit every run inside one batch() context so submitit sends them together
    # (a job array on Slurm); keep the job handles to wait on afterwards.
    with executor.batch():
        jobs = [
            executor.submit(train_single_network_func, config, network_id)
            for network_id in range(config['n_network_runs'])
        ]

    # Block until every job has finished and gather the per-job return values.
    results = [job.result() for job in jobs]
    print("finished training all networks")

    # Rebuild the identifiers used to locate the per-network metrics for aggregation.
    # NOTE(review): these must match the names used when the per-network metrics were
    # saved (presumably inside train_single_network_func) -- TODO confirm.
    n_network_runs = config['n_network_runs']
    choice_temp = config['choice_temp']
    hidden_sizes = config['hidden_sizes']
    learning_rate = config['learning_rate']
    dataset_name = config['dataset_name']
    bias_init_str = config['bias_init']
    data_normalise = config['data_normalise']
    orthogonalise = config['orthogonalise']  # whether the data is orthogonalised
    use_hidden_layer_bias = config['use_hidden_layer_bias']
    model_type = config['model_type']

    model_name = f'cnn_relu_{dataset_name}_orth={orthogonalise}_size={hidden_sizes}_init_method={config["init_method"]}_learning_rate={learning_rate}'
    hyperparam_path = f'choice_temp={choice_temp}_bias_term={use_hidden_layer_bias}_bias_init={bias_init_str}_loss_fn={config["loss_function"]}'

    # aggregate_metrics works from these identifiers and the run count, not from `results`.
    aggregated_data = aggregate_metrics(model_name, model_type, data_normalise, hyperparam_path, n_network_runs)
    save_aggregated_metrics_to_json(aggregated_data, model_type, data_normalise, model_name, hyperparam_path, config)

    return results


def train_cnns(
    config: Dict[str, Any],
    train_single_network_func: Callable[[Dict[str, Any], int], Any],
):
    """
    Train ``config['n_network_runs']`` networks one after another in this process, then aggregate their metrics.

    This is the sequential, non-submitit counterpart of ``train_cnns_parallel``.

    :param config: Configuration dictionary for the training. Keys read here: ``n_network_runs``,
                   ``choice_temp``, ``hidden_sizes``, ``learning_rate``, ``dataset_name``,
                   ``bias_init``, ``data_normalise``, ``orthogonalise``, ``init_method``,
                   ``use_hidden_layer_bias``, ``loss_function`` and ``model_type``.
    :param train_single_network_func: Function that trains one network. It is called as
                                      ``train_single_network_func(config, network_id)`` for each
                                      ``network_id`` in ``range(config['n_network_runs'])``.
    :return: A list with the return value of each call, in ``network_id`` order.
    """
    n_network_runs = config['n_network_runs']

    # Train each network in turn, keeping every return value so callers get one result per run.
    results = [
        train_single_network_func(config, network_id)
        for network_id in range(n_network_runs)
    ]
    print("finished training all networks")

    # Rebuild the identifiers used to locate the per-network metrics for aggregation.
    # NOTE(review): these must match the names used when the per-network metrics were
    # saved (presumably inside train_single_network_func) -- TODO confirm.
    choice_temp = config['choice_temp']
    hidden_sizes = config['hidden_sizes']
    learning_rate = config['learning_rate']
    dataset_name = config['dataset_name']
    bias_init_str = config['bias_init']
    data_normalise = config['data_normalise']
    orthogonalise = config['orthogonalise']  # whether the data is orthogonalised
    use_hidden_layer_bias = config['use_hidden_layer_bias']
    model_type = config['model_type']

    model_name = f'cnn_relu_{dataset_name}_orth={orthogonalise}_size={hidden_sizes}_init_method={config["init_method"]}_learning_rate={learning_rate}'
    hyperparam_path = f'choice_temp={choice_temp}_bias_term={use_hidden_layer_bias}_bias_init={bias_init_str}_loss_fn={config["loss_function"]}'

    # aggregate_metrics works from these identifiers and the run count, not from `results`.
    aggregated_data = aggregate_metrics(model_name, model_type, data_normalise, hyperparam_path, n_network_runs)
    save_aggregated_metrics_to_json(aggregated_data, model_type, data_normalise, model_name, hyperparam_path, config)

    return results

def get_timestamp() -> str:
    """Return the current UTC date and time as a ``YYYY-MM-DD-HH:MM`` string.

    NOTE: the result contains ``:``, which is not allowed in Windows file names.
    """
    now_utc = datetime.datetime.now(tz=datetime.timezone.utc)
    return now_utc.strftime("%Y-%m-%d-%H:%M")