import os
import logging
import matplotlib.pyplot as plt
import torch
import numpy as np

import pandas as pd # for data manipulation
import cartopy.crs as ccrs # for geographic projections
import cartopy.feature as cfeature # for geographic features

from runners.Basic_runner import BasicRunner # for basic runner functionality
from manifolds.Sphere import latlon_to_xyz, xyz_to_latlon # for spherical coordinates conversion
from utils import split_dataset, check_memory, save_model, load_model
from metric import compute_w2_distance_theta_phi, compute_js_distance_2d_histogram

class S2Runner(BasicRunner):
    """Runner for Earth datasets living on the 2-sphere (S^2).

    Loads a CSV of geographic coordinates and converts them to Cartesian points with
    ``latlon_to_xyz``. It splits them into train/val/test and flags validation/test points
    that sit in grid regions with no nearby training data ("unseen" points). It also
    provides validation, testing and sampling loops that plot on a PlateCarree map.
    """

    def __init__(self, config):
        super().__init__(config)  # initialize the basic runner

        self.load_data()  # also sets config.sample.sample_num to the dataset size

        # Prior: uniform samples on the sphere, one per data point.
        x_prior = self.manifold.uniform_sample(self.config.sample.sample_num)
        self.plot_sample(x_prior, savefig='prior')

        # Per the original note, training_set_path is (sample_num, N + 1, 3): axis 1 is the SDE
        # step. Transposing puts the step axis first. The view is only read, so no clone is needed.
        x_hist = self.training_set_path.transpose(0, 1)
        # Plot at 0-9 % and every 10 % of the forward process, plus the first five steps.
        plot_percentages = set(range(10)) | set(range(10, 101, 10))
        for i in range(self.sde.N + 1):
            if 100 * i / self.sde.N in plot_percentages or i < 5:
                self.plot_sample(self._positions(x_hist[i]), savefig=f'generating_fwd_{i}')

    def load_data(self):
        """Load the CSV dataset, split it, flag unseen points and build the training paths.

        Sets (among others) ``training_set``, ``val_set``, ``test_set``, the seen/unseen
        sub-splits, ``unseen_data``, ``training_set_path`` and the W2/JS reference set.
        It also saves the splits to ``samples_dir`` and plots them.
        """
        self.x_bound, self.y_bound = [-180., 180.], [-90., 90.]  # longitude / latitude bounds

        csv_path = f"./data/S2/earth_data/{self.dataset_name}.csv"
        # Lines starting with '#' are skipped; the first remaining row holds the column names.
        data_ori = pd.read_csv(csv_path, comment='#', header=0).values.astype("float32")
        data_ori = latlon_to_xyz(data_ori)  # geographic coordinates -> Cartesian points
        self.config.sample.sample_num = data_ori.shape[0]  # one generated sample per data point
        self.projection = ccrs.PlateCarree(central_longitude=0)  # projection used by all plots

        data_ori = torch.tensor(data_ori, dtype=torch.float32)
        self.data = data_ori.clone()
        self.training_set, self.test_set, self.val_set = split_dataset(data_ori, self.config.seed)

        self.unseen_val_set, self.seen_val_set = self.mark_unseen_data(self.val_set)
        self.unseen_test_set, self.seen_test_set = self.mark_unseen_data(self.test_set)
        self.unseen_data = torch.cat((self.unseen_val_set, self.unseen_test_set), dim=0)

        # Reference distribution for the W2 / JS metrics. It is frozen here, before unseen
        # points may be appended to the training set, so that they are not counted twice.
        self._full_data_theta_phi = self._latlon_array(
            torch.cat((self.training_set, self.val_set, self.test_set), dim=0))

        np.savez(f"{self.samples_dir}/dataset_{self.dataset_name}_{self.config.seed}.npz",
                 training_set=self.training_set.numpy(),
                 seen_test_set=self.seen_test_set.numpy(),
                 seen_val_set=self.seen_val_set.numpy(),
                 unseen_test_set=self.unseen_test_set.numpy(),
                 unseen_val_set=self.unseen_val_set.numpy())
        self.plot_sample_all()

        if self.unseen_data.shape[0] == 0:
            logging.info(f'No unseen points in validation set and test set.')
        else:
            logging.info(f'{self.unseen_val_set.shape[0]} unseen points in validation set.')
            logging.info(f'{self.unseen_test_set.shape[0]} unseen points in test set.')
            self.plot_sample(self.training_set, highlight_samples=self.unseen_data.cpu().numpy(),
                             savefig='training-set-unseen-data')

            if self.config.training.include_unseen_data_in_training:
                self.training_set = torch.cat((self.training_set, self.unseen_data), dim=0)
                logging.info(f'Adding {self.unseen_data.shape[0]} unseen points in training set. Training data size: {self.training_set.shape[0]}')

        self.training_set_path, _ = self.generate_path_dataset(self.training_set, keep_quiet=False)
        check_memory(self.training_set_path)  # report memory usage of the path dataset

    @staticmethod
    def _bin_index(values, bounds, ns):
        """Map coordinates to equal-width bin indices in [0, ns - 1] (upper bound -> last bin)."""
        width = (bounds[1] - bounds[0]) / ns
        return torch.floor((values - bounds[0]) / width).long().clamp(0, ns - 1)

    def _grid_cell_keys(self, points, ns):
        """Return the flat cell index ``lon_bin * ns + lat_bin`` of every point on the ns x ns grid."""
        lat, lon = xyz_to_latlon(points)
        lon_bin = self._bin_index(lon, self.x_bound, ns)
        lat_bin = self._bin_index(lat, self.y_bound, ns)
        return lon_bin * ns + lat_bin

    def mark_unseen_data(self, data_set):
        """Split ``data_set`` into points unseen by the training set and the remaining points.

        The map is divided into ``grid_size x grid_size`` (longitude x latitude) cells. A
        point is "unseen" when its own cell and its four neighbouring cells contain no
        training point. Longitude neighbours wrap around the antimeridian; latitude
        neighbours do not wrap from one pole to the other. Training and query points are
        binned with the same rule, so points on the upper bounds (lon = 180, lat = 90) are
        handled consistently.

        Args:
            data_set: tensor of Cartesian points, one point per row.

        Returns:
            ``(unseen, seen)``: both have the same number of columns as ``data_set``.
            ``unseen`` is ordered cell by cell (row-major over (longitude, latitude)) and by
            original position inside a cell. ``seen`` keeps the original order. ``unseen``
            has zero rows when no point qualifies.
        """
        ns = self.config.training.grid_size
        train_cells = self._grid_cell_keys(self.training_set, ns)
        data_cells = self._grid_cell_keys(data_set, ns)

        train_hist = torch.bincount(train_cells, minlength=ns * ns).reshape(ns, ns)

        # Training counts over each cell plus its 4-neighbourhood.
        # Longitude (dim 0) is periodic, hence torch.roll.
        neighbourhood = (train_hist
                         + torch.roll(train_hist, 1, dims=0)
                         + torch.roll(train_hist, -1, dims=0))
        # Latitude (dim 1) is not periodic: the edge cells simply have one neighbour fewer.
        neighbourhood[:, 1:] += train_hist[:, :-1]
        neighbourhood[:, :-1] += train_hist[:, 1:]
        empty_cells = (neighbourhood == 0).reshape(-1).to(data_cells.device)

        is_unseen = empty_cells[data_cells]
        unseen_pos = torch.nonzero(is_unseen, as_tuple=True)[0]
        # Group unseen points by cell, keeping the original order within each cell.
        _, order = torch.sort(data_cells[unseen_pos], stable=True)
        unseen_pos = unseen_pos[order]
        return data_set[unseen_pos], data_set[~is_unseen]

    @staticmethod
    def _to_numpy(array):
        """Return ``array`` as a numpy array, detaching and moving torch tensors to the CPU."""
        if isinstance(array, torch.Tensor):
            return array.detach().cpu().numpy()
        return np.asarray(array)

    def _positions(self, x):
        """Drop the momentum coordinates of CHMC states (keep the first ``manifold.out_dim``)."""
        if self.config.sample.sampler == 'CHMC':
            return x[..., :self.manifold.out_dim]
        return x

    def _latlon_array(self, points):
        """Convert Cartesian points to an (n, 2) array of (latitude, longitude)."""
        lat, lon = xyz_to_latlon(self._to_numpy(points))
        return np.stack([lat, lon], axis=1)

    @staticmethod
    def _add_earth_features(ax):
        """Draw land, ocean and coastlines on a cartopy axis and show the whole globe."""
        ax.add_feature(cfeature.LAND, zorder=0, facecolor="#e0e0e0")
        ax.add_feature(cfeature.OCEAN, zorder=0, facecolor="#b0c4de")
        ax.add_feature(cfeature.COASTLINE, zorder=1, linewidth=0.5)
        ax.set_global()
        ax.set_xlabel('Longitude (degrees)')
        ax.set_ylabel('Latitude (degrees)')

    def plot_sample(self, samples, values=None, highlight_samples=None, savefig=None):
        """Scatter Cartesian points on a world map and save the figure.

        Args:
            samples: points (tensor or array), one per row.
            values: optional per-point values (tensor or array) used as colours.
            highlight_samples: optional extra points drawn larger, in green.
            savefig: suffix used in the output file name.
        """
        samples = self._to_numpy(samples)
        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1, projection=self.projection)
        lat, lon = xyz_to_latlon(samples)

        if values is not None:
            ax.scatter(lon, lat, s=0.3, c=self._to_numpy(values), cmap='coolwarm',
                       label='Samples (via Value param)')
        else:
            ax.scatter(lon, lat, s=0.3, color='red', alpha=1.0, label='Samples')

        self._add_earth_features(ax)

        if highlight_samples is not None:
            lat_point, lon_point = xyz_to_latlon(self._to_numpy(highlight_samples))
            ax.scatter(lon_point, lat_point, s=3.0, c='green', label='Unseen samples (val-test)')

        ax.set_title(f'{samples.shape[0]} Sample plotting on Earth data')
        ax.legend()

        suffix = '_values' if values is not None else ''
        fig.savefig(os.path.join(self.savefig_dir, f"samples_latlon{suffix}_{savefig}.png"),
                    dpi=300, bbox_inches='tight')
        plt.close(fig)

    def plot_sample_all(self):
        """Plot the training (red), test (green) and validation (blue) splits on one map."""
        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1, projection=self.projection)
        splits = (
            (self.training_set, 'red', 'Training Set'),
            (self.test_set, 'green', 'Test Set'),
            (self.val_set, 'blue', 'Validation Set'),
        )
        for points, color, label in splits:
            lat, lon = xyz_to_latlon(self._to_numpy(points))
            ax.scatter(lon, lat, s=0.3, color=color, alpha=1.0, label=label)

        self._add_earth_features(ax)
        ax.set_title(f"Latitude and Longitude of {self.dataset_name} Dataset")
        ax.legend()
        fig.savefig(os.path.join(self.savefig_dir, "samples_latlon_all.png"), dpi=300, bbox_inches='tight')
        plt.close(fig)

    def _sample_reverse(self, **sampler_kwargs):
        """Draw ``sample_num`` uniform prior points and run the reverse SDE sampler from them."""
        init = self.manifold.uniform_sample(self.config.sample.sample_num).to(self.device)
        return self.SDE_sampler_manifolds(self.sde, self.manifold, init,
                                          reverse=True,
                                          score_net=self.network,
                                          **sampler_kwargs, **self.sde_kwargs)

    def _distribution_distances(self, x):
        """Return ``(w2, js)`` between generated points ``x`` and the full dataset, in (lat, lon)."""
        generated = self._latlon_array(x)
        reference = getattr(self, '_full_data_theta_phi', None)
        if reference is None:  # fallback if load_data did not build the cached reference
            reference = self._latlon_array(
                torch.cat([self.training_set, self.val_set, self.test_set], dim=0))
        w2 = compute_w2_distance_theta_phi(generated, reference)
        # Fixed ranges (latitude, longitude) keep the histograms comparable across calls.
        js = compute_js_distance_2d_histogram(generated, reference, bins=64,
                                              ranges=[[-90, 90], [-180, 180]])
        return w2, js

    def _log_split_nll(self, name, data_set, epoch):
        """Log ``nll_<name>`` and ``nll_K_<name>`` for ``data_set``; skip empty splits."""
        if data_set.shape[0] == 0:
            return
        nll_K, nll, *_ = self.negative_log_likelihood_fn(data_set)
        self.tb_logger.add_scalar(f'nll_{name}', nll, global_step=epoch)
        self.tb_logger.add_scalar(f'nll_K_{name}', nll_K, global_step=epoch)

    def _plot_momentum_norms(self, epoch, other_dict):
        """Plot mean momentum norms along the forward training paths and the reverse sampler path."""
        out_dim = self.manifold.out_dim
        # The coordinates after out_dim of each path state are the momentum.
        forward_momentum = self.training_set_path[:, :, out_dim:]
        forward_norms = torch.norm(forward_momentum, dim=2).mean(dim=0).detach().cpu().numpy()
        backward_norms = torch.norm(other_dict['v_hist_all'], dim=2).mean(dim=1).detach().cpu().numpy()
        backward_norms = backward_norms[::-1]  # reverse the order for display

        fig, (ax_fwd, ax_bwd) = plt.subplots(1, 2, figsize=(15, 6))

        ax_fwd.plot(np.linspace(0, 1, len(forward_norms)), forward_norms, 'b-', linewidth=2)
        ax_fwd.set_xlabel('Time')
        ax_fwd.set_ylabel('Average Momentum Norm')
        ax_fwd.set_title('Forward Path Momentum Norms')
        ax_fwd.grid(True, alpha=0.3)

        ax_bwd.plot(np.linspace(0, 1, len(backward_norms)), backward_norms, 'r-', linewidth=2)
        ax_bwd.set_xlabel('Time percentage')
        ax_bwd.set_ylabel('Average Momentum Norm')
        ax_bwd.set_title('Backward Path Momentum Norms')
        ax_bwd.grid(True, alpha=0.3)

        fig.tight_layout()
        fig.savefig(os.path.join(self.savefig_dir, f"Momentum_Norms_val_{epoch}.png"), dpi=300)
        plt.close(fig)

    def validate(self, mode=None, epoch=0, **kwargs):
        """Validation hook.

        ``mode='start'`` resets the best-checkpoint bookkeeping and ``mode='end'`` does
        nothing. Any other mode evaluates the NLL on every split (``kwargs["batch"]`` is
        the training batch). It saves the best checkpoint by mean validation ``nll_K``,
        samples from the model, and logs W2 / JS distances to the full dataset.
        """
        if mode == 'start':
            self.best_nll_K_val = torch.inf
            self.best_checkpoint_epoch = None
            return
        if mode == 'end':
            return

        logging.info(f"-------------------------Start validating: Epoch {epoch}-------------------------")
        logging.info(f"Time for trajectory generation: {self.trajectory_gen_time:.2f} seconds.")
        logging.info(f"Cumulative training time so far: {self.cumulative_training_time:.2f} seconds.")

        # NLL on the unseen / seen sub-splits (empty ones are skipped).
        self._log_split_nll('unseen_val', self.unseen_val_set, epoch)
        self._log_split_nll('unseen_test', self.unseen_test_set, epoch)
        self._log_split_nll('seen_val', self.seen_val_set, epoch)
        self._log_split_nll('seen_test', self.seen_test_set, epoch)

        nll_K_train, nll_train, loss_part, *_ = self.negative_log_likelihood_fn(kwargs["batch"])
        self.tb_logger.add_scalar('nll_train', nll_train, global_step=epoch)
        self.tb_logger.add_scalar('nll_K_train', nll_K_train, global_step=epoch)
        logging.info(f'nll_K_train:{nll_K_train:.4f}, nll_train: {nll_train:.4f}, loss part:{loss_part:.4f}.')

        # Per-point validation NLL: plotted on the map, then averaged.
        nll_K_val_all, nll_val_all, loss_val_all, *_ = self.negative_log_likelihood_fn(self.val_set, return_mean=False)
        self.plot_sample(self.val_set, values=nll_K_val_all, savefig=f"{epoch}_val")
        nll_K_val, nll_val, loss_val = nll_K_val_all.mean(), nll_val_all.mean(), loss_val_all.mean()
        self.tb_logger.add_scalar('nll_val', nll_val, global_step=epoch)
        self.tb_logger.add_scalar('nll_K_val', nll_K_val, global_step=epoch)
        self.tb_logger.add_scalar('loss_val', loss_val, global_step=epoch)
        logging.info(f'nll_K_val: {nll_K_val:.4f}, loss_val:{loss_val:.4f}, nll_val: {nll_val:.4f}.')

        nll_K_test, nll_test, loss_part_test, *_ = self.negative_log_likelihood_fn(self.test_set)
        self.tb_logger.add_scalar('nll_K_test', nll_K_test, global_step=epoch)
        self.tb_logger.add_scalar('nll_test', nll_test, global_step=epoch)
        logging.info(f'nll_K_test: {nll_K_test:.4f}, loss_test:{loss_part_test :.4f}, nll_test: {nll_test:.4f}')

        # Checkpoint only after a configurable fraction of training has passed.
        min_epoch = int(self.config.training.min_checkpoint_epoch_ratio * self.total_epochs)
        if nll_K_val < self.best_nll_K_val and epoch > min_epoch:
            self.best_nll_K_val = nll_K_val
            self.best_checkpoint_epoch = epoch
            logging.info(f'Epoch {epoch}: best validation nll {nll_K_val :.4f} on validation set.')
            logging.info(f'Epoch {epoch}: test nll {nll_K_test :.4f} with the best performance on validation set.')
            logging.info(f'Epoch {epoch}: training nll {nll_K_train:.4f} with the best performance on validation set.')
            save_model(self.validate_dir, self.network, name=f"model_best_nll_val.pt")

        # Sample with the reverse SDE. Strip CHMC momentum before plotting and scoring.
        x, _, other_dict = self._sample_reverse(keep_quiet=True)
        x = self._positions(x)
        self.plot_sample(x, savefig=f'sample_epoch_{epoch}')

        w2_dist_theta_phi, js_dist_hist = self._distribution_distances(x)
        self.tb_logger.add_scalar('w2_dist_theta_phi', w2_dist_theta_phi, global_step=epoch)
        logging.info(f'Epoch {epoch}: W2 distance (theta, phi) between generated samples and full dataset: {w2_dist_theta_phi:.4f}')
        self.tb_logger.add_scalar('js_dist_hist', js_dist_hist, global_step=epoch)
        logging.info(f'Epoch {epoch}: JS distance on 2D histogram: {js_dist_hist:.4f}')

        logging.info(f"-------------------------End validating: Epoch {epoch}-------------------------")

        if hasattr(self, 'training_set_path') and self.config.sample.sampler == 'CHMC':
            self._plot_momentum_norms(epoch, other_dict)

    def test(self):
        """Evaluate the best saved checkpoint (or the current network) and log NLL, W2 and JS."""
        logging.info(f"-------------------------Start testing-------------------------")

        model_path = os.path.join(self.validate_dir, "model_best_nll_val.pt")
        if self.config.if_train and os.path.exists(model_path):
            self.network = load_model(model_path).to(self.device)
        else:
            self.network.to(self.device)

        if self.config.calculate_mesh_nll_earth:
            self.plot_likelihood_earth(grid_N=400)

        logging.info('Calculate likelihood: ')

        nll_K_test, *_ = self.negative_log_likelihood_fn(self.test_set, return_mean=False)
        self.plot_sample(self.test_set, values=nll_K_test, savefig="test")

        nll_K_val, *_ = self.negative_log_likelihood_fn(self.val_set, return_mean=False)

        nll_K_train, *_ = self.negative_log_likelihood_fn(self.training_set, return_mean=False)
        self.plot_sample(self.training_set, values=nll_K_train, savefig="training")

        summary = (f'train nll {nll_K_train.mean().item():.4f}, '
                   f'val nll {nll_K_val.mean().item():.4f}, '
                   f'test nll {nll_K_test.mean().item():.4f}')
        # best_checkpoint_epoch exists only if validation ran and improved at least once.
        best_epoch = getattr(self, 'best_checkpoint_epoch', None)
        if self.config.if_train and best_epoch is not None:
            logging.info(f'On epoch {best_epoch}, the best checkpoints: {summary}.')
        else:
            logging.info(f'Evaluated checkpoint: {summary}.')

        logging.info('Generating samples for W2 distance calculation in test phase.')
        x, _, _ = self._sample_reverse(keep_quiet=True)
        x = self._positions(x)

        w2_dist_theta_phi, js_dist_hist = self._distribution_distances(x)
        logging.info(f'Test phase W2 distance (theta, phi) between generated samples and full dataset: {w2_dist_theta_phi:.4f}')
        logging.info(f'Test phase JS distance on 2D histogram: {js_dist_hist:.4f}')

        logging.info(f"-------------------------End testing-------------------------")

    def sample_on_manifolds(self):
        """Generate samples with the reverse SDE, plot intermediate steps and save the results."""
        logging.info(f'Start sampling on manifolds.')
        if self.network is not None:
            self.network.to(self.device)

        x, x_hist, _ = self._sample_reverse()
        self.plot_sample(self._positions(x), savefig='generated')
        # Plot every 10 % of the reverse process, every step of the last 10 %, and the final steps.
        plot_percentages = set(range(0, 100, 10)) | set(range(90, 101))
        for i in range(self.sde.N + 1):
            if 100 * i / self.sde.N in plot_percentages or i > self.sde.N - 5:
                self.plot_sample(self._positions(x_hist[i]), savefig=f'generating_bwd_{i}')

        np.save(f"{self.samples_dir}/{self.dataset_name}_samples_generated.npy", self._to_numpy(x))
        np.save(f"{self.samples_dir}/{self.dataset_name}_samples_test_set.npy", self.test_set.numpy())