"""
Evaluate pre-trained/DPPO-fine-tuned diffusion policy.

"""

import os
import numpy as np
import torch
import logging
from tqdm import tqdm
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# Module-level logger used for evaluation progress and result messages.
log = logging.getLogger(__name__)
from util.timer import Timer
from agent.eval.eval_agent import EvalAgent


class EvalDiffusionAgent(EvalAgent):
    """Evaluation agent for a pre-trained or DPPO-fine-tuned diffusion policy.

    ``run`` rolls the policy out deterministically in ``self.venv``. An episode
    counts as successful when it ends by termination rather than truncation.
    The agent collects end-effector (EE) positions reported through the env
    ``info`` dicts, then writes summary metrics, 3D trajectory plots, and a
    pickle of the EE trajectories.
    """

    # Target position drawn in the 3D plots. Hard-coded for the task this agent
    # was written for; the plot axis labels treat coordinates as metres.
    EE_TARGET_POS = (0.36239344, -0.12145063, 1.11076617)
    # Radius of the translucent reference sphere drawn around the target.
    SUCCESS_SPHERE_RADIUS = 0.05
    # Whether to also open the interactive plotly figure via ``fig.show()``.
    # Off by default: plotly's browser renderer waits for a browser to fetch the
    # page, which can stall evaluation on a headless machine. The HTML file is
    # always written.
    show_interactive_plot = False

    def __init__(self, cfg):
        super().__init__(cfg)

    def run(self):
        """Evaluate the policy, then log and save metrics and trajectory artifacts.

        The rollout stops at the first of these limits:

        - ``n_eval_episodes`` finished episodes. This is
          ``self.n_eval_episodes`` if set, otherwise an estimate derived from
          ``self.n_steps``.
        - ``self.n_steps`` policy calls.
        """
        timer = Timer()

        # Video paths for the first n_render envs --- only applies to the first
        # episode of each env when envs reset within the iteration.
        options_venv = [{} for _ in range(self.n_envs)]
        if self.render_video:
            for env_ind in range(self.n_render):
                options_venv[env_ind]["video_path"] = os.path.join(
                    self.render_dir, f"eval_trial-{env_ind}.mp4"
                )

        # Reset env before the rollout starts.
        self.model.eval()
        # firsts_trajs[t, e] == 1 marks that env e begins an episode at step t.
        # Row t + 1 is filled from the done flags returned by env step t.
        firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
        prev_obs_venv = self.reset_env_all(options_venv=options_venv)
        firsts_trajs[0] = 1
        reward_trajs = np.zeros((self.n_steps, self.n_envs))
        # Stays None unless full observations are recorded (needed by traj_plotter).
        obs_full_trajs = None
        if self.save_full_observations:  # state-only
            obs_full_trajs = np.empty((0, self.n_envs, self.obs_dim))
            obs_full_trajs = np.vstack(
                (obs_full_trajs, prev_obs_venv["state"][:, -1][None])
            )

        # End-effector trajectory bookkeeping.
        ee_trajectories = []  # finished trajectories, one array per episode
        trajectory_success = []  # success flag aligned with ee_trajectories
        current_ee_trajectory = [[] for _ in range(self.n_envs)]  # in-progress, per env
        episode_has_initial_pos = [False] * self.n_envs  # initial EE pos recorded yet?

        episodes_completed = 0
        episodes_succeeded = 0  # episodes that terminated without being truncated
        current_episode_steps = np.zeros(self.n_envs, dtype=int)
        # max(1, ...) avoids a zero division below when max_episode_steps < act_steps.
        max_steps_per_episode = max(1, self.max_episode_steps // self.act_steps)

        if self.n_eval_episodes is not None:
            n_eval_episodes = self.n_eval_episodes
            log.info(f"Evaluating {n_eval_episodes} episodes")
        else:
            n_eval_episodes = self.n_envs * (self.n_steps // max_steps_per_episode + 1)

        episode_pbar = tqdm(
            total=n_eval_episodes, desc="Episodes", position=0, leave=True, unit="ep"
        )
        step_pbar = tqdm(
            total=max_steps_per_episode,
            desc="Current episode",
            position=1,
            leave=False,
            unit="step",
        )
        try:
            # step indexes reward_trajs[step] and firsts_trajs[step + 1] directly,
            # so it must run 0 .. n_steps - 1.
            for step in range(self.n_steps):
                if episodes_completed >= n_eval_episodes:
                    break

                # Select action
                with torch.no_grad():
                    cond = {
                        "state": torch.from_numpy(prev_obs_venv["state"])
                        .float()
                        .to(self.device)
                    }
                    samples = self.model(cond=cond, deterministic=True)
                    output_venv = samples.trajectories.cpu().numpy()  # n_env x horizon x act
                action_venv = output_venv[:, : self.act_steps]

                # Apply multi-step action
                obs_venv, reward_venv, terminated_venv, truncated_venv, info_venv = (
                    self.venv.step(action_venv)
                )
                done_venv = terminated_venv | truncated_venv
                reward_trajs[step] = reward_venv
                firsts_trajs[step + 1] = done_venv
                if obs_full_trajs is not None:  # state-only
                    obs_full_venv = np.array(
                        [info["full_obs"]["state"] for info in info_venv]
                    )  # n_envs x act_steps x obs_dim
                    obs_full_trajs = np.vstack(
                        (obs_full_trajs, obs_full_venv.transpose(1, 0, 2))
                    )

                # Append the EE positions the env wrapper reported for this step.
                for env_idx in range(self.n_envs):
                    if self._record_ee_positions(
                        info_venv[env_idx],
                        current_ee_trajectory[env_idx],
                        not episode_has_initial_pos[env_idx],
                    ):
                        episode_has_initial_pos[env_idx] = True

                current_episode_steps += 1
                step_pbar.n = int(np.mean(current_episode_steps))
                step_pbar.refresh()

                # Close out finished episodes.
                for env_idx in np.flatnonzero(done_venv):
                    episodes_completed += 1
                    # Success: terminated (task condition) rather than truncated (timeout).
                    success = bool(
                        terminated_venv[env_idx] and not truncated_venv[env_idx]
                    )
                    episodes_succeeded += int(success)

                    if current_ee_trajectory[env_idx]:
                        ee_trajectories.append(np.array(current_ee_trajectory[env_idx]))
                        trajectory_success.append(success)
                    current_ee_trajectory[env_idx] = []
                    episode_has_initial_pos[env_idx] = False

                    episode_pbar.update(1)
                    current_episode_steps[env_idx] = 0
                    if self.n_envs == 1:
                        step_pbar.reset()

                    if episodes_completed >= n_eval_episodes:
                        break

                prev_obs_venv = obs_venv
        finally:
            # Always release the progress bars, even if the rollout raised.
            episode_pbar.close()
            step_pbar.close()

        num_episode_finished, avg_episode_reward, avg_best_reward = (
            self._summarize_rewards(firsts_trajs, reward_trajs)
        )
        if num_episode_finished == 0:
            log.info("[WARNING] No episode completed within the iteration!")
        # Success comes from termination flags, not a reward threshold, so it is
        # reported over every completed episode (matching the logged counts).
        success_rate = (
            episodes_succeeded / episodes_completed if episodes_completed > 0 else 0
        )

        # Plot state trajectories (only in D3IL)
        if self.traj_plotter is not None:
            if obs_full_trajs is None:
                log.info(
                    "[WARNING] traj_plotter needs save_full_observations; "
                    "skipping state trajectory plot"
                )
            else:
                self.traj_plotter(
                    obs_full_trajs=obs_full_trajs,
                    n_render=self.n_render,
                    max_episode_steps=self.max_episode_steps,
                    render_dir=self.render_dir,
                    itr=0,
                )

        if ee_trajectories:
            self._plot_3d_trajectories(ee_trajectories, trajectory_success)

        eval_time = timer()
        # NOTE(review): assumes the reward is the negative EE-to-target distance
        # in metres (as the original conversion implied) — confirm per task.
        avg_best_distance = -avg_best_reward

        log.info(
            f"eval: num episode {episodes_completed:4d} | RLBench success rate {success_rate:8.4f} ({episodes_succeeded}/{episodes_completed}) | "
            f"avg episode reward {avg_episode_reward:8.4f} | avg best distance {avg_best_distance*100:.2f}cm"
        )
        np.savez(
            self.result_path,
            num_episode=num_episode_finished,
            eval_success_rate=success_rate,
            eval_episode_reward=avg_episode_reward,
            eval_best_reward=avg_best_reward,
            time=eval_time,
        )

        if ee_trajectories:
            self._save_ee_trajectories(ee_trajectories)

    @staticmethod
    def _as_position(raw_pos):
        """Return ``raw_pos`` flattened to a length-3 array, or None if it is not one."""
        if raw_pos is None or not hasattr(raw_pos, "__len__"):
            return None
        pos = np.array(raw_pos).flatten()
        return pos if pos.shape[0] == 3 else None

    def _record_ee_positions(self, info, trajectory, need_initial_pos):
        """Append valid EE positions from one env's ``info`` dict to ``trajectory``.

        Reads ``info["initial_ee_pos"]``, but only when ``need_initial_pos`` is
        true. Then reads every entry of ``info["all_primitive_ee_positions"]``.
        Entries that are not 3-element positions are ignored.

        Returns:
            True if an initial EE position was appended, else False.
        """
        added_initial = False
        if need_initial_pos and "initial_ee_pos" in info:
            initial_pos = self._as_position(info["initial_ee_pos"])
            if initial_pos is not None:
                trajectory.append(initial_pos)
                added_initial = True
        if "all_primitive_ee_positions" in info:
            for raw_pos in info["all_primitive_ee_positions"]:
                ee_pos = self._as_position(raw_pos)
                if ee_pos is not None:
                    trajectory.append(ee_pos)
        return added_initial

    def _summarize_rewards(self, firsts_trajs, reward_trajs):
        """Average rewards over episodes that start and finish inside the rollout.

        Episodes are delimited by consecutive 1s in each column of
        ``firsts_trajs``. Single-step episodes are skipped.

        Returns:
            ``(num_episode_finished, avg_episode_reward, avg_best_reward)``;
            all zeros when no episode qualifies.
        """
        episodes_start_end = []
        for env_ind in range(self.n_envs):
            env_steps = np.flatnonzero(firsts_trajs[:, env_ind] == 1)
            for start, end in zip(env_steps[:-1], env_steps[1:]):
                if end - start > 1:
                    episodes_start_end.append((env_ind, start, end - 1))
        if not episodes_start_end:
            return 0, 0, 0

        reward_trajs_split = [
            reward_trajs[start : end + 1, env_ind]
            for env_ind, start, end in episodes_start_end
        ]
        episode_reward = np.array([np.sum(rewards) for rewards in reward_trajs_split])
        if self.furniture_sparse_reward:
            # Only for furniture tasks, where reward only occurs in one env step.
            episode_best_reward = episode_reward
        else:
            episode_best_reward = np.array(
                [np.max(rewards) / self.act_steps for rewards in reward_trajs_split]
            )
        return (
            len(reward_trajs_split),
            np.mean(episode_reward),
            np.mean(episode_best_reward),
        )

    def _save_ee_trajectories(self, ee_trajectories):
        """Pickle the list of EE trajectory arrays next to ``self.result_path``."""
        import pickle

        # splitext only strips the final extension, unlike str.replace(".npz", ...),
        # which would also rewrite ".npz" elsewhere in the path.
        root, _ = os.path.splitext(self.result_path)
        traj_path = f"{root}_trajectories.pkl"
        with open(traj_path, "wb") as f:
            pickle.dump(ee_trajectories, f)
        log.info(f"Saved {len(ee_trajectories)} trajectories to {traj_path}")

    def _plot_3d_trajectories(self, ee_trajectories, trajectory_success):
        """Plot EE trajectories in 3D: green for success, red for failure.

        Saves ``ee_trajectories_3d.png`` in ``self.render_dir``. It then writes an
        interactive plotly version via ``_save_interactive_plot``.
        """
        from matplotlib.lines import Line2D

        success_color, failure_color = "green", "red"
        target = np.asarray(self.EE_TARGET_POS, dtype=float)

        fig = plt.figure(figsize=(8, 7))
        try:
            ax = fig.add_subplot(111, projection="3d")

            for trajectory, success in zip(ee_trajectories, trajectory_success):
                if len(trajectory) == 0:
                    continue
                x, y, z = trajectory[:, 0], trajectory[:, 1], trajectory[:, 2]
                color = success_color if success else failure_color
                ax.plot(x, y, z, "-", color=color, linewidth=0.8, alpha=0.4)
                # Start point in blue, end point in the trajectory's color.
                ax.scatter(x[0], y[0], z[0], c="blue", marker="o", s=100,
                           edgecolors="black", linewidths=1, zorder=5, alpha=0.8)
                ax.scatter(x[-1], y[-1], z[-1], c=color, marker="*", s=150,
                           edgecolors="black", linewidths=1, zorder=5)

            # Fixed target point (large gold star).
            ax.scatter(target[0], target[1], target[2], c="gold", marker="*",
                       s=500, edgecolors="black", linewidths=2, zorder=10,
                       label="Target")

            # Translucent reference sphere around the target.
            radius = self.SUCCESS_SPHERE_RADIUS
            u = np.linspace(0, 2 * np.pi, 30)
            v = np.linspace(0, np.pi, 20)
            ax.plot_surface(
                radius * np.outer(np.cos(u), np.sin(v)) + target[0],
                radius * np.outer(np.sin(u), np.sin(v)) + target[1],
                radius * np.outer(np.ones(np.size(u)), np.cos(v)) + target[2],
                color="gold", alpha=0.15, zorder=1,
            )

            ax.set_xlabel("X Position (m)", fontsize=12)
            ax.set_ylabel("Y Position (m)", fontsize=12)
            ax.set_zlabel("Z Position (m)", fontsize=12)
            ax.set_title("End-Effector Trajectories (Multimodal Diffusion Policy)",
                         fontsize=14, fontweight="bold")

            legend_handles = [
                Line2D([0], [0], color=success_color, linewidth=2),
                Line2D([0], [0], color=failure_color, linewidth=2),
                Line2D([0], [0], marker="o", color="w", markerfacecolor="blue",
                       markersize=10, markeredgecolor="black", markeredgewidth=1,
                       linestyle="None"),
                Line2D([0], [0], marker="*", color="w", markerfacecolor="gray",
                       markersize=12, markeredgecolor="black", markeredgewidth=1,
                       linestyle="None"),
                Line2D([0], [0], marker="*", color="w", markerfacecolor="gold",
                       markersize=16, markeredgecolor="black", markeredgewidth=1,
                       linestyle="None"),
            ]
            ax.legend(legend_handles,
                      ["Success", "Failure", "Start Point", "End Point", "Target"],
                      loc="upper right", fontsize=10)

            # Equal aspect ratio for better visualization.
            ax.set_box_aspect([1, 1, 1])

            os.makedirs(self.render_dir, exist_ok=True)
            plot_path = os.path.join(self.render_dir, "ee_trajectories_3d.png")
            fig.savefig(plot_path, dpi=150, bbox_inches="tight")
            log.info(f"Saved 3D trajectory plot to {plot_path}")
        finally:
            plt.close(fig)

        self._save_interactive_plot(ee_trajectories, trajectory_success)

    def _save_interactive_plot(self, ee_trajectories, trajectory_success):
        """Write an interactive plotly version of the EE-trajectory plot as HTML.

        Skipped, with a log message, when plotly is not installed. The figure is
        opened in a browser only when ``self.show_interactive_plot`` is true.
        """
        try:
            import plotly.graph_objects as go
        except ImportError:
            log.info("Plotly not available - skipping interactive HTML plot")
            return

        fig = go.Figure()
        shown_in_legend = set()
        start_points = []
        for trajectory, success in zip(ee_trajectories, trajectory_success):
            if len(trajectory) == 0:
                continue
            label = "Success" if success else "Failure"
            fig.add_trace(go.Scatter3d(
                x=trajectory[:, 0], y=trajectory[:, 1], z=trajectory[:, 2],
                mode="lines",
                line=dict(color="green" if success else "red", width=2),
                name=label,
                legendgroup=label,
                # One legend entry per outcome, regardless of episode order.
                showlegend=label not in shown_in_legend,
                opacity=0.6,
            ))
            shown_in_legend.add(label)
            start_points.append(trajectory[0])

        if start_points:
            starts = np.asarray(start_points)
            fig.add_trace(go.Scatter3d(
                x=starts[:, 0], y=starts[:, 1], z=starts[:, 2], mode="markers",
                marker=dict(size=5, color="blue"), name="Start", showlegend=True,
            ))

        target = np.asarray(self.EE_TARGET_POS, dtype=float)
        fig.add_trace(go.Scatter3d(
            x=[target[0]], y=[target[1]], z=[target[2]],
            mode="markers", marker=dict(size=15, color="gold", symbol="diamond"),
            name="Target",
        ))

        fig.update_layout(
            title="End-Effector Trajectories (Interactive 3D)",
            scene=dict(xaxis_title="X (m)", yaxis_title="Y (m)", zaxis_title="Z (m)",
                       aspectmode="cube"),
            width=900, height=700,
        )

        os.makedirs(self.render_dir, exist_ok=True)
        html_path = os.path.join(self.render_dir, "ee_trajectories_3d.html")
        fig.write_html(html_path)
        log.info(f"Saved interactive 3D trajectory plot to {html_path}")

        if self.show_interactive_plot:
            fig.show()
            log.info("Opened interactive 3D plot in browser")