import os
from argparse import ArgumentParser
from collections import deque

import cv2
import gymnasium as gym
import imageio
import numpy as np
import torch as th
from matplotlib import pyplot as plt
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv

from customPolicy import CustomActorCriticPolicy
from protosac import ProtoSAC

# Run tensors on the GPU when PyTorch can see one, otherwise on the CPU.
device = th.device("cuda") if th.cuda.is_available() else th.device("cpu")

# Default output directory for the script's images and video.
# It is created as an import-time side effect so the default always exists.
SAVE_DIR = "prototype_visuals"
os.makedirs(SAVE_DIR, exist_ok=True)


class TerminateOnOffTrack(gym.Wrapper):
    """Preprocess image observations and end the episode when the car is off-track.

    Every RGB frame returned by the wrapped environment is converted to
    grayscale, resized to ``frame_size`` x ``frame_size`` and stacked with the
    previous frames along the last axis, so observations have shape
    ``(frame_size, frame_size, num_frames)`` and dtype ``uint8``.

    Args:
        env: Environment to wrap. Its observations must be RGB images.
        num_frames: Number of most recent frames kept in the stack.
        frame_size: Side length, in pixels, of each processed frame.
    """

    def __init__(self, env, num_frames=4, frame_size=64):
        super().__init__(env)
        self.num_frames = num_frames
        self.frame_size = frame_size
        self.frame_stack = deque(maxlen=num_frames)
        self.observation_space = gym.spaces.Box(
            low=0,
            high=255,
            shape=(frame_size, frame_size, num_frames),  # grayscale frames stacked
            dtype=np.uint8
        )

    def step(self, action):
        """Step the wrapped env and return the stacked, preprocessed observation.

        Sets ``terminated`` to True and ``info["off_track"]`` to True when the
        base environment reports a truthy ``on_grass`` attribute.
        """
        obs, reward, terminated, truncated, info = self.env.step(action)
        processed_frame = self.preprocess(obs)
        stacked_obs = self.get_stacked_obs(processed_frame)

        # Read the flag from the innermost environment: ``self.env`` is usually
        # a chain of wrappers added by ``gym.make`` and Gymnasium >= 1.0
        # wrappers no longer forward unknown attributes, so a lookup on
        # ``self.env`` would silently always yield the default.
        # NOTE(review): stock Gymnasium CarRacing does not appear to define
        # ``on_grass`` -- confirm the environment in use really sets it,
        # otherwise this check never fires.
        off_track = getattr(self.env.unwrapped, "on_grass", False)

        if off_track:
            terminated = True  # end episode immediately
            info["off_track"] = True

        return stacked_obs, reward, terminated, truncated, info

    def preprocess(self, obs):
        """Convert an RGB frame to a ``frame_size`` x ``frame_size`` uint8 grayscale image.

        The full frame is used; no cropping is applied.
        """
        gray = cv2.cvtColor(obs, cv2.COLOR_RGB2GRAY)
        resized = cv2.resize(
            gray, (self.frame_size, self.frame_size), interpolation=cv2.INTER_AREA
        )
        return resized.astype(np.uint8)

    def get_stacked_obs(self, new_frame):
        """Push ``new_frame`` onto the stack and return the stacked observation.

        When the stack is not yet full (right after a reset) the remaining
        slots are filled with copies of ``new_frame``.
        """
        self.frame_stack.append(new_frame)
        missing = self.frame_stack.maxlen - len(self.frame_stack)
        self.frame_stack.extend([new_frame] * missing)
        # Stack along the last axis to get shape (H, W, num_frames)
        return np.stack(self.frame_stack, axis=-1)

    def reset(self, **kwargs):
        """Reset the wrapped env and return a stack filled with the first frame."""
        obs, info = self.env.reset(**kwargs)
        processed_frame = self.preprocess(obs)
        self.frame_stack.clear()
        stacked_obs = self.get_stacked_obs(processed_frame)
        return stacked_obs, info


def make_env(env_name):
    """Create the named environment in ``rgb_array`` render mode, wrapped in ``Monitor``.

    ``CarRacing-v3`` is additionally wrapped in ``TerminateOnOffTrack`` before
    the ``Monitor`` (which records stats such as returns) is applied.
    """
    base_env = gym.make(env_name, render_mode="rgb_array")
    wrapped = TerminateOnOffTrack(base_env) if env_name == "CarRacing-v3" else base_env
    return Monitor(wrapped)


def _parse_args():
    """Parse and validate the command-line arguments."""
    parser = ArgumentParser()
    parser.add_argument('--env', type=str, default="Pendulum-v1",
                        help='OpenAI Gym environment name. Default: Pendulum-v1')
    parser.add_argument('--model_path', type=str, required=True,
                        help='Path to the trained model.')
    parser.add_argument('--episodes', type=int, default=30,
                        help='Number of episodes to evaluate the model. Default: 30')
    parser.add_argument('--save_dir', type=str, default=SAVE_DIR,
                        help='Directory to save the prototype images and video. Default: "prototype_visuals"')
    parser.add_argument('--num_prototypes', type=int, default=10,
                        help='Number of prototypes to visualise. Default: 10')

    args = parser.parse_args()
    if args.episodes < 1:
        parser.error("--episodes must be at least 1")
    if args.num_prototypes < 1:
        parser.error("--num_prototypes must be at least 1")
    return args


def _collect_rollout(model, env, n_episodes):
    """Run the deterministic policy for ``n_episodes`` episodes.

    ``env`` must be a vectorised environment holding a single sub-environment.

    Returns:
        ``(frames, states)``: two lists of equal length where ``frames[k]`` is
        the rendered image of the state whose observation is ``states[k]``.
    """
    frames, states = [], []
    obs = env.reset()
    episodes_done = 0

    while episodes_done < n_episodes:
        frames.append(env.render())
        states.append(obs)
        action, _ = model.predict(obs, deterministic=True)
        # A VecEnv step returns (obs, rewards, dones, infos). When a
        # sub-environment is done it is reset automatically and ``obs`` is
        # already the first observation of the next episode, so no explicit
        # reset is needed here.
        obs, _rewards, dones, _infos = env.step(action)
        if dones[0]:
            episodes_done += 1

    return frames, states


def _plot_action_distribution(action_means, action_log_stds, out_path):
    """Save a histogram of sampled actions, one Gaussian per action dimension."""
    plt.figure(figsize=(5, 3))

    for mean, log_std in zip(action_means, action_log_stds):
        std = float(np.exp(log_std))
        samples = th.normal(float(mean), std, size=(10000,)).numpy()  # 10000 samples
        num_bins = min(200, len(np.unique(samples)))
        try:
            plt.hist(samples, bins=num_bins, density=True)
        except ValueError:
            # Degenerate sample set: fall back to a coarse histogram.
            plt.hist(samples, bins=2, density=True)

    plt.xlabel("Action Value")
    plt.xlim(-10, 10)
    plt.ylabel("Probability Density")
    plt.savefig(out_path, dpi=300, bbox_inches='tight')
    plt.close()


def main():
    """Record a rollout video, evaluate the model and visualise its prototypes."""
    args = _parse_args()

    # A custom --save_dir is not created at import time, only the default is.
    os.makedirs(args.save_dir, exist_ok=True)

    env = DummyVecEnv([lambda: make_env(args.env)])
    try:
        # Load the model from the specified path. No freshly initialised model
        # is built beforehand: it would be discarded immediately.
        model = ProtoSAC.load(args.model_path)

        means = model.actor.mean.detach().cpu().numpy()
        log_stds = model.actor.log_std if model.use_sde else model.actor.log_stds
        log_stds = th.clamp(log_stds, -20, 2).detach().cpu().numpy()

        all_frames, all_states = _collect_rollout(model, env, args.episodes)

        # Save the frames as a video
        imageio.mimsave(f"{args.save_dir}/video.mp4", all_frames, fps=30)
        print("Done! model video saved.")

        # Evaluate the model using evaluate_policy from stable_baselines3
        mean_reward, std_reward = evaluate_policy(
            model, env, n_eval_episodes=args.episodes, deterministic=True
        )
        print(f"Model evaluation results: Mean reward = {mean_reward}, Standard deviation = {std_reward}")

        # Each stored observation has a leading env axis of size 1; joining
        # along it yields one row per visited state.
        states = th.as_tensor(np.concatenate(all_states, axis=0)).float().to(device)
        with th.no_grad():
            encoded_states = model.state_encoder(states)
            # NOTE(review): assumes prototype_layer returns a 2-D tensor of
            # shape (num_states, num_prototypes) -- confirm. After the
            # transpose each row holds one prototype's similarity to every
            # visited state.
            similarity = model.actor.prototype_layer(encoded_states).T
        similarity = similarity[:args.num_prototypes, :]

        plt.rcParams.update({'font.size': 14})

        # Match each prototype to the most similar visited state
        for i, proto_similarity in enumerate(similarity):
            idx = th.argmax(proto_similarity).item()
            imageio.imwrite(f"{args.save_dir}/proto_{i:02d}.png", all_frames[idx])
            _plot_action_distribution(
                means[i], log_stds[i],
                f"{args.save_dir}/prototype{i:02d}_action_dist.png",
            )
    finally:
        env.close()

    print("Done! Prototype images saved.")


if __name__ == "__main__":
    main()
