import argparse
import dataclasses
from pathlib import Path
from typing import Sequence

import moviepy.editor as mpy
import numpy as np
import torch as th
from cleanba.environments import BoxobanConfig

# %%

# Directory one level above the folder containing this script; the level cache, model and
# probe paths used further down are all resolved relative to it.
BASE_DIR = Path(__file__).parent.parent

parser = argparse.ArgumentParser(description="Parse file index and steps to think.")

# Add the arguments
parser.add_argument("--file_idx", type=int, required=False, help="The index of the file to process (e.g. 22).", default=22)  # xsokoban
parser.add_argument("--level_idxs", type=str, required=False, help="The indices of the level to play (e.g. '[0, 2, 31]')", default="[31]")
parser.add_argument("--steps_to_think", type=int, required=False, help="The number of steps to think (e.g. 0).", default=0)
parser.add_argument("--max_steps", type=int, required=False, default=1000)

# NOTE: parsing happens at import time, so this file is meant to be run as a script.
args = parser.parse_args()
# NOTE(review): `eval` executes arbitrary Python taken from the command line. If only
# literal lists such as "[0, 2, 31]" need to be supported, `ast.literal_eval` would be a
# safer drop-in; it would however reject expressions like "list(range(10))" — confirm
# which inputs are relied upon before switching.
args.level_idxs = eval(args.level_idxs)
# NOTE: this validation is an `assert`, so it is skipped when Python runs with `-O`.
assert isinstance(args.level_idxs, list) and all(isinstance(i, int) for i in args.level_idxs), f"{args.level_idxs=} is not a list of ints"


# %%
# Environment configuration shared by the rest of the script: it builds the module-level
# `env` right below, is passed to `load_jax_model_to_torch`, and is handed to `EvalConfig`
# (whose `run` method makes its own environment from it).
# Both the minimum and maximum episode length are set to --max_steps, and a single
# environment is requested (`num_envs=1`); the evaluation code assumes a batch size of 1.
boxo_cfg = BoxobanConfig(
    num_envs=1,
    min_episode_steps=args.max_steps,
    max_episode_steps=args.max_steps,
    tinyworld_obs=True,
    cache_path=BASE_DIR / "plot/alternative-levels/levels",
    seed=1234,
    difficulty="unfiltered",
    split="train",
)
env = boxo_cfg.make()

# %%
from learned_planners.interp.utils import load_jax_model_to_torch

# Disabled alternative: build the policy from `cfg.torch.json` and load a torch state dict
# directly. Kept for reference only; the active path below uses `load_jax_model_to_torch`.
# with (BASE_DIR / "drc_model" / "cfg.torch.json").open("r") as f:
#     model_cfg = farconf.from_dict(json.load(f), ConvLSTMPolicyConfig)

# model_cls, kwargs = model_cfg.policy_and_kwargs(env)
# model = model_cls(
#     observation_space=spaces.Box(0, 255, (3, 10, 10), dtype=np.uint8),
#     action_space=spaces.Discrete(4),
#     activation_fn=th.nn.ReLU,
#     lr_schedule=lambda _: 0.0,
#     normalize_images=True,
#     **kwargs)
# model.eval()

# model_path = BASE_DIR / "drc_model" / "model.pt"
# model.load_state_dict(th.load(model_path, weights_only=True))

# Active path: load the policy used by `get_action_fn` from `<BASE_DIR>/drc_model/model.pt`,
# together with the environment configuration defined above.
model_path = BASE_DIR / "drc_model" / "model.pt"
model = load_jax_model_to_torch(model_path, boxo_cfg)

print("Loaded model from ", model_path)

# %%
def obs_to_torch(obs):
    """Turn one rank-3 observation into a batched tensor with its last axis moved to axis 1.

    A leading axis of size 1 is prepended, then the axes are reordered so that an
    input of shape (H, W, C) comes out as (1, C, H, W).
    """
    batched = th.as_tensor(obs)[None]
    return batched.permute(0, 3, 1, 2)


# Reset the module-level environment once at a fixed level (file 3, level 1) to obtain an
# example observation; this `obs` is what is later handed to
# `model.input_dependent_hooks_context`. The levels that are actually evaluated are chosen
# by --file_idx / --level_idxs and played in a separate environment made by `EvalConfig.run`.
obs, info = env.reset(options=dict(level_file_idx=3, level_idx=1))

# %%


@dataclasses.dataclass
class EvalConfig:
    """Evaluate a recurrent policy on a fixed list of levels, one episode per level.

    `run` drives a single environment built from `env`, so every per-episode
    buffer it allocates has a batch dimension of 1.
    """

    # Environment configuration; `run` calls `.make()` on it and reads `max_episode_steps`.
    env: BoxobanConfig
    # Passed to `env.reset` as `level_file_idx` for every episode.
    level_file_idx: int = 3  # dimitri & yorick
    # Level indices to play, in order. Iteration stops at the first index for which
    # `env.reset` raises IndexError.
    level_idxs: Sequence[int] = range(1000)
    # One complete pass over `level_idxs` is made for each entry of this list.
    steps_to_think: list[int] = dataclasses.field(default_factory=lambda: [0])
    # Not read anywhere in `run`.
    temperature: float = 0.0

    # Upper bound on the number of steps per episode, applied on top of the
    # environment's own `max_episode_steps`.
    safeguard_max_episode_steps: int = 30000

    def run(self, get_action_fn, initialize_carry_fn) -> dict[str, float]:
        """Play every level in `level_idxs` once per entry of `steps_to_think`.

        Args:
            get_action_fn: Called as `get_action_fn(obs, carry, episode_starts)`; must
                return `(action, new_carry)` where `action` supports `.item()`.
            initialize_carry_fn: Called as `initialize_carry_fn(obs)` after every reset
                to build the initial recurrent state.

        Returns:
            A dict holding, for each thinking-step count `N` (formatted as two digits
            `NN`): `NN_episode_returns`, `NN_episode_lengths` and `NN_episode_successes`
            (means over episodes), `NN_num_episodes`, and `NN_all_episode_info` (a dict
            with the raw per-episode returns, lengths, successes, observations, actions,
            rewards and level infos). Note that the values are therefore not all floats.

        Raises:
            ValueError: If no episode could be run for some thinking-step count (empty
                `level_idxs`, or the very first level could not be loaded).
        """
        # assert isinstance(self.env, EnvpoolBoxobanConfig)
        max_steps = min(self.safeguard_max_episode_steps, self.env.max_episode_steps)
        episode_starts_no = th.zeros(1, dtype=th.bool)

        metrics = {}
        # Build the environment before entering `try`: if `make()` fails there is nothing
        # to close, and `finally` must not end up closing some other `env` that happens
        # to be in scope.
        env = self.env.make()
        try:
            for steps_to_think in self.steps_to_think:
                all_episode_returns = []
                all_episode_lengths = []
                all_episode_successes = []
                all_obs = []
                all_acts = []
                all_rewards = []
                all_level_infos = []
                # envs = dataclasses.replace(self.env, seed=env_seed).make()
                for episode_i in self.level_idxs:
                    print(f"{steps_to_think=}, {episode_i=}")
                    try:
                        obs, level_infos = env.reset(options=dict(level_idx=episode_i, level_file_idx=self.level_file_idx))
                    except IndexError:
                        # Presumably `episode_i` is past the last level of this file;
                        # stop iterating over levels.
                        break

                    obs = obs_to_torch(obs)
                    carry = initialize_carry_fn(obs)

                    # "Thinking": advance the recurrent state on the initial observation
                    # without acting; the proposed actions are discarded.
                    for _ in range(steps_to_think):
                        _, carry = get_action_fn(obs, carry, episode_starts_no)

                    eps_done = np.zeros(1, dtype=np.bool_)
                    episode_success = np.zeros(1, dtype=np.bool_)
                    episode_returns = np.zeros(1, dtype=np.float64)
                    episode_lengths = np.zeros(1, dtype=np.int64)
                    episode_obs = np.zeros((max_steps + 1, *obs.shape), dtype=np.int64)
                    episode_acts = np.zeros((max_steps, 1), dtype=np.int64)
                    episode_rewards = np.zeros((max_steps, 1), dtype=np.float64)

                    episode_obs[0] = obs
                    i = 0
                    while not np.all(eps_done):
                        # The buffers above hold exactly `max_steps` transitions, so stop
                        # before indexing past them even if the environment keeps going.
                        if i >= max_steps:
                            break

                        action, carry = get_action_fn(obs, carry, episode_starts_no)

                        cpu_action = action.item()
                        obs, rewards, terminated, truncated, infos = env.step(cpu_action)
                        obs = obs_to_torch(obs)

                        episode_returns += rewards  # type: ignore
                        episode_lengths += 1
                        episode_success |= terminated  # If episode terminates it's a success

                        episode_obs[i + 1, ...] = obs
                        episode_acts[i] = action
                        episode_rewards[i] = rewards

                        # Set as done the episodes which are done
                        eps_done |= truncated | terminated
                        i += 1

                    all_episode_returns.append(episode_returns)
                    all_episode_lengths.append(episode_lengths)
                    all_episode_successes.append(episode_success)

                    # Keep only the steps that were actually taken and drop the size-1
                    # batch axis. The observation seen after the final step is not kept.
                    num_steps = int(episode_lengths[0])
                    all_obs.append(episode_obs[:num_steps, 0])
                    all_acts.append(episode_acts[:num_steps, 0])
                    all_rewards.append(episode_rewards[:num_steps, 0])

                    all_level_infos.append(level_infos)

                if not all_episode_returns:
                    raise ValueError(
                        f"No episodes were run for level_file_idx={self.level_file_idx}: "
                        "level_idxs is empty or its first level could not be loaded"
                    )

                all_episode_returns = np.stack(all_episode_returns)
                all_episode_lengths = np.stack(all_episode_lengths)
                all_episode_successes = np.stack(all_episode_successes)
                if isinstance(self.env, BoxobanConfig):
                    # Stack every public (non-underscore) info key across episodes.
                    all_level_infos = {
                        k: np.stack([d[k] for d in all_level_infos])
                        for k in all_level_infos[0].keys()
                        if not k.startswith("_")
                    }
                else:
                    all_level_infos = {
                        k: np.stack([d[k] for d in all_level_infos]) for k in all_level_infos[0].keys() if "level" in k
                    }
                    total = set(zip(all_level_infos["level_file_idx"], all_level_infos["level_idx"]))
                    print(f"Total levels: {len(total)}")

                metrics.update(
                    {
                        f"{steps_to_think:02d}_episode_returns": float(np.mean(all_episode_returns)),
                        f"{steps_to_think:02d}_episode_lengths": float(np.mean(all_episode_lengths)),
                        f"{steps_to_think:02d}_episode_successes": float(np.mean(all_episode_successes)),
                        f"{steps_to_think:02d}_num_episodes": len(all_episode_returns),
                        f"{steps_to_think:02d}_all_episode_info": dict(
                            episode_returns=all_episode_returns,
                            episode_lengths=all_episode_lengths,
                            episode_successes=all_episode_successes,
                            episode_obs=all_obs,
                            episode_acts=all_acts,
                            episode_rewards=all_rewards,
                            level_infos=all_level_infos,
                        ),
                    }
                )
                print(f"Success rate for {steps_to_think} steps: {np.mean(all_episode_successes)}")
        finally:
            env.close()
        return metrics


# %%

# Intended to collect, per step, whether the probe-based action matched the model's own
# action. The only lines that append to or read from it are currently commented out
# (inside `get_action_fn` and at the end of the file), so the list stays empty.
prediction_record = []
# %%


def initialize_carry(obs, num_layers: int = 3, hidden_channels: int = 32):
    """Build an all-zero recurrent state whose spatial size matches `obs`.

    Args:
        obs: Tensor whose dimensions 2 and 3 give the spatial size of the state. In
            this script it is the output of `obs_to_torch`, i.e. shaped (1, C, H, W).
        num_layers: Number of entries in the returned list. The default of 3 is the
            value this script has always used — presumably one entry per recurrent
            cell of the model; confirm against the model definition.
        hidden_channels: Size of the channel axis of every state tensor (default 32).

    Returns:
        A list of `num_layers` tuples, each holding two independent zero tensors of
        shape (1, 1, hidden_channels, H, W).
    """
    height, width = obs.shape[2], obs.shape[3]
    state_shape = (1, 1, hidden_channels, height, width)
    return [(th.zeros(state_shape), th.zeros(state_shape)) for _ in range(num_layers)]


# Values captured by `_save_fn`, keyed by hook name. `get_action_fn` empties this dict at
# the start of every call and reads the freshly captured entries afterwards.
cache = {}


def _save_fn(input, hook):
    """Forward hook: remember the hooked value in `cache` under `hook.name`.

    An existing entry with the same name is overwritten. Nothing is returned.
    """
    cache.update({hook.name: input})

# Linear probe used by `get_action_fn`: `combined_probe` is contracted with the channel
# axis of the cached activations (einsum "nchw,oc->nohw") and `combined_intercepts` is
# added per output channel.
combined_probe, combined_intercepts = th.load(BASE_DIR / "drc_model" / "action_l2_probe.pt", weights_only=True)
print("Loaded probe for actions")

# Coefficients with which `get_action_fn` mixes three spatial summaries of the probe
# output — mean, max and fraction of positive entries, in that order — plus a per-action
# offset that is added to the mixed score.
# NOTE(review): the origin of these numbers is not shown in this file — presumably they
# were fitted offline; confirm before changing the probe or the pooling.
weight = th.tensor([1.2086, -0.0582, 0.2070])
bias = th.tensor([0.3337, -0.0921, -0.0632, -0.0539])
# %%




def get_action_fn(obs, carry, episode_starts):
    """Choose an action from a linear-probe readout of cached model activations.

    The model's recurrent feature extractor is run once; the installed hooks are
    expected to store the `hook_h.0.2` and `hook_c.0.2` activations of
    `features_extractor.cell_list.2` in `cache` while it runs. Those activations are
    concatenated along the channel axis, passed through the probe, summarised over the
    two spatial axes in three ways and mixed with `weight` / `bias`.

    Returns:
        `(action, new_carry)`: the argmax over the mixed per-action scores and the
        recurrent state returned by the model.
    """
    # Drop whatever the previous call captured; the hooks refill `cache` during the pass.
    cache.clear()
    _, next_carry = model._recurrent_extract_features(obs, carry, episode_starts)
    # mlp_action, _value, _log_prob, new_carry = model(obs, carry, episode_starts, deterministic=True)

    hidden = cache["features_extractor.cell_list.2.hook_h.0.2"]
    cell = cache["features_extractor.cell_list.2.hook_c.0.2"]
    activations = th.cat((hidden, cell), dim=1)

    # Action prediction using "action channels"
    # action_prediction = activations[:, [29, 8, 27, 3]]

    # Action prediction using probes: one spatial map per probe output channel.
    probe_map = th.einsum("nchw,oc->nohw", activations, combined_probe) + combined_intercepts[None, :, None, None]

    spatial_axes = (2, 3)
    mean_pooled = probe_map.mean(spatial_axes)
    max_pooled = probe_map.amax(dim=spatial_axes)
    # Fraction of spatial positions where the probe output is positive.
    positive_fraction = (probe_map > 0).float().mean(spatial_axes)

    scores = mean_pooled * weight[0] + max_pooled * weight[1] + positive_fraction * weight[2] + bias

    # prediction_record.append(action == mlp_action)
    return scores.argmax(1), next_carry


# Hooks registered through `input_dependent_hooks_context` (which also receives the example
# observation `obs`) and through `model.hooks`; all of them store into `cache` via `_save_fn`.
input_dependent_fwd_hooks = [
    ("features_extractor.cell_list.2.hook_h.0", _save_fn),
    ("features_extractor.cell_list.2.hook_c.0", _save_fn),
]
cache_fwd_hooks = [
    ("features_extractor.cell_list.2.hook_h.0.2", _save_fn),
    ("features_extractor.cell_list.2.hook_c.0.2", _save_fn),
]

# Evaluate the probe-driven policy on the requested levels and write one video per level.
with th.no_grad():
    with model.input_dependent_hooks_context(obs, fwd_hooks=input_dependent_fwd_hooks, bwd_hooks=None), model.hooks(fwd_hooks=cache_fwd_hooks):
        evaluator = EvalConfig(
            boxo_cfg,
            level_file_idx=args.file_idx,
            level_idxs=args.level_idxs,
            steps_to_think=[args.steps_to_think],
        )
        metrics = evaluator.run(get_action_fn, initialize_carry)

        episode_info = metrics[f"{args.steps_to_think:02d}_all_episode_info"]
        scaling_factor = 20  # normal video is too small to play
        for level_idx, frames in zip(args.level_idxs, episode_info["episode_obs"]):
            # Move the size-3 axis last, as expected for image frames.
            frames = frames.transpose(0, 2, 3, 1)
            assert frames.shape[-1] == 3

            # Enlarge every frame by repeating each pixel along both image axes.
            resized_frames = frames.repeat(scaling_factor, axis=1).repeat(scaling_factor, axis=2)

            clip = mpy.ImageSequenceClip(list(resized_frames), fps=15)
            clip.write_videofile(f"level_{args.file_idx:03d}_{level_idx:03d}.mp4")

# print("Prediction record: ", np.mean(prediction_record))
