import os

import wandb
import numpy as np

def retrieve_runs(choice_temp:float, scale:float, project_path:str, metrics:list, group_name:str=None):
    """
    Retrieves logged metric histories for a group of Weights & Biases (wandb) runs
    and stacks them into per-metric NumPy arrays.

    Args:
        choice_temp (float): The choice temperature used in the runs. Only used to
            build the default group name.
        scale (float): The scale used in the runs. Only used to build the default
            group name.
        project_path (str): The wandb project path passed to ``wandb.Api().runs``.
        metrics (list): Metric keys to retrieve. If None, the keys of the first
            history row of the first run that has any history are used.
        group_name (str, optional): The wandb group to filter runs on. Defaults to
            ``"choice_temp={choice_temp}_scale={scale}"``.

    Returns:
        dict: Maps each metric name to an array of shape (n_runs, n_epochs).
        n_runs counts only runs with at least one history row. n_epochs is the
        length of the shortest such history; longer runs are truncated. A metric
        absent from a history row is stored as None.

    Raises:
        ValueError: If no run in the group has any history rows.
    """
    api = wandb.Api()
    if not group_name:
        group_name = f"choice_temp={choice_temp}_scale={scale}"
    runs = api.runs(path=project_path, filters={"group": group_name})

    collected_data = []
    for i, run in enumerate(runs):
        # run.history returns sampled rows. samples=100000 is set high, presumably
        # so every logged step comes back. TODO confirm (run.scan_history is unsampled).
        if metrics is None:
            history_dicts = list(run.history(pandas=False, samples=100000))
        else:
            history_dicts = list(run.history(keys=metrics, pandas=False, samples=100000))
        print(i)

        if not history_dicts:
            # An empty run would make the stacked array ragged, so skip it entirely.
            continue

        if metrics is None:
            # Fix the metric order once so every run's columns line up.
            metrics = list(history_dicts[0].keys())

        collected_data.append(
            [[row.get(metric) for metric in metrics] for row in history_dicts]
        )

    if not collected_data:
        raise ValueError(
            f"No logged history found for group '{group_name}' in '{project_path}'."
        )

    # Truncate every run to the shortest recorded history so the data stacks
    # into a (n_runs, n_epochs, n_metrics) array.
    num_epochs = min(len(run_data) for run_data in collected_data)
    final_data_np = np.array([run_data[:num_epochs] for run_data in collected_data])

    return {metric: final_data_np[:, :, idx] for idx, metric in enumerate(metrics)}


def save_metrics(model_name:str, choice_temp:float, scale:float, final_data_by_metric:dict,
                 results_root:str=None):
    """
    Saves each metric array to its own ``.npy`` file.

    Files are written to
    ``<results_root>/<model_name>_temp<choice_temp>_scale<scale>/<metric>.npy``.
    A metric name containing ``/`` (e.g. ``"train/loss"``) is stored in a matching
    subdirectory (``train/loss.npy``).

    Args:
        model_name (str): The name of the model.
        choice_temp (float): The choice temperature used in the metrics data.
        scale (float): The scale used in the metrics data.
        final_data_by_metric (dict): Mapping from metric name to the array to save.
        results_root (str, optional): Directory under which the per-model folder is
            created. Defaults to ``results/model_runs`` three levels above this
            file's directory.

    Returns:
        None
    """
    if results_root is None:
        results_root = os.path.join(os.path.dirname(__file__),
                                    '..', '..', '..', 'results', 'model_runs')
    folder_name = f"{model_name}_temp{choice_temp}_scale{scale}"
    # Computed once; it does not depend on the metric.
    results_folder = os.path.abspath(os.path.join(results_root, folder_name))

    for metric, array in final_data_by_metric.items():
        # Split on "/" so nested wandb keys map to subdirectories. This also stops a
        # leading "/" from making the component absolute, which would make
        # os.path.join discard results_folder.
        file_path = os.path.join(results_folder, *str(metric).split("/")) + ".npy"

        # exist_ok avoids the exists()/makedirs() race and creates nested parents.
        os.makedirs(os.path.dirname(file_path), exist_ok=True)

        try:
            np.save(file_path, array)
            print(f"Saved {metric} to {file_path}")
        except Exception as e:
            # Best effort: report the failure and continue with the remaining metrics.
            print(f"Could not save {metric} due to: {e}")