import argparse
import os
import random
import time
import sys
import numpy as np
import torch
import torch.nn as nn
import gymnasium as gym
from torch.distributions.normal import Normal
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from dataclasses import dataclass
import tyro

# Add the parent directory to the path FIRST
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Now import from the project modules
from adversary.Adversary import ImagePoison, Discrete, Continuous, exp_cos, Dazer, cos_dist_np, l2dist, log_dist
from adversary.OuterLoop import SleeperNets, Learned_Inception
from adversary.InnerLoop import BadRLMiddleMan, TrojDRLMiddleMan, BadBots, OnCeption
from env.adversarial_mpd import DAZE_Outer
from adversary import patterns
from utils.models import Agent, LSTM_Agent, QNetwork
from utils.utils import Args, make_env, load_dict_from_yaml
from utils.vit import ViT_Agent
from env.fetch_env import FetchEnv

@dataclass
class EvalArgs(Args):
    """Command-line options for this evaluation script.

    Extends the project-wide ``Args`` with evaluation-only settings; ``main``
    parses it with ``tyro.cli``.
    """
    # Checkpoint path passed to ``torch.load`` to restore the agent's state dict.
    model_path: str = ""
    # Number of evaluation episodes rolled out by ``main``.
    num_episodes: int = 5
    # When True, ``main`` builds a trigger and applies it to every observation
    # before the policy sees it.
    poison: bool = False
    # Copied into ``args.num_frames`` at startup so the network is built with the
    # same frame count as the trained model.
    model_num_frames: int = 8

def save_frames_as_gif(frames, path='./', filename='gym_animation.gif', dpi=100):
    """Write a sequence of image frames to a looping animated GIF.

    Args:
        frames: Sized sequence (list, tuple or array) of array-like frames.
            A frame whose maximum is <= 1.0 is treated as normalized: it is
            clipped to [0, 1] and scaled to 0-255 ``uint8``. Any other frame
            is handed to Pillow unchanged.
        path: Output directory; created if it does not exist. An empty string
            means the current working directory.
        filename: Name of the GIF file inside ``path``.
        dpi: Unused; kept only so existing callers that pass it keep working.

    Returns:
        None. When ``frames`` is empty nothing is written and a notice is
        printed instead.
    """
    import os

    # Nothing to encode: bail out before touching the filesystem instead of
    # failing later with an IndexError on the first frame.
    if len(frames) == 0:
        print("No frames to save; skipping GIF export")
        return

    # Imported lazily so the empty-input path does not require Pillow.
    from PIL import Image

    # os.makedirs('') raises, and '' already denotes the current directory.
    if path:
        os.makedirs(path, exist_ok=True)

    # Join properly so a directory without a trailing separator still works.
    output_path = os.path.join(path, filename)

    # Subsample frames if too many
    if len(frames) > 100:
        step = len(frames) // 100
        frames = frames[::step]
        print(f"Subsampled to {len(frames)} frames")

    pil_frames = []
    for frame in frames:
        frame = np.array(frame)
        if frame.max() <= 1.0:
            # Clip first: negative values would otherwise wrap around when
            # cast to uint8 and show up as bright pixels.
            frame = (np.clip(frame, 0.0, 1.0) * 255).astype(np.uint8)
        pil_frames.append(Image.fromarray(frame))

    pil_frames[0].save(
        output_path,
        save_all=True,
        append_images=pil_frames[1:],
        duration=100,  # milliseconds per frame
        loop=0,  # 0 = loop forever
    )
    print(f"Saved GIF: {output_path}")

def _build_agent(envs, args):
    """Construct the policy network selected by ``args.lstm`` / ``args.vit``.

    ``args.lstm`` takes precedence over ``args.vit``; when neither is set the
    default ``Agent`` is returned. Weights are not loaded here.
    """
    if args.lstm:
        return LSTM_Agent(envs, args)
    if args.vit:
        return ViT_Agent(
            envs,
            image_size=84,
            patch_size=4,
            num_classes=2,
            dim=336,
            depth=3,
            heads=8,
            mlp_dim=672,
            channels=args.num_frames,
            dim_head=32,
        )
    return Agent(envs, args)


def _build_trigger(args, device):
    """Return the observation-poisoning callable, or ``None`` if ``args.poison`` is off."""
    if not args.poison:
        return None

    if args.robust:
        pattern = patterns.RobustTrigger(
            (84, 84), 8, 32, 0.25, .75, args.num_frames,
            edge=args.edge, fixed_pos=[36, 42],
        )

        def trigger(x):
            return pattern(x, True)
    else:
        pattern = patterns.Single_Stacked_Img_Pattern(
            (args.num_frames, 84, 84), (8, 8), min=-1, max=1
        ).to(device)
        pattern = pattern.flatten()
        trigger = ImagePoison(pattern, 0, 1)

    print(f"Poisoning enabled: robust={args.robust}, edge={args.edge}")
    return trigger


def _save_action_histogram(actions, poison):
    """Plot the distribution of the two action components and save it under ``gifs/``.

    ``actions`` is the list of per-step action arrays collected by ``main``;
    the plot indexes it as ``[:, 0, 0]`` and ``[:, 0, 1]``. Nothing is plotted
    when no actions were recorded.
    """
    actions = np.array(actions)
    print(f"Actions shape: {actions.shape}")

    # With zero recorded steps the array is 1-D and the indexing below would raise.
    if actions.size == 0:
        print("No actions recorded; skipping action histogram")
        return

    plt.figure(dpi=150)
    plt.hist(actions[:, 0, 0], label="linear", alpha=0.5, bins=20)
    plt.hist(actions[:, 0, 1], label="angular", alpha=0.5, bins=20)
    plt.legend()
    plt.title("Action Distribution")
    plt.xlabel("Action Value")
    plt.ylabel("Frequency")

    plot_name = "gifs/poisoned_actions.png" if poison else "gifs/benign_actions.png"
    plt.savefig(plot_name)
    plt.close()
    print(f"Saved action histogram: {plot_name}")


def main():
    """Evaluate a trained agent on the ``fetch-v0`` environment.

    Parses ``EvalArgs`` from the command line, restores the agent from
    ``args.model_path``, optionally poisons every observation with a trigger,
    and rolls out ``args.num_episodes`` episodes in a single-env
    ``SyncVectorEnv``. The first episode is saved as a GIF and a histogram of
    all actions is written under ``gifs/``; reward and success statistics are
    printed at the end.

    Raises:
        FileNotFoundError: If ``args.model_path`` is not an existing file.
    """
    args = tyro.cli(EvalArgs)

    # Fail fast with a clear message before building the environment.
    if not os.path.isfile(args.model_path):
        raise FileNotFoundError(f"Model checkpoint not found: {args.model_path!r}")

    # Override num_frames to match the trained model
    args.num_frames = args.model_num_frames
    print(f"Using num_frames: {args.num_frames}")

    # Set random seeds
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Force the fetch environment BEFORE deriving the run name, so the name
    # reflects the environment that is actually evaluated.
    args.env_id = "fetch-v0"
    run_name = f"{args.env_id}_eval_{int(time.time())}"

    envs = gym.vector.SyncVectorEnv(
        [make_env(
            args.env_id,
            0,
            args.capture_video,
            run_name,
            args.gamma,
            args
        )]
    )

    # Everything below runs under try/finally so the environment is closed
    # even if model loading or a rollout step raises.
    try:
        # Create agent with correct num_frames
        agent = _build_agent(envs, args)

        print(f"Loading model from {args.model_path}")
        # NOTE: torch.load unpickles the file; only load checkpoints from a
        # trusted source (consider weights_only=True where PyTorch supports it).
        agent.load_state_dict(torch.load(args.model_path, map_location=device))
        agent.eval()
        agent = agent.to(device)

        # Setup poisoning trigger (None when poisoning is disabled)
        trigger = _build_trigger(args, device)

        # Create gifs directory
        os.makedirs('./gifs/', exist_ok=True)

        # Run episodes
        total_rewards = []
        success_count = 0
        frames = []
        actions = []

        for episode in range(args.num_episodes):
            obs, _ = envs.reset()
            done = torch.tensor([False]).to(device)
            episode_reward = 0
            step_count = 0

            print(f"Starting episode {episode+1}/{args.num_episodes}")

            # Initialize LSTM state if needed
            if args.lstm:
                lstm_state = (
                    torch.zeros(agent.lstm.num_layers, 1, agent.lstm.hidden_size).to(device),
                    torch.zeros(agent.lstm.num_layers, 1, agent.lstm.hidden_size).to(device),
                )

            while not done[0]:
                # Apply poisoning; the poisoned observation is also what gets
                # recorded in the GIF below.
                if trigger is not None:
                    obs = trigger(torch.tensor(obs).to(device)).cpu().numpy()

                # Convert observation to tensor
                obs_tensor = torch.FloatTensor(obs).to(device)

                # Get action from policy
                with torch.no_grad():
                    if args.lstm:
                        action, _, _, _, lstm_state = agent.get_action_and_value(obs_tensor, lstm_state, done)
                    else:
                        action, _, _, _ = agent.get_action_and_value(obs_tensor)
                    action = torch.clamp(action, -1, 1)

                # Capture frames for first episode
                if episode == 0:
                    # View the observation as a stack of 84x84 frames and keep
                    # the last one in the stack.
                    obs_reshaped = obs.reshape(args.num_frames, 84, 84)
                    latest_frame = obs_reshaped[-1]

                    # Min-max normalize for visualization; a constant frame is
                    # kept as-is to avoid dividing by zero.
                    if latest_frame.max() > latest_frame.min():
                        frame_normalized = (latest_frame - latest_frame.min()) / (latest_frame.max() - latest_frame.min())
                    else:
                        frame_normalized = latest_frame

                    frames.append(frame_normalized)

                # Convert once; the same array is logged and sent to the env.
                action_np = action.cpu().numpy()
                actions.append(action_np)

                # Execute action in environment
                obs, reward, terminations, truncations, infos = envs.step(action_np)

                done = torch.tensor(np.logical_or(terminations, truncations)).to(device)
                episode_reward += reward[0]
                step_count += 1

                # Print progress
                if step_count % 10 == 0:
                    print(f"Episode {episode+1}, Step {step_count}, Current reward: {episode_reward:.2f}", end="\r")

                # Check for episode end
                if done[0]:
                    if "final_info" in infos:
                        for info in infos["final_info"]:
                            if info and "reason" in info:
                                print(f"\nEpisode end reason: {info['reason']}")
                                if info["reason"] == "success":
                                    success_count += 1

            # Save GIF for first episode
            if episode == 0 and frames:
                gif_name = "poisoned_fetch.gif" if args.poison else "benign_fetch.gif"
                save_frames_as_gif(frames, "gifs/", gif_name, dpi=100)

            # Episode summary
            total_rewards.append(episode_reward)
            print(f"\nEpisode {episode+1} finished: Steps={step_count}, Reward={episode_reward:.2f}")
            time.sleep(1)

        # Create action analysis
        _save_action_histogram(actions, args.poison)

        # Final summary
        print("\nRun complete!")
        if total_rewards:
            print(f"Average reward over {args.num_episodes} episodes: {np.mean(total_rewards):.2f}")
            print(f"Success rate: {success_count}/{args.num_episodes} ({100*success_count/args.num_episodes:.1f}%)")
        else:
            # Avoids a NaN mean and a ZeroDivisionError when num_episodes <= 0.
            print("No episodes were run; skipping reward and success statistics")
        print(f"Rewards: {total_rewards}")
    finally:
        envs.close()

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()