import json
import logging
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from hydra.experimental.callback import Callback

from xac.utils.plotting import get_plot_name, plot_trajectory

log = logging.getLogger(__name__)


class ResultAggregator(Callback):
    """Hydra callback that aggregates per-job metrics when a multirun ends.

    When ``config.meta.debug_mode`` is truthy, every ``metrics.json`` found
    below the sweep directory is loaded, all rows are written to
    ``<sweep_dir>/aggregated/metrics_agg.csv`` and, for each black-box
    function, mean + SEM trajectories are plotted per acquisition function.

    Only the plots with a quantile-capped y-axis (error metrics) and the
    timing plot for the reference acquisition function are produced. The
    seed-level plots were deactivated due to their computational cost, and
    the uncapped / individual-run plot variants are currently disabled.
    """

    # Metrics that are always considered.
    _BASE_METRICS = ("mae", "mse", "nlpd", "nlpd_noisy")
    # Extra metrics for the TabRepo benchmark (unless the surrogate uses NUTS).
    _CLASSIFICATION_METRICS = (
        "ce_loss",
        "ce_loss_noisy",
        "accuracy",
        "accuracy_noisy",
    )
    # Extra metrics when ``config.meta.time_ops`` is set.
    _TIMING_METRICS = ("hp_fit_duration", "acq_fun_duration")
    # Metrics that get the capped mean + SEM plots (also in log scale).
    _ERROR_METRICS = ("mae", "mse")
    # Quantiles of the raw values used as the upper y-axis limit.
    _Y_CAP_QUANTILES = (0.95, 0.97, 0.99, 0.999)
    # Acquisition function shown in the timing plots.
    _TIMING_REFERENCE = "EIG-FP"

    # Display order used when exactly this set of acquisition functions ran.
    _SHAPLEIG_ORDER = (
        "Regression MSR",
        "LeverageSHAP",
        "KernelSHAP",
        "Permutation Sampling",
        "SVARM",
        "LeverageSHAP-GP",
        "EIG-EP",
        "Random",
        "EIG-FP",
    )
    _SHAPLEIG_SV_ACCURACY = (
        "Regression MSR",
        "LeverageSHAP",
        "KernelSHAP",
        "Permutation Sampling",
        "SVARM",
        "EIG-FP",
    )
    _SHAPLEIG_SV_BASELINES = (
        "LeverageSHAP-GP",
        "EIG-EP",
        "Random",
        "EIG-FP",
    )

    def on_job_end(self, config, job_return, **kwargs) -> None:
        """Log the sweep sub-directory of a child job that just finished."""
        log.info("Finished sweep number %s", config.hydra.sweep.subdir)

    def on_multirun_end(self, config, **kwargs) -> None:
        """Aggregate and plot all job metrics once after the last job.

        The aggregation only runs when ``config.meta.debug_mode`` is truthy;
        the closing log message is emitted either way.
        """
        if config.meta.debug_mode:
            self._aggregate_and_plot(config)

        log.info("Finished on_multirun_end.")

    def _aggregate_and_plot(self, config) -> None:
        """Write the aggregated CSV and create the plots per black-box function."""
        # Can also be set manually when debugging, e.g.:
        # config.hydra.sweep.dir = 'multirun/2025-09-26/11-06-27'
        sweep_dir = Path(config.hydra.sweep.dir)

        # Sorted so that the row order of the CSV is reproducible.
        rows = [
            json.loads(metrics_file.read_text())
            for metrics_file in sorted(sweep_dir.rglob("metrics.json"))
            if metrics_file.is_file()
        ]
        if not rows:
            log.warning(
                "No metrics.json found below %s; nothing to aggregate.", sweep_dir
            )
            return

        results_df = pd.DataFrame(rows)

        path_agg = sweep_dir / "aggregated"
        path_agg.mkdir(parents=True, exist_ok=True)
        results_df.to_csv(path_agg / "metrics_agg.csv", index=False)

        # Both are identical for every black-box function, so compute them once.
        metrics = self._select_metrics(config)
        size_x0 = int(results_df["initial_design_size"].iloc[0])

        for bbf in results_df["blackbox"].unique():
            bbf_path = path_agg / str(bbf)
            bbf_path.mkdir(parents=True, exist_ok=True)

            bbf_df = results_df[results_df["blackbox"] == bbf]

            # Keep only seeds that were evaluated on all acquisition functions.
            n_acq_fns = bbf_df["acquisition"].nunique()
            seed_counts = bbf_df.groupby("seed")["acquisition"].nunique()
            valid_seeds = seed_counts[seed_counts == n_acq_fns].index
            bbf_df = bbf_df[bbf_df["seed"].isin(valid_seeds)]

            # One row per acquisition function; each cell lists the per-seed values.
            bbf_df_agg = bbf_df.groupby("acquisition").agg(list)
            if set(bbf_df_agg.index) == set(self._SHAPLEIG_ORDER):
                bbf_df_agg = bbf_df_agg.reindex(list(self._SHAPLEIG_ORDER))

            for metric in metrics:
                if metric not in bbf_df_agg.columns:
                    log.warning(
                        "Metric %s not found for %s; skipping its plots.",
                        metric,
                        bbf,
                    )
                    continue

                names, data = self._collect_metric(bbf_df_agg, metric)

                if metric in self._ERROR_METRICS:
                    for log_scaling in (True, False):
                        self._plot_error_metric(
                            names, data, metric, log_scaling, bbf_path, size_x0
                        )

                if metric in self._TIMING_METRICS:
                    self._plot_timing_metric(names, data, metric, bbf_path, size_x0)

    @classmethod
    def _select_metrics(cls, config) -> list:
        """Return the metric names to plot for the given configuration."""
        metrics = list(cls._BASE_METRICS)

        is_tabrepo = (
            config.application._target_
            == "xac.applications.TabRepoBenchmarkApplication"
        )
        uses_nuts = config.surrogate.fit_config._target_ == "xac.surrogates.NUTSConfig"
        if is_tabrepo and not uses_nuts:
            metrics.extend(cls._CLASSIFICATION_METRICS)

        if config.meta.time_ops:
            metrics.extend(cls._TIMING_METRICS)

        return metrics

    @staticmethod
    def _collect_metric(agg_df, metric):
        """Return aligned (names, tensors) of one metric per acquisition function.

        Acquisition functions whose values are all NaN are left out.
        """
        names = []
        tensors = []
        for acq_name, cell in agg_df[metric].items():
            values = torch.tensor(cell)
            if (~torch.isnan(values)).any():
                names.append(acq_name)
                tensors.append(values)
        return names, tensors

    @staticmethod
    def _lower_y_limit(subset_data, means):
        """Return the lower y-axis limit for a group of trajectories.

        This is the smallest mean value; with more than one run it is lowered
        by two standard errors of the mean taken at that position.
        """
        stacked_means = torch.concat(means)
        lowest = stacked_means.min()

        if subset_data[0].shape[0] == 1:
            # A single run has no sample standard deviation.
            return lowest.item()

        sems = torch.concat(
            [
                torch.tensor(
                    runs.numpy().std(axis=0, ddof=1) / np.sqrt(runs.shape[0])
                )
                for runs in subset_data
            ]
        )
        return (lowest - 2 * sems[stacked_means.argmin()]).item()

    def _plot_error_metric(
        self, names, data, metric, log_scaling, bbf_path, size_x0
    ) -> None:
        """Plot mean + SEM of an error metric with a quantile-capped y-axis.

        One plot is written per acquisition subset, quantile and legend
        setting. Subsets without any available acquisition function are
        skipped.
        """
        by_name = dict(zip(names, data))
        subsets = (
            self._SHAPLEIG_ORDER,
            self._SHAPLEIG_SV_ACCURACY,
            self._SHAPLEIG_SV_BASELINES,
        )
        log_suffix = "_log" if log_scaling else ""

        for subset_idx, subset in enumerate(subsets, start=1):
            # Filter names and data together so that labels match the curves.
            subset_names = [name for name in subset if name in by_name]
            if not subset_names:
                log.warning(
                    "No data for acquisition subset %d of metric %s; skipping.",
                    subset_idx,
                    metric,
                )
                continue

            subset_data = [by_name[name] for name in subset_names]
            means = [runs.mean(dim=0) for runs in subset_data]
            y_minimum = self._lower_y_limit(subset_data, means)
            all_values = torch.concat(subset_data)

            for quantile in self._Y_CAP_QUANTILES:
                y_maximum = torch.quantile(all_values, quantile).item()
                # ``:g`` keeps 95/97/99 unchanged and yields 99.9 for 0.999;
                # truncating to int made 0.999 overwrite the 0.99 plot.
                cap_label = f"{quantile * 100:g}"

                for legend in (True, False):
                    legend_suffix = "_leg" if legend else ""
                    file_name = (
                        f"{metric}_mean_w_std{log_suffix}{legend_suffix}"
                        f"_ycap{cap_label}_{subset_idx}.png"
                    )
                    plot_trajectory(
                        main_data=means,
                        granular_data=subset_data,
                        plot_title=None,
                        y_label=get_plot_name(metric) + " (Mean + SEM)",
                        categories=subset_names,
                        path=bbf_path / str(metric) / file_name,
                        plot_std=True,
                        plot_individual_runs=False,
                        y_range_top=None,
                        y_range_bottom=None,
                        y_log_scale=log_scaling,
                        y_maximum=y_maximum,
                        y_minimum=y_minimum,
                        legend=legend,
                        legend_placement_bottom=False,
                        size_X0=size_x0,
                    )

    def _plot_timing_metric(self, names, data, metric, bbf_path, size_x0) -> None:
        """Plot mean + SEM of a timing metric for the reference acquisition only."""
        reference = self._TIMING_REFERENCE
        if reference not in names:
            log.warning(
                "No %s data for %s; skipping the timing plot.", metric, reference
            )
            return

        runs = data[names.index(reference)]
        plot_trajectory(
            main_data=[runs.mean(dim=0)],
            granular_data=[runs],
            plot_title=None,
            y_label=get_plot_name(metric),
            # Pass the name, not its list position.
            categories=[reference],
            path=bbf_path / str(metric) / f"{metric}_mean_w_std_cust.png",
            plot_std=True,
            plot_individual_runs=False,
            y_range_top=None,
            y_range_bottom=None,
            y_log_scale=False,
            legend=False,
            legend_placement_bottom=False,
            size_X0=size_x0,
        )
