import logging
from runners.Basic_runner import BasicRunner
import matplotlib.pyplot as plt
import torch
import numpy as np
import os
from utils import (
    split_dataset,
    check_memory,
    load_model,
    run_func_in_batches, 
    save_model)
from sampling import SDE_sampler_manifolds
import pandas as pd
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from manifolds.Sphere import latlon_to_xyz, xyz_to_latlon


class S2Runner(BasicRunner):
    def __init__(self, config):
        """Initialise the runner, load the dataset and plot reference figures.

        After the base-class initialisation and ``load_data``, a sample drawn
        with ``self.manifold.uniform_sample`` is plotted (saved as ``prior``),
        followed by selected steps of the path stored in
        ``self.training_set_path``.

        Args:
            config: Configuration object forwarded to ``BasicRunner``.
        """
        super().__init__(config)
        self.load_data()

        # ---------------------------exhibit dataset--------------------------
        prior_samples = self.manifold.uniform_sample(self.config.sample.sample_num)
        self.plot_sample(prior_samples, savefig='prior')

        # Swap the first two axes so the path can be indexed by step number
        # (assumes the stored path is laid out as (sample, step, ...) -- TODO confirm).
        forward_path = self.training_set_path.clone().transpose(0, 1)

        # Snapshots are saved at 0-9% and at every 10% of the trajectory,
        # plus unconditionally for the first five steps.
        snapshot_percentages = [*range(10), *range(10, 101, 10)]
        for step in range(self.sde.N + 1):
            progress = 100 * step / self.sde.N
            if progress in snapshot_percentages or step < 5:
                self.plot_sample(forward_path[step].cpu().numpy(), savefig=f'generating_fwd_{step}')

    def load_data(self):
        """Read the earth dataset, split it, and prepare the training path.

        Steps performed, in order:

        1. Read ``./data/S2/earth_data/<dataset_name>.csv`` and convert the
           rows with ``latlon_to_xyz``.
        2. Overwrite ``config.sample.sample_num`` with the number of rows.
        3. Split into training / test / validation sets with ``split_dataset``.
        4. Separate "unseen" from "seen" points in the validation and test
           sets (see ``mark_unseen_data``) and save all subsets to an ``.npz``
           file in ``self.samples_dir``.
        5. Optionally append the unseen points to the training set
           (``config.training.include_unseen_data_in_training``).
        6. Build ``self.training_set_path`` with ``generate_path_dataset``.
        """
        self.x_bound = [-180., 180.]
        self.y_bound = [-90., 90.]

        csv_path = f"./data/S2/earth_data/{self.dataset_name}.csv"
        raw_table = pd.read_csv(csv_path, comment="#", header=0)
        points = latlon_to_xyz(raw_table.values.astype("float32"))

        # The number of generated samples follows the dataset size.
        self.config.sample.sample_num = points.shape[0]
        self.projection = ccrs.PlateCarree(central_longitude=0)

        points = torch.tensor(points, dtype=torch.float32)
        splits = split_dataset(points, self.config.seed)
        self.training_set, self.test_set, self.val_set = splits

        self.unseen_val_set, self.seen_val_set = self.mark_unseen_data(self.val_set)
        self.unseen_test_set, self.seen_test_set = self.mark_unseen_data(self.test_set)
        self.unseen_data = torch.cat((self.unseen_val_set, self.unseen_test_set), dim=0)

        archive_path = f"{self.samples_dir}/dataset_{self.dataset_name}_{self.config.seed}.npz"
        np.savez(
            archive_path,
            training_set=self.training_set.numpy(),
            seen_test_set=self.seen_test_set.numpy(),
            seen_val_set=self.seen_val_set.numpy(),
            unseen_test_set=self.unseen_test_set.numpy(),
            unseen_val_set=self.unseen_val_set.numpy(),
        )
        self.plot_sample_all()

        has_unseen = self.unseen_data.shape[0] != 0
        if has_unseen:
            logging.info(f'{self.unseen_val_set.shape[0]} unseen points in validation set.')
            logging.info(f'{self.unseen_test_set.shape[0]} unseen points in test set.')
            self.plot_sample(
                self.training_set,
                highlight_samples=self.unseen_data.cpu().numpy(),
                savefig='training-set-unseen-data',
            )

            if self.config.training.include_unseen_data_in_training:
                self.training_set = torch.cat((self.training_set, self.unseen_data), dim=0)
                logging.info(f'Adding {self.unseen_data.shape[0]} unseen points in training set. Training data size: {self.training_set.shape[0]}')
        else:
            logging.info('No unseen points in validation set and test set.')

        self.training_set_path = self.generate_path_dataset(self.training_set, keep_quiet=False)
        check_memory(self.training_set_path)

    def mark_unseen_data(self, data_set):
        """Split ``data_set`` into points lying far from the training data and the rest.

        The longitude/latitude rectangle given by ``self.x_bound`` and
        ``self.y_bound`` is divided into ``ns x ns`` cells, where ``ns`` is
        ``self.config.training.grid_size``. A cell occupied by ``data_set`` is
        "unseen" when neither the cell itself nor any of its four edge
        neighbours (indices wrap around in both directions) contains a point
        of ``self.training_set``.

        Args:
            data_set: Tensor of points accepted by ``xyz_to_latlon``; rows are
                selected as ``data_set[idx, :]`` and the unseen part is
                reshaped to ``(-1, 3)``.

        Returns:
            tuple: ``(unseen, seen)``. ``unseen`` gathers the points located
            in unseen cells, grouped cell by cell in row-major cell order;
            ``seen`` holds the remaining points in their original order. When
            no cell is unseen, ``unseen`` is ``torch.empty(0)`` and ``seen``
            is ``data_set`` itself.
        """
        ns = self.config.training.grid_size
        lon_min, lon_max = self.x_bound
        lat_min, lat_max = self.y_bound
        hist_range = [lon_min, lon_max, lat_min, lat_max]

        lat, lon = xyz_to_latlon(self.training_set)
        train_counts = torch.histogramdd(torch.stack((lon, lat), dim=1),
                                         bins=[ns, ns], range=hist_range).hist

        lat, lon = xyz_to_latlon(data_set)
        data_counts = torch.histogramdd(torch.stack((lon, lat), dim=1),
                                        bins=[ns, ns], range=hist_range).hist

        # Training count of each cell plus its four edge neighbours.
        # roll(+1)[i] == counts[i - 1] and roll(-1)[i] == counts[(i + 1) % ns],
        # i.e. the same periodic neighbourhood as explicit index arithmetic.
        neighbourhood = (train_counts
                         + torch.roll(train_counts, shifts=1, dims=0)
                         + torch.roll(train_counts, shifts=-1, dims=0)
                         + torch.roll(train_counts, shifts=1, dims=1)
                         + torch.roll(train_counts, shifts=-1, dims=1))

        # nonzero() lists the cells in row-major order (i outer, j inner).
        unseen_cells = torch.logical_and(data_counts > 0, ~(neighbourhood > 0)).nonzero().tolist()
        if not unseen_cells:
            return torch.empty(0), data_set

        # Cell index of every point of data_set. histogramdd puts points lying
        # exactly on the upper bound into the last bin, so the floor-based
        # index is clamped to ns - 1 to stay consistent with the histogram.
        lon_idx = torch.floor((lon - lon_min) / ((lon_max - lon_min) / ns)).int().clamp(max=ns - 1)
        lat_idx = torch.floor((lat - lat_min) / ((lat_max - lat_min) / ns)).int().clamp(max=ns - 1)

        unseen_idx = torch.cat([
            torch.logical_and(lon_idx == i, lat_idx == j).nonzero()
            for i, j in unseen_cells
        ]).squeeze(1)

        mask = torch.ones(data_set.shape[0], dtype=torch.bool)
        mask[unseen_idx] = False
        return data_set[unseen_idx, :].reshape(-1, 3), data_set[mask, :]

    def plot_likelihood_earth(self, grid_N):
        """Evaluate the learned likelihood on a lat/lon grid and plot it on a map.

        The negative log-likelihood is evaluated on a regular grid, saved to
        ``<samples_dir>/<dataset_name>_learned_nll_grid.npy`` and drawn as
        filled contours of ``exp(-nll)``, with the test set scattered on top.
        The figure is written to ``<savefig_dir>/potential_likelihood.png``.

        Args:
            grid_N: Number of grid points along latitude; longitude uses
                ``2 * grid_N`` points.
        """
        fig = plt.figure()
        lat = torch.linspace(-90, 90, grid_N)
        lon = torch.linspace(-180, 180, grid_N * 2)

        ax = plt.axes(projection=self.projection)

        lat_grid, lon_grid = torch.meshgrid(lat, lon, indexing="ij")
        latlon_grid = torch.cat([lat_grid.reshape(-1, 1), lon_grid.reshape(-1, 1)], dim=-1)
        samples_grid = latlon_to_xyz(latlon_grid).to(self.device)

        def nll_func(x):
            # Only the first returned value is used as the per-point NLL.
            return self.negative_log_likelihood_fn(x, return_mean=False)[0]

        nll = run_func_in_batches(nll_func, samples_grid, max_batch_size=10000, out_dim=1)
        nll = nll.detach().cpu().numpy()

        np.save(f"{self.samples_dir}/{self.dataset_name}_learned_nll_grid.npy", nll.reshape(grid_N, grid_N * 2))

        density = np.exp(-nll).reshape(grid_N, grid_N * 2)

        # Contour levels stop at 1; larger densities share the top colour
        # (extend="max"), which the colorbar labels as ">= 1".
        cs = ax.contourf(
            lon.detach().cpu().numpy(),
            lat.detach().cpu().numpy(),
            density,
            levels=np.linspace(0, 1, 11),
            alpha=0.7,
            extend="max",
            cmap="BuGn",
            antialiased=True,
        )

        cbar = plt.colorbar(cs, ax=ax, pad=0.01, ticks=[0, 1])
        # Raw string: "\g" is not a valid escape sequence in a normal literal.
        cbar.ax.set_yticklabels(["0", r"$\geq$1"])
        cbar.ax.set_ylabel("likelihood", fontsize=18, rotation=270, labelpad=10)
        ax.tick_params(axis="both", which="both", direction="in", length=3)
        cbar.ax.tick_params(axis="both", which="both", direction="in", length=3)
        cbar.set_alpha(0.7)

        ax.set_xlim([-180, 180])
        # Latitude limits belong on the y axis.
        ax.set_ylim([-90, 90])
        ax.set_xlabel('Longitude')
        ax.set_ylabel('Latitude')
        ax.set_title('Potential')

        lat_test_set, lon_test_set = xyz_to_latlon(self.test_set.numpy())
        ax.scatter(lon_test_set, lat_test_set, s=0.1, c='red', alpha=0.3, transform=ccrs.PlateCarree())

        ax.add_feature(cfeature.LAND, zorder=0, facecolor="#e0e0e0")
        ax.set_global()

        plt.savefig(self.savefig_dir + "/potential_likelihood.png", dpi=300, bbox_inches='tight')
        plt.close(fig)

    def plot_sample(self, samples, values=None, highlight_samples=None, savefig=None):
        """Scatter-plot points on a map and save the figure.

        The figure is written to
        ``<savefig_dir>/samples_latlon[_values]_<savefig>.png``; the
        ``_values`` infix is present when ``values`` is given.

        Args:
            samples: Points to plot, converted with ``xyz_to_latlon``. May be
                a ``torch.Tensor`` or an array.
            values: Optional per-point values used to colour the scatter (a
                colorbar is added). May be a ``torch.Tensor`` or an array.
            highlight_samples: Optional extra points drawn larger and in
                green. May be a ``torch.Tensor`` or an array.
            savefig: Suffix used in the output file name.
        """
        def as_numpy(data):
            # Tensors are detached and moved to the CPU; anything else is
            # passed through untouched.
            if isinstance(data, torch.Tensor):
                return data.detach().cpu().numpy()
            return data

        samples = as_numpy(samples)
        fig = plt.figure()
        lat, lon = xyz_to_latlon(samples)
        ax = fig.add_subplot(111, projection=self.projection)

        if values is not None:
            scatter = ax.scatter(lon, lat, s=0.3, c=as_numpy(values), cmap='coolwarm')
            fig.colorbar(scatter, ax=ax)
        else:
            ax.scatter(lon, lat, s=0.3, c='red', alpha=1.)

        ax.add_feature(cfeature.LAND, zorder=0, facecolor="#e0e0e0")
        ax.set_global()

        if highlight_samples is not None:
            lat_point, lon_point = xyz_to_latlon(as_numpy(highlight_samples))
            ax.scatter(lon_point, lat_point, s=3.0, c='green')

        ax.set_xlabel('Longitude')
        ax.set_ylabel('Latitude')
        ax.set_title(f'LatLon of {samples.shape[0]} samples')

        plt.savefig(self.savefig_dir + f"/samples_latlon{'_values' if values is not None else ''}_{savefig}.png", dpi=300, bbox_inches='tight')
        plt.close(fig)

    def plot_sample_all(self):
        """Plot the training, test and validation sets together on one map.

        The three sets are drawn in red, green and blue respectively and the
        figure is saved to ``<savefig_dir>/samples_latlon_all.png``.
        """
        fig = plt.figure()

        # (lat, lon) pairs for each subset, converted before the axes exist.
        train_latlon = xyz_to_latlon(self.training_set.numpy())
        test_latlon = xyz_to_latlon(self.test_set.numpy())
        val_latlon = xyz_to_latlon(self.val_set.numpy())

        ax = fig.add_subplot(111, projection=self.projection)
        layers = (
            (train_latlon, 'red', "training set"),
            (test_latlon, 'green', "test set"),
            (val_latlon, 'blue', "val set"),
        )
        for (lat, lon), colour, label in layers:
            ax.scatter(lon, lat, s=0.3, c=colour, alpha=1., label=label)

        ax.add_feature(cfeature.LAND, zorder=0, facecolor="#e0e0e0")
        ax.set_global()

        ax.set_xlabel('Longitude')
        ax.set_ylabel('Latitude')
        ax.set_title('LatLon of the dataset')
        plt.legend()

        output_path = self.savefig_dir + "/samples_latlon_all.png"
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.close(fig)


    def validate(self, mode=None, epoch=0, **kwargs):
        """Run one validation pass, log metrics and keep the best checkpoint.

        Args:
            mode: ``'start'`` resets the best-validation bookkeeping,
                ``'end'`` does nothing, and any other value runs a full
                validation pass.
            epoch: Current epoch, used as the TensorBoard step and in file
                names.
            **kwargs: Must contain ``batch`` (training points on which the
                training likelihood is evaluated) for a full validation pass.

        Raises:
            KeyError: If a full validation pass is requested without
                ``batch`` in ``kwargs``.
        """
        if mode == 'start':
            self.best_nll_K_val = torch.inf
            # test() reads this attribute; make sure it exists even when no
            # checkpoint ever gets saved during training.
            if not hasattr(self, 'best_checkpoint_epoch'):
                self.best_checkpoint_epoch = None
            return
        if mode == 'end':
            return

        logging.info(f"-------------------------Start validating: Epoch {epoch}-------------------------")
        # calculate likelihood
        if self.unseen_val_set.shape[0] != 0:
            nll_K_unseen_val, nll_unseen_val, _, _, _ = self.negative_log_likelihood_fn(self.unseen_val_set, keep_quiet=True)
            self.tb_logger.add_scalar('nll_unseen_val', nll_unseen_val, global_step=epoch)
            self.tb_logger.add_scalar('nll_K_unseen_val', nll_K_unseen_val, global_step=epoch)

        if self.unseen_test_set.shape[0] != 0:
            nll_K_unseen_test, _, _, _, _ = self.negative_log_likelihood_fn(self.unseen_test_set, keep_quiet=True)
            self.tb_logger.add_scalar('nll_K_unseen_test', nll_K_unseen_test, global_step=epoch)

        nll_K_seen_val, nll_seen_val, _, _, _ = self.negative_log_likelihood_fn(self.seen_val_set, keep_quiet=True)
        self.tb_logger.add_scalar('nll_seen_val', nll_seen_val, global_step=epoch)
        self.tb_logger.add_scalar('nll_K_seen_val', nll_K_seen_val, global_step=epoch)

        nll_K_seen_test, _, _, _, _ = self.negative_log_likelihood_fn(self.seen_test_set, keep_quiet=True)
        self.tb_logger.add_scalar('nll_K_seen_test', nll_K_seen_test, global_step=epoch)

        nll_K_train, nll_train, loss_part, _, _ = self.negative_log_likelihood_fn(kwargs["batch"], keep_quiet=True)
        self.tb_logger.add_scalar('nll_train', nll_train, global_step=epoch)
        self.tb_logger.add_scalar('nll_K_train', nll_K_train, global_step=epoch)
        logging.info(f'nll_K_train:{nll_K_train:.4f}, nll_train: {nll_train:.4f}, loss part:{loss_part:.4f}.')

        # Per-point values are requested here (return_mean=False) so they can
        # colour the validation scatter plot; means are taken afterwards.
        nll_K_val, nll_val, loss_part_val, _, _ = self.negative_log_likelihood_fn(self.val_set, keep_quiet=True, return_mean=False)
        self.plot_sample(self.val_set, values=nll_K_val, savefig=f"{epoch}_val")
        self.tb_logger.add_scalar('nll_val', nll_val.mean(), global_step=epoch)
        self.tb_logger.add_scalar('nll_K_val', nll_K_val.mean(), global_step=epoch)
        self.tb_logger.add_scalar('loss_val', loss_part_val.mean(), global_step=epoch)
        logging.info(f'nll_K_val: {nll_K_val.mean() :.4f}, loss_val:{loss_part_val.mean() :.4f}, nll_val: {nll_val.mean() :.4f}.')
        nll_K_val = nll_K_val.mean()

        nll_K_test, nll_test, loss_part_test, _, _ = self.negative_log_likelihood_fn(self.test_set, keep_quiet=True)
        self.tb_logger.add_scalar('nll_K_test', nll_K_test, global_step=epoch)
        logging.info(f'nll_K_test: {nll_K_test:.4f}, loss_test:{loss_part_test :.4f}, nll_test: {nll_test:.4f}')

        # save best model: only once the configured fraction of the total
        # epochs has passed, and only when the validation NLL improves.
        min_epoch = int(self.config.training.min_checkpoint_epoch_ratio * self.total_epochs)
        if nll_K_val < self.best_nll_K_val and epoch >= min_epoch:
            self.best_nll_K_val = nll_K_val
            self.best_checkpoint_epoch = epoch
            logging.info(f'Epoch {epoch}: best validation nll {nll_K_val :.4f} on validation set.')
            logging.info(f'Epoch {epoch}: test nll {nll_K_test :.4f} with the best performance on validation set.')
            logging.info(f'Epoch {epoch}: training nll {nll_K_train:.4f} with the best performance on validation set.')
            save_model(self.validate_dir, self.network, name="model_best_nll_val.pt")

        # sample
        init = self.manifold.uniform_sample(self.config.sample.sample_num).to(self.device)
        x, _, _ = SDE_sampler_manifolds(self.sde, self.manifold, init,
                                        reverse=True,
                                        score_net=self.network,
                                        keep_quiet=True)
        self.plot_sample(x.cpu().numpy(), savefig=f'epoch_{epoch}_generated')
        logging.info("-------------------------End validating.-------------------------")

    def test(self):
        """Evaluate the model on the test, validation and training sets.

        When training was enabled (``config.if_train``) and the checkpoint
        ``model_best_nll_val.pt`` exists in ``self.validate_dir``, that
        checkpoint replaces ``self.network``; otherwise the current network
        is used. For each data split the per-point likelihood is computed
        and plotted, and the mean values are logged. If
        ``config.calculate_mesh_nll_earth`` is set, the likelihood map is
        plotted as well.
        """
        logging.info('Start testing')

        model_path = os.path.join(self.validate_dir, "model_best_nll_val.pt")
        if self.config.if_train and os.path.exists(model_path):
            self.network = load_model(model_path).to(self.device)
        else:
            self.network.to(self.device)

        if self.config.calculate_mesh_nll_earth:
            self.plot_likelihood_earth(grid_N=400)

        logging.info('Calculate likelihood: ')
        # Evaluate and plot each split in turn; keep only the mean of the
        # first returned quantity for the summary line.
        mean_nll = {}
        for split_name, split_data in (("test", self.test_set),
                                       ("val", self.val_set),
                                       ("training", self.training_set)):
            nll_K, _, _, _, _ = self.negative_log_likelihood_fn(split_data, return_mean=False)
            self.plot_sample(split_data, values=nll_K, savefig=split_name)
            mean_nll[split_name] = nll_K.mean().detach().cpu().numpy()

        nll_K_train_np = mean_nll["training"]
        nll_K_val_np = mean_nll["val"]
        nll_K_test_np = mean_nll["test"]

        # The best epoch is only recorded when a checkpoint was saved during
        # validation; fall back to the generic message if it is missing
        # instead of failing after all the evaluation work is done.
        best_epoch = getattr(self, "best_checkpoint_epoch", None)
        if self.config.if_train and best_epoch is not None:
            logging.info(f'On epoch {best_epoch}, the best checkpoints: train nll {nll_K_train_np :.4f}, val nll {nll_K_val_np :.4f}, test nll {nll_K_test_np :.4f}.')
        else:
            logging.info(f', the best checkpoints: train nll {nll_K_train_np :.4f}, val nll {nll_K_val_np :.4f}, test nll {nll_K_test_np :.4f}.')

        return

    def sample_on_manifolds(self):
        """Generate samples with the reverse sampler, plot and save them.

        Starting from points drawn with ``self.manifold.uniform_sample``,
        ``SDE_sampler_manifolds`` is run with ``reverse=True``. The final
        samples and selected intermediate steps are plotted, and the final
        samples together with the test set are saved as ``.npy`` files in
        ``self.samples_dir``.
        """
        logging.info('Start sampling on manifolds.')
        if self.network is not None:
            self.network.to(self.device)

        # backward
        start_points = self.manifold.uniform_sample(self.config.sample.sample_num).to(self.device)
        generated, trajectory, _ = SDE_sampler_manifolds(
            self.sde,
            self.manifold,
            start_points,
            reverse=True,
            score_net=self.network,
        )
        self.plot_sample(generated.cpu().numpy(), savefig='generated')

        # Snapshots are saved at every 10% up to 90%, at every 1% from 90%
        # on, and unconditionally for the last few steps.
        snapshot_percentages = [*range(0, 100, 10), *range(90, 101)]
        for step in range(self.sde.N + 1):
            progress = 100 * step / self.sde.N
            if progress in snapshot_percentages or step > self.sde.N - 5:
                self.plot_sample(trajectory[step].cpu().numpy(), savefig=f'generating_bwd_{step}')

        generated_path = f"{self.samples_dir}/{self.dataset_name}_samples_generated.npy"
        test_set_path = f"{self.samples_dir}/{self.dataset_name}_samples_test_set.npy"
        np.save(generated_path, generated.cpu().numpy())
        np.save(test_set_path, self.test_set.numpy())


