import json
import os
import shutil
import flax
import yaml

def save_checkpoint(train_state, model_name, save_folder, net_id, epoch):
    """
    Save the training params at the specified epoch to the run's results directory.

    The file is written to
    ``results/model_runs_linear_flax/<model_name>/<save_folder>/net_<net_id>/params_epoch_<epoch>.flax``,
    where ``results`` is resolved three directories above this module. The bytes are
    produced by ``flax.serialization.to_bytes`` and can be restored with
    ``flax.serialization.from_bytes``.

    Args:
        train_state: The training state object; its ``params`` attribute is saved.
            NOTE(review): the params are passed through ``flax.jax_utils.unreplicate``,
            which keeps only the first entry along the leading axis of every leaf, so
            this assumes the state was replicated across devices (e.g. for ``pmap``)
            — confirm against the caller.
        model_name (str): The name of the model.
        save_folder (str): The folder (hyperparameter setting) to save the checkpoint in.
        net_id (int): The index of the network run.
        epoch (int): The epoch number; used in the file name.
    """
    # Explicit submodule imports guarantee both are loaded, independent of what
    # ``import flax`` happens to pull in for a given Flax version.
    from flax import jax_utils, serialization

    results_folder = os.path.abspath(os.path.join(os.path.dirname(__file__),
                                     '..', '..', '..', 'results', 'model_runs_linear_flax',
                                     model_name))
    log_dir = os.path.join(results_folder, save_folder, f'net_{net_id}')
    # The checkpoint is written *inside* log_dir, so log_dir itself (not just its
    # parent) must exist.
    os.makedirs(log_dir, exist_ok=True)

    # Unreplicate and serialize before opening the file, so a failure in either
    # step does not leave an empty/truncated checkpoint behind.
    params = jax_utils.unreplicate(train_state.params)
    # Param pytrees (dict / FrozenDict) have no ``to_bytes`` method; use Flax's
    # msgpack-based serializer instead.
    payload = serialization.to_bytes(params)

    params_path = os.path.join(log_dir, f"params_epoch_{epoch}.flax")
    with open(params_path, 'wb') as f:
        f.write(payload)


def save_metrics_to_json(metrics_dict, model_type, data_type, model_name, save_folder, net_id):
    """
    Write the metrics of a single network run to ``net_<net_id>.json``.

    The file is placed in
    ``results/model_runs_cnn/<model_type>/normalised=<data_type>/<model_name>/<save_folder>/``,
    where ``results`` is resolved three directories above this module. Any missing
    directories are created, and an existing file for the same run is overwritten.

    Args:
        metrics_dict (dict): JSON-serialisable metrics to store.
        model_type (str): The model family; used as a path component.
        data_type: Value interpolated into the ``normalised=<data_type>`` folder name
            (e.g. whether the data is normalised).
        model_name (str): The name of the model.
        save_folder (str): The specific subfolder from hyperparameter settings.
        net_id (int): The index of the network run; determines the file name.
    """
    print("saving metrics to json file")

    module_dir = os.path.dirname(__file__)
    model_dir = os.path.abspath(
        os.path.join(module_dir, '..', '..', '..', 'results', 'model_runs_cnn',
                     model_type, f'normalised={data_type}', model_name)
    )
    json_path = os.path.join(model_dir, save_folder, f'net_{net_id}.json')

    # Create the folder that will hold the JSON file (and any missing parents).
    os.makedirs(os.path.dirname(json_path), exist_ok=True)

    with open(json_path, 'w') as handle:
        json.dump(metrics_dict, handle, indent=4)


def aggregate_metrics(model_name, model_type, data_type, save_folder, n_network_runs):
    """
    Merge the per-run metric files ``net_0.json`` .. ``net_<n-1>.json`` into one dict.

    Files are read from
    ``results/model_runs_cnn/<model_type>/normalised=<data_type>/<model_name>/<save_folder>/``,
    where ``results`` is resolved three directories above this module. Each
    successfully read file is deleted afterwards. A missing file is reported with
    ``print`` and skipped.

    Args:
        model_name (str): The name of the model.
        model_type (str): The model family; used as a path component.
        data_type: Value interpolated into the ``normalised=<data_type>`` folder name.
        save_folder (str): The hyperparameter-setting folder holding the run files.
        n_network_runs (int): The number of network runs to look for.

    Returns:
        dict: Maps each metric key to a list of that key's values, one entry per
        run file that contained the key, in increasing ``net_id`` order. Skipped
        or key-less runs contribute nothing, so list positions are not guaranteed
        to line up with run indices.
    """
    model_dir = os.path.abspath(os.path.join(
        os.path.dirname(__file__), '..', '..', '..', 'results', 'model_runs_cnn',
        model_type, f'normalised={data_type}', model_name))
    run_dir = os.path.join(model_dir, save_folder)

    merged = {}
    for run_idx in range(n_network_runs):
        run_file = os.path.join(run_dir, f'net_{run_idx}.json')

        if not os.path.exists(run_file):
            print(f"File {run_file} does not exist")
            continue

        with open(run_file, 'r') as handle:
            run_metrics = json.load(handle)

        for metric, value in run_metrics.items():
            merged.setdefault(metric, []).append(value)

        # The per-run file is consumed: remove it once its data has been merged.
        os.remove(run_file)

    return merged

def save_aggregated_metrics_to_json(aggregated_data, model_type, data_type, model_name, save_folder, config=None):
    """
    Write the aggregated metrics, and optionally the config, as JSON files.

    Output goes to
    ``results/model_runs_cnn/<model_type>/normalised=<data_type>/<model_name>/<save_folder>/``,
    where ``results`` is resolved three directories above this module:
    ``aggregated_results.json`` always, and ``config.json`` when ``config`` is given.
    Existing files are overwritten.

    Args:
        aggregated_data (dict): The aggregated metrics data.
        model_type (str): The model family; used as a path component.
        data_type: Value interpolated into the ``normalised=<data_type>`` folder name.
        model_name (str): The name of the model.
        save_folder (str): The folder to save the aggregated results.
        config (dict, optional): The configuration of the model; must be
            JSON-serialisable. Skipped when ``None``.

    Returns:
        None
    """
    def _write_json(payload, path):
        # Shared writer so both files use the same formatting.
        with open(path, 'w') as handle:
            json.dump(payload, handle, indent=4)

    model_dir = os.path.abspath(os.path.join(
        os.path.dirname(__file__), '..', '..', '..', 'results', 'model_runs_cnn',
        model_type, f'normalised={data_type}', model_name))
    target_dir = os.path.join(model_dir, save_folder)
    aggregated_file_path = os.path.join(target_dir, 'aggregated_results.json')

    os.makedirs(os.path.dirname(aggregated_file_path), exist_ok=True)

    _write_json(aggregated_data, aggregated_file_path)

    if config is not None:
        _write_json(config, os.path.join(target_dir, 'config.json'))

    print(f"Aggregated metrics saved to {aggregated_file_path}")


if __name__ == "__main__":
    # Aggregate the per-run JSON files of one hyperparameter setting and write the
    # combined results. Keyword arguments are used so the calls cannot drift out
    # of sync with the function signatures (positional mix-ups previously passed
    # the model name as data_type and omitted required arguments).
    import argparse

    parser = argparse.ArgumentParser(
        description="Aggregate per-run metrics into aggregated_results.json.")
    parser.add_argument(
        "--model-name",
        default="cnn_relu_mnist_size=[32, 64, 96, 512]_init_method=he")
    parser.add_argument("--model-type", default="cnn")
    # No sensible default exists: this selects the 'normalised=<data_type>' folder.
    parser.add_argument(
        "--data-type", required=True,
        help="Value used in the 'normalised=<data_type>' folder name.")
    parser.add_argument(
        "--save-folder",
        default="choice_temp=0.2_scale=0.0015_bias_term=False")
    parser.add_argument("--n-network-runs", type=int, default=10)
    args = parser.parse_args()

    result = aggregate_metrics(
        model_name=args.model_name,
        model_type=args.model_type,
        data_type=args.data_type,
        save_folder=args.save_folder,
        n_network_runs=args.n_network_runs,
    )
    save_aggregated_metrics_to_json(
        result,
        model_type=args.model_type,
        data_type=args.data_type,
        model_name=args.model_name,
        save_folder=args.save_folder,
    )