import argparse
import os
import pathlib

import einops
import numpy as np
import torch as th
from cleanba.environments import BoxobanConfig, EnvpoolBoxobanConfig
from gym_sokoban.envs.sokoban_env import CHANGE_COORDINATES
from matplotlib import pyplot as plt
from scipy.stats import bootstrap
from sklearn.multioutput import MultiOutputClassifier

from learned_planners.interp.collect_dataset import DatasetStore
from learned_planners.interp.train_probes import TrainOn
from learned_planners.interp.utils import load_jax_model_to_torch, load_probe, play_level
from learned_planners.policies import download_policy_from_huggingface

# The existence of /training is used as the "running on the cluster" switch.
on_cluster = os.path.exists("/training")
LP_DIR = pathlib.Path(__file__).parent.parent.parent

MODEL_PATH_IN_REPO = "drc33/bkynosqi/cp_2002944000/"  # DRC(3, 3) 2B checkpoint
MODEL_PATH = download_policy_from_huggingface(MODEL_PATH_IN_REPO)
# Level cache location: absolute on the cluster, repo-relative otherwise.
BOXOBAN_CACHE = pathlib.Path("/training/.sokoban_cache/") if on_cluster else LP_DIR / "training/.sokoban_cache/"

parser = argparse.ArgumentParser()
parser.add_argument("--difficulty", type=str, default="medium")
parser.add_argument("--split", type=str, default="valid")
parser.add_argument("--thinking_steps", type=int, default=6)
parser.add_argument("--num_levels", type=int, default=1000)
parser.add_argument("--num_envs", type=int, default=128)
parser.add_argument("--probe_path", type=str, default="")
parser.add_argument("--probe_wandb_id", type=str, default="vb6474rg")
parser.add_argument("--dataset_name", type=str, default="boxes_future_direction_map")

args = parser.parse_args()
difficulty, split = args.difficulty, args.split
# An empty string or "none"/"null" (case-insensitive) means "no split".
if not split or split.lower() in ("none", "null"):
    split = None
thinking_steps, num_levels, num_envs = args.thinking_steps, args.num_levels, args.num_envs

# Environment backend: envpool-based config on the cluster, a synchronous
# BoxobanConfig with tinyworld observations locally.
if on_cluster:
    cfg_cls = EnvpoolBoxobanConfig
    extra_kwargs = dict(load_sequentially=True)
else:
    cfg_cls = BoxobanConfig
    extra_kwargs = dict(asynchronous=False, tinyworld_obs=True)

# Episodes are configured with min_episode_steps == max_episode_steps == thinking_steps.
boxo_cfg = cfg_cls(
    cache_path=BOXOBAN_CACHE,
    num_envs=num_envs,
    max_episode_steps=thinking_steps,
    min_episode_steps=thinking_steps,
    difficulty=difficulty,
    split=split,
    **extra_kwargs,
)
boxo_env = boxo_cfg.make()
# Load the policy at MODEL_PATH, configured against boxo_cfg.
cfg_th, policy_th = load_jax_model_to_torch(MODEL_PATH, boxo_cfg)

probe, grid_wise = load_probe(args.probe_path, args.probe_wandb_id)
probe_info = TrainOn(grid_wise=grid_wise, dataset_name=args.dataset_name)
probes, probe_infos = [probe], [probe_info]
multioutput = isinstance(probe, MultiOutputClassifier)
if multioutput:
    raise NotImplementedError("MultiOutputClassifier probes are not supported by the plan-quality metrics")


def non_empty_squares_in_plan(plan):
    """Count plan squares that carry a probe prediction.

    A value of -1 marks an empty square; any value >= 0 counts as non-empty.

    Args:
        plan: 4-D array (5-D in the unsupported multioutput case) whose last two
            axes are the spatial grid.

    Returns:
        ``plan`` with its last two axes reduced to the count of non-empty squares.

    Raises:
        NotImplementedError: for multioutput probes.
    """
    expected_ndim = 5 if multioutput else 4
    assert plan.ndim == expected_ndim, f"Got {plan.shape}"
    if multioutput:
        raise NotImplementedError

    occupied = plan >= 0  # -1 encodes an empty square
    return occupied.sum(axis=(-2, -1))


def continuous_chains_in_plan(plan, boxes):
    """Total continuous chain length starting from boxes.

    For every (batch, seq) entry, sums ``chain_length_from_box`` over all boxes of
    that batch element.

    Args:
        plan: 4-D array whose first two axes are (batch, seq) and last two are the grid.
        boxes: indexable by batch; ``boxes[b]`` iterates over box coordinates.

    Returns:
        Float array of shape ``plan.shape[:2]``.

    Raises:
        NotImplementedError: for multioutput probes.
    """
    assert plan.ndim == (5 if multioutput else 4), f"Got {plan.shape}"
    if multioutput:
        raise NotImplementedError

    n_batch, n_seq = plan.shape[:2]
    totals = np.zeros((n_batch, n_seq))
    for b, s in np.ndindex(n_batch, n_seq):
        grid = plan[b, s]
        totals[b, s] = sum(chain_length_from_box(grid, box) for box in boxes[b])
    return totals


def chain_length_from_box(plan, box, change_coordinates=None):
    """Continuous chain length starting from box.

    Each entry of ``plan`` that is >= 0 is a direction index into
    ``change_coordinates``; -1 marks an empty square. Starting at ``box``, the walk
    repeatedly moves in the direction stored at the current square. It stops when it
    reaches an empty square, leaves the grid, or revisits a square (a cycle).

    Args:
        plan: 2-D array of direction indices (-1 = empty).
        box: (row, col) start coordinate.
        change_coordinates: mapping from direction index to (d_row, d_col). Defaults
            to gym_sokoban's ``CHANGE_COORDINATES``.

    Returns:
        Number of direction-bearing squares the walk passed through.
    """
    assert plan.ndim == 2, f"Got {plan.shape}"
    if change_coordinates is None:
        change_coordinates = CHANGE_COORDINATES

    height, width = plan.shape
    pos = np.asarray(box, dtype=int)
    visited = {(int(pos[0]), int(pos[1]))}
    direction = plan[pos[0], pos[1]]
    chain_length = 0
    # ">= 0" matches non_empty_squares_in_plan's definition of a non-empty square.
    while direction >= 0:
        chain_length += 1
        # Advance from the *current* square (the previous code always stepped from the box).
        pos = pos + np.asarray(change_coordinates[int(direction)])
        row, col = int(pos[0]), int(pos[1])
        # Bounds check first: negative indices would otherwise silently wrap in numpy.
        if not (0 <= row < height and 0 <= col < width):
            break
        if (row, col) in visited:
            break
        visited.add((row, col))
        direction = plan[row, col]
    return chain_length


def plan_quality(policy_th=policy_th, probes=probes, probe_infos=probe_infos):
    """Roll out the policy on ``num_levels`` levels and score the probe's plan.

    Levels are played in batches of ``num_envs``; the final batch may be partial.
    For each level and each position of the flattened (step, internal tick) axis,
    two metrics are recorded:

    * the number of non-empty plan squares, and
    * the total continuous chain length from the boxes.

    Returns:
        ``(non_empty_squares, continuous_chains)``, each a float array of shape
        ``(num_levels, thinking_steps * 3)``.
    """
    n_positions = thinking_steps * 3
    non_empty_squares = np.zeros((num_levels, n_positions))
    continuous_chains = np.zeros((num_levels, n_positions))

    run_device = th.device("cuda" if th.cuda.is_available() else "cpu")
    policy_th = policy_th.to(run_device)

    n_batches = -(-num_levels // num_envs)  # ceil(num_levels / num_envs)
    for batch_idx in range(n_batches):
        out = play_level(
            boxo_env,
            policy_th=policy_th,
            probes=probes,
            probe_train_ons=probe_infos,
            internal_steps=True,
            thinking_steps=thinking_steps,
            max_steps=thinking_steps,
        )
        start = batch_idx * num_envs
        n_here = min(num_levels - start, num_envs)  # levels actually used from this batch

        # Put levels first and merge (step, internal tick) into a single time axis.
        plan = einops.rearrange(out.probe_outs[0], "t i b h w -> b (t i) h w")
        plan = plan[:n_here]
        boxes = np.stack(
            [DatasetStore.get_box_position_per_step(out.obs[0, env_idx].cpu()) for env_idx in range(n_here)]
        )

        rows = slice(start, start + n_here)
        non_empty_squares[rows] = non_empty_squares_in_plan(plan)
        continuous_chains[rows] = continuous_chains_in_plan(plan, boxes)

    return non_empty_squares, continuous_chains


# Results live under a directory keyed by dataset, difficulty and split, with the
# level count in the file name. Only cluster runs reuse a cached result.
# NOTE(review): the probe id is not part of the cache key either; a different probe
# with the same dataset_name would reuse the cache — consider adding it.
save_dir = pathlib.Path("/training/iclr_logs/") if on_cluster else LP_DIR / "plot/interp/"
save_dir = save_dir / f"plan_quality/{args.dataset_name}/{difficulty}_{split}/"
save_dir.mkdir(parents=True, exist_ok=True)
non_empty_path = save_dir / f"non_empty_squares_{num_levels}.npy"
chains_path = save_dir / f"continuous_chains_{num_levels}.npy"
expected_shape = (num_levels, thinking_steps * 3)

_cached = None
# Both files must exist; previously only the first was checked.
if on_cluster and non_empty_path.exists() and chains_path.exists():
    print("Loading from cache")
    _cached = (np.load(non_empty_path), np.load(chains_path))
    # thinking_steps is not encoded in the file name, so reject caches whose shape
    # does not match the current run instead of silently reusing them.
    if any(arr.shape != expected_shape for arr in _cached):
        print(f"Ignoring cache with shapes {[arr.shape for arr in _cached]}, expected {expected_shape}")
        _cached = None

if _cached is not None:
    non_empty_squares, continuous_chains = _cached
else:
    non_empty_squares, continuous_chains = plan_quality()
    np.save(non_empty_path, non_empty_squares)
    np.save(chains_path, continuous_chains)

rng = np.random.default_rng(seed=42)


def get_confidence_interval(data, *, n_resamples=1000, method="basic", random_state=None):
    """Bootstrap confidence interval for the mean of ``data`` along axis 0.

    With 2-D ``data`` of shape (samples, T), one interval is returned per column.

    Args:
        data: array of observations; resampling is done along axis 0.
        n_resamples: number of bootstrap resamples (default 1000, as before).
        method: scipy bootstrap interval method (default ``"basic"``, as before).
        random_state: RNG or seed for resampling. Defaults to the module-level
            ``rng``, so repeated default calls keep sharing one generator as before.

    Returns:
        The ``confidence_interval`` (``low``, ``high``) from ``scipy.stats.bootstrap``.
    """
    if random_state is None:
        random_state = rng
    return bootstrap(
        (data,),
        statistic=np.mean,
        random_state=random_state,
        n_resamples=n_resamples,
        vectorized=True,
        method=method,
    ).confidence_interval


# Per-time-index bootstrap CIs over levels, computed before averaging over levels.
non_empty_squares_ci = get_confidence_interval(non_empty_squares)
continuous_chains_ci = get_confidence_interval(continuous_chains)

non_empty_squares = non_empty_squares.mean(axis=0)
continuous_chains = continuous_chains.mean(axis=0)

fig, [ax1, ax2] = plt.subplots(1, 2, figsize=(10, 5))

# Left panel: every position on the flattened (step, internal tick) axis.
ax1.plot(non_empty_squares, label="Non-empty squares")
ax1.fill_between(range(len(non_empty_squares)), non_empty_squares_ci.low, non_empty_squares_ci.high, alpha=0.2)
ax1.plot(continuous_chains, label="Continuous chains")
ax1.fill_between(range(len(continuous_chains)), continuous_chains_ci.low, continuous_chains_ci.high, alpha=0.2)
ax1.set_title("Including internal steps")
ax1.legend()

# Right panel: index 2 of every group of 3 positions, i.e. the last internal tick of
# each step (assuming 3 internal ticks per step in the flattened time axis).
non_empty_squares = non_empty_squares[2::3]
continuous_chains = continuous_chains[2::3]
ax2.plot(non_empty_squares, label="Non-empty squares")
ax2.fill_between(range(len(non_empty_squares)), non_empty_squares_ci.low[2::3], non_empty_squares_ci.high[2::3], alpha=0.2)
ax2.plot(continuous_chains, label="Continuous chains")
ax2.fill_between(range(len(continuous_chains)), continuous_chains_ci.low[2::3], continuous_chains_ci.high[2::3], alpha=0.2)
ax2.set_title("Excluding internal steps")
ax2.legend()

# savefig does not create missing directories, so make sure plots/ exists first.
plot_dir = save_dir / "plots"
plot_dir.mkdir(parents=True, exist_ok=True)
fig.savefig(plot_dir / "plan_quality.png")
