"""AUC score for short- and long-term channels for predicting actions in the future."""

# %%
import argparse
import copy
import glob
import multiprocessing
import pathlib
import pickle
import warnings
from functools import partial

import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import bootstrap
from sklearn.metrics import roc_auc_score

from learned_planners import IS_NOTEBOOK, ON_CLUSTER
from learned_planners.interp.channel_group import layer_groups
from learned_planners.interp.collect_dataset import DatasetStore
from learned_planners.interp.offset_fns import apply_inv_offset_lc
from learned_planners.interp.plot import apply_style

# Number of step buckets kept per channel by `process_file`; also the length of the
# "Steps" axis in the plots further down this file.
TOTAL_STEPS = 50

# Index assigned to each direction name. `process_file` looks up the last word of a
# channel-group name here and compares ground-truth map entries against the result.
DIRECTION = {"up": 0, "down": 1, "left": 2, "right": 3}


def standardize_channel(channel_value, channel_dict):
    """Undo a channel's offset and orient it by the sign recorded in ``channel_dict``.

    Args:
        channel_value: Array-like exposing ``shape`` with at least two dimensions. It is
            passed to ``apply_inv_offset_lc`` with ``last_dim_grid=True``.
        channel_dict: Mapping providing ``"layer"``, ``"idx"`` and ``"sign"``.
            ``"sign"`` is either the string ``"+"`` / ``"-"`` or an integer (Python or
            NumPy) that is used directly as the multiplier.

    Returns:
        The offset-corrected ``channel_value`` multiplied by the resolved sign.

    Raises:
        ValueError: If ``"sign"`` is a string other than ``"+"`` / ``"-"``, or is neither
            a string nor an integer.
    """
    assert len(channel_value.shape) >= 2, f"Invalid channel value shape: {channel_value.shape}"
    channel_value = apply_inv_offset_lc(channel_value, channel_dict["layer"], channel_dict["idx"], last_dim_grid=True)
    sign = channel_dict["sign"]
    if isinstance(sign, str):
        # Raise rather than assert: asserts are stripped under ``python -O``, which would
        # let any unexpected string silently fall through to the "-" branch below.
        if sign not in ("+", "-"):
            raise ValueError(f"Invalid sign: {sign}")
        sign = 1 if sign == "+" else -1
    elif not isinstance(sign, (int, np.integer)):
        # NumPy integer scalars are not ``int`` subclasses, so they are accepted explicitly.
        raise ValueError(f"Invalid sign type: {type(sign)}")
    return channel_value * sign


def process_file(file, num_layers=3, repeats_per_step=3, is_drc=True, multioutput=False, hooks=["hook_h"]):
    with warnings.catch_warnings(action="ignore"):
        ds = DatasetStore.load(file)

    box_movements = ds.boxes_future_direction_map(multioutput=multioutput)
    agent_movements = ds.agents_future_direction_map(multioutput=multioutput)
    agent_exclusive = ds.agents_future_direction_map(multioutput=multioutput, agent_exclusive=True)

    key_format = "features_extractor.cell_list.{layer}.hook_h"
    cache = [ds.get_cache(key=key_format.format(layer=layer), only_env_steps=False) for layer in range(num_layers)]

    grouped_acts_by_steps = {}
    grouped_gt_by_steps = {}
    for grp_name, channel_dict_list in layer_groups.items():
        grouped_acts_by_steps[grp_name] = []
        grouped_gt_by_steps[grp_name] = []
        dir_idx = DIRECTION.get(grp_name.split(" ")[-1].lower(), None)
        if dir_idx is None:
            continue
        for channel_dict in channel_dict_list:
            if grp_name.startswith("B"):
                gt = box_movements
            elif grp_name.startswith("A"):
                if "exclusive" in channel_dict["description"]:
                    gt = agent_exclusive
                else:
                    gt = agent_movements
            else:
                raise NotImplementedError(f"Unknown group name {grp_name}")
            layer, channel = channel_dict["layer"], channel_dict["idx"]
            slc = slice(repeats_per_step - 1, None, repeats_per_step)
            acts: np.ndarray = standardize_channel(cache[layer][slc, channel], channel_dict)
            if "nfa" in channel_dict["description"].lower():
                slc = slice(repeats_per_step - 2, None, repeats_per_step)
            elif "mpa" in channel_dict["description"].lower():
                acts = acts.mean(axis=(-2, -1), keepdims=True)
                acts = np.broadcast_to(acts, gt.shape)

            acts_by_steps = [[] for _ in range(TOTAL_STEPS)]
            gt_by_steps = [[] for _ in range(TOTAL_STEPS)]
            for idx in range(len(gt) - 1):
                move_idx = np.where(gt[idx] != gt[idx + 1])
                if len(move_idx[0]) == 0:
                    continue
                move_idx = (move_idx[0][0], move_idx[1][0])
                move = gt[idx][move_idx]
                # a box could move through same square multiple times so we need to find
                # the first timestep where the box move could be predicted by the probe
                first_gt_move = np.where(gt[: idx + 1, *move_idx] == move)[0]
                last_step = min(idx, TOTAL_STEPS - 1)
                for i_idx, i in enumerate(range(last_step, first_gt_move[0] - 1, -1)):
                    acts_by_steps[i_idx].append(acts[i, *move_idx].item())
                    gt_by_steps[i_idx].append(gt[i, *move_idx].item() == dir_idx)
            grouped_acts_by_steps[grp_name].append(acts_by_steps)
            grouped_gt_by_steps[grp_name].append(gt_by_steps)

    return grouped_acts_by_steps, grouped_gt_by_steps


def flatten(lst):
    """Return a single list with the items of every iterable in ``lst``, in order."""
    flat = []
    for inner in lst:
        flat.extend(inner)
    return flat


# %%
if not IS_NOTEBOOK:
    # Script mode: read the settings from the command line.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_type",
        type=str,
        default="drc33",
        choices=["drc11", "drc33", "resnet"],
    )
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="/training/activations_dataset/valid_medium/0_think_step/",
    )
    parser.add_argument("--num_levels", type=int, default=1000)
    parser.add_argument("--num_workers", type=int, default=8)
    parser.add_argument(
        "--output_base_path",
        type=str,
        default="iclr_logs/future_accuracy_channels/",
        help="Path to save plots and cache.",
    )
    parser.add_argument("--no_cache", action="store_true", help="Don't use cache.")
    args = parser.parse_args()
else:
    # Notebook mode: there is no command line, so use a plain attribute holder that
    # carries the same values as the argparse defaults above.

    class Args:
        model_type = "drc33"
        dataset_path = "/training/activations_dataset/valid_medium/0_think_step/"
        num_levels = 1000
        num_workers = 8
        output_base_path = "iclr_logs/future_accuracy_channels/"
        no_cache = False

    args = Args()

# Only drc33 is supported end to end; the drc11/resnet branches below stay unreachable
# until this assertion is relaxed.
assert args.model_type == "drc33", "Only drc33 is supported."

dataset_path = pathlib.Path(args.dataset_path)
repeats_per_step = 1
is_drc = True
if args.model_type == "drc11":
    num_layers, repeats_per_step = 1, 1
elif args.model_type == "drc33":
    num_layers, repeats_per_step = 3, 3
elif args.model_type == "resnet":
    num_layers, repeats_per_step = -1, 0
    is_drc = False
else:
    raise NotImplementedError(f"Unknown model type {args.model_type}")

all_probe_preds = []
all_gts = []
# glob returns paths in arbitrary order; sort so the `num_levels` subset is reproducible.
files = sorted(glob.glob(str(dataset_path / "*.pkl")))
files = files[: args.num_levels] if args.num_levels > 0 else files

if ON_CLUSTER:
    args.output_base_path = pathlib.Path("/training/") / args.output_base_path
args.output_base_path = pathlib.Path(args.output_base_path)
save_dir = args.output_base_path
save_dir.mkdir(parents=True, exist_ok=True)

# NOTE(review): the cache file name does not depend on dataset_path or num_levels, so a
# stale cache from a different configuration is reused unless --no_cache is passed.
results_path = save_dir / "results.pkl"
if results_path.exists() and not args.no_cache:
    # pickle can execute arbitrary code on load; only read caches produced by this script.
    with open(results_path, "rb") as f:
        results = pickle.load(f)
else:
    assert len(files) > 0, f"No pkl files found in {dataset_path}"

    if args.num_workers <= 1:
        results = [process_file(file, num_layers, repeats_per_step, is_drc) for file in files]
    else:
        with multiprocessing.Pool(args.num_workers) as pool:
            map_fn = partial(
                process_file,
                num_layers=num_layers,
                repeats_per_step=repeats_per_step,
                is_drc=is_drc,
            )
            results = list(pool.imap(map_fn, files))
    # Write the cache for both the single-process and the pooled path.
    with open(results_path, "wb") as f:
        pickle.dump(results, f)
# Per-channel metadata copied from `layer_groups`; the plotting cells attach scores to it.
auc_scores = copy.deepcopy(layer_groups)

# %%

apply_style(figsize=(5.4, 1.5), font=6)


def stat(y_true_sample, y_pred_sample):
    """Bootstrap statistic: ROC AUC of one paired resample, in percent."""
    return 100 * roc_auc_score(y_true_sample, y_pred_sample)


x = np.arange(TOTAL_STEPS)

# Two figures with four subplots each: groups 0-3 of `auc_scores` go to the first figure
# (saved with the "box" suffix), groups 4-7 to the second (saved with the "agent" suffix).
for agent in [False, True]:
    fig, axes = plt.subplots(1, 4, sharex=True, sharey=True, layout="constrained")

    axes = axes.flatten()
    subplot_titles = list(auc_scores.keys())[:8]

    for fig_idx, key in enumerate(auc_scores.keys()):
        if not 4 * agent <= fig_idx < 4 * (agent + 1):
            continue
        ax = axes[fig_idx - 4 * agent]
        # Skip groups for which `process_file` collected no channels.
        if len(results[0][0][key]) == 0:
            continue
        for idx in range(len(auc_scores[key])):
            channel = auc_scores[key][idx]
            # AUC values and their bootstrap confidence bounds, one entry per step.
            auc_val_list = []
            ci_low_list = []
            ci_high_list = []
            for step in range(TOTAL_STEPS):
                y_true = np.concatenate([r[1][key][idx][step] for r in results])
                # ROC AUC is undefined unless both classes occur, so an all-positive
                # bucket gets the same -1 placeholder as an empty or all-negative one.
                if not np.any(y_true) or np.all(y_true):
                    auc_val_list.append(-1)
                    ci_low_list.append(-1)
                    ci_high_list.append(-1)
                    only_class = "zeros" if not np.any(y_true) else "ones"
                    warnings.warn(f"All {only_class} in {key} {idx} {step}")
                    continue
                y_pred = np.concatenate([r[0][key][idx][step] for r in results])
                auc_val_list.append(100 * roc_auc_score(y_true, y_pred))

                bs_result = bootstrap(
                    (y_true, y_pred),
                    statistic=stat,
                    n_resamples=10,
                    random_state=42,
                    confidence_level=0.95,
                    paired=True,
                    vectorized=False,
                    method="basic",
                )
                ci_low, ci_high = bs_result.confidence_interval
                ci_low_list.append(ci_low)
                ci_high_list.append(ci_high)
            channel["auc_scores"] = np.array(auc_val_list)
            channel["ci_low"] = np.array(ci_low_list)
            channel["ci_high"] = np.array(ci_high_list)
            label = f"L{channel['layer']}H{channel['idx']}"
            ax.plot(
                x,
                channel["auc_scores"],
                label=label,
                linewidth=1,
            )

            ax.fill_between(
                x,
                channel["ci_low"],
                channel["ci_high"],
                alpha=0.3,
            )
        ax.set_title(subplot_titles[fig_idx].replace("B", "Box").replace("A", "Agent"))
        ax.legend(handlelength=0.7)
    fig.supxlabel("Steps")
    fig.supylabel("AUC (%)")
    suffix = "agent" if agent else "box"
    if ON_CLUSTER:
        plt.savefig(f"/training/new_plots/future_auc_{suffix}.pdf")
    else:
        plt.savefig(f"../../../new_plots/future_auc_{suffix}.pdf")
    plt.show()

# %%

agent = False
all_lines = []

# Colour ramps for short-term (blues) and long-term (reds) channels. They are cycled
# through below but the selected colour is not passed to any plotting call in this cell.
short_term_colors = plt.cm.Blues(np.linspace(0.4, 0.8, 8))
long_term_colors = plt.cm.Reds(np.linspace(0.4, 0.8, 8))

# Position in each colour ramp.
short_term_idx = 0
long_term_idx = 0

for fig_idx, key in enumerate(auc_scores.keys()):
    # Same group window as the per-group figures: indices 0-3 when `agent` is False.
    if not 4 * agent <= fig_idx < 4 * (agent + 1):
        continue
    if len(results[0][0][key]) == 0:
        continue

    for idx in range(len(auc_scores[key])):
        channel = auc_scores[key][idx]
        step_scores = []
        for step in range(TOTAL_STEPS):
            y_true = np.concatenate([r[1][key][idx][step] for r in results])
            # ROC AUC is undefined unless both classes occur, so an all-positive bucket
            # gets the same -1 placeholder as an empty or all-negative one.
            if not np.any(y_true) or np.all(y_true):
                step_scores.append(-1)
                only_class = "zeros" if not np.any(y_true) else "ones"
                warnings.warn(f"All {only_class} in {key} {idx} {step}")
                continue
            y_pred = np.concatenate([r[0][key][idx][step] for r in results])
            step_scores.append(100 * roc_auc_score(y_true, y_pred))

        channel["auc_scores"] = np.array(step_scores)
        label = f"{key}-L{channel['layer']}H{channel['idx']}"
        long_term = channel["long-term"]

        # Advance the colour ramp that matches the channel's category.
        if long_term:
            color = long_term_colors[long_term_idx % len(long_term_colors)]
            long_term_idx += 1
        else:
            color = short_term_colors[short_term_idx % len(short_term_colors)]
            short_term_idx += 1

        all_lines.append((channel["auc_scores"], label, long_term))

# Split the collected curves by category for the combined figure.
short_term_lines = [(line, label) for line, label, is_long in all_lines if not is_long]
long_term_lines = [(line, label) for line, label, is_long in all_lines if is_long]
# %%
# The collected curves come from the box groups only, so the file suffix must be "box".
assert agent is False

smaller = True
plt.style.use("default")
if smaller:
    apply_style(figsize=(1.82, 1.3), font=8)
else:
    apply_style(figsize=(2.7, 1.5), font=8)

fig, ax = plt.subplots(layout="constrained")

# Empty proxy lines: they draw nothing and exist only to provide the two legend entries.
short_term_line = ax.plot([], [], color="C1", label="Short term")[0]
long_term_line = ax.plot([], [], color="C0", label="Long term")[0]

steps = np.arange(TOTAL_STEPS)
for line, _label in long_term_lines:
    ax.plot(steps, line, color="C0", alpha=0.7)
for line, _label in short_term_lines:
    ax.plot(steps, line, color="C1", alpha=0.7)

ax.set_xlabel("Steps")
ax.set_ylabel("AUC (%)")
if smaller:
    ax.set_yticks([60, 80, 100])
    ax.xaxis.labelpad = 1
    ax.yaxis.labelpad = 0
else:
    ax.set_yticks([50, 60, 70, 80, 90, 100])
# A single legend call: a second call on the same axes would replace the first legend.
ax.legend(
    handles=[short_term_line, long_term_line],
    loc="lower left",
    handlelength=0.8 if smaller else 1.2,
)

suffix = "agent" if agent else "box"
if smaller:
    suffix += "_smaller"
if ON_CLUSTER:
    plt.savefig(f"/training/new_plots/future_auc_combined_{suffix}.pdf", bbox_inches="tight")
else:
    plt.savefig(f"../../../new_plots/future_auc_combined_{suffix}.pdf", bbox_inches="tight")
plt.show()

# %%
