import os
import json
import numpy as np

def _results_dir(folder: str, *parts: str) -> str:
    """Return the absolute path ``results/<folder>/<parts...>``.

    The ``results`` directory is resolved three levels above the directory
    containing this file. Because ``os.path.join`` discards earlier
    components when it meets an absolute path, an absolute ``folder`` is
    used as-is.
    """
    return os.path.abspath(os.path.join(os.path.dirname(__file__),
                                        '..', '..', '..', 'results',
                                        folder, *parts))


def load_metrics(model_name: str,
                 choice_temp: float,
                 scale: float,
                 metric_list: list,
                 bias_input: bool = None,
                 folder: str = 'model_runs',
                 loading_from_json: bool = False,
                 bias_init: str = "constant",
                 loss_fn: str = None):
    """
    Loads the requested metrics for a single model run.

    Two on-disk layouts are supported (all relative to ``results/<folder>``):

    * JSON layout (``loading_from_json=True``): every metric lives in
      ``<model_name>/choice_temp=<choice_temp>_bias_term=<bias_input>
      _bias_init=<bias_init>[_loss_fn=<loss_fn>]/aggregated_results.json``.
    * Legacy NumPy layout: one ``<metric>.npy`` file per metric, inside
      ``<model_name>/choice_temp=<choice_temp>_scale=<scale>_bias_input=<bias_input>``
      when ``bias_input`` is not None, otherwise inside
      ``<model_name>_temp<choice_temp>_scale<scale>``.

    Args:
        model_name (str): The name of the model.
        choice_temp (float): The choice temperature encoded in the run folder name.
        scale (float): The scale encoded in the run folder name (NumPy layout only).
        metric_list (list): Names of the metrics to load.
        bias_input (bool, optional): Bias flag encoded in the run folder name.
            In the NumPy layout, None selects the flat ``<model_name>_temp..``
            folder. Defaults to None.
        folder (str, optional): Sub-folder of ``results`` to load from.
            Defaults to 'model_runs'.
        loading_from_json (bool, optional): Use the JSON layout instead of the
            NumPy layout. Defaults to False.
        bias_init (str, optional): Bias initialisation name encoded in the run
            folder name (JSON layout only). Defaults to "constant".
        loss_fn (str, optional): Loss-function name appended to the run folder
            name when given (JSON layout only). Defaults to None.

    Returns:
        dict: Maps each name in ``metric_list`` to its loaded data. In the
        NumPy layout, a metric whose file cannot be loaded maps to None.

    Raises:
        OSError: JSON layout only, if ``aggregated_results.json`` cannot be opened.
        KeyError: JSON layout only, if a requested metric is absent from the file.
    """
    # New saving format produced by the larger model-training code; this may
    # ultimately replace the legacy per-metric .npy format below.
    if loading_from_json:
        run_folder = (f"choice_temp={choice_temp}_bias_term={bias_input}"
                      f"_bias_init={bias_init}")
        if loss_fn is not None:
            run_folder += f"_loss_fn={loss_fn}"
        json_path = os.path.join(_results_dir(folder, model_name, run_folder),
                                 "aggregated_results.json")
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        return {metric: data[metric] for metric in metric_list}

    # Legacy saving format: one .npy file per metric.
    if bias_input is not None:
        results_folder = _results_dir(
            folder, model_name,
            f"choice_temp={choice_temp}_scale={scale}_bias_input={bias_input}")
    else:
        results_folder = _results_dir(
            folder, f"{model_name}_temp{choice_temp}_scale{scale}")

    metric_dict = {}
    for metric in metric_list:
        path = os.path.join(results_folder, f"{metric}.npy")
        try:
            # NOTE(review): allow_pickle=True can execute arbitrary code from
            # the file; only use on trusted, locally produced results.
            array = np.load(path, allow_pickle=True)
        except Exception as err:  # deliberate best-effort: missing metric -> None
            print(f"Could not load {metric} from {path}: {err}")
            metric_dict[metric] = None
        else:
            print(f"Loaded {metric} from {path}")
            metric_dict[metric] = array
    return metric_dict