import argparse
import os
import random
import time
import sys
import numpy as np
import torch
import torch.nn as nn
import gymnasium as gym
from dataclasses import dataclass
from torch.distributions.normal import Normal
from matplotlib import pyplot as plt
from matplotlib import animation
import tyro

# Add the parent directory to the path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Import environment module to ensure registration
from env.turtlebot_env2 import TurtlebotEnv2
#from env.turtlebot_env import TurtlebotEnv

from adversary.Adversary import ImagePoison, Discrete, Continuous, Dazer
from adversary.OuterLoop import SleeperNets, Learned_Inception
from adversary.InnerLoop import BadRLMiddleMan, TrojDRLMiddleMan, BadBots, OnCeption
from adversary import patterns
from utils.models import Agent, LSTM_Agent, QNetwork
from utils.utils import Args, make_env, load_dict_from_yaml

@dataclass
class Args(Args):
    """Evaluation options added on top of the ``Args`` imported from utils.utils.

    The subclass reuses the imported name, so ``Args`` refers to this class
    for the rest of the module.
    """
    # Checkpoint file handed to torch.load in main().
    model_path: str = ""
    # Number of evaluation episodes main() rolls out.
    num_episodes: int = 5
    # When True, main() applies the trigger to each observation before the policy sees it.
    poison: bool = False

def save_frames_as_gif(frames, path='./', filename='gym_animation.gif', dpi = 100):
    """Render a sequence of image frames to an animated GIF.

    Args:
        frames: Non-empty sequence of images accepted by ``plt.imshow``. They
            are drawn with the ``inferno`` colormap over the fixed range [0, 1].
        path: Directory the GIF is written into.
        filename: Name of the GIF file inside ``path``.
        dpi: Resolution of the 4x4 inch figure; converted with ``int``.

    Raises:
        ValueError: If ``frames`` is empty.
    """
    if len(frames) == 0:
        raise ValueError("save_frames_as_gif needs at least one frame")

    #Mess with this to change frame size
    fig = plt.figure(figsize=(4, 4), dpi=int(dpi))
    try:
        patch = plt.imshow(frames[0], vmin=0, vmax=1, cmap="inferno")
        plt.axis('off')

        def animate(i):
            patch.set_data(frames[i])

        anim = animation.FuncAnimation(fig, animate, frames=len(frames), interval=50)
        # os.path.join tolerates a directory given without a trailing separator,
        # which plain string concatenation silently turned into a wrong filename.
        anim.save(os.path.join(path, filename), writer='imagemagick', fps=30)
    finally:
        # Close the figure so it neither leaks nor stays pyplot's "current
        # figure" for whatever the caller plots next.
        plt.close(fig)

# def make_env(env_id, idx, capture_video, run_name, gamma, args):
#     def thunk():
#         # Prepare kwargs based on environment type
#         env_kwargs = {'gui': args.gui}  # GUI is common for all envs
        
#         # Add specific arguments based on env_id
#         if 'turtlebot' in env_id.lower():
#             env_kwargs.update({
#                 'tb2_speed': args.tb2_speed,
#                 'tb3_speed': args.tb3_speed
#             })
        
#         env = gym.make(
#             env_id,
#             render_mode="human" if args.gui else None,
#             **env_kwargs
#         )
        
#         if capture_video:
#             env = gym.wrappers.RecordVideo(
#                 env, 
#                 f"videos/{run_name}",
#                 episode_trigger=lambda x: x == 0
#             )
#         return env
#     return thunk

def main():
    """Roll out a saved agent for a number of episodes and report the results.

    Parses ``Args`` from the command line, builds a one-environment
    ``SyncVectorEnv``, loads the state dict at ``args.model_path`` and runs
    ``args.num_episodes`` episodes. Each step uses the first value returned by
    ``agent.get_mean_std`` clamped to [-1, 1] as the action. When
    ``args.poison`` is set, the trigger is applied to every observation before
    it reaches the policy.

    Side effects: during the first episode the triggered observation is dumped
    as per-frame PNGs under ``gifs/`` and the run blocks on ``input`` after
    every step; the first frames are then saved as ``gifs/test.gif``. After all
    episodes an action histogram is written to ``gifs/poisoned.png`` or
    ``gifs/benign.png`` and a reward / success summary is printed.

    Raises:
        ValueError: If ``args.num_episodes`` is smaller than 1.
    """
    args = tyro.cli(Args)

    # Fail early: with no episodes the histogram indexing and the success-rate
    # division further down would crash only after the environment was built.
    if args.num_episodes < 1:
        raise ValueError("num_episodes must be at least 1")

    # Set random seeds
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Create a run name for video capture
    run_name = f"{args.env_id}_eval_{int(time.time())}"

    envs = gym.vector.SyncVectorEnv(
        [make_env(
            args.env_id,
            0,
            args.capture_video,
            run_name,
            0.99,  # gamma value, not critical for evaluation
            args   # Pass all args to make_env
        )]
    )

    if args.lstm:
        agent = LSTM_Agent(envs, args)
    else:
        agent = Agent(envs, args)

    # Use the GPU when there is one, otherwise fall back to the CPU instead of
    # crashing on the first transfer.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Every image written below lands in gifs/; create it up front.
    os.makedirs("gifs", exist_ok=True)

    # Pixels in one 84x84 frame of the flattened observation.
    frame_pixels = 84 * 84

    pattern = patterns.RobustTrigger((84,84), 8, 32, 0.25, .75, args.num_frames, edge = args.edge, fixed_pos = [36,42])
    #pattern = patterns.RobustTrigger((84,84), 8,  16, 0, 0.2, args.num_frames, edge= args.edge)
    trigger = lambda x: pattern(x, True)
    # --- Set up Outer Loop Attacks --- #
    if args.poison:
        if args.robust:
            pattern = patterns.RobustTrigger((84,84), 8, 16, 0, .8, args.num_frames, fixed_pos=[12,42],  edge = args.edge)
            #pattern = patterns.RobustTrigger((84,84), 8,  16, 0, 0.2, args.num_frames, edge= args.edge)
            trigger = lambda x: pattern(x, True)
        else:
            pattern = patterns.Single_Stacked_Img_Pattern((args.num_frames, 84, 84), (8,8), min=-1, max = 1).to(device)
            pattern = pattern.flatten()
            trigger = ImagePoison(pattern, 0, 1)

    print(f"Loading model from {args.model_path}")
    # NOTE(review): torch.load unpickles the file, so only point model_path at
    # checkpoints from a trusted source.
    agent.load_state_dict(torch.load(args.model_path, map_location=torch.device('cpu')))
    agent.eval()  # Set to evaluation mode
    agent = agent.to(device)

    # Run episodes
    total_rewards = []
    success_count = 0
    frames = []
    actions = []

    for episode in range(args.num_episodes):
        obs, _ = envs.reset()
        done = torch.tensor([False]).to(device)
        episode_reward = 0
        step_count = 0

        print(f"Starting episode {episode+1}/{args.num_episodes}")

        if args.lstm:
            # NOTE(review): this state is initialised but never handed to the
            # agent; the stateful call
            # agent.get_action_and_value(obs_tensor, lstm_state, done) is
            # currently disabled - confirm get_mean_std is valid for LSTM_Agent.
            lstm_state = (
                torch.zeros(agent.lstm.num_layers, args.num_envs, agent.lstm.hidden_size).to(device),
                torch.zeros(agent.lstm.num_layers, args.num_envs, agent.lstm.hidden_size).to(device),
            )  # hidden and cell states (see https://youtu.be/8HyCNIVRbSU)

        while not done[0]:
            if args.poison:
                obs = trigger(torch.tensor(obs).to(device)).cpu().numpy()

            # Convert observation to a float32 tensor on the selected device
            obs_tensor = torch.tensor(obs, dtype=torch.float32, device=device)

            # Get action from policy
            with torch.no_grad():
                action, _ = agent.get_mean_std(obs_tensor)
                action = torch.clamp(action, -1, 1)

            if episode == 0:
                # Debug dump of the triggered observation, one PNG per stacked frame.
                # NOTE(review): when args.poison is set, obs already carries the
                # trigger from above, so it is applied a second time here.
                temp = trigger(obs)
                print(obs.shape)
                # Derive the frame count from the data rather than assuming 8,
                # so a shorter frame stack no longer breaks the reshape.
                for frame_idx in range(temp.shape[1] // frame_pixels):
                    image2 = np.reshape(
                        temp[:, frame_pixels*frame_idx:frame_pixels*(frame_idx+1)], (84,84)
                    )
                    plt.figure(dpi = 150)
                    plt.imshow(image2)
                    plt.savefig(f"gifs/{frame_idx}.png")
                    plt.close()
                # NOTE(review): blocks on stdin after every step of the first
                # episode; the script cannot run unattended while this stays.
                input("wait")
                image = np.reshape(temp[:,:frame_pixels], (84,84))
                frames.append(image)

            # Move the action to NumPy once; it is both recorded and executed.
            action = action.cpu().numpy()
            actions.append(action)

            # Execute action in environment
            obs, reward, terminations, truncations, infos = envs.step(action)

            done = torch.tensor(np.logical_or(terminations, truncations)).to(device)
            episode_reward += reward[0]
            step_count += 1

            # Print progress
            if step_count % 10 == 0:
                print(f"Episode {episode+1}, Step {step_count}, Current reward: {episode_reward:.2f}", end = "\r")

            # Check for episode end
            if done[0]:
                if "final_info" in infos:
                    for info in infos["final_info"]:
                        if info and "reason" in info:
                            print(f"Episode end reason: {info['reason']}")
                            if info["reason"] == "success":
                                success_count += 1
        if episode == 0:
            save_frames_as_gif(frames, "gifs/", "test.gif", dpi = 100)

        # Episode summary
        total_rewards.append(episode_reward)
        print(f"\nEpisode {episode+1} finished: Steps={step_count}, Reward={episode_reward:.2f}")

        # Short pause between episodes
        time.sleep(1)

    # Histogram of the two action components over every recorded step.
    actions = np.array(actions)
    plt.figure(dpi = 150)
    plt.hist(actions[:,0, 0], label = "linear", alpha = 0.5, bins = 20)
    plt.hist(actions[:,0, 1], label = "angular", alpha = 0.5, bins = 20)
    plt.legend()
    plt.savefig("gifs/poisoned.png" if args.poison else "gifs/benign.png")
    plt.close()

    # Final summary
    print("\nRun complete!")
    print(f"Average reward over {args.num_episodes} episodes: {np.mean(total_rewards):.2f}")
    print(f"Success rate: {success_count}/{args.num_episodes} ({100*success_count/args.num_episodes:.1f}%)")
    print(f"Rewards: {total_rewards}")

    envs.close()

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
