import argparse
import glob
import multiprocessing
import pathlib
import warnings
from functools import partial

import numpy as np
from matplotlib import pyplot as plt
from scipy.stats import bootstrap
from sklearn.multioutput import MultiOutputClassifier

import learned_planners.interp.plot  # noqa
from learned_planners import ON_CLUSTER
from learned_planners.interp.collect_dataset import DatasetStore
from learned_planners.interp.train_probes import TrainOn
from learned_planners.interp.utils import load_probe, predict

TOTAL_STEPS = 80


def process_file(file, probe, probe_info, hooks, multioutput, num_layers=3, repeats_per_step=3, is_drc=True):
    with warnings.catch_warnings(action="ignore"):
        ds = DatasetStore.load(file)
    try:
        if "direction" in probe_info.dataset_name:
            gts = getattr(ds, probe_info.dataset_name)(multioutput=multioutput)
        else:
            gts = getattr(ds, probe_info.dataset_name)()
    except AssertionError:
        print(f"Skipping {file} due to assertion error.")
        return [], []
    if probe_info.dataset_name in ["next_box"] and not ds.solved:
        return [], []

    cache = {k: ds.get_cache(key=k, only_env_steps=True) for k in ds.model_cache.keys() if any(hook in k for hook in hooks)}
    probe_preds = predict(
        cache,
        probe,
        probe_info,
        0,
        is_concatenated_cache=True,
        num_layers=num_layers,
        repeats_per_step=repeats_per_step,
        is_drc=is_drc,
    )
    correct_preds_in_advance = np.zeros(TOTAL_STEPS, dtype=int)
    total_preds_in_advance = np.zeros(TOTAL_STEPS, dtype=int)
    for idx in range(len(gts) - 1):
        move_idx = np.where(gts[idx] != gts[idx + 1])
        if len(move_idx[0]) == 0:
            continue
        move_idx = (move_idx[0][0], move_idx[1][0])
        move = gts[idx][move_idx]
        # a box could move through same square multiple times so we need to find
        # the first timestep where the box move could be predicted by the probe
        first_gt_move = np.where(gts[: idx + 1, *move_idx] == move)[0]
        first_correct_pred = np.where(probe_preds[: idx + 1, *move_idx] == move)[0]
        total_preds_in_advance[: (idx + 1 - first_gt_move[0])] += 1
        if len(first_correct_pred) == 0:
            continue
        first_correct_pred = first_correct_pred[0]
        correct_preds_in_advance[: (idx + 1 - first_correct_pred)] += 1

    return correct_preds_in_advance, total_preds_in_advance


if __name__ == "__main__":
    # Script entry point: parse CLI options, compute (or load cached) per-episode
    # future-prediction counts, bootstrap a confidence interval, and save a plot.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_type", type=str, default="drc33", choices=["drc11", "drc33", "resnet"])
    parser.add_argument("--dataset_path", type=str, default="/training/activations_dataset/valid_medium/0_think_step/")
    parser.add_argument("--probe_path", type=str, default="probes/best/boxes_future_direction_map_l-all.pkl")
    parser.add_argument("--probe_wandb_id", type=str, default="")
    parser.add_argument("--dataset_name", type=str, default="boxes_future_direction_map")
    parser.add_argument("--layer", type=int, default=-1)
    parser.add_argument("--num_levels", type=int, default=1000)
    parser.add_argument("--num_workers", type=int, default=8)
    parser.add_argument(
        "--output_base_path", type=str, default="iclr_logs/future_accuracy/", help="Path to save plots and cache."
    )
    parser.add_argument("--no_cache", action="store_true", help="Don't use cache.")

    args = parser.parse_args()

    probe, grid_wise = load_probe(args.probe_path, args.probe_wandb_id)

    multioutput = isinstance(probe, MultiOutputClassifier)

    hooks = ["hook_relu"] if args.model_type == "resnet" else ["hook_h", "hook_c"]
    probe_info = TrainOn(layer=args.layer, grid_wise=grid_wise, dataset_name=args.dataset_name, hooks=hooks)

    dataset_path = pathlib.Path(args.dataset_path)
    is_drc = True
    if args.model_type == "drc11":
        num_layers, repeats_per_step = 1, 1
    elif args.model_type == "drc33":
        num_layers, repeats_per_step = 3, 3
    elif args.model_type == "resnet":
        num_layers, repeats_per_step = -1, 0
        is_drc = False
    else:
        raise NotImplementedError(f"Unknown model type {args.model_type}")

    # glob order is filesystem-dependent; sort so --num_levels selects the same episodes every run.
    files = sorted(glob.glob(str(dataset_path / "*.pkl")))
    files = files[: args.num_levels] if args.num_levels > 0 else files

    if ON_CLUSTER:
        args.output_base_path = pathlib.Path("/training/") / args.output_base_path
    args.output_base_path = pathlib.Path(args.output_base_path)
    # NOTE(review): the cache below is keyed only by dataset_name and probe_wandb_id. Runs that
    # differ in --probe_path, --model_type, --dataset_path or --num_levels reuse it unless --no_cache.
    save_dir = args.output_base_path / args.dataset_name / args.probe_wandb_id
    save_dir.mkdir(parents=True, exist_ok=True)

    if (save_dir / "correct_preds_in_advance.npy").exists() and not args.no_cache:
        correct_preds_in_advance = np.load(save_dir / "correct_preds_in_advance.npy")
        total_preds_in_advance = np.load(save_dir / "total_preds_in_advance.npy")
    else:
        if not files:
            raise FileNotFoundError(f"No pkl files found in {dataset_path}")

        if args.num_workers <= 1:
            results = [
                process_file(file, probe, probe_info, hooks, multioutput, num_layers, repeats_per_step, is_drc)
                for file in files
            ]
        else:
            with multiprocessing.Pool(args.num_workers) as pool:
                map_fn = partial(
                    process_file,
                    probe=probe,
                    probe_info=probe_info,
                    hooks=hooks,
                    multioutput=multioutput,
                    num_layers=num_layers,
                    repeats_per_step=repeats_per_step,
                    is_drc=is_drc,
                )
                results = list(pool.imap(map_fn, files))
        # process_file returns ([], []) for skipped episodes; drop those before stacking.
        kept = [r for r in results if len(r[0]) > 0]
        if not kept:
            raise RuntimeError(f"All {len(files)} episode files were skipped; nothing to aggregate.")
        correct_preds_in_advance = np.stack([r[0] for r in kept])
        total_preds_in_advance = np.stack([r[1] for r in kept])
        np.save(save_dir / "correct_preds_in_advance.npy", correct_preds_in_advance)
        np.save(save_dir / "total_preds_in_advance.npy", total_preds_in_advance)

    def agg_fn(correct_preds_in_advance, total_preds_in_advance, axis=0):
        """Percentage of correct over total predictions, summed along `axis`.

        Offsets whose summed total is zero give a non-finite value (numpy division by zero).
        """
        return 100 * np.sum(correct_preds_in_advance, axis=axis) / np.sum(total_preds_in_advance, axis=axis)

    rng = np.random.default_rng(seed=42)

    # agg_fn already accepts `axis`, so let scipy call it on whole batches of resamples
    # instead of wrapping it in a per-element apply_along_axis (vectorized=False).
    bootstrap_ci = bootstrap(
        (correct_preds_in_advance, total_preds_in_advance),
        statistic=agg_fn,
        random_state=rng,
        batch=100,
        n_resamples=1000,
        vectorized=True,
        paired=True,
        method="basic",
        axis=0,
    ).confidence_interval
    mean_value = agg_fn(correct_preds_in_advance, total_preds_in_advance)

    fig, ax = plt.subplots(1, 1, figsize=(2.0, 1.6))
    ax.plot(mean_value, label="Mean")
    ax.fill_between(range(TOTAL_STEPS), bootstrap_ci[0], bootstrap_ci[1], alpha=0.2, label="CI")
    ax.set_xlabel("Episode timesteps")
    ax.set_ylabel("Future Accuracy")
    ax.grid()
    plt.savefig(save_dir / "future_accuracy.pdf")
    print(f"Saved plot to {save_dir / 'future_accuracy.pdf'}")
