"""
Evaluate parameterized diffusion policy (conditioned on z embeddings).

Extends EvalDiffusionAgent to support z-conditioning.
"""

import os
import sys
import numpy as np
import torch
import logging
from scipy.spatial.distance import cdist
from agent.eval.eval_diffusion_agent import EvalDiffusionAgent


def compute_dtw_distance(traj1, traj2):
    """
    Dynamic Time Warping (DTW) distance between two trajectories.

    Per-step costs are Euclidean distances between points (scipy ``cdist``).
    The cumulative cost of the cheapest monotone alignment is divided by the
    length of the longer trajectory, i.e. ``max(T1, T2)`` (same normalization
    as run_ftl_igm_eval.py).

    Args:
        traj1: numpy array of shape (T1, D)
        traj2: numpy array of shape (T2, D)

    Returns:
        Normalized DTW distance (float)
    """
    len_a = len(traj1)
    len_b = len(traj2)
    step_cost = cdist(traj1, traj2, metric='euclidean')

    # acc[r, c] is the cheapest cumulative cost of aligning traj1[:r] with
    # traj2[:c]. The padded first row/column is inf except acc[0, 0], which
    # forces every alignment path to start at the first pair of points.
    acc = np.full((len_a + 1, len_b + 1), np.inf)
    acc[0, 0] = 0

    for r in range(1, len_a + 1):
        costs_row = step_cost[r - 1]
        for c in range(1, len_b + 1):
            from_above = acc[r - 1, c]      # advance traj1 only
            from_left = acc[r, c - 1]       # advance traj2 only
            from_diag = acc[r - 1, c - 1]   # advance both
            acc[r, c] = costs_row[c - 1] + min(from_above, from_left, from_diag)

    return acc[len_a, len_b] / max(len_a, len_b)

# Module-level logger used by the agent class below.
log = logging.getLogger(__name__)

# Add RLBench directory to path for trajectory encoder.
# dppo_root is the directory two levels above the one containing this file
# (abspath, then dirname applied three times).
# NOTE(review): sys.path gets `<dppo_root>/RLBench` appended, yet the imports
# below are package-qualified (`RLBench.<module>`), which resolve through a
# directory that *contains* RLBench (e.g. dppo_root or the working directory).
# Presumably the appended entry serves imports made inside those RLBench
# modules — confirm before changing either side.
dppo_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.join(dppo_root, 'RLBench'))
from RLBench.trajectory_encoder import TrajectoryVAE
from RLBench.train_relative_trajectory_encoder import make_trajectory_relative


class EvalParameterizedDiffusionAgent(EvalDiffusionAgent):
    """
    Evaluation agent for parameterized diffusion policies.

    For every z embedding, ``self.model.current_z`` is set and the parent's
    evaluation loop is run with ``n_noise_samples`` episodes. End-effector
    trajectories from all runs are collected and plotted once at the end,
    next to the ground-truth mode demos, together with an average DTW score.

    Config should include:
        - z_list: List of z embeddings to evaluate (empty list or null to auto-generate)
        - encoder_checkpoint: Path to trained trajectory encoder
        - dataset_path: Path to normalized dataset (used to auto-generate z and
          to locate the demo trajectories drawn in the final plot)
        - n_noise_samples: Number of noise samples per z embedding (default: 2)
    """

    # Number of modes expected in the dataset (one z embedding / demo per mode).
    NUM_MODES = 10

    def __init__(self, cfg):
        # Store parameterized config before parent init (order kept from the
        # original implementation). `or []` also treats an explicit
        # `z_list: null` as "auto-generate" instead of failing on len(None).
        self.z_list = cfg.get('z_list') or []
        self.encoder_checkpoint = cfg.get('encoder_checkpoint', None)
        self.dataset_path = cfg.get('dataset_path', None)
        self.n_noise_samples = cfg.get('n_noise_samples', 2)

        super().__init__(cfg)

        if len(self.z_list) == 0:
            if self.encoder_checkpoint is None or self.dataset_path is None:
                raise ValueError(
                    "If z_list is empty, must provide encoder_checkpoint and dataset_path "
                    "to auto-generate z embeddings from first demo of each mode"
                )
            log.info(
                f"z_list is empty - auto-generating {self.NUM_MODES} z embeddings "
                "from first demo of each mode"
            )
            self.z_list = self._generate_z_embeddings()
        else:
            # User-supplied embeddings: convert to float tensors on the eval device.
            self.z_list = [torch.FloatTensor(z).to(self.device) for z in self.z_list]

        log.info(f"Evaluating with {len(self.z_list)} z embeddings, {self.n_noise_samples} noise samples each")
        log.info(f"Total episodes to evaluate: {len(self.z_list) * self.n_noise_samples}")

    # ------------------------------------------------------------------
    # Dataset / metadata helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _is_mode_representative(meta):
        """Return True for the first noise-free demo of a mode."""
        return meta['demo_in_mode'] == 0 and not meta['with_noise']

    def _variation_train_dir(self):
        """
        Return ``<task_dir>/variation0/train`` derived from ``dataset_path``.

        dataset_path is like: .../reach_target/processed/train_normalized.npz,
        so the task directory (.../reach_target) is two ``dirname`` calls up.
        """
        processed_dir = os.path.dirname(self.dataset_path)
        task_dir = os.path.dirname(processed_dir)
        return os.path.join(task_dir, 'variation0', 'train')

    def _load_metadata(self):
        """Load the per-episode metadata records from ``train_metadata.npy``."""
        metadata_path = os.path.join(self._variation_train_dir(), 'train_metadata.npy')
        # Entries are accessed as dict-like objects (meta[...] / meta.get), so
        # allow_pickle is needed; only load trusted, locally generated files.
        return np.load(metadata_path, allow_pickle=True)

    def _load_encoder(self):
        """Load the trained TrajectoryVAE from ``encoder_checkpoint`` in eval mode."""
        log.info(f"Loading trajectory encoder from {self.encoder_checkpoint}")

        # weights_only=False unpickles arbitrary objects (the checkpoint holds a
        # 'config' dict next to the state dict); only use trusted checkpoints.
        checkpoint = torch.load(self.encoder_checkpoint, map_location=self.device, weights_only=False)
        config = checkpoint['config']

        encoder = TrajectoryVAE(
            state_dim=config['state_dim'],
            action_dim=config['action_dim'],
            hidden_dim=config['hidden_dim'],
            num_layers=config['num_layers'],
            num_heads=config['num_heads'],
            latent_dim=config['latent_dim'],
            horizon=config['horizon']
        ).to(self.device)
        encoder.load_state_dict(checkpoint['model_state_dict'])
        encoder.eval()

        log.info(f"✓ Loaded encoder (latent_dim={config['latent_dim']}, horizon={config['horizon']})")
        return encoder

    def _collect_first_demo_per_mode(self):
        """
        Slice the first noise-free demo of each mode out of the flattened dataset.

        Returns:
            (mode_states, mode_actions): lists of numpy arrays, one entry per
            unique mode, in order of first appearance in the dataset.

        Raises:
            ValueError: if fewer than ``NUM_MODES`` unique modes are found, or
                if the selected demos have different lengths (they are later
                stacked into a single batch).
        """
        log.info(f"Loading dataset from {self.dataset_path}")
        # The context manager closes the .npz file handle; the arrays read here
        # are materialized in memory and remain valid after the block.
        with np.load(self.dataset_path) as data:
            all_states = data['states']  # (N, 22) flattened
            all_actions = data['actions']  # (N, 8) flattened
            traj_lengths = data['traj_lengths']  # (num_episodes,)

        metadata = self._load_metadata()

        log.info(f"Finding first demo from each of {self.NUM_MODES} modes...")
        log.info(f"Total episodes in dataset: {len(traj_lengths)}")
        log.info(f"Total metadata entries: {len(metadata)}")

        mode_states = []
        mode_actions = []
        seen_modes = set()

        # Episodes are stored back to back; walk them using the cumulative offset.
        start_idx = 0
        for episode_idx, traj_len in enumerate(traj_lengths):
            end_idx = start_idx + traj_len
            meta = metadata[episode_idx]

            if self._is_mode_representative(meta):
                mode_id = meta['mode']
                if mode_id in seen_modes:
                    log.warning(f"  DUPLICATE! Mode {mode_id} already seen, episode {episode_idx}")
                else:
                    seen_modes.add(mode_id)
                    mode_states.append(all_states[start_idx:end_idx])  # (traj_len, 22)
                    mode_actions.append(all_actions[start_idx:end_idx])  # (traj_len, 8)
                    log.info(f"  Mode {mode_id}: episode {episode_idx}, {traj_len} steps, control_point={meta.get('control_point', 'N/A')}")

            start_idx = end_idx
            if len(mode_states) == self.NUM_MODES:
                break

        log.info(f"Collected {len(mode_states)} unique modes: {sorted(seen_modes)}")
        if len(mode_states) != self.NUM_MODES:
            raise ValueError(f"Expected {self.NUM_MODES} modes, found {len(mode_states)}")

        lengths = sorted({len(states) for states in mode_states})
        if len(lengths) > 1:
            raise ValueError(
                f"Mode demos have different lengths {lengths}; "
                "they must share one length to be stacked into a single batch"
            )
        return mode_states, mode_actions

    def _generate_z_embeddings(self):
        """
        Auto-generate one z embedding per mode by:
        1. Loading trained trajectory encoder
        2. Finding first noise-free demo of each of ``NUM_MODES`` modes in dataset
        3. Converting those demos to coordinates relative to s0 (same as encoder
           training) and encoding them

        Returns:
            List of ``NUM_MODES`` z embeddings (each a torch tensor of shape
            (latent_dim,)), ordered by first appearance of each mode.
        """
        encoder = self._load_encoder()
        states_list, actions_list = self._collect_first_demo_per_mode()

        mode_states = torch.FloatTensor(np.array(states_list)).to(self.device)  # (num_modes, T, 22)
        mode_actions = torch.FloatTensor(np.array(actions_list)).to(self.device)  # (num_modes, T, 8)
        last = self.NUM_MODES - 1

        # Sanity check: absolute states should differ across modes.
        log.info("Checking if ABSOLUTE states are different across modes:")
        for m in (0, 1, last):
            log.info(f"  Mode {m}, timestep 0, first 4 dims: {mode_states[m, 0, :4].cpu().numpy()}")
        for m in (1, last):
            log.info(f"  L2 distance between mode 0 and mode {m} states (all timesteps): {torch.norm(mode_states[0] - mode_states[m]).item():.4f}")

        # Convert to RELATIVE coordinates (same as encoder training)
        log.info("Converting to RELATIVE coordinates (w.r.t. s0)...")
        mode_states_rel, mode_actions_rel = make_trajectory_relative(mode_states, mode_actions)

        log.info("Checking if RELATIVE states are different across modes:")
        for m in (0, 1):
            log.info(f"  Mode {m}, timestep 0 (relative), first 4 dims: {mode_states_rel[m, 0, :4].cpu().numpy()}")
        log.info(f"  L2 distance between mode 0 and mode 1 RELATIVE states: {torch.norm(mode_states_rel[0] - mode_states_rel[1]).item():.4f}")

        with torch.no_grad():
            # Per the original author's notes, encode() returns z_mean
            # (unbounded, no tanh).
            z_embeddings = encoder.encode(mode_states_rel, mode_actions_rel)  # (num_modes, latent_dim)

        log.info(f"✓ Generated {z_embeddings.shape[0]} z embeddings of dimension {z_embeddings.shape[1]}")
        log.info("Final z embedding values (z_mean, unbounded, ALL dimensions):")
        for i, z in enumerate(z_embeddings):
            log.info(f"  z[{i}]: {z.cpu().numpy()}")

        log.info("L2 distances between consecutive z embeddings:")
        for i in range(1, z_embeddings.shape[0]):
            dist = torch.norm(z_embeddings[i] - z_embeddings[i - 1]).item()
            log.info(f"  ||z[{i}] - z[{i-1}]|| = {dist:.6f}")

        # One tensor of shape (latent_dim,) per mode.
        return list(z_embeddings.unbind(0))

    # ------------------------------------------------------------------
    # Evaluation
    # ------------------------------------------------------------------

    def run(self):
        """
        Run evaluation for all z embeddings.

        For each z, sets ``self.model.current_z`` (shape (1, latent_dim)) and
        runs the parent's evaluation with ``n_eval_episodes`` temporarily set to
        ``n_noise_samples``. While ``self.current_z_id`` exists,
        ``_plot_3d_trajectories`` only collects trajectories. Once all z's are
        done, the attribute is removed and the collected trajectories are
        plotted once.
        """
        log.info(f"Starting evaluation of {len(self.z_list)} z embeddings x {self.n_noise_samples} noise samples each")

        # Collect all trajectories for final plotting
        self.all_eval_trajectories = []
        self.all_eval_success = []
        self.all_eval_z_ids = []

        original_n_eval = self.n_eval_episodes
        try:
            for z_idx, z_embedding in enumerate(self.z_list):
                log.info(f"\n{'='*80}")
                log.info(f"Evaluating z embedding {z_idx + 1}/{len(self.z_list)}")
                log.info(f"{'='*80}")

                # Store z in model (shape: (1, latent_dim) for batch size 1)
                self.model.current_z = z_embedding.unsqueeze(0)
                self.current_z_id = z_idx  # Tags trajectories collected during this run
                log.info(f"  Set model.current_z = z[{z_idx}]: {z_embedding[:4].cpu().numpy()}")

                self.n_eval_episodes = self.n_noise_samples
                super().run()
        finally:
            # Restore state even if an evaluation run raises.
            self.n_eval_episodes = original_n_eval
            # Leave collection mode so the next _plot_3d_trajectories call plots.
            if hasattr(self, 'current_z_id'):
                del self.current_z_id

        log.info(f"\n{'='*80}")
        log.info(f"Completed evaluation of all {len(self.z_list)} z embeddings")
        log.info(f"{'='*80}")

        if self.all_eval_trajectories:
            self._plot_3d_trajectories(self.all_eval_trajectories, self.all_eval_success)

    def _plot_3d_trajectories(self, ee_trajectories, trajectory_success):
        """
        Override the parent's hook to collect trajectories during runs and plot once.

        While ``self.current_z_id`` is set (inside ``run``'s per-z loop), the
        trajectories are appended to ``all_eval_*`` with the current z index
        and nothing is drawn. Otherwise, a two-panel figure is saved to
        ``<render_dir>/ee_trajectories_3d.png``. The left panel shows the
        evaluation trajectories (colored by z index, titled with the average
        DTW distance to the demo with the same index). The right panel shows
        the ground-truth mode demos.

        Args:
            ee_trajectories: per-episode end-effector paths, indexed as
                ``traj[:, 0..2]`` for x, y, z.
            trajectory_success: per-episode success flags.
        """
        if hasattr(self, 'current_z_id'):
            # Collecting phase - store trajectories with z_id
            for traj, success in zip(ee_trajectories, trajectory_success):
                self.all_eval_trajectories.append(traj)
                self.all_eval_success.append(success)
                self.all_eval_z_ids.append(self.current_z_id)
            return  # Don't plot yet

        import matplotlib.pyplot as plt
        from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 - registers '3d' projection on older matplotlib

        colors = plt.cm.tab10(np.arange(10))

        def color_for(index):
            # tab10 has 10 colors; wrap around so more than 10 z's / modes
            # don't raise IndexError.
            return colors[index % len(colors)]

        # Hard-coded start / target marker positions (not read from the env).
        fixed_start = np.array([0.27851078, -0.00815551, 1.4719069])
        fixed_target = np.array([0.36239344, -0.12145063, 1.11076617])

        def draw_scene(ax):
            """Add start/target markers, axis labels, grid, aspect and view angle."""
            ax.scatter(*fixed_start, color='green', s=200, marker='o', label='Start',
                       edgecolors='black', linewidths=2, zorder=10)
            ax.scatter(*fixed_target, color='gold', s=400, marker='*', label='Target',
                       edgecolors='black', linewidths=2, zorder=10)
            ax.set_xlabel('X (m)', fontsize=11)
            ax.set_ylabel('Y (m)', fontsize=11)
            ax.set_zlabel('Z (m)', fontsize=11)
            ax.grid(True, alpha=0.3)
            ax.set_box_aspect([1, 1, 1])
            ax.view_init(elev=20, azim=45)

        mode_demo_trajectories = self._load_mode_demos()
        z_ids = self.all_eval_z_ids

        # DTW between each eval trajectory and the demo with the same index.
        # NOTE(review): index i only corresponds to a mode when the z's were
        # auto-generated; for a user-supplied z_list this pairing is by position.
        dtw_distances = []
        for trajectory, z_id in zip(ee_trajectories, z_ids):
            if len(trajectory) == 0 or z_id >= len(mode_demo_trajectories):
                continue
            demo_traj = mode_demo_trajectories[z_id]
            if len(demo_traj) > 0:
                dtw_distances.append(compute_dtw_distance(trajectory, demo_traj))

        avg_dtw = np.mean(dtw_distances) if dtw_distances else 0.0
        log.info(f"Average DTW distance: {avg_dtw:.4f}")

        fig = plt.figure(figsize=(16, 7))

        # LEFT SUBPLOT: evaluation trajectories (n_noise_samples per z)
        ax1 = fig.add_subplot(121, projection='3d')
        ax1.set_title(
            f'Evaluation Trajectories ({self.n_noise_samples} per mode)\nAvg DTW: {avg_dtw:.4f}',
            fontsize=14, fontweight='bold',
        )
        for trajectory, success, z_id in zip(ee_trajectories, trajectory_success, z_ids):
            if len(trajectory) == 0:
                continue
            color = color_for(z_id)
            x, y, z = trajectory[:, 0], trajectory[:, 1], trajectory[:, 2]
            alpha = 0.8 if success else 0.4
            linewidth = 2 if success else 1
            ax1.plot(x, y, z, '-', color=color, linewidth=linewidth, alpha=alpha)
            ax1.scatter(x[-1], y[-1], z[-1], c=color, marker='*' if success else 'x',
                        s=100, alpha=alpha, edgecolors='black', linewidths=0.5)
        draw_scene(ax1)
        ax1.legend(fontsize=10)

        # RIGHT SUBPLOT: ground-truth demos (one per mode)
        ax2 = fig.add_subplot(122, projection='3d')
        ax2.set_title(f'Ground Truth Demos ({len(mode_demo_trajectories)} modes)',
                      fontsize=14, fontweight='bold')
        for mode_idx, demo_traj in enumerate(mode_demo_trajectories):
            if len(demo_traj) == 0:
                continue  # placeholder for a missing demo file
            color = color_for(mode_idx)
            x, y, z = demo_traj[:, 0], demo_traj[:, 1], demo_traj[:, 2]
            ax2.plot(x, y, z, '-', color=color, linewidth=2, alpha=0.8, label=f'Mode {mode_idx}')
            ax2.scatter(x[-1], y[-1], z[-1], c=color, marker='*', s=100,
                        edgecolors='black', linewidths=0.5)
        draw_scene(ax2)
        ax2.legend(fontsize=8, ncol=2, loc='upper right')

        plt.tight_layout()
        plot_path = os.path.join(self.render_dir, 'ee_trajectories_3d.png')
        plt.savefig(plot_path, dpi=150, bbox_inches='tight')
        log.info(f"Saved 3D trajectory comparison plot to {plot_path}")
        plt.close(fig)

    def _load_mode_demos(self):
        """
        Load RAW end-effector positions of the first noise-free demo of each mode.

        Uses the same selection rule as ``_generate_z_embeddings`` (first
        noise-free demo per mode, duplicate modes skipped, episode order), so
        list index ``i`` lines up with z index ``i`` for auto-generated
        embeddings. A missing episode file yields an empty ``(0, 3)``
        placeholder instead of being dropped, which keeps that alignment.

        Returns:
            List of numpy arrays of shape (T, 3) (x, y, z per step). Returns an
            empty list if ``dataset_path`` is not configured.
        """
        import pickle
        from pathlib import Path

        if self.dataset_path is None:
            log.warning("dataset_path is not set - no demo trajectories to compare against")
            return []

        metadata = self._load_metadata()
        episodes_dir = Path(self._variation_train_dir()) / 'episodes'

        mode_ee_trajectories = []
        seen_modes = set()
        for episode_idx, meta in enumerate(metadata):
            if len(mode_ee_trajectories) == self.NUM_MODES:
                break
            if not self._is_mode_representative(meta) or meta['mode'] in seen_modes:
                continue
            seen_modes.add(meta['mode'])

            demo_path = episodes_dir / f'episode{episode_idx}' / 'low_dim_obs.pkl'
            if not demo_path.exists():
                log.warning(f"Demo file not found for mode {meta['mode']}: {demo_path}")
                mode_ee_trajectories.append(np.empty((0, 3)))
                continue

            # NOTE: pickle can execute arbitrary code on load; only use trusted local demos.
            with open(demo_path, 'rb') as f:
                demo = pickle.load(f)
            # gripper_pose is [x, y, z, qx, qy, qz, qw]; keep the position part.
            mode_ee_trajectories.append(np.array([obs.gripper_pose[:3] for obs in demo]))

        return mode_ee_trajectories
