import os
import gymnasium as gym
import imageio
import numpy as np
import torch as th
from matplotlib import pyplot as plt
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv
from customPolicy import CustomActorCriticPolicy
from protosac import ProtoSAC
from argparse import ArgumentParser
from stable_baselines3.common.evaluation import evaluate_policy

# Torch device used for encoder / prototype inference: CUDA when torch can
# see a GPU, CPU otherwise.
if th.cuda.is_available():
    device = th.device("cuda")
else:
    device = th.device("cpu")

# Default output directory for prototype images and the rollout video.
# Note: this directory is created eagerly, at import time.
SAVE_DIR = "prototype_visuals"
os.makedirs(SAVE_DIR, exist_ok=True)


def make_env(env_name):
    """Build the gymnasium environment ``env_name`` for evaluation.

    The environment renders to RGB arrays (so frames can be saved) and is
    wrapped in a ``Monitor`` that records stats such as episode returns.

    Args:
        env_name: Registered gymnasium environment id, e.g. ``"Pendulum-v1"``.

    Returns:
        The Monitor-wrapped environment.
    """
    return Monitor(gym.make(env_name, render_mode="rgb_array"))


def _collect_rollouts(env, model, n_episodes):
    """Run ``model`` deterministically in ``env`` until ``n_episodes`` episodes end.

    Args:
        env: A single-environment ``DummyVecEnv`` whose env renders RGB arrays.
        model: The loaded model; only ``predict`` is used.
        n_episodes: Number of finished episodes to record.

    Returns:
        ``(observations, frames)``, one entry per step and index-aligned, so
        ``frames[k]`` is the render taken when ``observations[k]`` was observed.
    """
    observations = []
    frames = []
    obs = env.reset()
    episodes_done = 0
    while episodes_done < n_episodes:
        frames.append(env.render())
        observations.append(obs)
        action, _ = model.predict(obs, deterministic=True)
        # VecEnv.step returns (obs, rewards, dones, infos), not gymnasium's
        # (obs, reward, terminated, truncated, info). `dones` already merges
        # terminated/truncated, and DummyVecEnv resets a finished sub-env by
        # itself, so `obs` is already the first observation of the next episode.
        obs, _, dones, _ = env.step(action)
        episodes_done += int(np.sum(dones))
    return observations, frames


def _plot_action_distribution(mean_row, log_std_row, out_path, n_samples=10000):
    """Save a histogram of Gaussian samples for each (mean, log-std) pair.

    Args:
        mean_row: Means, presumably one per action dimension for a single
            prototype -- TODO confirm against the actor definition.
        log_std_row: Matching log standard deviations (exponentiated here).
        out_path: Image file to write.
        n_samples: Number of samples drawn per (mean, log-std) pair.
    """
    plt.figure(figsize=(5, 3))
    try:
        for mean, log_std in zip(mean_row, log_std_row):
            std = float(np.exp(log_std))
            samples = th.normal(float(mean), std, size=(n_samples,)).numpy()
            # Never request more bins than there are distinct sample values.
            num_bins = min(200, len(np.unique(samples)))
            try:
                plt.hist(samples, bins=num_bins, density=True)
            except ValueError:
                # Degenerate sample set: fall back to coarse binning.
                plt.hist(samples, bins=2, density=True)
        plt.xlabel("Action Value")
        plt.ylabel("Probability Density")
        plt.savefig(out_path, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting/saving failed.
        plt.close()


def main():
    """Visualise a trained ProtoSAC model.

    Records a rollout video, reports ``evaluate_policy`` statistics, and for
    each prototype saves the rollout frame it is most similar to plus a plot
    of its action distribution.
    """
    parser = ArgumentParser()
    parser.add_argument('--environment', type=str, default="Pendulum-v1",
                        help='OpenAI Gym environment name. Default: Pendulum-v1')
    parser.add_argument('--model_path', type=str, required=True,
                        help='Path to the trained model.')
    parser.add_argument('--episodes', type=int, default=30,
                        help='Number of episodes to evaluate the model. Default: 30')
    parser.add_argument('--save_dir', type=str, default=SAVE_DIR,
                        help='Directory to save the prototype images and video. Default: "prototype_visuals"')
    parser.add_argument('--num_prototypes', type=int, default=10,
                        help='Maximum number of prototypes to visualise. Default: 10')

    args = parser.parse_args()
    if args.episodes < 1:
        parser.error("--episodes must be at least 1")

    # Only the default SAVE_DIR is created at import time; a custom one is not.
    os.makedirs(args.save_dir, exist_ok=True)
    plt.rcParams.update({'font.size': 14})

    env = DummyVecEnv([lambda: make_env(args.environment)])
    try:
        model = ProtoSAC.load(args.model_path)

        # Log standard deviations, clamped to [-20, 2]; exponentiated when plotting.
        log_stds = model.actor.log_std if model.use_sde else model.actor.log_stds
        log_stds = th.clamp(log_stds, -20, 2).detach().cpu().numpy()
        means = model.actor.mean.detach().cpu().numpy()

        # NOTE(review): every frame of every episode is kept in memory because
        # prototype images are looked up by step index afterwards.
        all_states, all_frames = _collect_rollouts(env, model, args.episodes)

        imageio.mimsave(os.path.join(args.save_dir, "video.mp4"), all_frames, fps=30)
        print("Done! model video saved.")

        mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=args.episodes, deterministic=True)
        print(f"Model evaluation results: Mean reward = {mean_reward}, Standard deviation = {std_reward}")

        with th.no_grad():
            # Each stored obs is a (1, obs_dim) batch from the single-env
            # VecEnv; concatenate into one (N, obs_dim) batch.
            states = th.as_tensor(np.concatenate(all_states, axis=0), dtype=th.float32, device=device)
            encoded_states = model.state_encoder(states)
            # After the transpose, rows presumably index prototypes and columns
            # the N collected states -- TODO confirm prototype_layer's output.
            similarity = model.actor.prototype_layer(encoded_states).T

        # Match each prototype to the collected state it is most similar to.
        for i, proto_scores in enumerate(similarity[:args.num_prototypes]):
            idx = int(th.argmax(proto_scores).item())
            imageio.imwrite(os.path.join(args.save_dir, f"proto_{i:02d}.png"), all_frames[idx])
            _plot_action_distribution(
                means[i],
                log_stds[i],
                os.path.join(args.save_dir, f"prototype{i:02d}_action_dist.png"),
            )
    finally:
        env.close()
    print("Done! Prototype images saved.")


if __name__ == "__main__":
    main()

