"""
Evaluate pre-trained BC-GMM (Behavior Cloning with Gaussian Mixture Model) policy.

"""

import os
import numpy as np
import torch
import logging
from tqdm import tqdm
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

log = logging.getLogger(__name__)
from util.timer import Timer
from agent.eval.eval_agent import EvalAgent


class EvalBCGMMAgent(EvalAgent):
    """Evaluate a pre-trained BC-GMM policy in a vectorized environment.

    Besides the summary metrics saved to ``self.result_path``, this agent
    collects the end-effector positions the environment reports through
    ``info["initial_ee_pos"]`` / ``info["all_primitive_ee_positions"]``.
    It saves them as a pickle next to the results file and as a 3D plot in
    ``self.render_dir``.
    """

    def __init__(self, cfg):
        super().__init__(cfg)

    def run(self):
        """Roll out the policy, then log and save the evaluation results.

        The rollout stops when either limit is reached first:
        ``n_eval_episodes`` episodes have finished, or ``self.n_steps``
        environment steps have been taken. An episode counts as successful
        when it ended through ``terminated`` without ``truncated``.
        """
        timer = Timer()

        # Prepare video paths for the first n_render envs
        options_venv = [{} for _ in range(self.n_envs)]
        if self.render_video:
            for env_ind in range(self.n_render):
                options_venv[env_ind]["video_path"] = os.path.join(
                    self.render_dir, f"eval_trial-{env_ind}.mp4"
                )

        # Reset env before iteration starts
        self.model.eval()
        # Row 0 marks the initial reset. Row t + 1 is set for envs whose
        # episode finished at step t.
        firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
        prev_obs_venv = self.reset_env_all(options_venv=options_venv)
        firsts_trajs[0] = 1
        reward_trajs = np.zeros((self.n_steps, self.n_envs))
        # Stays None unless full observations are requested. Checked before
        # use so traj_plotter cannot hit an undefined name.
        obs_full_trajs = None
        if self.save_full_observations:
            obs_full_trajs = np.empty((0, self.n_envs, self.obs_dim))
            obs_full_trajs = np.vstack(
                (obs_full_trajs, prev_obs_venv["state"][:, -1][None])
            )

        # End-effector trajectories for 3D plotting. Finished episodes go into
        # ee_trajectories / trajectory_success; in-progress ones are kept per env.
        ee_trajectories = []
        trajectory_success = []
        current_ee_trajectory = [[] for _ in range(self.n_envs)]
        episode_has_initial_pos = [False] * self.n_envs

        # Episode bookkeeping for progress bars and success rate
        episodes_completed = 0
        episodes_succeeded = 0
        current_episode_steps = np.zeros(self.n_envs, dtype=int)
        # Clamp to 1: if act_steps > max_episode_steps this would be 0, and the
        # default episode budget below would divide by zero.
        max_steps_per_episode = max(1, self.max_episode_steps // self.act_steps)

        # Get target number of episodes
        if self.n_eval_episodes is not None:
            n_eval_episodes = self.n_eval_episodes
            log.info(f"Evaluating {n_eval_episodes} episodes")
        else:
            n_eval_episodes = self.n_envs * (self.n_steps // max_steps_per_episode + 1)

        episode_pbar = tqdm(total=n_eval_episodes, desc="Episodes", position=0, leave=True, unit="ep")
        step_pbar = tqdm(total=max_steps_per_episode, desc="Current episode", position=1, leave=False, unit="step")

        try:
            # step is 0-based so that reward_trajs[step] and firsts_trajs[step + 1]
            # stay within their (n_steps, ...) and (n_steps + 1, ...) bounds.
            for step in range(self.n_steps):
                if episodes_completed >= n_eval_episodes:
                    break

                # Sample from the GMM (deterministic=False) rather than taking
                # the highest-weighted mode, so different modes can be explored.
                with torch.no_grad():
                    cond = {
                        "state": torch.from_numpy(prev_obs_venv["state"])
                        .float()
                        .to(self.device)
                    }
                    samples = self.model(cond=cond, deterministic=False)
                    output_venv = samples.trajectories.cpu().numpy()
                action_venv = output_venv[:, : self.act_steps]

                # Apply multi-step action
                obs_venv, reward_venv, terminated_venv, truncated_venv, info_venv = (
                    self.venv.step(action_venv)
                )
                done_venv = terminated_venv | truncated_venv
                reward_trajs[step] = reward_venv
                firsts_trajs[step + 1] = done_venv
                if obs_full_trajs is not None:
                    obs_full_venv = np.array(
                        [info["full_obs"]["state"] for info in info_venv]
                    )
                    obs_full_trajs = np.vstack(
                        (obs_full_trajs, obs_full_venv.transpose(1, 0, 2))
                    )

                self._record_ee_positions(
                    info_venv, current_ee_trajectory, episode_has_initial_pos
                )

                # Step progress bar shows the mean in-episode step across envs
                current_episode_steps += 1
                step_pbar.n = int(np.mean(current_episode_steps))
                step_pbar.refresh()

                for env_idx in np.where(done_venv)[0]:
                    episodes_completed += 1
                    success = bool(terminated_venv[env_idx]) and not bool(
                        truncated_venv[env_idx]
                    )
                    if success:
                        episodes_succeeded += 1

                    if current_ee_trajectory[env_idx]:
                        ee_trajectories.append(np.array(current_ee_trajectory[env_idx]))
                        trajectory_success.append(success)
                        current_ee_trajectory[env_idx] = []
                        episode_has_initial_pos[env_idx] = False

                    episode_pbar.update(1)
                    current_episode_steps[env_idx] = 0
                    if self.n_envs == 1:
                        step_pbar.reset()

                    if episodes_completed >= n_eval_episodes:
                        break

                prev_obs_venv = obs_venv
        finally:
            # Close the bars even if the rollout raises, so the terminal is restored.
            episode_pbar.close()
            step_pbar.close()

        num_episode_finished, avg_episode_reward, avg_best_reward = (
            self._summarize_episode_rewards(firsts_trajs, reward_trajs)
        )
        # Success rate comes from the done flags observed during the rollout.
        # It must not be zeroed just because no episode qualified for the
        # reward summary (which skips single-step episodes).
        success_rate = (
            episodes_succeeded / episodes_completed if episodes_completed > 0 else 0
        )

        # Plot state trajectories
        if self.traj_plotter is not None:
            if obs_full_trajs is None:
                log.warning(
                    "traj_plotter is set but save_full_observations is False; "
                    "skipping state trajectory plot."
                )
            else:
                self.traj_plotter(
                    obs_full_trajs=obs_full_trajs,
                    n_render=self.n_render,
                    max_episode_steps=self.max_episode_steps,
                    render_dir=self.render_dir,
                    itr=0,
                )

        if ee_trajectories:
            self._plot_3d_trajectories(ee_trajectories, trajectory_success)

        # Log and save metrics
        elapsed = timer()
        # Negated best reward, reported in cm below. Presumably the reward is
        # a negative distance in metres — confirm against the env definition.
        avg_best_distance = -avg_best_reward

        log.info(
            f"eval: num episode {episodes_completed:4d} | success rate {success_rate:8.4f} ({episodes_succeeded}/{episodes_completed}) | "
            f"avg episode reward {avg_episode_reward:8.4f} | avg best distance {avg_best_distance*100:.2f}cm"
        )
        np.savez(
            self.result_path,
            num_episode=num_episode_finished,
            eval_success_rate=success_rate,
            eval_episode_reward=avg_episode_reward,
            eval_best_reward=avg_best_reward,
            time=elapsed,
        )

        # Save trajectories next to the results file
        if ee_trajectories:
            import pickle

            # splitext only strips the final extension and works for str or
            # PathLike. str.replace would hit every ".npz" in the path.
            traj_path = os.path.splitext(self.result_path)[0] + "_trajectories.pkl"
            with open(traj_path, "wb") as f:
                pickle.dump(ee_trajectories, f)
            log.info(f"Saved {len(ee_trajectories)} trajectories to {traj_path}")

    @staticmethod
    def _to_xyz(pos):
        """Return ``pos`` flattened to a length-3 array, or None if it is not a 3D point."""
        if pos is None or not hasattr(pos, "__len__"):
            return None
        xyz = np.array(pos).flatten()
        return xyz if xyz.shape[0] == 3 else None

    def _record_ee_positions(self, info_venv, current_ee_trajectory, episode_has_initial_pos):
        """Append end-effector positions from ``info_venv`` to the per-env buffers.

        For each env:
        - ``info["initial_ee_pos"]`` is added at most once per episode, tracked
          by ``episode_has_initial_pos``.
        - Every valid point in ``info["all_primitive_ee_positions"]`` is appended.
        - Entries that are None or do not flatten to three values are ignored.

        Both list arguments are mutated in place.
        """
        for env_idx in range(self.n_envs):
            info = info_venv[env_idx]
            if not episode_has_initial_pos[env_idx] and "initial_ee_pos" in info:
                initial_pos = self._to_xyz(info["initial_ee_pos"])
                if initial_pos is not None:
                    current_ee_trajectory[env_idx].append(initial_pos)
                    episode_has_initial_pos[env_idx] = True

            if "all_primitive_ee_positions" in info:
                for ee_pos in info["all_primitive_ee_positions"]:
                    xyz = self._to_xyz(ee_pos)
                    if xyz is not None:
                        current_ee_trajectory[env_idx].append(xyz)

    def _summarize_episode_rewards(self, firsts_trajs, reward_trajs):
        """Compute reward statistics over the episodes that finished during the rollout.

        An episode of env ``i`` spans the steps between two consecutive
        ``firsts_trajs[:, i] == 1`` markers. Single-step spans are skipped, as
        is each env's trailing unfinished episode.

        Returns:
            Tuple ``(num_episode_finished, avg_episode_reward, avg_best_reward)``.
            All three are 0 when no episode qualifies.
        """
        episodes_start_end = []
        for env_ind in range(self.n_envs):
            env_steps = np.where(firsts_trajs[:, env_ind] == 1)[0]
            for start, end in zip(env_steps[:-1], env_steps[1:]):
                if end - start > 1:
                    episodes_start_end.append((env_ind, start, end - 1))

        if not episodes_start_end:
            log.warning("No episode completed within the iteration!")
            return 0, 0, 0

        reward_trajs_split = [
            reward_trajs[start : end + 1, env_ind]
            for env_ind, start, end in episodes_start_end
        ]
        episode_reward = np.array(
            [np.sum(reward_traj) for reward_traj in reward_trajs_split]
        )
        if self.furniture_sparse_reward:
            episode_best_reward = episode_reward
        else:
            episode_best_reward = np.array(
                [
                    np.max(reward_traj) / self.act_steps
                    for reward_traj in reward_trajs_split
                ]
            )
        return (
            len(reward_trajs_split),
            np.mean(episode_reward),
            np.mean(episode_best_reward),
        )

    def _plot_3d_trajectories(self, ee_trajectories, trajectory_success):
        """Save a 3D plot of end-effector trajectories to ``render_dir``.

        Styling:
        - Successful trajectories are drawn in green, failed ones in red.
        - Each start point is a blue circle.
        - Each end point is a star in its trajectory's color.

        Args:
            ee_trajectories: arrays whose columns 0, 1, 2 are x, y, z.
            trajectory_success: per-trajectory success flags, in the same order.
        """
        from matplotlib.lines import Line2D

        success_color = "green"
        failure_color = "red"

        fig = plt.figure(figsize=(8, 7))
        try:
            ax = fig.add_subplot(111, projection="3d")

            for trajectory, success in zip(ee_trajectories, trajectory_success):
                if len(trajectory) == 0:
                    continue
                x, y, z = trajectory[:, 0], trajectory[:, 1], trajectory[:, 2]
                color = success_color if success else failure_color

                ax.plot(x, y, z, "-", color=color, linewidth=0.8, alpha=0.4)
                ax.scatter(x[0], y[0], z[0], c="blue", marker="o", s=100,
                           edgecolors="black", linewidths=1, zorder=5, alpha=0.8)
                ax.scatter(x[-1], y[-1], z[-1], c=color, marker="*", s=150,
                           edgecolors="black", linewidths=1, zorder=5)

            ax.set_xlabel("X Position (m)", fontsize=12)
            ax.set_ylabel("Y Position (m)", fontsize=12)
            ax.set_zlabel("Z Position (m)", fontsize=12)
            ax.set_title("End-Effector Trajectories (BC-GMM)", fontsize=14, fontweight="bold")

            legend_handles = [
                Line2D([0], [0], color=success_color, linewidth=2, label="Success"),
                Line2D([0], [0], color=failure_color, linewidth=2, label="Failure"),
                Line2D([0], [0], marker="o", color="w", markerfacecolor="blue",
                       markersize=10, markeredgecolor="black", markeredgewidth=1,
                       linestyle="None", label="Start Point"),
                Line2D([0], [0], marker="*", color="w", markerfacecolor="gray",
                       markersize=12, markeredgecolor="black", markeredgewidth=1,
                       linestyle="None", label="End Point"),
            ]
            ax.legend(handles=legend_handles, loc="upper right", fontsize=10)
            ax.set_box_aspect([1, 1, 1])

            # Create render_dir if needed so savefig does not fail on a missing directory.
            os.makedirs(self.render_dir, exist_ok=True)
            plot_path = os.path.join(self.render_dir, "ee_trajectories_3d.png")
            fig.savefig(plot_path, dpi=150, bbox_inches="tight")
        finally:
            # Always release the figure, even if plotting or saving fails.
            plt.close(fig)
        log.info(f"Saved 3D trajectory plot to {plot_path}")
