import logging
import matplotlib.pyplot as plt
import torch
import numpy as np
from runners.Basic_runner import BasicRunner
from utils import (
    split_dataset,
    check_memory)
# Import PCA from scikit-learn
from sklearn.decomposition import PCA
from metric import frobenius_norm_of_jacobian
from scipy.spatial.distance import jensenshannon


class SOnRunner(BasicRunner):
    """Runner for SO(n) matrix data loaded from ``./data/SOn/<dataset_name>.npy``.

    Samples are stored flattened (``self.manifold.out_dim`` entries per row)
    and reshaped to ``(mat_dim, mat_dim)`` matrices whenever matrix algebra is
    needed. Sample quality is assessed through the trace statistics
    ``tr(S^p)`` for every ``p`` in ``self.power_list`` (histograms and
    Jensen-Shannon scores) and through a 2-D PCA projection fitted on the
    reference samples.
    """

    # Samplers whose states concatenate position and momentum along the last
    # dimension; only the first ``manifold.out_dim`` entries are positions.
    _MOMENTUM_SAMPLERS = ('CHMC_OBABO', 'CHMC_OABOA', 'ULLA_OABOA')

    def __init__(self, config):
        """Load the data and write reference plots (true, uniform, forward path).

        Args:
            config: experiment configuration; reads ``sample.sample_num``,
                ``sample.sampler``, ``if_train``, ``if_sample`` and ``seed``
                (plus whatever ``BasicRunner`` reads).
        """
        super().__init__(config)
        self.load_data()

        # ----------------------------- exhibit dataset -----------------------------
        sample_num = self.config.sample.sample_num
        # NOTE(review): reference samples come from the full shuffled dataset
        # (self.data_set), not only from self.test_set — confirm this is intended.
        test_set_samples = self.data_set[:sample_num]
        self.test_set_samples = test_set_samples.clone().detach()
        self.test_set_statistics = self.get_statistics(test_set_samples).cpu().numpy()
        self.plot_hist(self.test_set_statistics, savefig="true", compare=False)

        uniform = self.manifold.uniform_sample(sample_num)
        # .cpu() so this also works when uniform_sample returns a GPU tensor.
        uniform_statistics = self.get_statistics(uniform).cpu().numpy()
        self.plot_hist(uniform_statistics, savefig="uniform", compare=False)

        logging.info("Fitting PCA on the true dataset for visualization...")
        self.pca = PCA(n_components=2)
        self.pca.fit(self.test_set_samples.cpu().numpy())
        self.plot_pca_comparison(self.test_set_samples, savefig="true_data")

        if self.config.if_train or self.config.if_sample:
            self._exhibit_forward_path()

    def _exhibit_forward_path(self):
        """Plot trace histograms along the forward path and save the path statistics."""
        # training_set_path appears to be (sample, time, dim): slice samples,
        # then transpose so that x_hist[i] is the batch at time step i.
        x_hist = self.training_set_path[:self.config.sample.sample_num].clone().transpose(0, 1)
        if self.config.sample.sampler in self._MOMENTUM_SAMPLERS:
            # Truncate momentum dimension to match manifold output dimension.
            x_hist = x_hist[:, :, :self.manifold.out_dim]

        x_end = self.filter_sample(x_hist[-1])
        self.plot_hist(self.get_statistics(x_end).cpu().numpy(), savefig='forward_end')

        # Percent-of-path checkpoints: every 1% up to 9%, then every 10%;
        # the first 5 steps are always plotted as well.
        plot_idx = set(range(10)) | set(range(10, 101, 10))
        for i in range(self.sde.N + 1):
            if (100 * i / self.sde.N in plot_idx) or (i < 5):
                statistics = self.get_statistics(x_hist[i].clone()).cpu().numpy()
                self.plot_hist(statistics, savefig=f'generating_fwd_{i}')

        statistics_path = self.get_statistics_path(x_hist.cpu().detach()).numpy()
        np.save(f"{self.samples_dir}/{self.dataset_name}_statistics_fwd.npy", statistics_path)

    def load_data(self):
        """Load, shuffle and split the dataset; build forward paths when training/sampling."""
        # Powers p used for the trace statistics tr(S^p).
        self.power_list = [1, 2, 4, 5]

        data_ori = torch.tensor(np.load(f"./data/SOn/{self.dataset_name}.npy")).reshape(-1, self.manifold.out_dim)
        self.data_set = data_ori[torch.randperm(data_ori.shape[0])].clone()
        self.training_set, self.test_set, self.val_set = split_dataset(self.data_set, self.config.seed)

        if self.config.if_train or self.config.if_sample:
            self.training_set_path, _ = self.generate_path_dataset(self.training_set, keep_quiet=False)
            check_memory(self.training_set_path)

    def filter_sample(self, samples):
        """Keep only samples whose matrix has a positive determinant.

        Args:
            samples: tensor whose rows reshape to ``(mat_dim, mat_dim)`` matrices.

        Returns:
            The rows of ``samples`` with ``det > 0``; the others are dropped and
            the kept/total count is logged.
        """
        all_mat = samples.reshape(-1, self.manifold.mat_dim, self.manifold.mat_dim)
        manifolds_idx = torch.where(torch.linalg.det(all_mat) > 0)[0]
        logging.info(f"The number of samples on the correct connected component: {manifolds_idx.shape[0]}/{samples.shape[0]}, the others are dropped.")
        return samples[manifolds_idx]

    def get_statistics(self, samples):
        """Return ``tr(S^p)`` for each ``p`` in ``self.power_list``.

        Args:
            samples: tensor or array reshapeable to ``(-1, mat_dim, mat_dim)``.

        Returns:
            Tensor of shape ``(num_samples, len(self.power_list))``.
        """
        if not isinstance(samples, torch.Tensor):
            samples = torch.tensor(samples)

        mats = samples.reshape(-1, self.manifold.mat_dim, self.manifold.mat_dim)
        traces = [
            torch.matrix_power(mats, power).diagonal(dim1=-2, dim2=-1).sum(dim=-1, keepdim=True)
            for power in self.power_list
        ]
        return torch.cat(traces, dim=1)

    def get_statistics_path(self, path):
        """Apply :meth:`get_statistics` to every time step of ``path``.

        Args:
            path: tensor indexed ``path[t]`` -> batch of samples at step ``t``.

        Returns:
            Tensor of shape ``(path.shape[0], path.shape[1], len(self.power_list))``
            in torch's default floating dtype.
        """
        statistics = torch.zeros(path.shape[0], path.shape[1], len(self.power_list))
        for i, x_t in enumerate(path):
            statistics[i] = self.get_statistics(x_t)
        return statistics

    @staticmethod
    def _finite_values(values):
        """Return ``values`` as a NumPy array with NaN/inf entries removed."""
        values = np.asarray(values)
        return values[np.isfinite(values)]

    def plot_hist(self, statistics, savefig=None, compare=True):
        """Save a grid of trace-statistic histograms to ``Hist_Statistics_{savefig}.png``.

        Args:
            statistics: array of shape ``(num_samples, len(self.power_list))``.
            savefig: suffix used in the output file name.
            compare: if True, overlay ``self.test_set_statistics`` as "True".
        """
        # About 100 samples per bin, but never fewer than 10 bins, so that
        # moderate sample counts (101-999) are not collapsed into 1-9 bins.
        bins = max(10, statistics.shape[0] // 100)
        n_rows = (len(self.power_list) + 1) // 2
        fig = plt.figure(figsize=(10, 5 * n_rows))
        for i, power in enumerate(self.power_list):
            ax = fig.add_subplot(n_rows, 2, i + 1)
            # NaN/inf (e.g. diverged samples) would make matplotlib's binning fail.
            values = self._finite_values(statistics[:, i])
            if compare:
                ax.hist(values, bins=bins, histtype='stepfilled', alpha=0.5, density=True, color='green', label='Generated')
                ax.hist(self._finite_values(self.test_set_statistics[:, i]), bins=bins, histtype='stepfilled', alpha=0.5, density=True, color='red', label='True')
                ax.legend()
            else:
                ax.hist(values, bins=bins, alpha=1.0, density=True)
            # Braces keep multi-digit powers inside the LaTeX superscript.
            ax.set_title(rf"Histogram of $tr(S^{{{power}}})$")

        plt.savefig(self.savefig_dir + f"/Hist_Statistics_{savefig}.png", dpi=300)
        plt.close(fig)

    def calculate_and_log_jsd(self, generated_stats, bins=100):
        """Calculate and log Jensen-Shannon scores between true and generated traces.

        Both samples are histogrammed on common, equally spaced bins, smoothed
        with a small epsilon and normalized. The score is
        ``scipy.spatial.distance.jensenshannon``, which returns the
        Jensen-Shannon *distance* (the square root of the divergence, natural
        log base).

        Args:
            generated_stats: array of shape ``(num_samples, len(self.power_list))``.
            bins: number of common histogram bins per statistic.

        Returns:
            Mean score over the statistics that could be compared, or NaN if none could.
        """
        logging.info("--- Calculating Jensen-Shannon Divergence (JSD) ---")
        true_stats = np.asarray(self.test_set_statistics)
        generated_stats = np.asarray(generated_stats)
        epsilon = 1e-10  # avoids zero probabilities in empty bins
        jsd_scores = []

        for i in range(true_stats.shape[1]):
            true_col = self._finite_values(true_stats[:, i])
            gen_col = self._finite_values(generated_stats[:, i])
            if true_col.size == 0 or gen_col.size == 0:
                # e.g. filter_sample dropped every generated sample.
                logging.warning(f"No finite values for tr(S^{self.power_list[i]}); skipping JSD.")
                continue

            # Common bins for a fair comparison.
            min_val = min(true_col.min(), gen_col.min())
            max_val = max(true_col.max(), gen_col.max())
            if max_val <= min_val:
                # All values identical: widen so bins have non-zero width.
                min_val, max_val = min_val - 0.5, max_val + 0.5
            bin_edges = np.linspace(min_val, max_val, bins + 1)

            true_hist, _ = np.histogram(true_col, bins=bin_edges, density=True)
            generated_hist, _ = np.histogram(gen_col, bins=bin_edges, density=True)

            true_hist += epsilon
            generated_hist += epsilon
            true_hist /= true_hist.sum()
            generated_hist /= generated_hist.sum()

            jsd = jensenshannon(true_hist, generated_hist)
            jsd_scores.append(jsd)
            logging.info(f"INFO: JSD for tr(S^{self.power_list[i]}): {jsd:.6f}")

        if jsd_scores:
            avg_jsd = np.mean(jsd_scores)
        else:
            logging.warning("No statistic could be compared; average JSD is NaN.")
            avg_jsd = float('nan')
        logging.info(f"INFO: Average JSD across all statistics: {avg_jsd:.6f}")
        logging.info("--------------------------------------------------")
        return avg_jsd

    def plot_pca_comparison(self, samples, savefig=None):
        """Scatter the PCA projection of ``samples`` over that of the true data.

        Args:
            samples: tensor or array of flattened samples, shape ``(n, out_dim)``.
            savefig: suffix used in the plot title and ``PCA_Plot_{savefig}.png``.
        """
        if not hasattr(self, 'pca'):
            logging.warning("PCA model not found. Skipping PCA plot.")
            return

        if isinstance(samples, torch.Tensor):
            samples_np = samples.detach().cpu().numpy()
        else:
            samples_np = np.asarray(samples)

        # PCA.transform rejects NaN/inf rows and empty inputs.
        finite_rows = np.isfinite(samples_np).all(axis=1)
        if not finite_rows.all():
            logging.warning(f"Dropping {int((~finite_rows).sum())} non-finite samples from the PCA plot.")
            samples_np = samples_np[finite_rows]
        if samples_np.shape[0] == 0:
            logging.warning(f"No finite samples to project for '{savefig}'. Skipping PCA plot.")
            return

        samples_pca = self.pca.transform(samples_np)
        true_data_pca = self.pca.transform(self.test_set_samples.cpu().numpy())

        fig, ax = plt.subplots(figsize=(10, 8))
        ax.scatter(true_data_pca[:, 0], true_data_pca[:, 1],
                   s=15, alpha=0.3, label='True Data Projection', c='blue')
        ax.scatter(samples_pca[:, 0], samples_pca[:, 1],
                   s=15, alpha=0.5, label='Generated Data Projection', c='red')

        ax.set_title(f'PCA Projection Comparison ({savefig})', fontsize=16)
        ax.set_xlabel('Principal Component 1', fontsize=12)
        ax.set_ylabel('Principal Component 2', fontsize=12)
        ax.grid(True, linestyle='--', alpha=0.6)
        ax.legend()

        if 'circle' in self.dataset_name:
            ax.set_aspect('equal', 'box')

        plt.savefig(self.savefig_dir + f"/PCA_Plot_{savefig}.png", dpi=300)
        plt.close(fig)

    def validate(self, mode=None, epoch=0, **kwargs):
        """Validation hook.

        Args:
            mode: ``'start'`` resets ``best_nll_K_val``; ``'end'`` does nothing;
                any other value runs a validation pass.
            epoch: epoch index used for logging and file names.
            **kwargs: must contain ``"batch"`` for a validation pass.
        """
        if mode == 'start':
            self.best_nll_K_val = torch.inf
            return
        if mode == 'end':
            return

        logging.info(f"-------------------------Start validating: Epoch {epoch}-------------------------")
        nll_K_train, nll_train, loss_part, loss_OBA_part, loss_BO_part = self.negative_log_likelihood_fn(kwargs["batch"])
        self.tb_logger.add_scalar('nll_train', nll_train, global_step=epoch)
        self.tb_logger.add_scalar('nll_K_train', nll_K_train, global_step=epoch)
        logging.info(f'nll_K_train:{nll_K_train:.4f}, nll_train: {nll_train:.4f}, loss_train: {loss_part:.4f}, loss_OBA_train: {loss_OBA_part:.4f}, loss_BO_train: {loss_BO_part:.4f}.')

        print(f"Start sampling at epoch {epoch}.")
        init = self.manifold.uniform_sample(self.config.sample.sample_num).to(self.device)
        x, x_hist, other_dict = self.SDE_sampler_manifolds(self.sde, self.manifold, init,
                                                            reverse=True,
                                                            score_net=self.network,
                                                            keep_quiet=True, **self.sde_kwargs)
        uses_momentum = self.config.sample.sampler in self._MOMENTUM_SAMPLERS
        if uses_momentum:
            x = x[:, :self.manifold.out_dim]

        x = self.filter_sample(x)

        statistics_gen = self.get_statistics(x).cpu().numpy()
        self.plot_hist(statistics_gen, savefig=f'val_{epoch}_generated')

        self.calculate_and_log_jsd(statistics_gen)
        self.plot_pca_comparison(x, savefig=f'val_{epoch}_generated')

        logging.info("-------------------------End validating.-------------------------")

        if hasattr(self, 'training_set_path') and uses_momentum:
            self._plot_momentum_norms(other_dict, epoch)

    def _plot_momentum_norms(self, other_dict, epoch):
        """Plot mean momentum norms along the forward path and the backward sampler path."""
        # Momentum is the trailing out_dim entries; averaging over dim=0 gives
        # one value per time step (path appears to be (sample, time, dim)).
        forward_momentum = self.training_set_path[:, :, -self.manifold.out_dim:]
        forward_momentum_norms = torch.norm(forward_momentum, dim=2).mean(dim=0).detach().cpu().numpy()

        # Averaged over dim=1, i.e. the sampler history appears to be (time, sample, dim).
        backward_momentum = other_dict['v_hist_all']
        backward_momentum_norms = torch.norm(backward_momentum, dim=2).mean(dim=1).detach().cpu().numpy()
        # Reversed, presumably so both panels run in the forward-time direction.
        backward_momentum_norms = backward_momentum_norms[::-1]

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

        time_steps_forward = np.linspace(0, 1, len(forward_momentum_norms))
        ax1.plot(time_steps_forward, forward_momentum_norms, 'b-', linewidth=2)
        ax1.set_xlabel('Time')
        ax1.set_ylabel('Average Momentum Norm')
        ax1.set_title('Forward Path Momentum Norms')
        ax1.grid(True, alpha=0.3)

        time_steps_backward = np.linspace(0, 1, len(backward_momentum_norms))
        ax2.plot(time_steps_backward, backward_momentum_norms, 'r-', linewidth=2)
        ax2.set_xlabel('Time percentage')
        ax2.set_ylabel('Average Momentum Norm')
        ax2.set_title('Backward Path Momentum Norms')
        ax2.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(self.savefig_dir + f"/Momentum_Norms_val_{epoch}.png", dpi=300)
        plt.close(fig)

    def test(self):
        """No separate test phase for this runner."""
        return

    def sample_on_manifolds(self):
        """Run the backward SDE from uniform noise, then plot, score and save the results."""
        logging.info('Start sampling on manifolds.')
        device = self.device
        if self.network is not None:
            self.network.to(device)

        logging.info("Start sampling backward SDE.")
        init = self.manifold.uniform_sample(self.config.sample.sample_num).to(device)
        x, x_hist, other_dict = self.SDE_sampler_manifolds(self.sde, self.manifold, init,
                                                           reverse=True,
                                                           score_net=self.network,
                                                           keep_quiet=False, **self.sde_kwargs)

        if self.config.sample.sampler in self._MOMENTUM_SAMPLERS:
            # Keep only positions; drop the momentum part.
            x = x[:, :self.manifold.out_dim]
            x_hist = x_hist[:, :, :self.manifold.out_dim]

        self.calculate_constrain(x)
        x = self.filter_sample(x)

        statistics_gen = self.get_statistics(x).cpu().numpy()
        self.plot_hist(statistics_gen, savefig='generated_final')

        self.calculate_and_log_jsd(statistics_gen)
        self.plot_pca_comparison(x, savefig='generated_final')

        # Percent-of-path checkpoints: every 10%, then every 1% from 90% on;
        # the last 5 steps are always plotted as well.
        plot_idx = set(range(0, 100, 10)) | set(range(90, 101))
        for i in range(self.sde.N + 1):
            if (100 * i / self.sde.N in plot_idx) or (i > self.sde.N - 5):
                statistics = self.get_statistics(x_hist[i]).cpu().numpy()
                self.plot_hist(statistics, savefig=f'generating_bwd_{i}')

        statistics_path = self.get_statistics_path(x_hist.cpu().detach()).numpy()
        np.save(f"{self.samples_dir}/{self.dataset_name}_statistics_bwd.npy", statistics_path)

        return