#  Copyright (c) 2025

import os
from pathlib import Path
from typing import List, Dict, Union, Any

import pandas as pd
import wandb


def get_wandb_panel(
    project_name: str,
    groups: List[Union[Dict[str, Any], str]],
    attribute_name: str = "episode_reward_mean",
    x_axis_name: str = "training_iteration",
    x_iterations: int = 400,
    filter: Dict = None,
    aggregate: bool = True,
    save_to_csv: bool = False,
    file_name: str = "",
    username: str = "user"
):
    """Download a metric from W&B runs and collect it per config group.

    Every run in ``{username}/{project_name}`` returned for ``filter`` is
    assigned to a group with :func:`get_group`. Its ``attribute_name`` history
    is then fetched against ``x_axis_name``. All runs of a group are
    outer-merged on the x axis into one DataFrame, with one column per run,
    named after the run. Runs whose history is empty are skipped.

    Args:
        project_name: W&B project to query.
        groups: Config keys that define a group; see :func:`get_group`.
        attribute_name: Logged metric to download.
        x_axis_name: Logged key used as the x axis.
        x_iterations: Number of history samples requested per run. Also the
            length of the ``0..x_iterations-1`` x-axis skeleton each group
            starts from (this assumes integer x values starting at 0 --
            confirm for non-step axes).
        filter: Passed unchanged to ``api.runs(filters=...)``.
        aggregate: If True, collapse each group to per-x ``mean`` and ``std``
            columns (``std`` is 0 where it is undefined, e.g. a single run).
            If False, keep one column per run.
        save_to_csv: If True, write each group's merged, un-aggregated
            DataFrame to ``<this file's dir>/<file_name>_<group>.csv``.
        file_name: Prefix for the CSV file names.
        username: W&B entity that owns the project.

    Returns:
        A list of ``(group, df)`` tuples. ``group`` is ``str()`` of the group
        dict. ``df`` has every row containing a NaN dropped. Groups are listed
        in the order their first run was returned by the API.
    """
    api = wandb.Api()
    runs = api.runs(f"{username}/{project_name}", filters=filter)

    # (history, group) for each run that has at least one sample. Each history
    # holds the x-axis column plus one column renamed to the run's name.
    run_histories = []
    for run in runs:
        group = get_group(groups, run.config)
        history = run.history(
            keys=[attribute_name],
            samples=x_iterations,
            x_axis=x_axis_name,
        )
        if len(history):
            history = history.rename(columns={attribute_name: run.name})
            run_histories.append((history, group))

    # Keep first-seen order. Iterating a set of strings is not stable across
    # processes (hash randomization), so the result order used to vary.
    unique_groups = list(dict.fromkeys(str(group) for _, group in run_histories))
    group_index = {name: i for i, name in enumerate(unique_groups)}

    # Every group starts from an integer x-axis skeleton 0..x_iterations-1.
    group_dfs = [
        pd.DataFrame(range(x_iterations), columns=[x_axis_name])
        for _ in unique_groups
    ]
    for history, group in run_histories:
        i = group_index[str(group)]
        group_dfs[i] = pd.merge(
            group_dfs[i],
            history,
            how="outer",
            on=x_axis_name,
        )

    if save_to_csv:
        # NOTE(review): str(group) is used verbatim in the file name. Config
        # values containing path separators, or characters that are invalid
        # on the target OS, will make to_csv fail.
        out_dir = Path(os.path.abspath(__file__)).parent
        for name, df in zip(unique_groups, group_dfs):
            df.to_csv(out_dir / f"{file_name}_{name}.csv")

    results = []
    for name, group_df in zip(unique_groups, group_dfs):
        if aggregate:
            run_columns = group_df.drop(columns=[x_axis_name])
            stats = run_columns.agg(["mean", "std"], axis="columns")
            # Build an owned frame so adding columns never writes through a
            # view of group_df.
            out_df = group_df[[x_axis_name]].copy()
            out_df["mean"] = stats["mean"]
            # std is NaN when fewer than two runs have a value at that x.
            out_df["std"] = stats["std"].fillna(0)
        else:
            out_df = group_df
        results.append((name, out_df.dropna()))

    return results


def get_group(groups: List[Union[Dict[str, Any], str]], run_config):
    """Build the flat dict of config values that identifies a run's group.

    Each element of ``groups`` is one of two things:

    * a config key (``str``), whose value is copied from ``run_config``;
    * a dict mapping a config key to a nested group spec. That spec is a list
      of the same kind of elements, or a single key string, and is resolved
      recursively against ``run_config[key]``.

    All results are merged into one flat dict. An entry with the same name as
    an earlier one overwrites it.

    Example:
        ``get_group(["lr", {"env": ["n"]}], {"lr": 1, "env": {"n": 2}})``
        returns ``{"lr": 1, "n": 2}``.

    Raises:
        KeyError: if a requested key is missing from a dict config.
    """
    group = {}
    for element in groups:
        if isinstance(element, str):
            group[element] = run_config[element]
            continue
        for key, sub_groups in element.items():
            # Treat a bare string as one key. Without this it would be
            # iterated as a sequence of single-character keys.
            if isinstance(sub_groups, str):
                sub_groups = [sub_groups]
            group.update(get_group(sub_groups, run_config[key]))
    return group
