import logging
import torch
import numpy as np
import os
import trimesh
import plotly.graph_objs as go
import plotly.offline as offline

from runners.Basic_runner import BasicRunner
from utils import (
    split_dataset,
    check_memory,
    run_func_in_batches,
    load_model,
    save_model)
from datasets.get_mesh_data import refine_dataset_SDF
from scipy.stats import entropy



class SDFRunner(BasicRunner):
    """Runner for generative modelling on a surface described by ``self.manifold``.

    Loads the refined point-cloud dataset ``./data/{obj}/{dataset_name}_refined.npy``,
    renders samples and per-face sample histograms on ``self.manifold.mesh`` with
    plotly, and implements the ``validate``, ``test`` and ``sample_on_manifolds``
    steps. Attributes such as ``self.obj``, ``self.device``, ``self.sde``,
    ``self.network`` and ``self.savefig_dir`` are read here but are not set in this
    class (presumably provided by ``BasicRunner`` — not visible from this file).
    """

    def __init__(self, config):
        super().__init__(config)
        self.load_data()

        # --------------------------- exhibit dataset ---------------------------
        samples_test = self.training_set[:self.config.sample.sample_num].detach().cpu().numpy()
        self._plot_sample_and_histogram(samples_test, 'training_set')

        if self.config.if_train or self.config.if_sample:
            # The path dataset is sliced on axis 0 by sample count and, after the
            # transpose, indexed by time step below, so axis 0 of x_hist is time
            # (presumably (N+1, samples, dim) — TODO confirm against generate_path_dataset).
            x_hist = self.training_set_path[:self.config.sample.sample_num].clone().transpose(0, 1)
            if self.config.sample.sampler == 'CHMC':
                # Truncate momentum dimension to match manifold output dimension.
                x_hist = x_hist[:, :, :self.manifold.out_dim]

            self._plot_sample_and_histogram(x_hist[-1].detach().cpu().numpy(), 'forward_end')

            # Snapshot the first 5 steps plus the steps at 0..9% and every 10% of the path.
            plot_idx = set(range(10)) | set(range(10, 101, 10))
            for i in range(self.sde.N + 1):
                if (100 * i / self.sde.N in plot_idx) or (i < 5):
                    x_temp = x_hist[i].detach().cpu().numpy()
                    self._plot_sample_and_histogram(x_temp, f'generating_fwd_{i}')

    def load_data(self):
        """Load and split the dataset, prepare the sampler's initial points and plot settings.

        Sets ``uniform_sample``, ``data_set`` (shuffled), ``training_set``/``test_set``/
        ``val_set``, ``training_set_path`` (only when training or sampling), ``mesh``
        and ``scene_dict``.
        """
        data_ori = torch.tensor(np.load(f"./data/{self.obj}/{self.dataset_name}_refined.npy"))

        # Initial points for the reverse sampler, projected with refine_dataset_SDF.
        uniform_sample = self.manifold.uniform_sample(self.config.sample.sample_num).to(self.device)
        self.uniform_sample = refine_dataset_SDF(self.manifold.constrain_fn, uniform_sample)
        # Shuffle before splitting (ordering of RNG calls kept as-is for reproducibility).
        self.data_set = data_ori[torch.randperm(data_ori.shape[0])].clone()
        self.training_set, self.test_set, self.val_set = split_dataset(self.data_set, self.config.seed)
        # Cached face histogram of data_set, filled lazily by validate().
        self._probs_original = None

        if self.config.if_train or self.config.if_sample:
            self.training_set_path, _ = self.generate_path_dataset(self.training_set, keep_quiet=False)
            check_memory(self.training_set_path)

        self.mesh = self.manifold.mesh

        # Shared plotly scene; only the camera position differs for the bunny.
        eye = dict(x=-0.5, y=0, z=-2) if self.obj == "bunny" else dict(x=-1, y=1, z=1)
        self.scene_dict = dict(xaxis=dict(range=(-1.05, 1.05), autorange=False),
                               yaxis=dict(range=(-1.05, 1.05), autorange=False),
                               zaxis=dict(range=(-1.05, 1.05), autorange=False),
                               aspectratio=dict(x=1, y=1, z=1),
                               camera=dict(eye=eye,
                                           up=dict(x=0, y=1, z=0),
                                           center=dict(x=0, y=0, z=0)))

    def _plot_sample_and_histogram(self, samples, tag):
        """Save both the scatter plot and the per-face histogram of ``samples`` under ``tag``."""
        self.plot_sample(samples, savefig=tag)
        self.plot_histogram_on_surface(samples, savefig=tag)

    def plot_sample(self, samples, savefig=None):
        """Scatter ``samples`` over the translucent mesh; write ``Samples_{savefig}`` .html and .png.

        Args:
            samples: array indexed as ``samples[:, 0..2]`` (x, y, z columns).
            savefig: suffix of the output file name in ``self.savefig_dir``.
        """
        verts = self.mesh.vertices
        I, J, K = self.mesh.faces.transpose()

        trace = [go.Mesh3d(x=verts[:, 0], y=verts[:, 1], z=verts[:, 2],
                           i=I, j=J, k=K, alphahull=5, opacity=0.4, color='cyan'),
                 go.Scatter3d(x=samples[:, 0], y=samples[:, 1], z=samples[:, 2], mode='markers', marker=dict(size=3))]
        fig = go.Figure(data=trace)
        fig.update_layout(title=f'{samples.shape[0]} scatters',
                          scene=self.scene_dict, width=1400, height=1400, showlegend=True)

        filename0 = self.savefig_dir + f"/Samples_{savefig}"
        offline.plot(fig, filename=f'{filename0}.html', auto_open=False)
        fig.write_image(f"{filename0}.png")

    def _face_probabilities(self, samples):
        """Return, per mesh face, the fraction of ``samples`` whose closest face it is."""
        _, _, closest_faces = trimesh.proximity.closest_point(self.mesh, samples)
        counts = np.bincount(closest_faces, minlength=len(self.mesh.faces))
        return counts / len(samples)

    def compute_histogram_on_surface(self, samples):
        """
        Computes the probability distribution (histogram) of samples over the mesh faces.

        Each sample is assigned to its closest face. A small epsilon is added so the
        result has no zeros, which would make KL/JS divergence involve log(0); the
        result therefore sums to slightly more than 1.
        """
        epsilon = 1e-12
        return self._face_probabilities(samples) + epsilon

    def plot_histogram_on_surface(self, samples, colorscale=None, savefig=None):
        """Colour each mesh face by sample density; write ``Histgram_{savefig}`` .html and .png.

        Density is (fraction of samples closest to the face) / (face area).

        Args:
            samples: (n, 3)-like array of points.
            colorscale: colour limits. ``None`` uses ``(-0.1, 95th percentile)``;
                a scalar sets ``cmax`` (``cmin`` stays -0.1); a pair sets ``(cmin, cmax)``.
            savefig: suffix of the output file name in ``self.savefig_dir``.
        """
        verts = self.mesh.vertices
        I, J, K = self.mesh.faces.transpose()

        probs = self._face_probabilities(samples)
        with np.errstate(divide='ignore', invalid='ignore'):
            densities = probs / self.mesh.area_faces
        # Zero-area faces give nan (0/0) or inf (k/0); zero both so they neither
        # render nor corrupt the percentile-based colour limit.
        densities[~np.isfinite(densities)] = 0

        if colorscale is None:
            cmin, cmax = -0.1, np.percentile(densities, 95)
        elif np.ndim(colorscale) == 0:
            cmin, cmax = -0.1, colorscale
        else:
            cmin, cmax = colorscale

        traces = [go.Mesh3d(x=verts[:, 0], y=verts[:, 1], z=verts[:, 2],
                            i=I, j=J, k=K, name='Samples_hist',
                            opacity=1.0, intensity=densities, intensitymode="cell", colorscale="Viridis",
                            cmin=cmin, cmax=cmax)]
        layout = go.Layout(title=f'Histgram of {samples.shape[0]} scatters', scene=self.scene_dict, width=1400, height=1400, showlegend=True)
        fig = go.Figure(data=traces, layout=layout)

        filename0 = self.savefig_dir + f"/Histgram_{savefig}"
        offline.plot(fig, filename=f'{filename0}.html', auto_open=False)
        fig.write_image(f"{filename0}.png")

    def validate(self, mode=None, epoch=0, **kwargs):
        """Validation hook.

        ``mode='start'`` resets the best validation nll_K; ``mode='end'`` does nothing.
        Any other mode runs one validation pass:

        - logs train NLL on ``kwargs["batch"]`` (required) and validation NLL;
        - saves ``model_best_nll_val.pt`` when validation nll_K improves, once
          ``epoch >= min_checkpoint_epoch_ratio * total_epochs``;
        - samples with the reverse SDE from ``self.uniform_sample`` and plots the result;
        - logs the JS divergence between the face histograms of the generated
          samples and of the full ``data_set``.
        """
        if mode == 'start':
            self.best_nll_K_val = torch.inf
            return
        if mode == 'end':
            return

        logging.info(f"-------------------------Start validating: Epoch {epoch}-------------------------")
        # calculate likelihood
        nll_K_train, nll_train, loss_part = self.negative_log_likelihood_fn(kwargs["batch"])
        self.tb_logger.add_scalar('nll_train', nll_train, global_step=epoch)
        self.tb_logger.add_scalar('nll_K_train', nll_K_train, global_step=epoch)
        logging.info(f'nll_K_train:{nll_K_train:.4f}, nll_train: {nll_train:.4f}, loss part:{loss_part:.4f}.')

        nll_K_val, nll_val, loss_part_val = self.negative_log_likelihood_fn(self.val_set)
        self.tb_logger.add_scalar('nll_val', nll_val, global_step=epoch)
        self.tb_logger.add_scalar('nll_K_val', nll_K_val, global_step=epoch)
        self.tb_logger.add_scalar('loss_val', loss_part_val, global_step=epoch)
        logging.info(f'nll_K_val: {nll_K_val:.4f}, loss_val:{loss_part_val :.4f}, nll_val: {nll_val:.4f}.')

        # save best model (only after the configured warm-up fraction of epochs)
        min_epoch = int(self.config.training.min_checkpoint_epoch_ratio * self.total_epochs)
        if nll_K_val < self.best_nll_K_val and epoch >= min_epoch:
            self.best_nll_K_val = nll_K_val
            self.best_checkpoint_epoch = epoch
            logging.info(f'Epoch {epoch}: best validation nll {nll_K_val :.4f} on validation set.')
            save_model(self.validate_dir, self.network, name=f"model_best_nll_val.pt")

        # sample
        init = self.uniform_sample.to(self.device)
        x, _, _ = self.SDE_sampler_manifolds(self.sde, self.manifold, init,
                                             reverse=True,
                                             score_net=self.network,
                                             keep_quiet=True, **self.sde_kwargs)
        if self.config.sample.sampler == 'CHMC':
            # Truncate momentum dimension to match manifold output dimension.
            x = x[:, :self.manifold.out_dim]

        x = x.detach().cpu().numpy()
        self._plot_sample_and_histogram(x, f'val_{epoch}_generated')

        # --- JS Divergence Calculation ---
        logging.info(f"Calculating JS Divergence for epoch {epoch}.")
        probs_generated = self.compute_histogram_on_surface(x)

        # The reference histogram uses the whole data_set, which is fixed after
        # load_data, so it is computed once and cached.
        probs_original = getattr(self, '_probs_original', None)
        if probs_original is None:
            probs_original = self.compute_histogram_on_surface(self.data_set.cpu().numpy())
            self._probs_original = probs_original

        # JS(P, Q) = 0.5 * (KL(P || M) + KL(Q || M)) with M = (P + Q) / 2;
        # scipy's entropy normalises its inputs and uses the natural log.
        M = 0.5 * (probs_generated + probs_original)
        js_divergence = 0.5 * (entropy(probs_generated, M) + entropy(probs_original, M))

        self.tb_logger.add_scalar('JS_divergence_val', js_divergence, global_step=epoch)
        logging.info(f'JS Divergence (generated vs. original): {js_divergence:.6f}')
        # --- End of JS Divergence Calculation ---

        logging.info("-------------------------End validating.-------------------------")

    def test(self):
        """Evaluate test-set NLL; optionally save per-vertex NLL on the complex mesh.

        When training ran and ``model_best_nll_val.pt`` exists, that checkpoint is
        loaded; otherwise the current ``self.network`` is used.
        """
        logging.info(f'Start testing models.')
        device = self.device
        self.manifold.model.to(device)

        model_path = os.path.join(self.validate_dir, "model_best_nll_val.pt")
        if self.config.if_train and os.path.exists(model_path):
            self.network = load_model(model_path).to(device)
        else:
            self.network.to(device)

        logging.info('Calculate likelihood on test data.')
        nll_K_test, nll_test, _ = self.negative_log_likelihood_fn(self.test_set[:5000])
        logging.info(f'nll_K_test: {nll_K_test:.4f}, nll_test: {nll_test:.4f}')
        # best_checkpoint_epoch only exists if validate() saved a checkpoint this run.
        best_epoch = getattr(self, 'best_checkpoint_epoch', None)
        if self.config.if_train and best_epoch is not None:
            logging.info(f'On epoch {best_epoch}, the best test nll_K: {nll_K_test:.4f}')
        else:
            logging.info(f'The best test nll_K: {nll_K_test:.4f}')

        if self.config.calculate_mesh_nll_bunny_spot:
            logging.info('Calculate likelihood on mesh.')
            mesh = trimesh.load(f"./data/{self.obj}/{self.obj}_mesh_complex.ply")
            point = torch.tensor(mesh.vertices).float().to(device)

            # NOTE(review): refines with manifold.model here but manifold.constrain_fn
            # in load_data — confirm this difference is intended.
            def func_refine(x):
                return refine_dataset_SDF(self.manifold.model, x)

            # presumably batches of 50000 points; meaning of the trailing 3 — TODO confirm.
            refined_samples = run_func_in_batches(func_refine, point, 50000, 3)
            nll_K_mesh, _, _ = self.negative_log_likelihood_fn(refined_samples, return_mean=False)
            np.save(f"{self.samples_dir}/{self.dataset_name}_learned_likelihood_complex.npy",
                    nll_K_mesh.cpu().detach().numpy())
            logging.info(f"Shape of nll_K_val and the refined data: {nll_K_mesh.shape}, {mesh.vertices.shape}.")

    def sample_on_manifolds(self):
        """Sample with the reverse SDE, plot snapshots of the trajectory and save the result.

        Writes ``{dataset_name}_samples_generated.npy`` to ``self.samples_dir``.
        """
        logging.info(f'Start sampling on manifolds.')
        device = self.device
        if self.network is not None:
            self.network.to(device)
        self.manifold.model.to(device)

        # backward
        logging.info("Start sampling backward SDE.")
        init = self.uniform_sample.to(device)
        x, x_hist, _ = self.SDE_sampler_manifolds(self.sde, self.manifold, init,
                                                  reverse=True,
                                                  score_net=self.network,
                                                  keep_quiet=False, **self.sde_kwargs)
        if self.config.sample.sampler == 'CHMC':
            # Truncate momentum dimension to match manifold output dimension.
            x = x[:, :self.manifold.out_dim]
            x_hist = x_hist[:, :, :self.manifold.out_dim]

        self.calculate_constrain(x)
        x = x.detach().cpu().numpy()
        self._plot_sample_and_histogram(x, 'generated')

        # Snapshot every 10% of the trajectory, every 1% of the last 10%, and the last 5 steps.
        plot_idx = set(range(0, 100, 10)) | set(range(90, 101))
        for i in range(self.sde.N + 1):
            if (100 * i / self.sde.N in plot_idx) or (i > self.sde.N - 5):
                x_temp = x_hist[i].detach().cpu().numpy()
                self._plot_sample_and_histogram(x_temp, f'generating_bwd_{i}')

        np.save(f"{self.samples_dir}/{self.dataset_name}_samples_generated.npy", x)

