import torch
import numpy as np
import logging
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.patches import Circle
from tqdm import tqdm
from sklearn.decomposition import PCA

from runners.Basic_runner import BasicRunner
from manifolds.Robot import Manifold_Robot, forward_kinematics_pytorch_batched
from utils import split_dataset, check_memory

class RobotPrior:
    """Empirical prior that resamples rows of a fixed array of prior draws.

    Args:
        prior_sample: NumPy array of prior draws. It is converted to a float32
            torch tensor and rows are drawn along its first axis.
    """

    def __init__(self, prior_sample):
        as_tensor = torch.from_numpy(prior_sample)
        self.prior_sample = as_tensor.float()

    def prior_sampler(self, n):
        """Return ``n`` rows of ``prior_sample`` drawn uniformly with replacement."""
        pool_size = self.prior_sample.shape[0]
        picks = torch.randint(0, pool_size, (n,))
        return self.prior_sample[picks]

class RobotRunner(BasicRunner):
    """Runner for the 7-DoF robot-arm trajectory dataset.

    Each sample is one trajectory flattened to ``time_steps * 14`` values. Per time step
    this holds 7 cosines followed by 7 sines of the joint angles (see ``_get_3d_path``).
    Labels are one scalar per trajectory and are stored separately from the trajectories.
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.load_data()

        logging.info(
            f"Fitting PCA on the training dataset ({self.training_set.shape[1]}D flattened "
            f"trajectories) for visualization..."
        )
        self.pca = PCA(n_components=2)
        self.pca.fit(self.training_set.cpu().numpy())

        logging.info("Visualizing initial training set and forward process.")
        sample_num = self.config.sample.sample_num
        self.plot_sample_hist(self.training_set[:sample_num].cpu().numpy(), savefig="training_set")
        if self.config.problem.prior_mode == 'load' and self.prior is not None:
            self.plot_sample_hist(self.prior.prior_sample[:sample_num].cpu().numpy(), savefig="prior_dist")

        # training_set_path is only built by load_data when training or sampling.
        if self.config.if_train or self.config.if_sample:
            self._plot_forward_process()

    def _plot_forward_process(self):
        """Plot and save snapshots of the forward process for the first ``sample_num`` training paths."""
        if not self.config.problem.if_plot_fwd:
            return
        # NOTE(review): assumes training_set_path is (batch, N + 1, dim); after the transpose,
        # x_hist[i] is the whole batch at step i — confirm against generate_path_dataset.
        x_hist = self.training_set_path[:self.config.sample.sample_num].clone().transpose(0, 1)
        self.plot_sample_hist(x_hist[-1].cpu().numpy(), savefig='forward_end')
        # Steps at 0..9 % and every 10 % up to 100 % of N, plus the first five steps.
        plot_pct = set(range(10)) | set(range(10, 101, 10))
        for i in range(self.sde.N + 1):
            if (100 * i / self.sde.N in plot_pct) or (i < 5):
                self.plot_sample_hist(x_hist[i].cpu().numpy(), savefig=f'generating_fwd_{i}')
        np.save(f"{self.samples_dir}/{self.dataset_name}_hist_fwd.npy", x_hist.cpu().detach().numpy())

    @staticmethod
    def _split_off_labels(dataset):
        """Split a 2-D tensor into (all-but-last columns, last column kept as shape ``(n, 1)``)."""
        return dataset[:, :-1], dataset[:, -1:]

    def load_data(self):
        """Load trajectories and labels, split them, and set up the prior and forward paths.

        Sets ``full_data`` (every trajectory, before shuffling and splitting), the
        ``training_set`` / ``test_set`` / ``val_set`` trajectory tensors, and their matching
        ``*_labels`` tensors of shape ``(n, 1)``. It also sets ``prior`` (a ``RobotPrior`` in
        ``'load'`` mode, else ``None``) and, when training or sampling, ``training_set_path``.

        Raises:
            ValueError: if ``config.problem.prior_mode`` is neither ``'load'`` nor ``'generate'``.
        """
        data_path = './data/robot_arm/'
        paths = np.load(f'{data_path}robot_7dof_joints_paths.npy')
        labels = np.load(f'{data_path}robot_7dof_joints_labels.npy')

        data_ori = torch.tensor(paths, dtype=torch.float32).reshape(paths.shape[0], -1)
        self.full_data = data_ori.clone()
        labels_ori = torch.tensor(labels, dtype=torch.float32).reshape(labels.shape[0], 1)
        # Append the label as a trailing column so shuffling/splitting keeps pairs aligned.
        full_dataset = torch.cat([data_ori, labels_ori], dim=1)
        full_dataset = full_dataset[torch.randperm(full_dataset.shape[0])]

        self.training_set, self.test_set, self.val_set = split_dataset(full_dataset, self.config.seed)
        self.training_set, self.training_labels = self._split_off_labels(self.training_set)
        self.test_set, self.test_labels = self._split_off_labels(self.test_set)
        self.val_set, self.val_labels = self._split_off_labels(self.val_set)

        # self._check_constraints(self.training_set, "Training Set")

        self.prior = None
        if self.config.problem.prior_mode == 'load':
            logging.info("Using 'load' mode for prior. Loading from file.")
            prior_paths = f'{data_path}robot_7dof_joints_prior.npy'
            prior_sample_np = np.load(prior_paths).reshape(-1, self.config.problem.time_steps * 14)
            self.prior = RobotPrior(prior_sample=prior_sample_np)
            # self._check_constraints(self.prior.prior_sample, "Prior Set")

        elif self.config.problem.prior_mode == 'generate':
            logging.info("Using 'generate' mode...")
        else:
            raise ValueError(f"Invalid prior_mode: {self.config.problem.prior_mode}.")

        if self.config.if_train or self.config.if_sample:
            logging.info("Generating forward path dataset...")
            self.training_set_path, _ = self.generate_path_dataset(self.training_set, keep_quiet=False)
            check_memory(self.training_set_path)

    def _check_constraints(self, dataset, dataset_name):
        """Log the worst equality / inequality constraint violation of ``dataset`` on the manifold.

        Equality violations are ``max |h(x)|`` (satisfied below 1e-4). Inequality violations
        are ``max g(x)`` (satisfied at or below 1e-5). Nothing is returned; results are only
        logged. The check is skipped when ``self.manifold`` is missing, ``dataset`` is empty,
        or the manifold defines neither ``h`` nor ``g``.
        """
        if not hasattr(self, 'manifold'):
            logging.warning("Manifold not initialized, skipping constraint check.")
            return
        logging.info(f"--- Checking constraints for {dataset_name} ({len(dataset)} samples) ---")
        if len(dataset) == 0:
            return
        h_func, g_func = self.manifold.h, self.manifold.g
        if h_func is None and g_func is None:
            logging.info("No constraints defined. Skipping check.")
            return

        dataset_tensor = dataset.to(self.device)
        with torch.no_grad():
            if h_func is not None:
                h_vals = h_func(dataset_tensor)
                overall_max_h = h_vals.abs().max().item()
                logging.info(f"Max absolute equality violation max|h(x)|: {overall_max_h:.6f}")
                logging.info(" -> SATISFIED." if overall_max_h < 1e-4 else " -> VIOLATED.")
            if g_func is not None:
                g_vals = g_func(dataset_tensor)
                overall_max_g = g_vals.max().item()
                logging.info(f"Max inequality violation max g(x): {overall_max_g:.6f}")
                logging.info(" -> SATISFIED." if overall_max_g <= 1e-5 else " -> VIOLATED.")
        logging.info("-" * 45)

    def _get_initial_samples(self, n, labels):
        """Return ``n`` starting points for the reverse SDE, placed on ``self.device``.

        In ``'load'`` mode, rows are resampled from the loaded prior. In ``'generate'`` mode,
        randomly chosen training trajectories are pushed through the reverse sampler first.

        Raises:
            ValueError: for any other ``prior_mode`` (previously this silently returned None).
        """
        mode = self.config.problem.prior_mode
        if mode == 'load':
            return self.prior.prior_sampler(n).to(self.device)
        if mode == 'generate':
            indices = torch.randint(0, self.training_set.shape[0], (n,))
            init_from_data = self.training_set[indices].to(self.device)
            generated_samples, _, _ = self.SDE_sampler_manifolds(
                self.sde, self.manifold, init_from_data,
                reverse=True, score_net=self.network,
                labels=labels, keep_quiet=True, **self.sde_kwargs
            )
            return generated_samples
        raise ValueError(f"Invalid prior_mode: {mode}.")

    def _get_3d_path(self, trajectory_14d_flat):
        """Convert one flattened (cos, sin) joint trajectory into its end-effector path.

        Args:
            trajectory_14d_flat: tensor or NumPy array with ``time_steps * 14`` entries. Per
                step, the first 7 are joint-angle cosines and the last 7 are sines.

        Returns:
            NumPy array of the end-effector link taken from the forward-kinematics output.
            The callers read columns 0 and 1 as X and Y.
        """
        if isinstance(trajectory_14d_flat, np.ndarray):
            trajectory_14d_flat = torch.from_numpy(trajectory_14d_flat).float()

        q_cos_sin = trajectory_14d_flat.to(self.device).reshape(self.config.problem.time_steps, 14)
        # Recover joint angles from the (cos, sin) encoding: shape (time_steps, 7).
        q_points_theta = torch.arctan2(q_cos_sin[..., 7:], q_cos_sin[..., :7])

        with torch.no_grad():
            link_positions = forward_kinematics_pytorch_batched(q_points_theta)
            path_3d = link_positions[self.manifold.end_effector_link_index]

        return path_3d.cpu().numpy()

    def _draw_obstacles(self, ax, color, zorder=None):
        """Add each configured obstacle to ``ax`` as a radius-0.1 circle at its (x, y) position."""
        obstacles = self.config.problem.get("obstacles_info")
        if not obstacles:
            return
        # Only pass zorder when requested so the patch otherwise keeps matplotlib's default.
        extra = {} if zorder is None else {"zorder": zorder}
        for obs in obstacles:
            ax.add_patch(Circle((obs["position"][0], obs["position"][1]),
                                radius=0.1, color=color, **extra))

    @staticmethod
    def _finish_topdown_figure(ax, title, legend_elements, fig_path):
        """Apply the shared top-down axis styling and legend, save to ``fig_path``, and close."""
        ax.set_title(title, fontsize=18, pad=15)
        ax.set_xlabel("X position", fontsize=16)
        ax.set_ylabel("Y position", fontsize=16)
        ax.tick_params(axis="both", which="major", labelsize=14)
        ax.grid(True, alpha=0.4)
        ax.set_aspect("equal", adjustable="box")
        ax.legend(handles=legend_elements, fontsize=14, loc="upper right", frameon=False)

        plt.tight_layout()
        plt.savefig(fig_path, bbox_inches="tight")
        plt.close()

    def plot_topdown_comparison(self, true_samples, gen_samples, savefig=None):
        """Overlay true (light green) and generated (salmon) end-effector paths in the XY plane.

        Args:
            true_samples: iterable of flattened trajectories (tensor rows or NumPy rows).
            gen_samples: iterable of flattened generated trajectories.
            savefig: tag used in the file name ``TopDown_{savefig}.pdf``.
        """
        fig_path = f"{self.savefig_dir}/TopDown_{savefig}.pdf"
        plt.figure(figsize=(8, 8))
        ax = plt.gca()

        for samples, color in ((true_samples, "lightgreen"), (gen_samples, "salmon")):
            for traj_flat in samples:
                path_3d = self._get_3d_path(traj_flat)
                ax.plot(path_3d[:, 0], path_3d[:, 1], color=color, alpha=0.3)

        # Obstacles drawn above the trajectories.
        self._draw_obstacles(ax, color="black", zorder=5)

        legend_elements = [
            Line2D([0], [0], color="lightgreen", lw=2, label="True Trajectories"),
            Line2D([0], [0], color="salmon", lw=2, label="Generated Trajectories"),
            Circle((0, 0), 0.1, color="black", label="Obstacles"),
        ]
        self._finish_topdown_figure(ax, "Top-Down Trajectory View", legend_elements, fig_path)

    def plot_sample_hist(self, samples, savefig=None):
        """Plot end-effector paths of ``samples`` in the XY plane and save ``Hist_{savefig}.pdf``.

        Colour is chosen from the x coordinate of each path's midpoint: above 0.4 is drawn
        gold (legend "c=1"), otherwise magenta (legend "c=0"). The stored labels are not used.
        """
        fig_path = f"{self.savefig_dir}/Hist_{savefig}.pdf"
        plt.figure(figsize=(8, 8))
        ax = plt.gca()

        for traj_flat in samples:
            path_3d = self._get_3d_path(traj_flat)
            midpoint_x = path_3d[len(path_3d) // 2, 0]
            color = "gold" if midpoint_x > 0.4 else "magenta"
            ax.plot(path_3d[:, 0], path_3d[:, 1], color=color, alpha=0.2)

        self._draw_obstacles(ax, color="green")

        legend_elements = [
            Line2D([0], [0], color="magenta", lw=2, label="c=0"),
            Line2D([0], [0], color="gold", lw=2, label="c=1"),
            Circle((0, 0), 0.1, color="green", label="Obstacles"),
        ]
        self._finish_topdown_figure(
            ax, f"Top-Down Trajectory View ({len(samples)} samples)", legend_elements, fig_path
        )

    def plot_pca_comparison(self, samples, savefig=None):
        """Scatter the 2-D PCA projections of training data (blue) and ``samples`` (red).

        This does nothing if the PCA has not been fitted yet. The output is
        ``PCA_Plot_{savefig}.png`` in ``savefig_dir``.
        """
        if not hasattr(self, 'pca'):
            return
        samples_np = samples.detach().cpu().numpy() if isinstance(samples, torch.Tensor) else samples
        samples_pca = self.pca.transform(samples_np)

        # Project the training set once and reuse it. The cache is keyed on the tensor object,
        # so it is refreshed if training_set is ever reassigned.
        cached = getattr(self, "_true_data_pca_cache", None)
        if cached is None or cached[0] is not self.training_set:
            cached = (self.training_set, self.pca.transform(self.training_set.cpu().numpy()))
            self._true_data_pca_cache = cached
        true_data_pca = cached[1]

        fig, ax = plt.subplots(figsize=(10, 8))
        ax.scatter(true_data_pca[:, 0], true_data_pca[:, 1], s=15, alpha=0.3, label='True Data Projection', c='blue')
        ax.scatter(samples_pca[:, 0], samples_pca[:, 1], s=15, alpha=0.5, label='Generated Data Projection', c='red')
        ax.set_title(f'PCA Projection Comparison ({savefig})')
        ax.set_xlabel('PC 1')
        ax.set_ylabel('PC 2')
        ax.grid(True, linestyle='--')
        ax.legend()
        plt.savefig(self.savefig_dir + f"/PCA_Plot_{savefig}.png", dpi=300)
        plt.close(fig)

    def validate(self, mode=None, epoch=0, **kwargs):
        """Sample ``sample_num`` trajectories with random validation labels and save diagnostic plots.

        Calls with ``mode`` equal to ``'start'`` or ``'end'`` are ignored.
        """
        if mode in ('start', 'end'):
            return
        logging.info(f"--- Start validating: Epoch {epoch} ---")
        val_indices = torch.randint(0, len(self.val_set), (self.config.sample.sample_num,))
        labels = self.val_labels[val_indices].to(self.device)
        init = self._get_initial_samples(self.config.sample.sample_num, labels)
        x, _, _ = self.SDE_sampler_manifolds(self.sde, self.manifold, init,
                                             reverse=True, score_net=self.network, labels=labels,
                                             keep_quiet=True, **self.sde_kwargs)
        x_np = x.detach().cpu().numpy()
        self.plot_sample_hist(x_np, savefig=f'val_{epoch}_generated')
        self.plot_pca_comparison(x_np, savefig=f'val_{epoch}_generated')
        # NOTE(review): this draws every trajectory in full_data (the whole dataset before
        # splitting), with one forward-kinematics call each. It may be slow on large datasets.
        self.plot_topdown_comparison(self.full_data, x_np, savefig=f'val_{epoch}_comparison')

        logging.info(f"--- End validating. Plots saved for epoch {epoch}. ---")

    def test(self):
        pass

    def sample_on_manifolds(self):
        """Generate ``sample_num`` trajectories split across labels 0 and 1, then plot and save them.

        Label 0 gets ``sample_num // 2`` samples and label 1 gets the remainder, so an odd
        ``sample_num`` still yields exactly ``sample_num`` samples. A label whose share is
        zero is skipped.
        """
        logging.info('Start sampling on manifolds.')
        if self.network is not None:
            self.network.to(self.device)

        total = self.config.sample.sample_num
        n_c0 = total // 2
        finals, hists = [], []
        for label_value, n in ((0.0, n_c0), (1.0, total - n_c0)):
            if n == 0:
                continue
            labels = torch.full((n, 1), label_value, device=self.device)
            init = self._get_initial_samples(n, labels)
            x_c, x_hist_c, _ = self.SDE_sampler_manifolds(self.sde, self.manifold, init,
                                                          reverse=True, score_net=self.network, labels=labels,
                                                          keep_quiet=False, **self.sde_kwargs)
            finals.append(x_c)
            hists.append(x_hist_c)

        if not finals:
            logging.warning("sample_num is 0; nothing to sample.")
            return

        x_np = torch.cat(finals, dim=0).detach().cpu().numpy()
        # Assumes the history's batch axis is dim 1 (x_hist[i] is the batch at step i).
        x_hist = torch.cat(hists, dim=1).detach().cpu().numpy()
        self.plot_sample_hist(x_np, savefig='generated_final')
        self.plot_pca_comparison(x_np, savefig='generated_final')
        # Every 10 % of N, every 1 % from 90 % to 100 %, plus the last five steps.
        plot_pct = set(range(0, 100, 10)) | set(range(90, 101))
        for i in range(self.sde.N + 1):
            if (100 * i / self.sde.N in plot_pct) or (i > self.sde.N - 5):
                self.plot_sample_hist(x_hist[i], savefig=f'generating_bwd_{i}')
        np.save(f"{self.samples_dir}/{self.dataset_name}_samples_generated.npy", x_np)
        np.save(f"{self.samples_dir}/{self.dataset_name}_hist_bwd.npy", x_hist)