import argparse
import multiprocessing
import pathlib
from functools import partial

import numpy as np
import pandas as pd
import torch as th
from gym_sokoban.envs.room_utils import CHANGE_COORDINATES
from sklearn.multioutput import MultiOutputClassifier
from tqdm import tqdm

from learned_planners import BOXOBAN_CACHE, MODEL_PATH_IN_REPO, ON_CLUSTER
from learned_planners.interp.channel_group import get_group_channels
from learned_planners.interp.collect_dataset import DatasetStore
from learned_planners.interp.offset_fns import offset_yx
from learned_planners.interp.plot import save_video as save_video_fn
from learned_planners.interp.train_probes import TrainOn
from learned_planners.interp.utils import get_boxoban_cfg, load_jax_model_to_torch, load_probe, play_level
from learned_planners.policies import download_policy_from_huggingface

# Policy checkpoint, downloaded via download_policy_from_huggingface; every rollout in this
# script loads this model (see load_jax_model_to_torch in ci_score_on_a_level).
MODEL_PATH = download_policy_from_huggingface(MODEL_PATH_IN_REPO)

# "mpa" channel group, indexed by direction (patch_channels_mpa reads entries 0..3).
# Only the first channel of each direction entry is used: MPA_CHANNELS[d] / MPA_SIGNS[d] are
# the channel index and the sign value written when steering towards direction d.
MPA_CHANNEL_INFO = get_group_channels("mpa", return_dict=True)
MPA_CHANNELS = [d[0]["idx"] for d in MPA_CHANNEL_INFO]
MPA_SIGNS = [d[0]["sign"] for d in MPA_CHANNEL_INFO]
print("MPA_CHANNELS", MPA_CHANNELS)
print("MPA_SIGNS", MPA_SIGNS)

# "nfa" channel group, indexed by direction in the same way (first channel of each entry).
NFA_CHANNEL_INFO = get_group_channels("nfa", return_dict=True)
NFA_CHANNELS = [d[0]["idx"] for d in NFA_CHANNEL_INFO]
NFA_SIGNS = [d[0]["sign"] for d in NFA_CHANNEL_INFO]
print("NFA_CHANNELS", NFA_CHANNELS)
print("NFA_SIGNS", NFA_SIGNS)
# Number of values hooked for `layer` in "features_extractor.cell_list.<layer>...".
NUM_LAYERS = 3
# Number of values hooked for `int_pos` (last component of the hook name); presumably the
# number of internal recurrent ticks per environment step — TODO confirm.
NUM_TICKS = 3
# Box / agent direction channel groups. get_box_group_fwd_hooks indexes these by direction
# index. Agent groups exclude the NFA/MPA channels, which are patched separately.
box_groups = get_group_channels("box", return_dict=True)
agent_groups = get_group_channels("agent", return_dict=True, exclude_nfa_mpa=True)

# Re-index the box groups as [group][layer] -> [(channel idx, sign), ...] so a hook for one
# layer can look up only that layer's channels of a direction group.
box_channels_grp_layer_wise = [[[] for _ in range(NUM_LAYERS)] for _ in range(len(box_groups))]
for box_group_idx, box_group in enumerate(box_groups):
    for channel_info in box_group:
        layer = channel_info["layer"]
        box_channels_grp_layer_wise[box_group_idx][layer].append((channel_info["idx"], channel_info["sign"]))

# Same [group][layer] -> [(channel idx, sign), ...] layout for the agent groups.
agent_channels_grp_layer_wise = [[[] for _ in range(NUM_LAYERS)] for _ in range(len(agent_groups))]
for agent_group_idx, agent_group in enumerate(agent_groups):
    for channel_info in agent_group:
        layer = channel_info["layer"]
        agent_channels_grp_layer_wise[agent_group_idx][layer].append((channel_info["idx"], channel_info["sign"]))


def patch_channels(
    cache,
    hook,
    new_direction_channels,
    old_direction_channels,
    new_channel_values,
    grid_i,
    grid_j,
    layer,
):
    """Forward hook: swap one grid cell's direction encoding from old channels to new ones.

    The old direction's channels are zeroed and ``new_channel_values`` are written into the new
    direction's channels. Each write happens at ``(grid_i, grid_j)`` as adjusted by ``offset_yx``
    for that channel set and ``layer``, across the whole leading dimension of ``cache``.

    ``hook`` is accepted for the hook-callback signature but not used.
    Returns the (in-place modified) ``cache``.
    """
    # Cell coordinates for each channel set after the per-channel/layer offset.
    old_i, old_j = offset_yx(grid_i, grid_j, old_direction_channels, layer)
    new_i, new_j = offset_yx(grid_i, grid_j, new_direction_channels, layer)

    # Erase the direction that was originally encoded ...
    cache[:, old_direction_channels, old_i, old_j] = 0

    # ... and write the target direction in its place.
    replacement = th.tensor(new_channel_values, device=cache.device, dtype=cache.dtype)
    cache[:, new_direction_channels, new_i, new_j] = replacement
    return cache


def patch_channels_agent_for_box(
    cache,
    hook,
    old_direction_idx,
    new_direction_idx,
    new_direction_channels,
    old_direction_channels,
    new_channel_values,
    grid_i,
    grid_j,
    layer,
):
    """Forward hook: move agent-direction channels associated with a box cell.

    For a direction ``d`` the agent cell is the box cell ``(grid_i, grid_j)`` shifted by
    ``CHANGE_COORDINATES[flip_direction(d)]`` (presumably one step opposite to ``d``).
    The old direction's channels are zeroed at the old agent cell. ``new_channel_values`` are
    then written into the new direction's channels at the new agent cell. Both positions are
    further adjusted by ``offset_yx`` for their channel set and ``layer``.

    ``hook`` is accepted for the hook-callback signature but not used.
    Returns the (in-place modified) ``cache``.
    """
    box_cell = np.array([grid_i, grid_j])

    # Clear the old direction at the agent cell implied by the old direction.
    prev_agent_i, prev_agent_j = box_cell + CHANGE_COORDINATES[flip_direction(old_direction_idx)]
    prev_i, prev_j = offset_yx(prev_agent_i, prev_agent_j, old_direction_channels, layer)
    cache[:, old_direction_channels, prev_i, prev_j] = 0

    # Write the new direction at the agent cell implied by the new direction.
    next_agent_i, next_agent_j = box_cell + CHANGE_COORDINATES[flip_direction(new_direction_idx)]
    next_i, next_j = offset_yx(next_agent_i, next_agent_j, new_direction_channels, layer)
    replacement = th.tensor(new_channel_values, device=cache.device, dtype=cache.dtype)
    cache[:, new_direction_channels, next_i, next_j] = replacement
    return cache


def patch_channels_mpa(
    cache,
    hook,
    old_direction_idx,
    new_direction_idx,
):
    """Forward hook: make the MPA channels encode ``new_direction_idx`` everywhere inside the map.

    The three other directions' MPA channels are zeroed, and the target direction's channel is
    filled with its sign (``MPA_SIGNS``). Only the interior ``[2:-2, 2:-2]`` of the last two
    dimensions is modified.

    ``hook`` and ``old_direction_idx`` are accepted for the hook-callback signature but unused.
    Returns the (in-place modified) ``cache``.
    """
    target = [MPA_CHANNELS[new_direction_idx]]
    others = [MPA_CHANNELS[d] for d in range(4) if d != new_direction_idx]
    # boundary neurons shouldn't be modified for this intervention
    interior = (slice(2, -2), slice(2, -2))

    cache[(slice(None), others) + interior] = 0
    fill = th.tensor(MPA_SIGNS[new_direction_idx], device=cache.device, dtype=cache.dtype)
    cache[(slice(None), target) + interior] = fill
    return cache


def get_box_group_fwd_hooks(new_direction_idx, old_direction_idx, grid_i, grid_j, channel_type, is_box=True):
    """Build ``(hook_name, hook_fn)`` pairs that steer one cell from one direction to another.

    Args:
        new_direction_idx: direction index (0-3) to write into the cell.
        old_direction_idx: direction index currently encoded at the cell; its channels are cleared.
        grid_i, grid_j: grid cell to steer.
        channel_type: "mpa"/"nfa" (matched by substring) select a single hook point on those
            groups; otherwise the layer-wise box or agent groups are patched at every layer and
            ``int_pos``. "box_agent" additionally patches the agent groups.
        is_box: use the box channel groups if true, otherwise the agent channel groups.

    Returns:
        List of ``(name, partial)`` where ``name`` is
        ``"features_extractor.cell_list.<layer>.hook_h.<pos>.<int_pos>"``.
    """
    groups = box_channels_grp_layer_wise if is_box else agent_channels_grp_layer_wise
    hook_kinds = ["hook_h"]

    def hook_name(layer, kind, pos, int_pos):
        return f"features_extractor.cell_list.{layer}.{kind}.{pos}.{int_pos}"

    def channel_ids(table, direction, layer):
        return [entry[0] for entry in table[direction][layer]]

    def channel_signs(table, direction, layer):
        return [int(entry[1]) for entry in table[direction][layer]]

    if "mpa" in channel_type:
        # One hook point only: layer 2, pos 0, int_pos 2.
        return [
            (
                hook_name(2, kind, 0, 2),
                partial(
                    patch_channels_mpa,
                    old_direction_idx=old_direction_idx,
                    new_direction_idx=new_direction_idx,
                ),
            )
            for kind in hook_kinds
        ]

    if "nfa" in channel_type:
        # One hook point only: layer 2, pos 0, int_pos 1.
        return [
            (
                hook_name(2, kind, 0, 1),
                partial(
                    patch_channels,
                    new_direction_channels=[NFA_CHANNELS[new_direction_idx]],
                    old_direction_channels=[NFA_CHANNELS[old_direction_idx]],
                    new_channel_values=[NFA_SIGNS[new_direction_idx]],
                    grid_i=grid_i,
                    grid_j=grid_j,
                    layer=2,
                ),
            )
            for kind in hook_kinds
        ]

    # Layer-wise groups: one hook per (int_pos, layer), int_pos-major order.
    hooks = [
        (
            hook_name(layer, kind, pos, int_pos),
            partial(
                patch_channels,
                new_direction_channels=channel_ids(groups, new_direction_idx, layer),
                old_direction_channels=channel_ids(groups, old_direction_idx, layer),
                new_channel_values=channel_signs(groups, new_direction_idx, layer),
                grid_i=grid_i,
                grid_j=grid_j,
                layer=layer,
            ),
        )
        for pos in (0,)
        for int_pos in range(NUM_TICKS)
        for layer in range(NUM_LAYERS)
        for kind in hook_kinds
    ]

    if is_box and channel_type == "box_agent":
        # Also patch the agent groups, positioned relative to the box by patch_channels_agent_for_box.
        agent_table = agent_channels_grp_layer_wise
        hooks.extend(
            (
                hook_name(layer, kind, pos, int_pos),
                partial(
                    patch_channels_agent_for_box,
                    old_direction_idx=old_direction_idx,
                    new_direction_idx=new_direction_idx,
                    new_direction_channels=channel_ids(agent_table, new_direction_idx, layer),
                    old_direction_channels=channel_ids(agent_table, old_direction_idx, layer),
                    new_channel_values=channel_signs(agent_table, new_direction_idx, layer),
                    grid_i=grid_i,
                    grid_j=grid_j,
                    layer=layer,
                ),
            )
            for pos in (0,)
            for int_pos in range(NUM_TICKS)
            for layer in range(NUM_LAYERS)
            for kind in hook_kinds
        )
    return hooks


def flip_direction(direction):
    """Return the index of the direction opposite to ``direction``.

    Directions 0..3 are paired as up<->down (0<->1) and left<->right (2<->3).
    """
    opposite_of = [1, 0, 3, 2]
    return opposite_of[direction]


def _rollout_to_dataset_store(rollout):
    """Wrap a ``play_level`` rollout in a ``DatasetStore`` so ground-truth maps can be derived."""
    return DatasetStore(
        None,
        rollout.obs.squeeze(1),
        rollout.rewards,
        rollout.solved,
        rollout.acts,
        th.zeros(len(rollout.obs)),
        {},
    )


def _future_direction_maps(ds_cache, box_direction, timestep):
    """Return the (direction map, timestep map) at ``timestep`` as numpy arrays.

    Uses the box future-direction map when ``box_direction`` is true, else the agent one.
    """
    if box_direction:
        direction_map, timestep_map = ds_cache.boxes_future_direction_map(multioutput=False, return_timestep_map=True)
    else:
        direction_map, timestep_map = ds_cache.agents_future_direction_map(multioutput=False, return_timestep_map=True)
    return direction_map.numpy()[timestep], timestep_map.numpy()[timestep]


def _ci_succeeded(steer_gt, steer_timestep_map, i, j, ci_direction, hook_time, walls_on_side, box_direction):
    """Decide whether steering cell ``(i, j)`` towards ``ci_direction`` had the intended effect."""
    if walls_on_side.any():
        if box_direction:
            # For boxes, a wall on the target side or on the opposite side counts as blocking
            # (presumably because the agent must stand opposite the push direction).
            blocked = walls_on_side[ci_direction] or walls_on_side[flip_direction(ci_direction)]
        else:
            blocked = walls_on_side[ci_direction]
        if blocked:
            # CI should cause a wasted step
            return steer_timestep_map[i, j] > hook_time
    return steer_gt[i, j] == ci_direction


def ci_score_on_a_level(
    level_idx_tuple,
    boxo_cfg,
    save_video=False,
    intervention_steps=-1,
    channel_type="box",
    box_direction=True,
    logits=1,
):
    """Score causal interventions (CIs) on direction channels for every target cell of a level.

    The policy first plays the level without hooks (up to 80 steps). From that rollout it takes
    the future-direction map and the timestep map at step 0. For every cell with a direction
    (``!= -1``) and each of the other three directions, the level is replayed with forward hooks
    from ``get_box_group_fwd_hooks`` that rewrite the cell's direction. The CI succeeds when the
    steered map shows the new direction at that cell. When a wall blocks the new direction, it
    instead succeeds when the cell's timestep moves past the hooked step.

    Args:
        level_idx_tuple: ``(level_file_idx, level_idx)`` of the level to play.
        boxo_cfg: Boxoban env config, used to build the env and load the policy.
        save_video: if true, save each steered rollout as ``automated_cis/<...>.mp4``.
        intervention_steps: for non-NFA/MPA channel types, hooks are applied from
            ``hook_time - intervention_steps`` (clamped at 0) to ``hook_time``. Must be >= 1
            for every channel type.
        channel_type: channel group to patch ("box", "agent", "box_agent", "nfa", "mpa").
        box_direction: score box future directions if true, agent future directions otherwise.
        logits: currently unused; the ``logit`` output column is always 1.

    Returns:
        Int array of shape ``(n, 9)`` with columns ``(lfi, li, i, j, next_to_wall,
        direction_idx, ci_direction, logit, ci_success)``. When the level yields no CIs the
        shape is ``(0, 9)`` (not ``(0,)``), so per-level results can always be concatenated.
    """
    assert intervention_steps >= 1, f"{intervention_steps} should be >= 1"
    boxo_env = boxo_cfg.make()
    _, policy_th = load_jax_model_to_torch(MODEL_PATH, boxo_cfg)
    lfi, li = level_idx_tuple
    reset_opts = {"level_file_idx": lfi, "level_idx": li}
    boxo_env.reset(options=reset_opts)

    out = play_level(
        boxo_env,
        policy_th=policy_th,
        reset_opts=reset_opts,
        thinking_steps=0,
        max_steps=80,  # 80 steps to compute the gt
    )
    ds_cache = _rollout_to_dataset_store(out)
    first_timestep = 0
    gt, timestep_map = _future_direction_maps(ds_cache, box_direction, first_timestep)

    is_box = "box" in channel_type
    # NFA and MPA are not layer-wise: they are hooked at the target step only.
    single_step_hooks = "nfa" in channel_type or "mpa" in channel_type

    all_ci_outputs = []
    for i, j in zip(*np.where(gt != -1)):
        # for box intervention, other boxes are considered as walls
        walls_on_side = ds_cache.get_wall_directions(i, j, box_is_wall=box_direction)
        next_to_wall = walls_on_side.any()
        direction_idx = gt[i, j]
        hook_time = timestep_map[i, j]
        assert hook_time >= 0, f"hook_time: {hook_time}"
        assert direction_idx > -1
        for ci_direction in range(4):
            if ci_direction == direction_idx:
                continue
            fwd_hooks = get_box_group_fwd_hooks(
                new_direction_idx=ci_direction,
                old_direction_idx=direction_idx,
                grid_i=i,
                grid_j=j,
                channel_type=channel_type,
                is_box=is_box,
            )
            if single_step_hooks:
                hook_steps = [hook_time]
            else:
                hook_steps = range(max(0, hook_time - intervention_steps), hook_time + 1)

            steer_out = play_level(
                boxo_env,
                policy_th=policy_th,
                reset_opts=reset_opts,
                thinking_steps=0,
                fwd_hooks=fwd_hooks,
                hook_steps=hook_steps,
                max_steps=80,
            )
            steer_gt, steer_timestep_map = _future_direction_maps(
                _rollout_to_dataset_store(steer_out), box_direction, first_timestep
            )
            ci_success = _ci_succeeded(
                steer_gt, steer_timestep_map, i, j, ci_direction, hook_time, walls_on_side, box_direction
            )
            all_ci_outputs.append((lfi, li, i, j, next_to_wall, direction_idx, ci_direction, 1, ci_success))
            if save_video:
                name = f"automated_cis/{lfi}_{li}_idx_{i}_{j}_steered_{direction_idx}_{ci_direction}.mp4"
                save_video_fn(name, steer_out.obs.squeeze(1))
                print("success:", ci_success)

    # Always 2-D: an empty level must give shape (0, 9), not (0,), or concatenation/DataFrame fails.
    all_ci_outputs = np.array(all_ci_outputs, dtype=int).reshape(-1, 9)
    mean_ci_success = all_ci_outputs[:, -1].mean() if len(all_ci_outputs) > 0 else "NA"
    print("Level:", (lfi, li), "Mean ci success:", mean_ci_success, "Total:", len(all_ci_outputs))
    return all_ci_outputs


def get_random_levels(num_levels, total_files, difficulty=None):
    """Draw ``num_levels`` random ``(level_file_idx, level_idx)`` pairs, with replacement.

    A flat index is drawn with the global numpy RNG and split as ``(idx // 1000, idx % 1000)``,
    i.e. level files are assumed to hold 1000 levels each.

    Args:
        num_levels: number of pairs to draw.
        total_files: number of level files; ignored for "hard", whose 3332 levels are hard-coded.
        difficulty: level difficulty. When omitted, falls back to the module-level ``difficulty``
            set by this script's ``__main__`` block. If neither is set, the non-"hard" path is
            used; the function no longer raises ``NameError`` when imported and called elsewhere.

    Raises:
        ValueError: if there are no levels to sample from (e.g. ``total_files == 0``).
    """
    if difficulty is None:
        # Backward compatibility: the script assigns `difficulty` as a module global.
        difficulty = globals().get("difficulty")
    total_levels = 3332 if difficulty == "hard" else 1000 * total_files
    if total_levels <= 0:
        raise ValueError(f"No levels to sample from (difficulty={difficulty!r}, total_files={total_files})")
    return [(ri // 1000, ri % 1000) for ri in np.random.randint(total_levels, size=num_levels)]


def get_probe_and_info(args):
    """Load a probe and build the ``TrainOn`` description of what it was trained on.

    Args:
        args: namespace with ``probe_path``, ``probe_wandb_id``, ``layer``, ``dataset_name`` and
            ``hooks`` (comma-separated). NOTE(review): this file's CLI defines none of these
            options, so callers must provide such a namespace themselves.

    Returns:
        ``(probe, probe_info, multioutput)``. ``multioutput`` is always ``False``, because
        multi-output probes are rejected.

    Raises:
        NotImplementedError: if the loaded probe is a ``MultiOutputClassifier``.
    """
    probe, grid_wise = load_probe(args.probe_path, args.probe_wandb_id)
    if isinstance(probe, MultiOutputClassifier):
        raise NotImplementedError

    probe_info = TrainOn(
        layer=args.layer,
        grid_wise=grid_wise,
        dataset_name=args.dataset_name,
        hooks=args.hooks.split(","),
    )
    return probe, probe_info, False


def get_coef(probe, probe_info, multioutput):
    """Extract probe weights as torch tensors plus segment bookkeeping.

    Args:
        probe: fitted probe; for ``multioutput`` it must expose ``estimators_`` (each with
            ``coef_`` and ``intercept_``), otherwise ``coef_`` and ``intercept_`` directly.
        probe_info: object whose ``layer`` attribute is -1 for "all 3 layers" or a single layer.
        multioutput: whether ``probe`` is a multi-output probe.

    Returns:
        ``(coef, intercept, n_segments, per_segment_neurons)``. ``n_segments`` is twice the
        number of layers covered, and ``per_segment_neurons = coef.shape[1] // n_segments``.
    """
    if not multioutput:
        coef = th.tensor(probe.coef_)
        intercept = th.tensor(probe.intercept_)
        # zero out null direction as we only want to steer in the 4 directions
        coef[0] = 0
        intercept[0] = 0
    else:
        estimators = probe.estimators_
        coef = th.tensor([est.coef_.squeeze(0) for est in estimators])
        intercept = th.tensor([est.intercept_ for est in estimators])

    layers_covered = 3 if probe_info.layer == -1 else 1
    n_segments = layers_covered * 2
    return coef, intercept, n_segments, coef.shape[1] // n_segments


if __name__ == "__main__":
    # Example usage:
    # python ci_score_direction_channel.py --level 0 0 --num_levels 1 --num_workers 0 --save_video
    parser = argparse.ArgumentParser()
    parser.add_argument("--channel_type", type=str, default="box", choices=["box", "agent", "box_agent", "nfa", "mpa"])
    parser.add_argument("--level", type=int, nargs=2, default=(-1, -1), help="level file index, level index")
    # ci_score_on_a_level asserts intervention_steps >= 1. The old default (-1) made every run
    # without this flag fail, including the example above, so default to the smallest valid value.
    parser.add_argument("--intervention_steps", type=int, default=1, help="hook steps")
    parser.add_argument("--split", type=str, default="valid", help="split")
    parser.add_argument("--difficulty", type=str, default="medium", help="difficulty")
    parser.add_argument("--num_levels", type=int, default=10, help="number of levels to run")
    parser.add_argument("--save_video", action="store_true", help="save video")
    parser.add_argument("--num_workers", type=int, default=4, help="number of workers")
    parser.add_argument("--seed", type=int, default=0, help="seed")
    parser.add_argument("--output_base_path", type=str, default="ci_score/", help="Path to save plots and cache.")
    parser.add_argument("--logits", type=str, default="1", help="logits to try")
    args = parser.parse_args()

    # "box" and "box_agent" are scored against box future directions, the rest against agent ones.
    box_direction = "box" in args.channel_type
    if box_direction:
        dataset_name = "boxes_future_direction_map"
    else:
        dataset_name = "agents_future_direction_map"

    split = args.split
    difficulty = args.difficulty  # module global; get_random_levels reads it
    logits = [int(logit) for logit in args.logits.split(",")]

    boxo_cfg = get_boxoban_cfg(
        difficulty=difficulty,
        split=split if args.split != "None" and args.split != "" else None,
        use_envpool=False,  # envpool doesn't support options on reset
    )

    if ON_CLUSTER and not args.output_base_path.startswith("/"):
        args.output_base_path = pathlib.Path("/training/") / args.output_base_path
    args.output_base_path = pathlib.Path(args.output_base_path) / dataset_name / f"{split}_{difficulty}"
    args.output_base_path.mkdir(parents=True, exist_ok=True)

    map_fn = partial(
        ci_score_on_a_level,
        boxo_cfg=boxo_cfg,
        logits=logits,
        save_video=args.save_video,
        intervention_steps=args.intervention_steps,
        box_direction=box_direction,
        channel_type=args.channel_type,
    )

    if all(v >= 0 for v in args.level):
        lfi, li = args.level
        print("Running on level", args.level)
        # Reshape guards against an empty (0,)-shaped result, which DataFrame would reject.
        results = np.asarray(map_fn(args.level), dtype=int).reshape(-1, 9)
        file_name = f"ci_results_lfi_{lfi}_li_{li}.csv"
    else:
        np.random.seed(args.seed)

        level_files_dir = BOXOBAN_CACHE / "boxoban-levels-master" / difficulty / split
        num_files = len(list(level_files_dir.glob("*.txt")))

        lfi_li_list = get_random_levels(args.num_levels, num_files)
        if args.num_workers > 1:
            # The context manager terminates the workers if anything inside raises.
            with multiprocessing.Pool(args.num_workers) as pool:
                per_level = list(tqdm(pool.imap(map_fn, lfi_li_list), total=len(lfi_li_list)))
                pool.close()
                pool.join()
        else:
            per_level = [map_fn(lfi_li) for lfi_li in tqdm(lfi_li_list)]
        # Normalize every per-level result to (n, 9). Seed with an empty (0, 9) array so that levels
        # without interventions, or num_levels == 0, still concatenate cleanly.
        results = np.concatenate(
            [np.empty((0, 9), dtype=int)] + [np.asarray(r, dtype=int).reshape(-1, 9) for r in per_level]
        )
        file_name = "ci_results.csv"
    df = pd.DataFrame(
        results,
        columns=["lfi", "li", "i", "j", "next_to_wall", "direction_idx", "ci_direction", "logit", "ci_success"],
    )
    csv_file = args.output_base_path / file_name
    df.to_csv(csv_file, index=False)
    print("Final Mean ci success: ", df["ci_success"].mean(), "Total: ", len(results))
