from __future__ import annotations
import argparse
from pathlib import Path
import numpy as np
import torch
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt

from ..utils.seed import set_seed
from ..utils.io import ensure_dir
from ..gridworld.env import LongCorridor
from ..gridworld.agent import sample_trajectory
from ..models.rnn import SeqPredGRU


def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the hidden-state evaluation script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            the original behaviour; passing a list allows programmatic use.

    Returns:
        The populated ``argparse.Namespace``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--ckpt", type=str, default="outputs/checkpoints/best.pt",
                    help="path of the checkpoint file to load")
    ap.add_argument("--n_traj_eval", type=int, default=100,
                    help="number of trajectories to sample for evaluation")
    ap.add_argument("--T_eval", type=int, default=50,
                    help="length of each evaluation trajectory")
    # The environment/model options below are only fallbacks: values stored
    # under the checkpoint's "args" entry take precedence in main().
    ap.add_argument("--k", type=int, default=3,
                    help="prediction window length (fallback if not in checkpoint)")
    ap.add_argument("--Lx", type=int, default=48,
                    help="environment Lx (fallback if not in checkpoint)")
    ap.add_argument("--Ly", type=int, default=5,
                    help="environment Ly (fallback if not in checkpoint)")
    ap.add_argument("--n_colors", type=int, default=6,
                    help="number of colors (fallback if not in checkpoint)")
    ap.add_argument("--obs_size", type=int, default=5,
                    help="side length of the observation window (fallback if not in checkpoint)")
    ap.add_argument("--seed", type=int, default=123,
                    help="base random seed")
    ap.add_argument("--device", type=str,
                    default="cuda" if torch.cuda.is_available() else "cpu",
                    help="torch device used to run the model")
    ap.add_argument("--outdir", type=str, default="outputs",
                    help="output directory; figures go into its 'figures' subfolder")
    return ap.parse_args(argv)


def _collect_hidden_states(model, env, args, device):
    """Run the model over sampled trajectories and gather its hidden states.

    A window of ``model.k`` steps is slid over every trajectory. Each window
    feeds the observation at its first step plus the per-step features
    (actions concatenated with ``HDF``) to the model. Every hidden state of
    the window is paired with the position and ``HDF`` row one step later
    (index ``t + k_step + 1``).

    Returns:
        Tuple ``(H, Ppos, HDF_data)`` of arrays stacked along axis 0, one row
        per collected hidden state.

    Raises:
        ValueError: If the settings produce no windows at all.
    """
    k = model.k
    n_windows = args.T_eval - k - 1
    # Fail early with a readable message instead of a cryptic np.stack error.
    if n_windows <= 0 or args.n_traj_eval <= 0:
        raise ValueError(
            "No hidden states to collect: need T_eval > k + 1 and n_traj_eval > 0 "
            f"(got T_eval={args.T_eval}, k={k}, n_traj_eval={args.n_traj_eval})"
        )

    Hlist, Plist, HDFlist = [], [], []

    print(f"Running {args.n_traj_eval} trajectories of length {args.T_eval}...")
    with torch.no_grad():
        for traj_idx in range(args.n_traj_eval):
            # Sample a new trajectory (a distinct seed per trajectory)
            O, A, HDF, P = sample_trajectory(env, T=args.T_eval, seed=args.seed + traj_idx)

            for t in range(n_windows):
                o0 = torch.from_numpy(O[t]).unsqueeze(0).float().to(device)
                feats = np.concatenate([A[t:t + k], HDF[t:t + k]], axis=1)
                feats = torch.from_numpy(feats).unsqueeze(0).float().to(device)
                _, outs = model(o0, feats)  # outs shape: [1, k, hidden_size]

                # One device->host transfer per window rather than per step.
                hidden = outs[0].cpu().numpy()
                for k_step in range(k):
                    Hlist.append(hidden[k_step])
                    # Position and heading at the corresponding timestep t + k_step + 1
                    Plist.append(P[t + k_step + 1])
                    HDFlist.append(HDF[t + k_step + 1])

            # Report progress only once the trajectory has actually been processed.
            if (traj_idx + 1) % 20 == 0:
                print(f"Completed {traj_idx + 1}/{args.n_traj_eval} trajectories")

    print(f"Collected {len(Hlist)} hidden states from {args.n_traj_eval} trajectories")
    return np.stack(Hlist, axis=0), np.stack(Plist, axis=0), np.stack(HDFlist, axis=0)


def _scatter_pca(ax, Z, values, title, cmap, label):
    """Scatter the 2-D PCA projection ``Z`` on ``ax`` colored by ``values``; return the colorbar."""
    sc = ax.scatter(Z[:, 0], Z[:, 1], c=values, alpha=0.7, cmap=cmap, s=40)
    ax.set_title(title)
    ax.set_xlabel("PC1")
    ax.set_ylabel("PC2")
    ax.set_aspect('equal')
    return plt.colorbar(sc, ax=ax, label=label)


def _plot_losses(ax, train_losses, val_losses):
    """Draw whichever loss curves are available on ``ax``, each against its own epoch range."""
    ax.set_title("Training and Validation Loss")
    if not train_losses and not val_losses:
        ax.text(0.5, 0.5, "No loss data available",
                ha='center', va='center', transform=ax.transAxes)
        return
    # Separate x-ranges: the two lists are not guaranteed to have equal length.
    if train_losses:
        ax.plot(range(1, len(train_losses) + 1), train_losses, 'b-',
                label='Training Loss', linewidth=2)
    if val_losses:
        ax.plot(range(1, len(val_losses) + 1), val_losses, 'r-',
                label='Validation Loss', linewidth=2)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.legend()
    ax.grid(True, alpha=0.3)


def _plot_trajectory(ax, env, P_sample, HDF_sample):
    """Plot one trajectory with start/end markers and a heading arrow every 10 steps."""
    ax.plot(P_sample[:, 0], P_sample[:, 1], 'b-', alpha=0.7, linewidth=2, label='Trajectory')
    ax.scatter(P_sample[0, 0], P_sample[0, 1], c='green', s=100, marker='o',
               label='Start', zorder=5)
    ax.scatter(P_sample[-1, 0], P_sample[-1, 1], c='red', s=100, marker='s',
               label='End', zorder=5)

    # Add heading arrows every 10 steps
    for i in range(0, min(len(P_sample), len(HDF_sample)), 10):
        heading = np.argmax(HDF_sample[i, :4])
        dx, dy = env.dirs[heading]
        ax.arrow(P_sample[i, 0], P_sample[i, 1], dx * 0.3, dy * 0.3,
                 head_width=0.2, head_length=0.2, fc='black', ec='black', alpha=0.6)

    ax.set_xlim(-0.5, env.Lx - 0.5)
    ax.set_ylim(-0.5, env.Ly - 0.5)
    ax.set_xlabel("X Position")
    ax.set_ylabel("Y Position")
    ax.set_title("Agent Trajectory in Environment")
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.set_aspect('equal')


def _obs_to_rgb(obs_2d: np.ndarray, n_colors: int) -> np.ndarray:
    """Turn a per-cell channel stack into an RGB image.

    The strongest channel of each cell picks its color: indices below
    ``n_colors`` get an evenly spaced hue, every other index is drawn black.
    """
    import colorsys  # stdlib; local import keeps this helper self-contained

    rgb = np.zeros(obs_2d.shape[:2] + (3,))
    strongest = np.argmax(obs_2d, axis=-1)
    for color_idx in range(n_colors):
        # Convert the HSV triple to real RGB; imshow interprets the array as RGB.
        rgb[strongest == color_idx] = colorsys.hsv_to_rgb(color_idx / n_colors, 0.8, 0.8)
    return rgb


def _plot_example_scene(ax, env, P_sample, HDF_sample):
    """Render the egocentric observation at the midpoint of the sample trajectory."""
    mid_idx = len(P_sample) // 2
    x_example, y_example = P_sample[mid_idx]  # Middle of trajectory
    h_example = np.argmax(HDF_sample[mid_idx, :4])  # Get heading from HDF
    obs_example = env.egocentric_obs(x_example, y_example, h_example)

    # Reshape observation to 2D for visualization
    obs_2d = obs_example.reshape(env.obs_size, env.obs_size, env.n_colors + 2)
    obs_rgb = _obs_to_rgb(obs_2d, env.n_colors)

    # True division keeps the extent symmetric; `-obs_size // 2` floors the
    # negative bound (e.g. [-3, 2] for obs_size=5).
    half = env.obs_size / 2
    ax.imshow(obs_rgb, origin='upper', extent=[-half, half, -half, half])
    ax.set_title(f"Example Visual Scene\nPosition: ({x_example}, {y_example}), Heading: {['N','E','S','W'][h_example]}")
    ax.set_xlabel("Relative X")
    ax.set_ylabel("Relative Y")
    ax.grid(True, alpha=0.3)


def main():
    """Evaluate a trained ``SeqPredGRU`` checkpoint and save a summary figure.

    Loads the checkpoint, rebuilds the environment and model (settings stored
    in the checkpoint take precedence over the CLI values), collects hidden
    states over evaluation trajectories, projects them to 2-D with PCA and
    writes a 3x2 figure (PCA colored by x, y and heading; loss curves; a
    sample trajectory; an example observation) to
    ``<outdir>/figures/hidden_pca_k<k>.png``.
    """
    args = parse_args()
    set_seed(args.seed)
    device = torch.device(args.device)

    # Load model args from ckpt; tolerate a missing or empty "args" entry
    ckpt = torch.load(args.ckpt, map_location="cpu")
    margs = ckpt.get("args") or {}

    # Use the same parameters that were used during training
    env = LongCorridor(Lx=margs.get("Lx", args.Lx), Ly=margs.get("Ly", args.Ly),
                       n_colors=margs.get("n_colors", args.n_colors),
                       obs_size=margs.get("obs_size", args.obs_size),
                       seed=args.seed)

    # Calculate the correct observation dimension based on training parameters
    # +1 for wall channel, +1 for object channel
    obs_dim = (env.obs_size * env.obs_size) * (env.n_colors + 2)
    feat_dim = 10

    # Use the same k value that was used during training
    k_value = margs.get("k", args.k)
    hidden_size = margs.get("hidden", 128)

    model = SeqPredGRU(obs_dim=obs_dim, feat_dim=feat_dim, hidden=hidden_size, k=k_value).to(device)
    model.load_state_dict(ckpt["model"], strict=True)
    model.eval()

    H, Ppos, HDF_data = _collect_hidden_states(model, env, args, device)
    Z = PCA(n_components=2).fit_transform(H)

    outdir = ensure_dir(args.outdir)
    figdir = ensure_dir(Path(outdir) / "figures")

    # Create figure with 3x2 subplots
    fig, ((ax1, ax2), (ax3, ax4), (ax5, ax6)) = plt.subplots(3, 2, figsize=(15, 18))
    fig.suptitle(f"Results for k={k_value}", fontsize=20)

    _scatter_pca(ax1, Z, Ppos[:, 0], "All Hidden States - Colored by x-position",
                 'plasma', "x position")
    _scatter_pca(ax2, Z, Ppos[:, 1], "All Hidden States - Colored by y-position",
                 'plasma', "y position")

    # Extract heading from HDF (first 4 elements are one-hot encoding)
    headings = np.argmax(HDF_data[:, :4], axis=1)
    cbar3 = _scatter_pca(ax3, Z, headings, "All Hidden States - Colored by head direction",
                         'tab10', "heading")
    cbar3.set_ticks([0, 1, 2, 3])
    cbar3.set_ticklabels(['North', 'East', 'South', 'West'])

    _plot_losses(ax4, ckpt.get("train_losses", []), ckpt.get("val_losses", []))

    # Sample a single trajectory for the two illustrative panels
    _, _, HDF_sample, P_sample = sample_trajectory(env, T=min(1000, args.T_eval), seed=args.seed)
    _plot_trajectory(ax5, env, P_sample, HDF_sample)
    _plot_example_scene(ax6, env, P_sample, HDF_sample)

    fig.tight_layout()
    fpath = Path(figdir) / f"hidden_pca_k{k_value}.png"
    fig.savefig(fpath, dpi=300)
    plt.close(fig)  # release the figure once it is on disk
    print(f"Saved {fpath}")

# Script entry point: run the evaluation only when executed directly, not on
# import. Because this file uses relative imports, it has to be launched as a
# module inside its package (``python -m ...``).
if __name__ == "__main__":
    main()
