import os
import gym
import torch
import numpy as np
import matplotlib.pyplot as plt
import glob
import sys
from dataclasses import dataclass, field # Import dataclasses
import tyro # Import tyro if needed to parse defaults, though loading from file is better
import imageio.v2 as imageio # Import imageio for GIF creation
import fnmatch # Import fnmatch for pattern matching

# --- Import necessary components from the original script ---
# Assuming sac_continuous_action_dexgym_lle.py is in the same directory or Python path
try:
    import sac_continuous_action_dexgym_lle
    from sac_continuous_action_dexgym_lle import Actor, LLE, make_env, Args # Import Args dataclass
except ImportError:
    print("Error: Could not import sac_continuous_action_dexgym_lle. Make sure the file is in the same directory or in your Python path.")
    sys.exit(1)

# --- Helper function to load args ---
def load_args_from_file(filepath):
    """Build an args namespace from the ``Args`` defaults overlaid with a run's args file.

    The file is read line by line; every line containing a ``:`` is split on the
    first colon into ``key: value``. The value is converted to the type of the
    default for that key (bool, int or float); the literal ``None`` becomes
    ``None``; anything else stays a string. Keys not present in the defaults are
    kept as strings.

    Args:
        filepath: Path to the ``args.txt`` file written alongside a run.

    Returns:
        An object exposing every argument as an attribute. Accessing an argument
        that is missing from both the defaults and the file yields ``None``
        instead of raising, so consumers of the namespace do not break on
        arguments added after the run was recorded.
    """
    args_dict = {}
    # Get default values from Args dataclass first
    # NOTE(review): this call is kept exactly as it was. Whether tyro accepts
    # `default={}` / `exit_on_error` the way this code expects could not be
    # verified from this file -- confirm against the installed tyro version.
    default_args = tyro.cli(Args, default={}, exit_on_error=False) # Get defaults without exiting
    args_dict.update(vars(default_args))

    # Override with values from file
    try:
        with open(filepath, 'r') as f:
            for line in f:
                if ':' not in line:
                    continue
                # Split on the first colon only, so values may themselves contain ':'.
                key, value = line.strip().split(':', 1)
                key = key.strip()
                value = value.strip()
                # Attempt to convert value to appropriate type based on default
                original_type = type(args_dict.get(key, '')) # Default to string type if key not in defaults
                try:
                    if value == 'None':
                        value = None
                    elif original_type is bool:
                        value = value.lower() == 'true'
                    elif original_type is int:
                        value = int(value)
                    elif original_type is float:
                        value = float(value)
                    # Add other type conversions if needed
                    else: # Keep as string if no specific type match or it's already string
                        value = str(value)
                except (ValueError, TypeError):
                    print(f"Warning: Could not convert key '{key}' value '{value}' to type {original_type}. Keeping as string.")
                    value = str(value) # Keep as string on error
                args_dict[key] = value
    except FileNotFoundError:
        print(f"Warning: args.txt file not found at {filepath}. Using default Args.")
        # If file not found, args_dict already contains defaults

    # Create an object that mimics the args namespace
    class ArgsNamespace:
        def __init__(self, **kwargs):
            self.__dict__.update(kwargs)

        def __getattr__(self, name):
            # Only reached when normal attribute lookup fails.
            # Dunder lookups are protocol probes (copy, pickle, ...); answering
            # them with None makes the prober believe the hook exists and then
            # fail when it tries to call it, so report them as truly missing.
            if name.startswith('__') and name.endswith('__'):
                raise AttributeError(name)
            # Ordinary missing args resolve to None to avoid breaking LLE if an
            # arg is absent, though this might hide issues.
            # print(f"Warning: Accessing potentially missing arg '{name}'") # Uncomment for debugging
            return None

    return ArgsNamespace(**args_dict)
# --- End Helper function ---

# --- Attention map configuration ---
# Fractions of an evaluation episode at which an attention map is plotted.
# Each fraction is multiplied by the episode's last step index to get a target
# step, and the recorded snapshot closest to that step is the one saved.
ATTENTION_PERCENTAGES = [0.1, 0.5, 0.9]

def save_attention_plot(attention_map, env_id, seed, step_number, percentage_label, save_dir,
                        num_reward_features=128):
    """Render one attention map as a heatmap and save it as a PDF in ``save_dir``.

    When the map has at least ``2 * num_reward_features`` rows, the axes are
    annotated with a dashed split at ``num_reward_features`` and the two regions
    are labelled "Reward Features" and "LLE Features"; otherwise a warning is
    printed and the plain heatmap is saved.

    Args:
        attention_map: Array with ``.size`` and ``.shape``; ``None`` or an empty
            array is skipped with a warning. The same tick positions are used on
            both axes, so a square 2-D map is assumed -- TODO confirm.
        env_id: Environment id, used in the file name and log messages.
        seed: Seed, used in the file name and log messages.
        step_number: Integer step the map was captured at (zero-padded to three
            digits in the file name).
        percentage_label: Label such as ``"10%"`` embedded in the file name.
        save_dir: Output directory; created if it does not exist.
        num_reward_features: Index at which the reward features end and the LLE
            features begin. Defaults to 128, the previously hard-coded value.

    Returns:
        None. The plot is written to
        ``{env_id}_seed{seed}_{percentage_label}_step{step_number:03d}.pdf``.
    """
    if attention_map is None or attention_map.size == 0:
        print(f"Warning: Empty attention map for {env_id}, seed {seed} at {percentage_label}. Skipping save.")
        return

    fig = plt.figure(figsize=(12, 10))
    try:
        im = plt.imshow(attention_map, cmap='viridis', aspect='auto', origin='lower')
        cbar = plt.colorbar(im)
        cbar.ax.tick_params(labelsize=16)
        plt.title(f"Attention Map - Step {step_number}", fontsize=20)
        plt.xlabel("Attention Heads / Features", fontsize=18)
        plt.ylabel("Query Features", fontsize=18)

        num_features = attention_map.shape[0]
        if num_features >= 2 * num_reward_features:
            ax = plt.gca()
            tick_positions = [0, num_reward_features, num_features - 1]
            tick_labels = ['0', str(num_reward_features), str(num_features - 1)]
            plt.xticks(tick_positions, tick_labels, fontsize=16)
            plt.yticks(tick_positions, tick_labels, fontsize=16)

            # Pixel i covers [i - 0.5, i + 0.5], so the boundary between the two
            # feature groups sits half a cell below num_reward_features.
            split = num_reward_features - 0.5
            ax.axhline(y=split, color='gray', linestyle='--', linewidth=1)
            ax.axvline(x=split, color='gray', linestyle='--', linewidth=1)

            xlim = ax.get_xlim()
            ylim = ax.get_ylim()
            text_offset_x = (xlim[1] - xlim[0]) * 0.08
            text_offset_y = (ylim[1] - ylim[0]) * 0.08
            font_size = 18

            # Centres of the two regions along an axis (same for x and y).
            reward_center = split / 2
            lle_center = split + (num_features - num_reward_features + 0.5) / 2

            # Region labels to the left of the y axis ...
            plt.text(xlim[0] - text_offset_x, reward_center, "Reward Features",
                     rotation=90, va='center', ha='right', fontsize=font_size)
            plt.text(xlim[0] - text_offset_x, lle_center, "LLE Features",
                     rotation=90, va='center', ha='right', fontsize=font_size)

            # ... and below the x axis.
            plt.text(reward_center, ylim[0] - text_offset_y, "Reward Features",
                     ha='center', va='top', fontsize=font_size)
            plt.text(lle_center, ylim[0] - text_offset_y, "LLE Features",
                     ha='center', va='top', fontsize=font_size)

            plt.subplots_adjust(left=0.2, bottom=0.2)
        else:
            print(f"Warning: Attention map size ({num_features}) is less than {2 * num_reward_features}. Skipping feature labels.")

        plt.tight_layout()
        os.makedirs(save_dir, exist_ok=True)
        save_filename = f"{env_id}_seed{seed}_{percentage_label}_step{step_number:03d}.pdf"
        save_path = os.path.join(save_dir, save_filename)
        plt.savefig(save_path)
    finally:
        # Always release the figure, even if directory creation or saving fails,
        # so repeated calls do not accumulate open figures.
        plt.close(fig)
    print(f"Saved attention map for {env_id} seed {seed} at {percentage_label} (step {step_number}) to {save_path}")

# --- Configuration for finding runs (Edit these values) ---
# Each value is interpolated verbatim into the run-directory name pattern built
# in the processing loop (as v…, lws…, bs…, lrW…, lrPhi…, gsW…, gsPhi…, lrTr…).
# Matching is done on the directory name as text, so spell each value exactly
# as it appears there (e.g. "1e-2", not "0.01").
V_SEARCH = "1"
LWS_SEARCH = "5"
BS_SEARCH = "2048"
LRW_SEARCH = "1e-2"
LRPHI_SEARCH = "1e-3"
GSW_SEARCH = "1e-15"
GSPHI_SEARCH = "1e-10"
LRTR_SEARCH = "1e-3"

# --- Define environments and seeds to loop over ---
# Every (environment, seed) pair is processed independently; uncomment entries
# to include more environments.
envs_to_process = [
    "EggCatchUnderarm-v0",
    # "BlockCatchUnderarm-v0",
    # "PenCatchUnderarm-v0",
    # "EggCatchOverarm-v0",
    # "BlockCatchOverarm-v0",
    # "PenCatchOverarm-v0",
]
# range(1, 2) yields only seed 1; widen the range to process more seeds.
seeds_to_process = range(1,2)

# --- Automatic Run Directory Finding and Processing Loop ---
# For every (environment, seed) pair: locate the single matching run directory,
# rebuild the actor and LLE models from its checkpoints, roll out evaluation
# episodes while recording the actor's attention weights, save attention-map
# plots at the configured episode fractions, and append the average return to a
# CSV file in the run's output directory.
run_dir_base = "runs-dexgym"

for ENV_ID_SEARCH in envs_to_process:
    for SEED_SEARCH_INT in seeds_to_process:
        SEED_SEARCH = str(SEED_SEARCH_INT) # Convert seed to string for pattern matching

        print(f"\n--- Processing Environment: {ENV_ID_SEARCH}, Seed: {SEED_SEARCH} ---")

        # Construct the pattern to search for (timestamp is always wildcard)
        run_pattern = f"{ENV_ID_SEARCH}__v{V_SEARCH}_sac_lle_dexgym_{ENV_ID_SEARCH}_lws{LWS_SEARCH}_bs{BS_SEARCH}_lrW{LRW_SEARCH}_lrPhi{LRPHI_SEARCH}_gsW{GSW_SEARCH}_gsPhi{GSPHI_SEARCH}_lrTr{LRTR_SEARCH}_seed{SEED_SEARCH}__{SEED_SEARCH}__*"
        print(f"Searching for run directories matching pattern: {run_pattern} in {run_dir_base}")

        matching_runs = []
        try:
            # List all items in the base run directory
            all_items = os.listdir(run_dir_base)
            # Filter for directories that match the pattern
            for item in all_items:
                item_path = os.path.join(run_dir_base, item)
                if os.path.isdir(item_path):
                    if fnmatch.fnmatch(item, run_pattern):
                        matching_runs.append(item)
        except FileNotFoundError:
            print(f"Error: Directory '{run_dir_base}' not found. Skipping {ENV_ID_SEARCH}, Seed {SEED_SEARCH}.")
            continue # Skip to the next iteration
        except Exception as e:
            print(f"An error occurred while searching for directories for {ENV_ID_SEARCH}, Seed {SEED_SEARCH}: {e}. Skipping.")
            continue # Skip to the next iteration

        if not matching_runs:
            print(f"No run directories found matching the pattern: {run_pattern}. Skipping {ENV_ID_SEARCH}, Seed {SEED_SEARCH}.")
            continue # Skip to the next iteration
        elif len(matching_runs) > 1:
            # Ambiguous match: refuse to guess which run the user meant.
            print(f"Found multiple run directories matching the pattern: {run_pattern}")
            print("Please refine your search parameters to select a single run for this env/seed combination. Skipping:")
            for run in matching_runs:
                print(f"- {run}")
            continue # Skip to the next iteration
        else:
            # Exactly one run found, use it
            run_name = matching_runs[0]
            run_dir = os.path.join(run_dir_base, run_name)
            print(f"Automatically selected run directory: {run_name}")

        # --- Configuration ---
        args_file = os.path.join(run_dir, "args.txt")
        actor_path = os.path.join(run_dir, "actor.pth")
        lle_model_path = os.path.join(run_dir, "lle_model.pth")
        # Make output directory specific to the run
        output_dir = f"attention_maps/attention_maps_{run_name}"
        os.makedirs(output_dir, exist_ok=True)
        plot_maps_dir = os.path.join(output_dir, "plot-maps")
        # NOTE(review): gif_filename is not used anywhere in the visible script.
        gif_filename = f"attention_maps_animation_{run_name}.gif"

        # --- Load Args and Set Up ---
        # 1. Load args from the file
        loaded_args = load_args_from_file(args_file)

        # 2. Inject the loaded args into the imported module's namespace
        # Ensure env_id and seed from loaded args match the current loop iteration (should be the case if run_name matches)
        if loaded_args.env_id != ENV_ID_SEARCH or str(loaded_args.seed) != SEED_SEARCH:
            print(f"Warning: Loaded args (env_id: {loaded_args.env_id}, seed: {loaded_args.seed}) do not match expected (env_id: {ENV_ID_SEARCH}, seed: {SEED_SEARCH}). Proceeding with loaded args.")

        sac_continuous_action_dexgym_lle.args = loaded_args
        print("Injected 'args' into sac_continuous_action_dexgym_lle module.")

        # 3. Set device based on loaded args
        device = torch.device("cuda" if torch.cuda.is_available() and loaded_args.cuda else "cpu")
        print(f"Using device: {device}")

        # 4. Initialize Environment using loaded args
        # Use the seed from loaded_args to ensure consistency with the loaded model
        envs = gym.vector.SyncVectorEnv([make_env(loaded_args.env_id, loaded_args.seed, 0, False, "plot_run")])
        print(f"Initialized environment: {loaded_args.env_id} with seed {loaded_args.seed}")

        # 5. Inject the created envs object into the imported module's namespace
        sac_continuous_action_dexgym_lle.envs = envs
        print("Injected 'envs' into sac_continuous_action_dexgym_lle module.")

        # 6. Load Actor Model
        actor = Actor(envs)
        try:
            actor.load_state_dict(torch.load(actor_path, map_location=device))
            actor.to(device)
            actor.eval()
            print("Actor model loaded.")
        except FileNotFoundError:
            print(f"Error: Actor model not found at {actor_path}. Skipping {ENV_ID_SEARCH}, Seed {SEED_SEARCH}.")
            envs.close() # Clean up environment
            continue # Skip to the next iteration
        except Exception as e:
            print(f"Error loading Actor model for {ENV_ID_SEARCH}, Seed {SEED_SEARCH}: {e}. Skipping.")
            envs.close() # Clean up environment
            continue # Skip to the next iteration

        # 7. Load LLE Model - it will now find sac_continuous_action_dexgym_lle.args
        lle_model = LLE(envs) # LLE.__init__ will use sac_continuous_action_dexgym_lle.args
        try:
            lle_model.load_state_dict(torch.load(lle_model_path, map_location=device))
            lle_model.to(device)
            lle_model.eval()
            print("LLE model loaded.")
        except FileNotFoundError:
            print(f"Error: LLE model not found at {lle_model_path}. Skipping {ENV_ID_SEARCH}, Seed {SEED_SEARCH}.")
            envs.close() # Clean up environment
            continue # Skip to the next iteration
        except Exception as e:
            print(f"Error loading LLE model for {ENV_ID_SEARCH}, Seed {SEED_SEARCH}: {e}. Skipping.")
            envs.close() # Clean up environment
            continue # Skip to the next iteration


        # --- Evaluation ---
        print("\n--- Starting Evaluation ---")
        # Number of evaluation episodes. Plot file names do not include the
        # episode index, so with more than one episode later plots can
        # overwrite earlier ones that land on the same step.
        num_eval_episodes = 1
        eval_returns = []
        # Use a separate environment for evaluation with a fixed seed
        eval_envs = gym.vector.SyncVectorEnv([make_env(loaded_args.env_id, 100, 0, False, "eval_run")]) # Using seed 100 for eval

        # Ensure output directory for plots exists
        os.makedirs(plot_maps_dir, exist_ok=True)
        print(f"Saving attention maps at 10%, 50%, and 90% progress to: {plot_maps_dir}")


        for episode in range(num_eval_episodes):
            print(f"Running evaluation episode {episode + 1}/{num_eval_episodes}")
            obs = eval_envs.reset()
            done = np.array([False])
            total_return = 0
            eval_step_count = 0
            max_eval_steps = 1000 # Limit evaluation episode length
            # (step index, attention map) pairs recorded during this episode
            attention_snapshots = []

            while not done.any() and eval_step_count < max_eval_steps:
                obs_tensor = torch.Tensor(obs).to(device)
                with torch.no_grad():
                    lle_features = lle_model(obs_tensor)
                    # Get action from the loaded actor model
                    action, _, _ = actor.get_action(obs_tensor, lle_features) # Use evaluate=True if your get_action has it

                # The actor is expected to expose the attention weights of its
                # latest forward pass; skip recording when it does not.
                if hasattr(actor, 'actor_attention_weights') and actor.actor_attention_weights is not None:
                    attention_map = actor.actor_attention_weights.detach().cpu().numpy().squeeze(0)
                    if attention_map is not None and attention_map.size > 0:
                        attention_snapshots.append((eval_step_count, attention_map.copy()))
                    else:
                        print(f"Warning: Invalid attention map at Eval Ep {episode + 1}, Step {eval_step_count} for {ENV_ID_SEARCH}, Seed {SEED_SEARCH}.")

                obs, reward, done, info = eval_envs.step(action.cpu().numpy())
                total_return += reward.sum() # Sum rewards across potentially multiple envs (though here it's 1)
                eval_step_count += 1

            if attention_snapshots:
                max_step_index = max(eval_step_count - 1, 0)
                for percentage in ATTENTION_PERCENTAGES:
                    # Pick the recorded snapshot nearest to the requested
                    # fraction of the episode.
                    target_step = int(max_step_index * percentage)
                    closest_step, closest_map = min(attention_snapshots, key=lambda snap: abs(snap[0] - target_step))
                    percentage_label = f"{int(percentage * 100):02d}%"
                    save_attention_plot(
                        closest_map,
                        ENV_ID_SEARCH,
                        SEED_SEARCH,
                        closest_step,
                        percentage_label,
                        plot_maps_dir,
                    )
            else:
                print(f"Warning: No attention snapshots recorded for {ENV_ID_SEARCH}, Seed {SEED_SEARCH}.")

            eval_returns.append(total_return)
            print(f"Evaluation Episode {episode + 1} finished with return: {total_return:.2f}, Steps: {eval_step_count}") # Access the return from the first (and only) env

        eval_envs.close()
        # The setup environment was previously closed only on the error paths;
        # release it here too so it does not leak once per env/seed iteration.
        envs.close()
        average_return = np.mean(eval_returns)
        print(f"\nAverage return over {num_eval_episodes} episodes: {average_return:.2f}")

        # --- Save Average Return to CSV ---
        csv_filepath = os.path.join(output_dir, "eval_results.csv") # Saved in this run's output directory
        csv_header = ["env_id", "average_return"]

        # Check if file exists to write header
        write_header = not os.path.exists(csv_filepath)

        # Append mode: re-running the script adds another row for the same run.
        with open(csv_filepath, 'a') as f:
            if write_header:
                f.write(",".join(csv_header) + "\n")
            f.write(f"{ENV_ID_SEARCH},{average_return:.4f}\n")

        print(f"Average return saved to {csv_filepath}")

        print(f"Finished processing Environment: {ENV_ID_SEARCH}, Seed: {SEED_SEARCH}")
        # --- End Evaluation ---

        print(f"Finished processing Environment: {ENV_ID_SEARCH}, Seed: {SEED_SEARCH}. Check {plot_maps_dir} for attention maps at 10%, 50%, and 90% of the episode.")


print("\nScript finished processing all environments and seeds.")
