import argparse
import ast
import dataclasses
import json
from pathlib import Path
from typing import Any, Sequence

import farconf
import moviepy.editor as mpy
import numpy as np
import pandas as pd
import torch as th
from gymnasium import spaces

from learned_planners.environments import BoxobanConfig
from learned_planners.policies import ConvLSTMPolicyConfig

# %%

# Directory two levels above this file; all model and level paths below are resolved against it.
BASE_DIR = Path(__file__).parent.parent

parser = argparse.ArgumentParser(description="Parse file index and steps to think.")

# Add the arguments
parser.add_argument("--file_idx", type=int, required=False, help="The index of the file to process (e.g. 22).", default=22)  # xsokoban
parser.add_argument("--level_idxs", type=str, required=False, help="The indices of the level to play (e.g. '[0, 2, 31]')", default="[31]")
parser.add_argument("--steps_to_think", type=int, required=False, help="The number of steps to think (e.g. 0).", default=0)
parser.add_argument("--max_steps", type=int, required=False, default=1000)

args = parser.parse_args()

# SECURITY: this used to be `eval(args.level_idxs)`, which executes arbitrary code taken from the
# command line. `ast.literal_eval` accepts only Python literals, so '[0, 2, 31]' still works but
# expressions such as 'list(range(10))' are now rejected instead of being executed.
try:
    args.level_idxs = ast.literal_eval(args.level_idxs)
except (ValueError, TypeError, SyntaxError) as err:
    parser.error(f"--level_idxs must be a Python list literal of ints, got {args.level_idxs!r} ({err})")

# Explicit check rather than `assert`, so the validation also runs under `python -O`.
if not (isinstance(args.level_idxs, list) and all(isinstance(i, int) for i in args.level_idxs)):
    parser.error(f"{args.level_idxs=} is not a list of ints")


# %%
# Configuration for a single environment. The minimum and maximum episode lengths are both set to
# --max_steps, and levels are read from the cache directory under BASE_DIR.
boxo_cfg = BoxobanConfig(
    seed=1234,
    split="train",
    difficulty="unfiltered",
    cache_path=BASE_DIR / "plot/alternative-levels/levels",
    tinyworld_obs=True,
    n_envs=1,
    n_envs_to_render=1,
    min_episode_steps=args.max_steps,
    max_episode_steps=args.max_steps,
)
# Module-level environment; used below for the initial reset and for building the policy.
env = boxo_cfg.make()

# %%
# Rebuild the policy from its serialized configuration, then restore the trained weights.
with (BASE_DIR / "drc_model" / "cfg.torch.json").open("r") as f:
    model_cfg = farconf.from_dict(json.load(f), ConvLSTMPolicyConfig)

model_cls, kwargs = model_cfg.policy_and_kwargs(env)
model = model_cls(
    # Observation and action spaces are fixed here rather than taken from `env`.
    spaces.Box(0, 255, (3, 10, 10), dtype=np.uint8),
    action_space=spaces.Discrete(4),
    lr_schedule=lambda _: 0.0,  # constant zero: this script only runs inference
    activation_fn=th.nn.ReLU,
    normalize_images=True,
    **kwargs,
) if False else model_cls(
    observation_space=spaces.Box(0, 255, (3, 10, 10), dtype=np.uint8),
    action_space=spaces.Discrete(4),
    lr_schedule=lambda _: 0.0,  # constant zero: this script only runs inference
    activation_fn=th.nn.ReLU,
    normalize_images=True,
    **kwargs,
)
model.eval()

model_path = BASE_DIR / "drc_model" / "model.pt"
# weights_only=True restricts unpickling to tensors and plain containers.
model.load_state_dict(th.load(model_path, weights_only=True))
print("Loaded model from ", model_path)

# %%
def obs_to_torch(obs):
    """Turn a single 3-D observation into a 4-D tensor with the last axis moved to position 1.

    A leading axis of size 1 is added, then the axes are reordered from (0, 1, 2, 3) to
    (0, 3, 1, 2) -- e.g. an (H, W, C) array becomes a (1, C, H, W) tensor.
    """
    batched = th.as_tensor(obs)[None]
    return batched.permute(0, 3, 1, 2)


# Initial reset on a fixed level; the resulting `obs` is later handed to
# `model.input_dependent_hooks_context` when the hooks are installed.
obs, info = env.reset(options={"level_file_idx": 3, "level_idx": 1})

# %%


@dataclasses.dataclass
class EvalConfig:
    """Settings for evaluating a recurrent policy on a list of levels, one episode per level.

    Fields:
        env: configuration whose ``make()`` builds the environment used for evaluation.
        level_file_idx: level file index passed to ``env.reset`` for every episode.
        level_idxs: level indices to play, in order.
        steps_to_think: each entry triggers one full pass over ``level_idxs`` in which the policy is
            first called that many times on the initial observation before it starts acting.
        temperature: not read by ``run``.
        safeguard_max_episode_steps: upper bound on the steps per episode, applied together with
            ``env.max_episode_steps``.
    """

    env: BoxobanConfig
    level_file_idx: int = 3  # dimitri & yorick
    # No trailing comma here: `range(1000),` would make the default the 1-tuple `(range(1000),)`.
    level_idxs: Sequence[int] = range(1000)
    steps_to_think: list[int] = dataclasses.field(default_factory=lambda: [0])
    temperature: float = 0.0

    safeguard_max_episode_steps: int = 30000

    def run(self, get_action_fn, initialize_carry_fn) -> dict[str, Any]:
        """Play every level once per ``steps_to_think`` value and collect per-episode statistics.

        Args:
            get_action_fn: called as ``get_action_fn(obs, carry, episode_starts)``; must return
                ``(action, new_carry)``. ``action.item()`` is what gets passed to ``env.step``.
            initialize_carry_fn: called as ``initialize_carry_fn(obs)`` with the first observation
                of each episode to build the initial recurrent state.

        Returns:
            A dict that holds, for every ``steps_to_think`` value ``n`` (prefix ``f"{n:02d}"``):
            the mean return, mean length and success rate as floats, the number of episodes as an
            int, and ``<prefix>_all_episode_info`` with the raw per-episode arrays.
            NOTE: ``episode_obs`` holds the first ``episode_length`` observations of each episode,
            i.e. the initial observation is included and the final one is not.

        Raises:
            ValueError: if no episode could be played for some ``steps_to_think`` value.
        """
        # assert isinstance(self.env, EnvpoolBoxobanConfig)
        max_steps = min(self.safeguard_max_episode_steps, self.env.max_episode_steps)
        # Handed to the policy on every call; never flags an episode start.
        episode_starts_no = th.zeros(1, dtype=th.bool)

        metrics = {}
        # Build the environment before entering `try`: if `make()` fails, the error propagates
        # unchanged instead of being masked by an UnboundLocalError from `env.close()` below.
        env = self.env.make()
        try:
            for steps_to_think in self.steps_to_think:
                all_episode_returns = []
                all_episode_lengths = []
                all_episode_successes = []
                all_obs = []
                all_acts = []
                all_rewards = []
                all_level_infos = []
                # envs = dataclasses.replace(self.env, seed=env_seed).make()
                for episode_i in self.level_idxs:
                    print(f"{steps_to_think=}, {episode_i=}")
                    try:
                        obs, level_infos = env.reset(options=dict(level_idx=episode_i, level_file_idx=self.level_file_idx))
                    except IndexError:
                        # Presumably raised when `episode_i` is past the last level of the file
                        # (TODO confirm); stop this pass and keep the episodes played so far.
                        break

                    obs = obs_to_torch(obs)
                    carry = initialize_carry_fn(obs)

                    # "Thinking": advance the recurrent state on the initial observation while
                    # discarding the proposed actions.
                    for _ in range(steps_to_think):
                        _, carry = get_action_fn(obs, carry, episode_starts_no)

                    eps_done = np.zeros(1, dtype=np.bool_)
                    episode_success = np.zeros(1, dtype=np.bool_)
                    episode_returns = np.zeros(1, dtype=np.float64)
                    episode_lengths = np.zeros(1, dtype=np.int64)
                    # One more observation slot than steps: slot 0 is the initial observation.
                    episode_obs = np.zeros((max_steps + 1, *obs.shape), dtype=np.int64)
                    episode_acts = np.zeros((max_steps, 1), dtype=np.int64)
                    episode_rewards = np.zeros((max_steps, 1), dtype=np.float64)

                    episode_obs[0] = obs
                    i = 0
                    while not np.all(eps_done):
                        # Guard with `max_steps` (the size the buffers were allocated with) so the
                        # writes below can never run past the end of the arrays.
                        if i >= max_steps:
                            break

                        action, carry = get_action_fn(obs, carry, episode_starts_no)

                        cpu_action = action.item()
                        obs, rewards, terminated, truncated, infos = env.step(cpu_action)
                        obs = obs_to_torch(obs)

                        episode_returns += rewards  # type: ignore
                        episode_lengths += 1
                        episode_success |= terminated  # If episode terminates it's a success

                        episode_obs[i + 1, ...] = obs
                        episode_acts[i] = action
                        episode_rewards[i] = rewards

                        # Set as done the episodes which are done
                        eps_done |= truncated | terminated
                        i += 1

                    all_episode_returns.append(episode_returns)
                    all_episode_lengths.append(episode_lengths)
                    all_episode_successes.append(episode_success)

                    # Only one environment is stepped, so take column 0 and trim to the real length.
                    length = episode_lengths[0]
                    all_obs.append(episode_obs[:length, 0])
                    all_acts.append(episode_acts[:length, 0])
                    all_rewards.append(episode_rewards[:length, 0])

                    all_level_infos.append(level_infos)

                if not all_episode_returns:
                    # `np.stack([])` would raise a ValueError anyway; give it a useful message.
                    raise ValueError(
                        f"No episodes were played for {steps_to_think=}: "
                        f"level_idxs={self.level_idxs!r}, level_file_idx={self.level_file_idx}"
                    )

                all_episode_returns = np.stack(all_episode_returns)
                all_episode_lengths = np.stack(all_episode_lengths)
                all_episode_successes = np.stack(all_episode_successes)
                if isinstance(self.env, BoxobanConfig):
                    all_level_infos = {
                        k: np.stack([d[k] for d in all_level_infos])
                        for k in all_level_infos[0].keys()
                        if not k.startswith("_")
                    }
                else:
                    all_level_infos = {
                        k: np.stack([d[k] for d in all_level_infos]) for k in all_level_infos[0].keys() if "level" in k
                    }
                    total = set(zip(all_level_infos["level_file_idx"], all_level_infos["level_idx"]))
                    print(f"Total levels: {len(total)}")

                metrics.update(
                    {
                        f"{steps_to_think:02d}_episode_returns": float(np.mean(all_episode_returns)),
                        f"{steps_to_think:02d}_episode_lengths": float(np.mean(all_episode_lengths)),
                        f"{steps_to_think:02d}_episode_successes": float(np.mean(all_episode_successes)),
                        f"{steps_to_think:02d}_num_episodes": len(all_episode_returns),
                        f"{steps_to_think:02d}_all_episode_info": dict(
                            episode_returns=all_episode_returns,
                            episode_lengths=all_episode_lengths,
                            episode_successes=all_episode_successes,
                            episode_obs=all_obs,
                            episode_acts=all_acts,
                            episode_rewards=all_rewards,
                            level_infos=all_level_infos,
                        ),
                    }
                )
                print(f"Success rate for {steps_to_think} steps: {np.mean(all_episode_successes)}")
        finally:
            env.close()
        return metrics


# %%

# Debugging aid: only referenced from commented-out code in this file (the `append` in
# `get_action_fn` and the final `print`), so it stays empty unless those lines are re-enabled.
prediction_record = []
# %%


def initialize_carry(obs):
    """Build an all-zero recurrent state sized from the spatial dimensions of ``obs``.

    Returns a list of 3 entries, each a 2-tuple of zero tensors of shape
    ``(1, 1, 32, obs.shape[2], obs.shape[3])`` -- presumably one (hidden, cell) pair per
    ConvLSTM layer of the policy (TODO confirm against the model config).
    """
    state_shape = (1, 1, 32, obs.shape[2], obs.shape[3])
    carry = []
    for _layer in range(3):
        first = th.zeros(state_shape)
        second = th.zeros(state_shape)
        carry.append((first, second))
    return carry


# Most recent value seen by each hook, keyed by hook name. Cleared at the start of every
# `get_action_fn` call.
cache: dict = {}


def _save_fn(input, hook):
    """Hook callback: remember ``input`` under ``hook.name`` in the module-level ``cache``.

    Always returns ``None`` (presumably the hook API then leaves the value untouched -- TODO
    confirm against the hook implementation).
    """
    cache[hook.name] = input

# Linear probe applied to the cached activations in `get_action_fn`: `combined_probe` is contracted
# with the channel axis and `combined_intercepts` is added per output channel.
combined_probe, combined_intercepts = th.load(
    BASE_DIR / "drc_model" / "action_l2_probe.pt",
    weights_only=True,
)
print("Loaded probe for actions")

# Coefficients used in `get_action_fn` to combine the three pooled probe statistics
# (mean, max, fraction > 0, in that order) and the additive per-action offset.
weight = th.tensor([1.2086, -0.0582, 0.2070])
bias = th.tensor([0.3337, -0.0921, -0.0632, -0.0539])
# %%




def get_action_fn(obs, carry, episode_starts):
    """Choose an action from probe readouts of cached activations instead of the policy's own head.

    Runs ``model._recurrent_extract_features`` (which fills ``cache`` through the installed hooks),
    applies the linear probe to the concatenated ``hook_h`` / ``hook_c`` activations of
    ``cell_list.2``, pools the probe output over the two trailing axes in three ways, and takes the
    argmax of their weighted sum plus ``bias``.

    Returns:
        ``(action, new_carry)`` where ``action`` is the argmax over axis 1 and ``new_carry`` is the
        recurrent state returned by the model.
    """
    cache.clear()
    _, next_carry = model._recurrent_extract_features(obs, carry, episode_starts)
    # mlp_action, _value, _log_prob, new_carry = model(obs, carry, episode_starts, deterministic=True)

    hidden_act = cache["features_extractor.cell_list.2.hook_h.0.2"]
    cell_act = cache["features_extractor.cell_list.2.hook_c.0.2"]
    activations = th.cat([hidden_act, cell_act], dim=1)

    # Alternative readout using "action channels" directly:
    # action_prediction = activations[:, [29, 8, 27, 3]]

    # Probe readout: contract the channel axis with the probe, then add the intercepts.
    probe_out = th.einsum("nchw,oc->nohw", activations, combined_probe)
    probe_out = probe_out + combined_intercepts[None, :, None, None]

    pooled_mean = probe_out.mean(dim=(2, 3))
    pooled_max = probe_out.amax(dim=(2, 3))
    # Fraction of positions whose probe output is positive.
    pooled_active = (probe_out > 0).float().mean(dim=(2, 3))

    scores = pooled_mean * weight[0] + pooled_max * weight[1] + pooled_active * weight[2] + bias

    # prediction_record.append(action == mlp_action)
    return scores.argmax(1), next_carry


# Run the evaluation with the activation-saving hooks installed, then write one video per level.
with th.no_grad():
    with model.input_dependent_hooks_context(
        obs,
        fwd_hooks=[
            ("features_extractor.cell_list.2.hook_h.0", _save_fn),
            ("features_extractor.cell_list.2.hook_c.0", _save_fn),
        ],
        bwd_hooks=None,
    ):
        with model.hooks(
            fwd_hooks=[
                ("features_extractor.cell_list.2.hook_h.0.2", _save_fn),
                ("features_extractor.cell_list.2.hook_c.0.2", _save_fn),
            ]
        ):
            eval_cfg = EvalConfig(
                boxo_cfg,
                level_file_idx=args.file_idx,
                level_idxs=args.level_idxs,
                steps_to_think=[args.steps_to_think],
            )
            metrics = eval_cfg.run(get_action_fn, initialize_carry)

            recorded_obs = metrics[f"{args.steps_to_think:02d}_all_episode_info"]["episode_obs"]
            for level_idx, frames in zip(args.level_idxs, recorded_obs):
                # Move the axis at position 1 to the end; the assert checks it has size 3.
                frames = frames.transpose(0, 2, 3, 1)
                assert frames.shape[-1] == 3

                scaling_factor = 20  # normal video is too small to play
                resized_frames = []
                for frame in frames:
                    # Nearest-neighbour upscaling: repeat every row, then every column.
                    taller = np.repeat(frame, scaling_factor, axis=0)
                    resized_frames.append(np.repeat(taller, scaling_factor, axis=1))

                clip = mpy.ImageSequenceClip(resized_frames, fps=15)
                clip.write_videofile(f"level_{args.file_idx:03d}_{level_idx:03d}.mp4")

# print("Prediction record: ", np.mean(prediction_record))
