import logging
import matplotlib.pyplot as plt
import torch
import numpy as np
import os
import sys
import plotly.graph_objects as go # Plotly import

# --- Imports from the project, similar to SOn_runner and MD_runner ---
from runners.Basic_runner import BasicRunner
from utils import split_dataset, check_memory

# --- Imports for the GeneralRunner's specific functionalities ---
from manifolds.general import Manifold_general
from sklearn.decomposition import PCA
from scipy.stats import wasserstein_distance

# Add the parent directory to the path to allow importing 'constraints'
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from constraints import get_constraint_functions
from torch.func import vmap

class GeneralRunner(BasicRunner):
    def __init__(self, config):
        """Initialise the runner, load the data and fit the reference PCA.

        After the base-class initialisation and ``load_data()``, a
        2-component PCA is fitted on the first ``config.sample.sample_num``
        rows of the shuffled dataset. The fitted model and the projected
        reference points are kept on the instance so later evaluations are
        all expressed in the same 2-D coordinates.

        Args:
            config: Experiment configuration, forwarded to
                ``BasicRunner.__init__``. This method reads
                ``sample.sample_num`` through ``self.config`` (presumably
                stored by the base class -- not visible in this file).
        """
        super().__init__(config)
        self.load_data()

        # The reference subset is taken from the full shuffled dataset
        # (self.data_set), not from the train/test/val splits.
        logging.info("Fitting PCA on the true dataset for visualization...")
        test_set_samples = self.data_set[:self.config.sample.sample_num]
        self.test_set_samples_np = test_set_samples.cpu().numpy()

        self.pca = PCA(n_components=2)
        self.pca.fit(self.test_set_samples_np)
        self.true_data_pca = self.pca.transform(self.test_set_samples_np)

        # Reference plot of the true data on its own. compare=False avoids
        # overlaying the true data with itself under a "Generated" label,
        # which is what the default compare=True would draw here.
        self.plot_pca_comparison(
            self.test_set_samples_np, savefig='true_data', compare=False
        )
        logging.info("PCA model fitted and reference plot for true data saved.")

    def load_data(self):
        """Load the dataset, build the constrained manifold and the forward paths.

        Sets the following attributes:
            data_set: float tensor holding the rows of
                ``./data/general/<dataset>.npy`` in a random order.
            training_set, test_set, val_set: the three values returned by
                ``split_dataset(data_set, config.seed)``, in that order.
            manifold: a ``Manifold_general`` built from the constraint
                functions registered for the dataset name.
            training_set_path: first value returned by
                ``generate_path_dataset``; only set when ``config.if_train``
                or ``config.if_sample`` is truthy.
        """
        problem_cfg = self.config.problem

        # --- Read the raw array from disk and convert it to a float tensor ---
        dataset_path = os.path.join("./data/general/", f"{problem_cfg.dataset}.npy")
        logging.info(f"Loading dataset from: {dataset_path}")
        raw_array = np.load(dataset_path)
        data_ori = torch.from_numpy(raw_array).float()

        # --- Shuffle the rows, then split ---
        row_order = torch.randperm(data_ori.shape[0])
        self.data_set = data_ori[row_order].clone()
        splits = split_dataset(self.data_set, self.config.seed)
        self.training_set, self.test_set, self.val_set = splits

        # --- Build the manifold from the dataset-specific constraints ---
        logging.info("Initializing General Manifold with dynamic constraints...")
        h_func, g_func = get_constraint_functions(problem_cfg.dataset)
        self.manifold = Manifold_general(
            dim=problem_cfg.dim,
            m=problem_cfg.m,
            l=problem_cfg.l,
            h=h_func,
            g=g_func,
            boundary_repulsion_rate=self.config.sample.epsilon,
        )

        # --- Forward paths are only needed for training or sampling ---
        needs_paths = self.config.if_train or self.config.if_sample
        if not needs_paths:
            return
        self.training_set_path, _unused = self.generate_path_dataset(
            self.training_set, keep_quiet=False
        )
        check_memory(self.training_set_path)

    def sample_on_manifolds(self):
        """Run the reverse SDE from a prior, then evaluate and save the samples.

        Steps:
            1. Obtain prior samples: load ``<dataset>_prior.npy`` from
               ``self.samples_dir`` if it exists; otherwise run the forward
               SDE (no score network) on the first ``sample_num`` training
               points, keep the last entry of the returned history, and
               cache it to that file.
            2. Run the reverse SDE from the prior with ``self.network`` as
               the score network.
            3. Pass the result to ``evaluate_and_plot`` under the name
               ``"generated_final"``, using ``training.n_epochs`` as step.
            4. Save the result to ``<dataset>_generated.npy``.
        """
        logging.info('Start sampling on manifolds.')
        device = self.device
        if self.network is not None:
            self.network.to(device)

        dataset_name = self.config.problem.dataset

        # --- Part 1: prior samples (generate and cache, or load from cache) ---
        prior_path = os.path.join(self.samples_dir, f"{dataset_name}_prior.npy")

        if not os.path.exists(prior_path):
            logging.warning(f"Prior file not found at {prior_path}. Generating it now...")
            n_start = self.config.sample.sample_num
            start_data = self.training_set[:n_start].to(device)
            forward_steps = self.config.sample.get('forward_steps', self.sde.N)

            _, forward_hist, _ = self.SDE_sampler_manifolds(
                self.sde,
                self.manifold,
                start_data,
                reverse=False,
                score_net=None,
                keep_quiet=False,
                n_steps=forward_steps,
                **self.sde_kwargs,
            )
            # The final entry of the forward history serves as the prior.
            init_samples = forward_hist[-1]
            np.save(prior_path, init_samples.cpu().numpy())
            logging.info(f"Saved newly generated prior to {prior_path}")
        else:
            logging.info(f"Loading pre-generated prior from {prior_path}")
            cached_prior = np.load(prior_path)
            init_samples = torch.from_numpy(cached_prior).to(device).float()

        # --- Part 2: reverse SDE from the prior ---
        logging.info("Start sampling backward SDE from the prior.")
        generated, _history, _extras = self.SDE_sampler_manifolds(
            self.sde,
            self.manifold,
            init_samples,
            reverse=True,
            score_net=self.network,
            keep_quiet=False,
            **self.sde_kwargs,
        )

        # --- Part 3: evaluation and plots ---
        self.evaluate_and_plot(generated, self.config.training.n_epochs, "generated_final")

        # --- Part 4: persist the generated samples ---
        save_path = os.path.join(self.samples_dir, f"{dataset_name}_generated.npy")
        np.save(save_path, generated.cpu().numpy())
        logging.info(f"Saved final generated samples to {save_path}")



    def evaluate_and_plot(self, samples, epoch, save_name):
        """Score samples against the true data in PCA space and plot them.

        The samples are projected with the PCA fitted in ``__init__``. For
        each of the two principal components the 1-D Wasserstein distance
        to the reference projection is computed; the mean of the two is
        logged and written to TensorBoard. A scatter plot is then saved via
        ``plot_pca_comparison``.

        Args:
            samples: Torch tensor of generated points; converted to a NumPy
                array before projection.
            epoch: Value used as ``global_step`` for the TensorBoard scalar.
            save_name: Label used in the log line and in the plot file name.
        """
        # detach() lets a tensor that still tracks gradients be converted;
        # it is a no-op for tensors that do not.
        samples_np = samples.detach().cpu().numpy()
        generated_pca = self.pca.transform(samples_np)

        # Quantitative evaluation. scipy's wasserstein_distance is the
        # 1-Wasserstein (earth mover's) distance, not W2, so the log line
        # names it W1.
        w1_pc1 = wasserstein_distance(self.true_data_pca[:, 0], generated_pca[:, 0])
        w1_pc2 = wasserstein_distance(self.true_data_pca[:, 1], generated_pca[:, 1])
        avg_w1 = (w1_pc1 + w1_pc2) / 2
        logging.info(f"Average W1 Distance on PCA for '{save_name}': {avg_w1:.6f}")
        # The scalar tag is left unchanged so curves from earlier runs stay
        # comparable in TensorBoard, even though the quantity is W1.
        self.tb_logger.add_scalar('eval/W2_distance', avg_w1, global_step=epoch)

        # Visual evaluation: PCA scatter plot.
        self.plot_pca_comparison(samples_np, savefig=save_name)

    def plot_pca_comparison(self, statistics, savefig=None, compare=True):
        """Save a 2-D PCA scatter plot of ``statistics``.

        Args:
            statistics: Array of points accepted by ``self.pca.transform``.
            savefig: Label used in the title and in the output file name
                ``PCA_<savefig>.png`` inside ``self.savefig_dir``. With the
                default ``None`` the file is literally ``PCA_None.png``.
            compare: When true, the reference projection
                ``self.true_data_pca`` is drawn in blue underneath the
                given points (red, labelled "Generated") with a legend.
                When false, only the given points are drawn.
        """
        samples_pca = self.pca.transform(statistics)
        fig, ax = plt.subplots(figsize=(10, 8))
        # try/finally ensures the figure is released even if drawing or
        # saving raises, so repeated calls do not accumulate open figures.
        try:
            if compare:
                ax.scatter(self.true_data_pca[:, 0], self.true_data_pca[:, 1],
                           s=10, alpha=0.2, label='True Data', c='blue')
                ax.scatter(samples_pca[:, 0], samples_pca[:, 1],
                           s=10, alpha=0.5, label='Generated', c='red')
                ax.legend()
            else:
                ax.scatter(samples_pca[:, 0], samples_pca[:, 1], s=10, alpha=0.5)
            ax.set_title(f'PCA Projection Comparison ({savefig})')
            ax.set_xlabel('Principal Component 1')
            ax.set_ylabel('Principal Component 2')
            ax.grid(True, linestyle='--', alpha=0.6)
            ax.set_aspect('equal', 'box')
            # Save through the figure object rather than the implicit
            # "current figure" so the right canvas is always written.
            fig.savefig(os.path.join(self.savefig_dir, f"PCA_{savefig}.png"), dpi=300)
        finally:
            plt.close(fig)

    def plot_3d_html_comparison(self, samples, epoch, save_name):
        """Generates an interactive 3D plot for sphere_mog validation.

        Draws a unit sphere coloured by the manifold's inequality function
        ``g``, overlays the reference samples and the generated samples as
        3-D scatter traces, and writes the figure to
        ``<save_name>_3d.html`` in ``self.savefig_dir``.

        Args:
            samples: Torch tensor of generated points; columns 0, 1 and 2
                are used as x, y and z.
            epoch: Epoch number, used in the log line and the figure title.
            save_name: Label used in the title and the output file name.
        """
        logging.info(f"Generating interactive 3D plot for epoch {epoch}...")
        samples_np = samples.cpu().numpy()

        # 1. Unit-sphere surface grid (100 x 50 points in spherical angles).
        u, v = np.mgrid[0:2*np.pi:100j, 0:np.pi:50j]
        x_sphere = np.cos(u)*np.sin(v)
        y_sphere = np.sin(u)*np.sin(v)
        z_sphere = np.cos(v)

        # 2. Evaluate g(x) at every grid point to colour the surface.
        # The grid tensor lives on the CPU; this assumes g_single can be
        # evaluated there and yields one value per point (the reshape
        # below needs exactly one value per grid node) -- TODO confirm.
        grid_xyz = np.stack(
            [x_sphere.ravel(), y_sphere.ravel(), z_sphere.ravel()], axis=-1
        )
        sphere_points = torch.tensor(grid_xyz, dtype=torch.float32)
        with torch.no_grad():
            g_values = vmap(self.manifold.g_single)(sphere_points).numpy().ravel()

        # Upper end of the colour range; fall back to 1 when g is never
        # positive so the range [0, cmax] does not collapse.
        g_max = g_values.max()
        color_max = g_max if g_max > 0 else 1

        # 3. Create Plotly figure
        fig = go.Figure()

        # Sphere surface: translucent cyan at g <= 0 blending to
        # translucent red at the largest g value.
        fig.add_trace(go.Surface(x=x_sphere, y=y_sphere, z=z_sphere,
                                 surfacecolor=g_values.reshape(x_sphere.shape),
                                 colorscale=[[0, 'rgba(0,255,255,0.1)'], [1, 'rgba(255,0,0,0.3)']],
                                 cmin=0, cmax=color_max,
                                 showscale=False,
                                 name='Manifold Constraint'))

        # Reference samples (note: red here, unlike the PCA plots).
        true_np = self.test_set_samples_np
        fig.add_trace(go.Scatter3d(x=true_np[:, 0], y=true_np[:, 1], z=true_np[:, 2],
                                   mode='markers', marker=dict(size=2, color='red', opacity=0.5), name='True Data'))

        # Generated samples.
        fig.add_trace(go.Scatter3d(x=samples_np[:, 0], y=samples_np[:, 1], z=samples_np[:, 2],
                                   mode='markers', marker=dict(size=2.5, color='blue'), name='Generated Samples'))

        # Layout
        fig.update_layout(title=f'3D Validation at Epoch {epoch} ({save_name})',
                          scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z',
                                     aspectmode='data'),
                          margin=dict(l=0, r=0, b=0, t=40))

        # Save to HTML file
        save_path = os.path.join(self.savefig_dir, f"{save_name}_3d.html")
        fig.write_html(save_path)
        logging.info(f"Saved interactive 3D plot to {save_path}")

    def validate(self, mode=None, epoch=0, **kwargs):
        """Generate samples at a validation epoch, log constraint violations and plot.

        Nothing is done when ``mode == 'end'``, when ``epoch`` is 0, or when
        ``epoch`` is not a multiple of ``config.training.val_freq``.
        Otherwise the reverse SDE is run from Gaussian noise, the mean
        ``|h(x)|`` and the max/mean of ``relu(g(x))`` are logged (when the
        manifold has equality / inequality constraints), and the PCA
        evaluation is produced. For the ``sphere_mog`` dataset in 3-D an
        interactive HTML plot is written as well.

        Args:
            mode: Validation mode; only the value ``'end'`` is inspected.
            epoch: Current epoch, used for scheduling and in output names.
            **kwargs: Accepted and ignored.
        """
        if mode == 'end' or epoch % self.config.training.val_freq != 0 or epoch == 0:
            return

        logging.info(f"--- Starting validation at epoch {epoch} ---")
        # NOTE(review): sample_on_manifolds starts the reverse SDE from a
        # forward-diffused prior, whereas validation starts from ambient
        # standard-normal noise -- confirm this difference is intended.
        init_samples = torch.randn(self.config.sample.sample_num, self.manifold.dim, device=self.device)

        # Generate samples
        x, _, _ = self.SDE_sampler_manifolds(self.sde, self.manifold, init_samples,
                                            reverse=True, score_net=self.network,
                                            keep_quiet=True, **self.sde_kwargs)

        # --- Calculate and Log Constraint Violations ---
        if self.manifold.m > 0:
            h_values = self.manifold.h(x)
            mean_h_violation = torch.abs(h_values).mean()
            logging.info(f"  Validation Equality Violation (mean |h(x)|): {mean_h_violation.item():.6f}")

        if self.manifold.l > 0:
            g_values = self.manifold.g(x)
            # Violation occurs when g(x) > 0. Use relu to isolate positive values.
            g_violations = torch.relu(g_values)
            max_g_violation = g_violations.max()
            mean_g_violation = g_violations.mean()
            logging.info(f"  Validation Inequality Violation (max g(x)): {max_g_violation.item():.6f}")
            logging.info(f"  Validation Inequality Violation (mean g(x)): {mean_g_violation.item():.6f}")

        # --- Perform Visualizations ---
        self.evaluate_and_plot(x, epoch, f'val_{epoch}_generated')

        # Add 3D plot only for the sphere_mog dataset if it's 3D.
        # `dataset_name` is not assigned anywhere in this class; use it when
        # the base class provides it, otherwise fall back to the config
        # value that the rest of this runner uses, instead of failing with
        # AttributeError after the sampling work is already done.
        dataset_name = getattr(self, 'dataset_name', self.config.problem.dataset)
        if dataset_name == 'sphere_mog' and self.manifold.dim == 3:
            self.plot_3d_html_comparison(x, epoch, f'val_{epoch}')

        logging.info(f"--- Finished validation at epoch {epoch} ---")