import argparse
import concurrent.futures
import os
import pathlib
import pickle
import subprocess
from glob import glob

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

from learned_planners.interp.collect_dataset import DatasetStore
from learned_planners.interp.train_probes import TrainOn
from learned_planners.interp.utils import predict

# The cluster is detected by the presence of the /training directory; the
# probe-discovery code branches on this flag.
on_cluster = os.path.exists("/training")

# Command-line interface, declared as (flag, type, default) triples.
# `--num_levels` values below 1 are treated as "all levels" by
# plot_metrics_across_steps.
parser = argparse.ArgumentParser()
for _flag, _type, _default in (
    ("--dataset_path", str, "/training/activations_dataset/hard/0_think_step/"),
    ("--num_levels", int, -1),
    ("--save_path", str, "probe_metrics.png"),
):
    parser.add_argument(_flag, type=_type, default=_default)
args = parser.parse_args()

# Unpack the parsed options into the module-level names used by the rest of
# the script.
dataset_path = pathlib.Path(args.dataset_path)
num_levels, save_path = args.num_levels, args.save_path

# Build two parallel lists: `probe_files` (paths of pickled probes) and
# `probe_infos` (the TrainOn description of what each probe was trained on).
# NOTE(review): layer=-1 presumably means "all layers", matching the `_l_all`
# suffix of the local file names -- confirm against TrainOn.
if on_cluster:
    # On the cluster each probe is identified by its wandb run id and located
    # on disk by the findprobe.sh helper, which prints the path on stdout.
    wandb_ids_and_infos = [
        ("dirnsbf3", TrainOn(layer=-1, dataset_name="agents_future_direction_map")),
        ("vb6474rg", TrainOn(layer=-1, dataset_name="boxes_future_direction_map")),
        ("42qs0bh1", TrainOn(layer=-1, dataset_name="next_target")),
        ("6e1w1bb6", TrainOn(layer=-1, dataset_name="next_box")),
    ]
    probe_files, probe_infos = [], []
    for wandb_id, probe_info in wandb_ids_and_infos:
        # The ids are hard-coded literals above, so interpolating them into a
        # shell command line is safe here.
        command = f"/training/findprobe.sh {wandb_id}"
        result = subprocess.run(command, shell=True, capture_output=True, text=True)
        file_name = result.stdout.strip()
        if not file_name:
            # Fail here with the helper's diagnostics instead of letting an
            # empty path reach open() and produce an unrelated-looking error.
            raise FileNotFoundError(
                f"{command!r} printed no probe path "
                f"(exit code {result.returncode}): {result.stderr.strip()}"
            )
        probe_files.append(file_name)
        probe_infos.append(probe_info)
else:
    # Locally the probes live in <repo root>/probes, three directories above
    # this file.
    LP_DIR = pathlib.Path(__file__).parent.parent.parent
    probe_name_infos = [
        ("agents_future_direction_map_l_all.pkl", TrainOn(layer=-1, dataset_name="agents_future_direction_map")),
        ("boxes_future_direction_map_l_all.pkl", TrainOn(layer=-1, dataset_name="boxes_future_direction_map")),
        # ("boxes_future_direction_map_sparse_l_1.pkl", TrainOn(layer=1, dataset_name="boxes_future_direction_map")),
        ("next_target_l_all.pkl", TrainOn(layer=-1, dataset_name="next_target")),
        ("next_box_l_all.pkl", TrainOn(layer=-1, dataset_name="next_box")),
    ]
    probe_files = [LP_DIR / "probes" / file for file, _ in probe_name_infos]
    probe_infos = [info for _, info in probe_name_infos]

# Load every probe into memory, in the same order as `probe_infos`.
# NOTE: pickle.load can execute arbitrary code; only point this script at
# probe files produced by trusted training runs.
probes = []
for file_name in probe_files:
    with open(file_name, "rb") as f:
        probes.append(pickle.load(f))


def _per_step_hit_counts(preds, labels):
    """Tally precision/recall hits for every index along the leading axis.

    Args:
        preds: Predicted labels with the same shape as ``labels``.
        labels: Ground-truth labels as a NumPy array. The leading axis is the
            one later trimmed by ``skip_steps``.

    Returns:
        Integer array of shape ``(4, labels.shape[0])`` whose rows are, per
        step: correct non-negative predictions, all non-negative predictions,
        correctly predicted non-negative labels, all non-negative labels.
    """
    preds = np.asarray(preds)
    # The smallest label value present in the level is treated as the
    # "negative" class and excluded from both metrics.
    negative_label = labels.min()
    predicted_positive = preds != negative_label
    actual_positive = labels != negative_label
    correct = preds == labels
    # Reduce over everything except the leading axis so counts stay per step;
    # boolean-mask indexing would flatten the arrays and lose the step axis.
    trailing_axes = tuple(range(1, labels.ndim))
    return np.stack(
        [
            (correct & predicted_positive).sum(axis=trailing_axes),
            predicted_positive.sum(axis=trailing_axes),
            (correct & actual_positive).sum(axis=trailing_axes),
            actual_positive.sum(axis=trailing_axes),
        ]
    )


def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or NaN when the denominator is zero."""
    return numerator / denominator if denominator else float("nan")


def _f1_score(prec, rec):
    """Return the harmonic mean of ``prec`` and ``rec`` (NaN-propagating, 0.0 for 0/0)."""
    if np.isnan(prec) or np.isnan(rec):
        return float("nan")
    if prec + rec == 0:
        return 0.0
    return 2 * prec * rec / (prec + rec)


def plot_metrics_across_steps(probes, probe_infos, dataset_path, num_levels=-1, *, max_skip_steps=20):
    """Plot probe precision/recall/F1 as a function of skipped initial steps.

    Every ``*.pkl`` level file in ``dataset_path`` is loaded with
    ``DatasetStore.load``, each probe is evaluated on it, and one subplot per
    probe shows the metrics when the first ``k`` steps of every level are
    ignored, for ``k`` in ``range(max_skip_steps)``. The figure is written to
    the module-level ``save_path`` and then shown.

    Args:
        probes: Loaded probe objects, passed through to ``predict``.
        probe_infos: One info object per probe; ``dataset_name`` names the
            ``DatasetStore`` method that yields the labels and is used as the
            subplot title.
        dataset_path: Directory containing the level ``*.pkl`` files.
        num_levels: Maximum number of level files to evaluate; values below 1
            mean "all files".
        max_skip_steps: Number of skip values to evaluate (x-axis length).

    Raises:
        FileNotFoundError: If ``dataset_path`` contains no ``*.pkl`` files.
    """
    # Sort so that a `num_levels` subset is the same on every run.
    level_files = sorted(glob(str(pathlib.Path(dataset_path) / "*.pkl")))
    if num_levels >= 1:
        # Actually restrict the work, not just the progress-bar total.
        level_files = level_files[:num_levels]
    if not level_files:
        raise FileNotFoundError(f"no '*.pkl' level files found in {dataset_path}")

    # probe_wise_counts[p] holds one (4, num_steps) count array per level.
    probe_wise_counts = [[] for _ in probes]
    # Threads overlap the file loading; results arrive in submission order.
    with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:
        for ds_cache in tqdm(executor.map(DatasetStore.load, level_files), total=len(level_files)):
            cache = {k: ds_cache.get_cache(k) for k in ds_cache.model_cache.keys()}
            for pidx, (probe, probe_info) in enumerate(zip(probes, probe_infos)):
                labels = getattr(ds_cache, probe_info.dataset_name)()
                if len(labels) == 0:
                    # Nothing to score for this probe on this level.
                    continue
                preds = predict(cache, probe, probe_info, 0, internal_steps=False, is_concatenated_cache=True)

                # NOTE(review): the 3 is presumably the number of predictions
                # produced per environment step -- confirm against `predict`.
                preds = preds.squeeze().reshape(-1, 3, *labels.shape[1:])[: labels.shape[0]]
                # Repeat the labels along the new axis so they align with preds.
                labels = labels.unsqueeze(1).repeat(1, preds.shape[1], 1, 1).numpy()
                assert preds.shape == labels.shape, f"preds shape {preds.shape} != labels shape {labels.shape}"

                probe_wise_counts[pidx].append(_per_step_hit_counts(preds, labels))

    # squeeze=False keeps `axs` two-dimensional even for a single probe.
    fig, axs = plt.subplots(1, len(probes), figsize=(5 * len(probes), 5), squeeze=False)
    for pidx, (_, probe_info) in enumerate(zip(probes, probe_infos)):
        precisions, recalls, f1_scores = [], [], []
        for skip_steps in range(max_skip_steps):
            # Pool the counts of all levels after dropping each level's first
            # `skip_steps` steps.
            totals = np.zeros(4, dtype=np.int64)
            for counts in probe_wise_counts[pidx]:
                totals += counts[:, skip_steps:].sum(axis=1)
            prec_hits, prec_total, rec_hits, rec_total = totals
            prec = _safe_ratio(prec_hits, prec_total)
            rec = _safe_ratio(rec_hits, rec_total)
            precisions.append(prec)
            recalls.append(rec)
            f1_scores.append(_f1_score(prec, rec))

        ax = axs[0, pidx]
        ax.plot(precisions, label="Precision")
        ax.plot(recalls, label="Recall")
        ax.plot(f1_scores, label="F1")
        ax.set_title(probe_info.dataset_name)
        ax.set_xlabel("Initial steps skipped")
        ax.legend()
    fig.savefig(save_path)
    plt.show()


# Script entry point: evaluate the loaded probes on the dataset and draw the
# per-probe metric curves.
plot_metrics_across_steps(
    probes=probes,
    probe_infos=probe_infos,
    dataset_path=dataset_path,
    num_levels=num_levels,
)
