"""
Evaluate a styles-parameterized diffusion policy (conditioned on style z embeddings).

Key differences from EvalParameterizedDiffusionAgent:
1. Uses canonical coordinate transformation (relative to BOTH start and end)
2. Passes start and end positions to the model
3. Handles styles dataset metadata (start_pos, end_pos, style)
"""

import os
import sys
import numpy as np
import torch
import torch.nn.functional as F
import logging
from scipy.spatial.distance import cdist
from agent.eval.eval_diffusion_agent import EvalDiffusionAgent

log = logging.getLogger(__name__)

# Make the repository root and the RLBench_reach (encoder) sources importable.
# The root is the directory two levels above the one containing this file.
dppo_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.extend([
    dppo_root,
    os.path.join(dppo_root, 'RLBench_reach'),
    os.path.join(dppo_root, 'RLBench_reach', 'encoder'),
])


def compute_dtw_distance(traj1, traj2):
    """Return the Dynamic Time Warping distance between two trajectories.

    Local costs are pairwise Euclidean distances between points. The total
    accumulated DTW cost is divided by the length of the longer trajectory.
    """
    len_a, len_b = len(traj1), len(traj2)
    pairwise = cdist(traj1, traj2, metric='euclidean')

    # The recurrence only looks one row back, so keep two rows of the
    # accumulated-cost table instead of the full (len_a + 1) x (len_b + 1) grid.
    prev_row = np.full(len_b + 1, np.inf)
    prev_row[0] = 0
    for row in range(1, len_a + 1):
        curr_row = np.full(len_b + 1, np.inf)
        for col in range(1, len_b + 1):
            best_prev = min(prev_row[col], curr_row[col - 1], prev_row[col - 1])
            curr_row[col] = pairwise[row - 1, col - 1] + best_prev
        prev_row = curr_row
    return prev_row[len_b] / max(len_a, len_b)


def make_trajectory_relative_canonical(states, actions, length_norm=True, eps=1e-4):
    """
    Convert trajectories to canonical coordinates (relative to BOTH start and end).

    Intended to match the transform in train_style_encoder.py and
    style_sequence.py, so keep the tensor operations in sync with those files.

    Args:
        states: 3-D tensor indexed (batch, time, feature). Features 0:7 (joint
            positions) and 7:14 (joint velocities) become offsets from the first
            timestep. Features 15:18 (EE xyz) are re-expressed in a frame built
            from the first and last EE positions. Other features are copied as-is.
        actions: tensor indexed (batch, time, feature). Features 0:7 become
            offsets from the first-timestep state features 0:7.
        length_norm: if True, divide the canonical EE coordinates by the
            start-to-end distance (clamped below by ``eps``).
        eps: lower bound on the start-to-end distance used as a divisor.

    Returns:
        (states_rel, actions_rel): transformed copies; the inputs are not modified.
    """
    out_states = states.clone()
    out_actions = actions.clone()
    batch, _, _ = states.shape

    # Joint space: offsets from the first state of each trajectory.
    first = states[:, 0:1]
    out_states[:, :, 0:7] = states[:, :, 0:7] - first[:, :, 0:7]
    out_states[:, :, 7:14] = states[:, :, 7:14] - first[:, :, 7:14]
    out_actions[:, :, 0:7] = actions[:, :, 0:7] - first[:, :, 0:7]

    # EE frame: origin at the start, first axis pointing at the end position.
    ee = states[:, :, 15:18]
    origin = ee[:, 0]
    span = ee[:, -1] - origin
    span_len = torch.clamp(torch.norm(span, dim=-1, keepdim=True), min=eps)
    axis_x = span / span_len

    # Reference "up" is world z, swapped for world x when the start->end
    # direction is nearly vertical (|cos| > 0.9) to keep the cross product stable.
    up = torch.zeros(batch, 3, device=states.device)
    up[:, 2] = 1.0
    near_parallel = torch.abs((axis_x * up).sum(dim=-1)) > 0.9
    up[near_parallel, 0] = 1.0
    up[near_parallel, 2] = 0.0

    axis_y = F.normalize(torch.cross(axis_x, up, dim=-1), dim=-1)
    axis_z = torch.cross(axis_x, axis_y, dim=-1)

    # Project every EE position (relative to the start) onto the three axes.
    offsets = ee - origin.unsqueeze(1)
    coords = [
        (offsets * axis.unsqueeze(1)).sum(dim=-1, keepdim=True)
        for axis in (axis_x, axis_y, axis_z)
    ]
    if length_norm:
        scale = span_len.unsqueeze(1)
        coords = [c / scale for c in coords]

    out_states[:, :, 15:18] = torch.cat(coords, dim=-1)
    return out_states, out_actions


class EvalStylesDiffusionAgent(EvalDiffusionAgent):
    """
    Evaluation agent for styles parameterized diffusion policies.

    Key differences from EvalParameterizedDiffusionAgent:
    - Uses canonical coordinate transformation
    - Passes start and end positions to the model
    - Handles styles dataset with multiple styles per target

    For every z embedding, ``run`` sets ``model.current_z``. When demo
    metadata is available it also sets ``model.current_start`` and
    ``model.current_end``. It then calls the parent ``run`` with
    ``n_eval_episodes`` temporarily set to ``n_noise_samples``. Trajectories
    handed to ``_plot_3d_trajectories`` during those runs are collected and
    plotted together once all z embeddings are done.
    """

    def __init__(self, cfg):
        """
        Read styles-specific config, initialise the parent agent and prepare z's.

        Config keys read here (all optional): ``z_list``, ``encoder_checkpoint``,
        ``dataset_path``, ``metadata_path``, ``n_noise_samples`` (default 2),
        ``n_styles`` (default 3), ``cfg_guidance_scale`` (default 1.0).

        Raises:
            ValueError: if ``z_list`` is empty and any of ``encoder_checkpoint``,
                ``dataset_path`` or ``metadata_path`` is missing. All three are
                needed to auto-generate z embeddings.
        """
        # Store styles-specific config before parent init.
        # An explicit `z_list: null` is treated the same as an empty list.
        z_list = cfg.get('z_list', None)
        self.z_list = [] if z_list is None else z_list
        self.encoder_checkpoint = cfg.get('encoder_checkpoint', None)
        self.dataset_path = cfg.get('dataset_path', None)
        self.metadata_path = cfg.get('metadata_path', None)
        self.n_noise_samples = cfg.get('n_noise_samples', 2)
        self.n_styles = cfg.get('n_styles', 3)
        self.cfg_guidance_scale = cfg.get('cfg_guidance_scale', 1.0)

        super().__init__(cfg)

        # Auto-generate z embeddings if z_list is empty
        if len(self.z_list) == 0:
            missing = [
                name for name in ('encoder_checkpoint', 'dataset_path', 'metadata_path')
                if getattr(self, name) is None
            ]
            if missing:
                raise ValueError(
                    "If z_list is empty, must provide encoder_checkpoint, dataset_path "
                    f"and metadata_path (missing: {', '.join(missing)})"
                )
            log.info("z_list is empty - auto-generating z embeddings from demos")
            self.z_list, self.demo_info = self._generate_z_embeddings()
        else:
            self.z_list = [torch.FloatTensor(z).to(self.device) for z in self.z_list]
            # No demo metadata for user-supplied z's, so run() cannot set start/end.
            self.demo_info = []

        log.info(f"Evaluating {len(self.z_list)} z embeddings, {self.n_noise_samples} noise samples each")

    def _load_encoder(self):
        """
        Load the style encoder (TrajectoryVAE) from ``self.encoder_checkpoint``.

        The encoder module is loaded by file path under a distinct module name
        via importlib, to avoid import conflicts with other modules on sys.path.
        The checkpoint is expected to hold ``'config'`` (constructor kwargs) and
        ``'model_state_dict'``.

        Returns:
            The encoder on ``self.device``, in eval mode.
        """
        import importlib.util

        encoder_module_path = os.path.join(dppo_root, 'RLBench_reach', 'encoder', 'trajectory_encoder.py')
        spec = importlib.util.spec_from_file_location("trajectory_encoder_styles", encoder_module_path)
        trajectory_encoder_module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(trajectory_encoder_module)
        TrajectoryVAE = trajectory_encoder_module.TrajectoryVAE

        log.info(f"Loading style encoder from {self.encoder_checkpoint}")
        # NOTE: weights_only=False unpickles arbitrary objects; only load
        # checkpoints from a trusted source.
        checkpoint = torch.load(self.encoder_checkpoint, map_location=self.device, weights_only=False)
        config = checkpoint['config']

        encoder = TrajectoryVAE(
            state_dim=config['state_dim'],
            action_dim=config['action_dim'],
            hidden_dim=config['hidden_dim'],
            num_layers=config['num_layers'],
            num_heads=config['num_heads'],
            latent_dim=config['latent_dim'],
            horizon=config['horizon']
        ).to(self.device)

        encoder.load_state_dict(checkpoint['model_state_dict'])
        encoder.eval()

        log.info(f"Loaded encoder (latent_dim={config['latent_dim']}, horizon={config['horizon']})")
        return encoder

    @staticmethod
    def _meta_to_dict(meta):
        """
        Return one metadata entry as a mapping that supports ``.get``.

        Records of a numpy structured array are ``np.void`` scalars. ``dict()``
        cannot convert them because they are not mappings, so their fields are
        read by name. Any other entry (e.g. a pickled dict) is returned unchanged.
        """
        if isinstance(meta, np.void):
            return {name: meta[name] for name in (meta.dtype.names or ())}
        return meta

    def _generate_z_embeddings(self):
        """
        Generate z embeddings from demo trajectories using canonical coordinates.

        For styles dataset:
        - Sample representative trajectories: the first episode seen for each
          (target, style) pair, for target indices below min(5, #distinct targets)
        - Encode using canonical transformation

        Returns:
            (z_list, demo_info): a list of 1-D z tensors, and a parallel list of
            dicts with keys 'target_idx', 'style', 'start_pos', 'end_pos',
            'episode_idx'.

        Raises:
            ValueError: if no demos are selected, or the selected demos differ in
                length (they are stacked into a single batch for the encoder).
        """
        encoder = self._load_encoder()

        # Load dataset. The context manager closes the .npz file handle once the
        # arrays have been read into memory.
        log.info(f"Loading dataset from {self.dataset_path}")
        with np.load(self.dataset_path) as data:
            all_states = data['states']
            all_actions = data['actions']
            traj_lengths = data['traj_lengths']

        # Load metadata (allow_pickle: the file must come from a trusted source)
        log.info(f"Loading metadata from {self.metadata_path}")
        metadata = [
            self._meta_to_dict(m)
            for m in np.load(self.metadata_path, allow_pickle=True)
        ]

        log.info(f"Total episodes: {len(traj_lengths)}, metadata entries: {len(metadata)}")

        # Collect demos: sample a few targets, all styles per target
        demo_states = []
        demo_actions = []
        demo_info = []  # (target_idx, style, start_pos, end_pos)

        # Sample targets to evaluate (e.g., first 5 targets out of 27).
        # Uses the same default as the per-episode lookup below.
        n_targets_to_eval = min(5, len({m.get('target_idx', 0) for m in metadata}))
        seen_combinations = set()

        start_idx = 0
        for episode_idx, traj_len in enumerate(traj_lengths):
            end_idx = start_idx + traj_len

            if episode_idx >= len(metadata):
                break

            meta = metadata[episode_idx]
            target_idx = meta.get('target_idx', 0)
            style = meta.get('style', 'unknown')
            start_pos = np.array(meta.get('start_pos', [0, 0, 0]))
            end_pos = np.array(meta.get('end_pos', [0, 0, 0]))

            # Only use first few targets, one per (target, style) combination
            combo_key = (target_idx, style)
            if target_idx < n_targets_to_eval and combo_key not in seen_combinations:
                seen_combinations.add(combo_key)

                demo_states.append(all_states[start_idx:end_idx])
                demo_actions.append(all_actions[start_idx:end_idx])
                demo_info.append({
                    'target_idx': target_idx,
                    'style': style,
                    'start_pos': start_pos,
                    'end_pos': end_pos,
                    'episode_idx': episode_idx
                })

                log.info(f"  Added target={target_idx}, style={style}, episode={episode_idx}")

            start_idx = end_idx

            # Stop if we have enough
            if len(demo_states) >= n_targets_to_eval * self.n_styles:
                break

        log.info(f"Collected {len(demo_states)} demo trajectories")

        if not demo_states:
            raise ValueError(
                f"No demo trajectories selected from {self.dataset_path}; "
                "expected metadata target_idx values starting at 0"
            )
        lengths = {len(s) for s in demo_states}
        if len(lengths) > 1:
            raise ValueError(
                f"Selected demo trajectories have different lengths {sorted(lengths)}; "
                "they must share one length to be encoded as a single batch"
            )

        # Convert to tensors
        demo_states = torch.FloatTensor(np.array(demo_states)).to(self.device)
        demo_actions = torch.FloatTensor(np.array(demo_actions)).to(self.device)

        # Convert to canonical coordinates
        log.info("Converting to canonical coordinates...")
        states_rel, actions_rel = make_trajectory_relative_canonical(demo_states, demo_actions)

        # Encode to get z embeddings
        with torch.no_grad():
            z_embeddings = encoder.encode(states_rel, actions_rel)

        log.info(f"Generated {z_embeddings.shape[0]} z embeddings of dim {z_embeddings.shape[1]}")

        # Log z values
        for i, info in enumerate(demo_info):
            log.info(f"  z[{i}] (target={info['target_idx']}, style={info['style']}): "
                     f"{z_embeddings[i, :4].cpu().numpy()}...")

        return [z_embeddings[i] for i in range(z_embeddings.shape[0])], demo_info

    def run(self):
        """
        Run evaluation for all z embeddings, then save one combined plot.

        Each z is evaluated by the parent ``run`` with ``n_eval_episodes``
        temporarily set to ``n_noise_samples``. The original value is restored
        even if the parent run raises.
        """
        log.info(f"Starting evaluation of {len(self.z_list)} z embeddings")

        self.all_eval_trajectories = []
        self.all_eval_success = []
        self.all_eval_z_ids = []

        if self.z_list and not self.demo_info:
            log.warning("No demo metadata for the provided z_list: "
                        "start/end conditioning will not be set on the model")

        for z_idx, z_embedding in enumerate(self.z_list):
            info = self.demo_info[z_idx] if z_idx < len(self.demo_info) else {}
            log.info(f"\n{'='*80}")
            log.info(f"Evaluating z[{z_idx}]: target={info.get('target_idx', '?')}, style={info.get('style', '?')}")
            log.info(f"{'='*80}")

            # Set z in model
            self.model.current_z = z_embedding.unsqueeze(0)
            # While current_z_id is set, _plot_3d_trajectories only collects.
            self.current_z_id = z_idx

            # Set start/end positions in model for conditioning
            if info:
                start_pos = torch.FloatTensor(info['start_pos']).to(self.device)
                end_pos = torch.FloatTensor(info['end_pos']).to(self.device)
                self.model.current_start = start_pos.unsqueeze(0)
                self.model.current_end = end_pos.unsqueeze(0)
                log.info(f"  Set start={info['start_pos']}, end={info['end_pos']}")

            # Run evaluation for this z
            original_n_eval = self.n_eval_episodes
            self.n_eval_episodes = self.n_noise_samples
            try:
                super().run()
            finally:
                self.n_eval_episodes = original_n_eval

        log.info(f"\n{'='*80}")
        log.info(f"Completed evaluation of {len(self.z_list)} z embeddings")
        log.info(f"{'='*80}")

        # Leave collect mode unconditionally so later plot calls draw again.
        if hasattr(self, 'current_z_id'):
            del self.current_z_id

        # Final plot
        if len(self.all_eval_trajectories) > 0:
            self._plot_3d_trajectories(self.all_eval_trajectories, self.all_eval_success)

    def _plot_3d_trajectories(self, ee_trajectories, trajectory_success):
        """
        Collect or plot evaluation trajectories grouped by style/target.

        If ``current_z_id`` is set (inside ``run``'s per-z loop), trajectories
        are appended to the ``all_eval_*`` buffers, tagged with that z index.
        Otherwise all given trajectories are drawn in one 3D plot and saved to
        ``<render_dir>/ee_trajectories_3d.png``. Colour encodes the target index,
        the end marker encodes the style, and failed episodes are drawn fainter.

        Args:
            ee_trajectories: per-episode trajectories; presumably (T, >=3) with
                xyz in the first three columns — TODO confirm against the parent.
            trajectory_success: per-episode success flags.
        """
        if hasattr(self, 'current_z_id'):
            # Collecting phase
            for traj, success in zip(ee_trajectories, trajectory_success):
                self.all_eval_trajectories.append(traj)
                self.all_eval_success.append(success)
                self.all_eval_z_ids.append(self.current_z_id)
            return

        # Plotting phase
        import matplotlib.pyplot as plt
        from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (registers '3d' projection on older matplotlib)

        fig = plt.figure(figsize=(12, 8))
        try:
            ax = fig.add_subplot(111, projection='3d')

            colors = plt.cm.tab10(np.arange(10))
            style_markers = {'quadratic_bezier': 'o', 's_curve_cubic': 's', 'piecewise_linear': '^'}

            for traj, success, z_id in zip(ee_trajectories, trajectory_success, self.all_eval_z_ids):
                info = self.demo_info[z_id] if z_id < len(self.demo_info) else {}
                target_idx = info.get('target_idx', 0)
                style = info.get('style', 'unknown')

                color = colors[target_idx % 10]
                marker = style_markers.get(style, 'o')

                traj = np.asarray(traj)
                if len(traj) == 0:
                    continue
                x, y, z = traj[:, 0], traj[:, 1], traj[:, 2]
                alpha = 0.8 if success else 0.4
                ax.plot(x, y, z, '-', color=color, linewidth=2, alpha=alpha)
                ax.scatter(x[-1], y[-1], z[-1], c=[color], marker=marker,
                           s=100, alpha=alpha, edgecolors='black')

            ax.set_xlabel('X (m)')
            ax.set_ylabel('Y (m)')
            ax.set_zlabel('Z (m)')
            ax.set_title('Styles Evaluation Trajectories')
            ax.grid(True, alpha=0.3)

            fig.tight_layout()
            plot_path = os.path.join(self.render_dir, 'ee_trajectories_3d.png')
            fig.savefig(plot_path, dpi=150, bbox_inches='tight')
            log.info(f"Saved trajectory plot to {plot_path}")
        finally:
            # Release the figure even if plotting or saving fails.
            plt.close(fig)