import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import os
import argparse
from functools import partial

import navix as nx
from src.envs.gymnax_wrappers import NavixToGymnax, PixelNoise
from src.policies.navix_policies import (
    random_policy, 
    navigation_policy, 
    key_collection_policy, 
    door_opening_policy, 
    goal_reaching_policy, 
    doorkey_policy,
    run_episode,
    create_gif
)

def create_output_dir(dir_name="outputs"):
    """Ensure the output directory exists and return its path.

    Args:
        dir_name: Path of the directory to create. Missing parent
            directories are created as well.

    Returns:
        ``dir_name`` unchanged, so the call can be used inline.

    Raises:
        FileExistsError: If ``dir_name`` already exists but is not a
            directory (e.g. a regular file with that name).
    """
    # exist_ok=True avoids the check-then-create race of a separate
    # os.path.exists() test and still fails loudly when the path is
    # occupied by something that is not a directory.
    os.makedirs(dir_name, exist_ok=True)
    return dir_name

def setup_environment(env_name, obs_type="symbolic", img_size=64, render_mode="rgb"):
    """Build a Navix environment and wrap it for Gymnax-style use.

    Args:
        env_name: Environment id without the ``"Navix-"`` prefix; the prefix
            is prepended before calling ``nx.make``.
        obs_type: One of ``"categorical"``, ``"symbolic"`` or ``"rgb"``. Any
            other value raises ``KeyError`` from the lookup table.
        img_size: Side length used when resizing ``"rgb"`` observations.
        render_mode: Accepted for interface compatibility; not read by this
            function.

    Returns:
        The environment wrapped in ``NavixToGymnax`` and, for ``"rgb"``
        observations only, additionally in ``PixelNoise``.
    """
    base_obs_fn = {
        "categorical": nx.observations.categorical,
        "symbolic": nx.observations.symbolic,
        "rgb": nx.observations.rgb,
    }[obs_type]

    navix_env = nx.make(
        f"Navix-{env_name}",
        observation_fn=base_obs_fn,
        action_set=nx.actions.COMPLETE_ACTION_SET,
    )

    if obs_type == "rgb":
        target_shape = (img_size, img_size, 3)

        def scaled_rgb(state):
            # Resize to the requested square size, then divide by 255.
            frame = jax.image.resize(base_obs_fn(state), target_shape, method="bilinear")
            return frame / 255

        navix_env = navix_env.replace(
            observation_fn=scaled_rgb,
            observation_space=nx.spaces.Continuous.create(target_shape, 0, 1),
        )
    elif obs_type == "symbolic":
        def scaled_symbolic(state):
            # Symbolic observations are divided by 10.
            return base_obs_fn(state) / 10

        navix_env = navix_env.replace(observation_fn=scaled_symbolic)

    # Adapt the Navix environment to the Gymnax-compatible wrapper.
    wrapped_env = NavixToGymnax(navix_env)

    # Only pixel observations get the small noise wrapper.
    if obs_type == "rgb":
        wrapped_env = PixelNoise(wrapped_env, noise_sigma=1e-3)

    return wrapped_env

def run_all_demos(output_dir, seed=0):
    """Run the built-in set of DoorKey demonstrations and save one GIF each.

    Args:
        output_dir: Directory the GIF files are written into.
        seed: Seed forwarded to ``run_episode`` for every demonstration.
    """
    # Each entry: (environment name, policy function, step budget).
    demo_specs = (
        ("DoorKey-5x5-v0", doorkey_policy, 50),
        ("DoorKey-8x8-v0", doorkey_policy, 100),
        ("DoorKey-16x16-v0", doorkey_policy, 250),
        ("DoorKey-Random-8x8-v0", doorkey_policy, 100),
    )

    print("Starting policy demonstrations...")

    for env_name, policy_fn, max_steps in demo_specs:
        policy_label = policy_fn.__name__
        print(f"\nRunning demonstration for {env_name} with {policy_label}")

        # RGB observations at 128x128 are used for the rendered rollout.
        env = setup_environment(env_name, obs_type="rgb", img_size=128)
        frames, total_reward = run_episode(
            env=env,
            policy_fn=policy_fn,
            max_steps=max_steps,
            render=True,
            seed=seed,
        )

        # The GIF is named after the environment and the policy function.
        gif_path = os.path.join(output_dir, f"{env_name}_{policy_label}.gif")
        create_gif(frames, filename=gif_path, fps=5)

        print(f"Completed episode with total reward: {total_reward}")
        print(f"Saved GIF to {gif_path}")

    print("\nAll demonstrations completed!")

def run_specific_demo(env_name, policy_name, output_dir, max_steps=100, seed=0):
    """Run one policy on one environment and save the rollout as a GIF.

    Args:
        env_name: Environment name passed to ``setup_environment``.
        policy_name: Key selecting the policy; see the table below.
        output_dir: Directory the GIF file is written into.
        max_steps: Step budget forwarded to ``run_episode``.
        seed: Seed forwarded to ``run_episode``.

    Raises:
        ValueError: If ``policy_name`` is not a known policy key.
    """
    available = {
        "random": random_policy,
        "navigation": navigation_policy,
        "key_collection": key_collection_policy,
        "door_opening": door_opening_policy,
        "goal_reaching": goal_reaching_policy,
        "doorkey": doorkey_policy,
    }

    policy_fn = available.get(policy_name)
    if policy_fn is None:
        raise ValueError(f"Unknown policy: {policy_name}. Available policies: {list(available)}")

    print(f"\nRunning demonstration for {env_name} with {policy_name} policy")

    # RGB observations at 128x128 are used for the rendered rollout.
    env = setup_environment(env_name, obs_type="rgb", img_size=128)
    frames, total_reward = run_episode(
        env=env,
        policy_fn=policy_fn,
        max_steps=max_steps,
        render=True,
        seed=seed,
    )

    # The GIF is named after the environment and the policy key.
    gif_path = os.path.join(output_dir, f"{env_name}_{policy_name}.gif")
    create_gif(frames, filename=gif_path, fps=5)

    print(f"Completed episode with total reward: {total_reward}")
    print(f"Saved GIF to {gif_path}")

if __name__ == "__main__":
    # Command-line entry point: either run every built-in demo (--all) or a
    # single environment/policy pair chosen via the remaining flags.
    cli = argparse.ArgumentParser(description="Run Navix policy demonstrations")

    # Flags are registered in a fixed order so the --help listing is stable.
    cli_flags = (
        ("--output-dir", dict(type=str, default="outputs/navix_demos",
                              help="Directory to save output files")),
        ("--env", dict(type=str, default="DoorKey-8x8-v0",
                       help="Environment name")),
        ("--policy", dict(type=str, default="doorkey",
                          help="Policy to use (random, navigation, key_collection, door_opening, goal_reaching, doorkey)")),
        ("--max-steps", dict(type=int, default=100,
                             help="Maximum number of steps per episode")),
        ("--seed", dict(type=int, default=0,
                        help="Random seed")),
        ("--all", dict(action="store_true",
                       help="Run all demonstrations")),
    )
    for flag, options in cli_flags:
        cli.add_argument(flag, **options)

    opts = cli.parse_args()

    # Make sure the destination exists before any demo writes into it.
    target_dir = create_output_dir(opts.output_dir)

    if not opts.all:
        run_specific_demo(
            env_name=opts.env,
            policy_name=opts.policy,
            output_dir=target_dir,
            max_steps=opts.max_steps,
            seed=opts.seed,
        )
    else:
        # --all ignores --env, --policy and --max-steps.
        run_all_demos(target_dir, seed=opts.seed)