# import logging
import numpy as np
import torch
import torch.optim as optim
import time
import pytorch_lightning as pl
from sklearn.decomposition import PCA
import src.models.generate_NN as g_NN
import src.datamodules.generate_data as g_data
import src.utils.plot_utils as PLU
import jacinle.io as io
from torch.optim.lr_scheduler import StepLR as StepLR
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.neighbors import KernelDensity


class Base_Tmap(pl.LightningModule):
    """Shared machinery for the transport-map Lightning modules below.

    This base class does not build any state itself: the methods read
    ``cfg``, ``map_t``, ``h``, ``train_data``, ``density_q``, ``k``,
    ``epoch``, ``P_save_path``, ``image_save_path`` and ``handle``, which a
    subclass is expected to set (``Gaussian_Tmap.__init__`` does).
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def map_forward(self, input_data):
        """Push ``input_data`` through the transport map, keeping the autograd graph."""
        return self.map_t(input_data)

    def map_forward_nograd(self, input_data):
        """Push ``input_data`` through the transport map and drop the autograd graph.

        The result is meant to be stored (e.g. as new ``train_data``) or used
        as a fixed sample set, so it must not keep the map's forward graph
        alive. Detaching after the call, instead of wrapping it in
        ``torch.no_grad()``, keeps this correct even if ``map_t``'s forward
        pass itself needs autograd.
        """
        return self.map_t(input_data).detach()

    def nn_loss(self):
        """Return an additional network loss term; 0 in the base class."""
        return 0

    def configure_optimizers(self):
        """Create one Adam optimizer for ``map_t`` (lr ``cfg.LR_g``) and one for ``h`` (lr ``cfg.LR_h``).

        Returns ``([opt_map, opt_h], [sched_map, sched_h])`` with ``StepLR``
        schedulers when ``cfg.schedule_learning_rate`` is truthy, otherwise
        the tuple ``(opt_map, opt_h)``.
        """
        optimizer_map = optim.Adam(self.map_t.parameters(), lr=self.cfg.LR_g)
        optimizer_h = optim.Adam(self.h.parameters(), lr=self.cfg.LR_h)
        if not self.cfg.schedule_learning_rate:
            return optimizer_map, optimizer_h
        # NOTE(review): Gaussian_Tmap sets automatic_optimization = False; under
        # manual optimization Lightning does not step schedulers on its own --
        # confirm these are stepped somewhere (e.g. via self.lr_schedulers()).
        schedulers = [
            StepLR(opt, step_size=self.cfg.lr_schedule_per_epoch, gamma=self.cfg.lr_schedule_scale)
            for opt in (optimizer_map, optimizer_h)
        ]
        return [optimizer_map, optimizer_h], schedulers

    def train_dataloader(self):
        """Shuffled loader over ``self.train_data`` with ``cfg.BATCH_SIZE`` and 4 workers."""
        return torch.utils.data.DataLoader(
            self.train_data, batch_size=self.cfg.BATCH_SIZE, shuffle=True, num_workers=4)

    def update_mu(self, P_k):
        """Refit the Gaussian reference measure mu to the samples ``P_k``.

        Uses the sample mean and unbiased sample covariance of ``P_k``. Then:

        - redraws ``train_data[:, :, 1]`` (``cfg.N_TRAIN_SAMPLES`` rows) from
          that Gaussian;
        - stores ``[mean, inverse covariance]`` in ``self.density_mu``;
        - for ``cfg.type_data == 'Gauss2Gauss'``, dumps the log10 symmetric KL
          to the closed-form reference solution.
        """
        mean_mu = P_k.mean(axis=0)
        centered = P_k - mean_mu
        cov_mu = centered.T @ centered / (P_k.shape[0] - 1)
        print("check mean mu max, cov_mu max", mean_mu.max(), cov_mu.max())

        self.train_data[:, :, 1] = g_data.torch_samples_generate_Gaussian(
            self.cfg.N_TRAIN_SAMPLES, mean_mu, cov_mu)
        inv_cov_mu = torch.inverse(cov_mu)
        self.density_mu = [mean_mu, inv_cov_mu]
        if self.cfg.type_data == 'Gauss2Gauss':
            log10symkl = self.calculate_log10symkl(mean_mu.cpu(), cov_mu.cpu(), inv_cov_mu.cpu())
            io.dump(self.P_save_path + f'/log10symkl_{self.k}.pt', log10symkl)

    def calculate_log10symkl(self, mean_mu, cov_mu, inv_cov_mu):
        """Return log10 of the symmetric KL between N(mean_mu, cov_mu) and the closed-form Gaussian at t = k * step_a.

        With A = inverse covariance of q and m_q its mean, the reference
        Gaussian used here is:

            mean(t) = (I - exp(-A t)) m_q
            cov(t)  = cov_q (I - exp(-2 A t)) + exp(-2 A t)

        At t = 0 this reduces to N(0, I).
        """
        dim = cov_mu.shape[0]
        mean_q, inv_cov_q = self.density_q[0].detach().cpu(), self.density_q[1].detach().cpu()
        cov_q = torch.inverse(inv_cov_q)
        # Column vectors so kl_gaussian's quadratic form is a (1, 1) product.
        mean_mu = mean_mu.reshape(-1, 1)
        mean_q = mean_q.reshape(-1, 1)

        t_now = self.k * self.cfg.step_a
        exp2m = torch.matrix_exp(-2 * inv_cov_q * t_now)
        expm = torch.matrix_exp(-inv_cov_q * t_now)

        mean_now = (torch.eye(dim) - expm) @ mean_q
        cov_now = cov_q @ (torch.eye(dim) - exp2m) + exp2m
        inv_cov_now = torch.inverse(cov_now)

        kl_mu_gt = self.kl_gaussian(mean_mu, cov_mu, mean_now, cov_now, inv_cov_now)
        kl_gt_mu = self.kl_gaussian(mean_now, cov_now, mean_mu, cov_mu, inv_cov_mu)
        return torch.log10(kl_mu_gt + kl_gt_mu)

    def kl_gaussian(self, mean_p, cov_p, mean_q, cov_q, inv_cov_q):
        """Return KL(N(mean_p, cov_p) || N(mean_q, cov_q)).

        ``calculate_log10symkl`` passes the means as column vectors of shape
        (dim, 1), so the result is a (1, 1) tensor. The log-determinant ratio
        is computed as a difference of ``logdet`` values rather than
        ``log(det / det)``. A raw determinant under- or overflows quickly as
        ``dim`` grows.
        """
        dim = cov_p.shape[0]
        diff = mean_p - mean_q
        log_det_ratio = torch.logdet(cov_q) - torch.logdet(cov_p)
        mahalanobis = diff.T @ inv_cov_q @ diff
        trace_term = torch.trace(inv_cov_q @ cov_p)
        return (log_det_ratio - dim + mahalanobis + trace_term) / 2

    def plot_single_h_2d(self, x_grid, y_grid, xy_list, ax_param):
        """Evaluate ``h`` on the grid points ``xy_list`` and save it as a surface plot."""
        h_values = self.h(xy_list).detach().cpu().numpy().reshape(ax_param.num_grid, -1)
        PLU.surface_alone(
            x_grid, y_grid, h_values, ax_params=ax_param,
            save_path=self.image_save_path + f'/epoch{self.epoch}_h_surface.png')

    def print_y_psf(self, Tx):
        """Log (TensorBoard histograms) and print the sample mean and covariance of ``Tx``."""
        y_psf_mean_epoch = Tx.mean(dim=0)
        centered = Tx - y_psf_mean_epoch
        # Unbiased estimator normalized by Tx's own sample count.
        y_psf_var_epoch = centered.T @ centered / (Tx.shape[0] - 1)

        self.trainer.logger.experiment.add_histogram(
            tag='y_psf_mean_epoch', values=y_psf_mean_epoch)
        self.trainer.logger.experiment.add_histogram(
            tag='y_psf_var_epoch', values=y_psf_var_epoch)

        print('y_psf_mean_epoch', y_psf_mean_epoch)
        print('y_psf_var_epoch', y_psf_var_epoch)

    def _get_xy_grid(self):
        """Return the 2-D plotting grid ``(x_grid, y_grid, xy_list)`` described by ``self.handle.ax_param``."""
        x_grid, y_grid, xy_list = PLU.grid_NN_2_generator(
            self.handle.ax_param.num_grid, self.handle.ax_param.left_place, self.handle.ax_param.right_place)
        return x_grid, y_grid, xy_list


class Gaussian_Tmap(Base_Tmap):
    """Transport-map module with Gaussian target q and Gaussian reference mu.

    ``density_q`` and ``density_mu`` are ``[mean, inverse covariance]`` pairs,
    as unpacked by ``gaussian_energy``. ``train_data[:, :, 0]`` holds the
    current samples P_k and ``train_data[:, :, 1]`` holds samples from mu.
    Optimization is manual: ``training_step`` drives both optimizers itself.
    """

    def __init__(self, cfg, density_q, density_mu, train_data, P_save_path, image_save_path=None):
        super().__init__()
        self.cfg = cfg
        # At most one of the two loops may run more than one iteration.
        assert cfg.N_outer_ITERS == 1 or cfg.N_inner_ITERS == 1

        self.density_q = density_q
        self.density_mu = density_mu
        self.gmm_mu_estimator = None

        self.h, self.map_t = g_NN.generate_monge_NN(cfg)
        self.epoch = 1
        self.k = 1
        self.P_save_path = P_save_path
        self.image_save_path = image_save_path
        # train_data [:,:,0] is P_k, train_data [:,:,1] is \mu
        self.train_data = train_data
        self.pca = PCA(n_components=2, random_state=1)

        # Select the objective used to train h (precedence: log_kl > exp_h > basic).
        if self.cfg.log_kl:
            self._loss_h_function = self.log_kl_loss_h
        elif self.cfg.exp_h:
            if self.cfg.exp_h_add_small:
                self._loss_h_function = self.revised_exp_h_loss_h
            else:
                self._loss_h_function = self.exp_h_loss_h
        else:
            self._loss_h_function = self.basic_loss_h

        self.automatic_optimization = False
        self.handle = PLU.DIM2_PLOT(
            left_place=-self.cfg.plot_size,
            right_place=self.cfg.plot_size, bandwidth=self.cfg.band_width)

        self.t_tr_strt, self.t_tr_elaps = 0.0, 0.0

    def revised_exp_h_loss_h(self, y_data, Tx):
        """E_mu[h] - E[log(h(Tx) + 1e-10)]; the epsilon guards log(0)."""
        E_h_mu = (self.h(y_data)).mean()
        loss_h = E_h_mu - torch.log(self.h(Tx) + 1e-10).mean()
        return loss_h

    def exp_h_loss_h(self, y_data, Tx):
        """E_mu[h] - E[log h(Tx)] (assumes h > 0 on the pushed samples)."""
        E_h_mu = (self.h(y_data)).mean()
        loss_h = E_h_mu - torch.log(self.h(Tx)).mean()
        return loss_h

    def log_kl_loss_h(self, y_data, Tx):
        """log E_mu[exp h] - E[h(Tx)] (negative Donsker-Varadhan objective).

        log E[exp h] is evaluated as logsumexp(h) - log(n). Computing
        exp(h) directly overflows once h exceeds ~88 in float32.
        """
        h_y = self.h(y_data)
        log_E_exp_h = torch.logsumexp(h_y.flatten(), dim=0) - np.log(h_y.numel())
        loss_h = log_E_exp_h - self.h(Tx).mean()
        return loss_h

    def basic_loss_h(self, y_data, Tx):
        """E_mu[exp h] - E[h(Tx)]."""
        E_exp_h = torch.exp(self.h(y_data)).mean()
        loss_h = E_exp_h - self.h(Tx).mean()
        return loss_h

    def on_train_epoch_start(self):
        """Record the epoch start time (read in ``on_train_epoch_end``)."""
        self.t_tr_strt = time.perf_counter()

    def training_step(self, real_data, batch_idx):
        """Run ``cfg.N_inner_ITERS`` updates of h, then (unless ``cfg.debug_h``) one update of T.

        ``real_data[:, :, 0]`` are samples x from P_k and ``real_data[:, :, 1]``
        samples y from mu. h is stepped only every ``cfg.N_outer_ITERS``
        batches.

        T's gradient is the sum of:

        - the last inner h-loss gradient;
        - the W2 term, and the log-density term unless ``cfg.mu_equal_q``.

        The gradient is then negated, so ``optimizer_t.step()`` ascends that
        objective.
        """
        optimizer_t, optimizer_h = self.optimizers()
        x_i = real_data[:, :, 0]
        y_data = real_data[:, :, 1]

        h_ot_loss_value_batch = 0

        optimizer_t.zero_grad()
        ######################################################
        #                Inner Loop Begin                   #
        ######################################################
        for inner_iter in range(1, self.cfg.N_inner_ITERS + 1):
            optimizer_h.zero_grad()

            Tx = x_i if self.cfg.debug_h else self.map_forward(x_i)
            loss_h = self._loss_h_function(y_data, Tx)
            # Only needed for logging: detach so inner-iteration graphs are freed.
            h_ot_loss_value_batch += loss_h.detach()
            if self.cfg.debug_h:
                loss_h.backward()
            else:
                # Tx's graph through map_t is backpropagated again for the T update below.
                loss_h.backward(retain_graph=True)

            # ! update h: try using more updates for map
            if (batch_idx + 1) % self.cfg.N_outer_ITERS == 0:
                optimizer_h.step()

            # Just for the last iteration keep the gradient on T intact
            if inner_iter != self.cfg.N_inner_ITERS:
                optimizer_t.zero_grad()

        ######################################################
        #                Inner Loop Ends                     #
        ######################################################
        if not self.cfg.debug_h:
            w2_loss = -(x_i - Tx).pow(2).sum(dim=1).mean() / 2 / self.cfg.step_a
            if self.cfg.mu_equal_q:
                remaining_g_loss = w2_loss
                log_loss = torch.zeros_like(w2_loss)
            else:
                log_loss = self.log_total_loss(Tx)
                remaining_g_loss = log_loss + w2_loss

            remaining_g_loss.backward()
            # Flip the accumulated gradient so the optimizer step is an ascent step.
            for p in self.map_t.parameters():
                # Parameters that took no part in the forward pass have no gradient.
                if p.grad is not None:
                    p.grad.neg_()

            # ! update T
            optimizer_t.step()
        else:
            w2_loss, log_loss, remaining_g_loss = 0, 0, 0

        h_ot_loss_value_batch /= self.cfg.N_inner_ITERS

        self.log_dict(
            {'h_ot_loss': h_ot_loss_value_batch, 'log_loss': log_loss, 'w2_loss': w2_loss}, prog_bar=True)
        return None

    def log_q_loss(self, Tx):
        """Mean Gaussian log-density of ``Tx`` under q, up to an additive constant."""
        return self.gaussian_energy(Tx, self.density_q)

    def gaussian_energy(self, Tx, density):
        """Return -mean((Tx - mean)^T inv_cov (Tx - mean)) / 2 for ``density = (mean, inv_cov)``."""
        mean, inv_cov = density
        mu_centered = Tx - mean
        log_auxili_gauss = -((mu_centered @ inv_cov) *
                             mu_centered).sum(axis=1).mean() / 2
        return log_auxili_gauss

    def log_mu_loss(self, Tx):
        """Negated Gaussian energy of ``Tx`` under mu (up to an additive constant)."""
        return -self.gaussian_energy(Tx, self.density_mu)

    def log_total_loss(self, Tx):
        """Sum of ``log_mu_loss`` and ``log_q_loss`` evaluated at ``Tx``."""
        log_auxili_gauss = self.log_mu_loss(Tx)
        log_q = self.log_q_loss(Tx)
        log_loss = log_auxili_gauss + log_q
        return log_loss

    def on_train_epoch_end(self):
        """End-of-epoch bookkeeping.

        Pushes P_k through T. Every ``cfg.epochs`` epochs (unless
        ``cfg.mu_equal_q``) it advances to the next outer step k. It also
        records the epoch's elapsed time.
        """
        # Detached: these samples become the new training data and are used to
        # refit mu; they must not carry map_t's autograd graph into later steps.
        Tx = self.map_forward(self.train_data[:, :, 0].cuda()).detach()
        # self.print_y_psf(Tx)

        if not self.cfg.mu_equal_q and self.epoch % self.cfg.epochs == 0:
            self.iterate_dataloader(Tx)
        self.epoch += 1
        if (self.k - 1) % 2 == 0:
            self.t_tr_elaps = time.perf_counter() - self.t_tr_strt
            print("time=", self.t_tr_elaps, "dim=", self.cfg.INPUT_DIM)
            io.dump(self.P_save_path + f'/train_time_{self.k-1}.pt', self.t_tr_elaps)

    def iterate_dataloader(self, P_k):
        """Advance to outer step k + 1 using the pushed samples ``P_k``.

        Saves the current map, replaces ``train_data[:, :, 0]`` with ``P_k``,
        refits mu and dumps ``P_k``.
        """
        torch.save(self.map_t.state_dict(), self.P_save_path + f'/map_{self.k}.pt')
        # TODO new_generator
        self.train_data[:, :, 0] = P_k
        self.update_mu(P_k)
        self.k += 1
        io.dump(self.P_save_path + f'/P_{self.k-1}.pt', P_k)


class GM_Tmap(Gaussian_Tmap):
    """Variant of ``Gaussian_Tmap`` whose target q is a Gaussian mixture.

    ``density_q`` is ``(means, covariances, weights)``, as unpacked by
    ``gm_density``. mu remains the Gaussian ``[mean, inverse covariance]``
    pair handled by the parent class. Most methods here are plotting and
    diagnostics.
    """

    def gauss_density(self, Tx, mean, cov):
        """Evaluate the normalized N(mean, cov) density at each row of ``Tx``."""
        centered = Tx - mean
        quad = ((centered @ torch.inverse(cov)) * centered).sum(axis=1)
        return torch.exp(-quad / 2) / torch.det(2 * np.pi * cov).sqrt()

    def gm_density(self, Tx, density, n_comp=None):
        """Evaluate a Gaussian-mixture density at each row of ``Tx``.

        ``density`` is ``(means, covs, weights)``. The first ``n_comp``
        components are summed, or all of them when ``n_comp`` is None. A 1e-10
        floor keeps a subsequent ``log`` finite.
        """
        mean_q_list, cov_q, weights = density
        if n_comp is None:
            n_comp = len(weights)
        q_density = 0
        for idx in range(n_comp):
            q_density += weights[idx] * self.gauss_density(Tx, mean_q_list[idx], cov_q[idx])
        q_density += 1e-10
        return q_density

    def log_q_loss(self, Tx):
        """Mean log-density of ``Tx`` under the mixture q."""
        return self.gm_density(Tx, self.density_q, self.cfg.NUM_GMM_COMPONENT[1]).log().mean()

    def on_train_epoch_start(self):
        """On the first epoch, fit the PCA projection on samples from q and plot them.

        Skipped for 1-D data, where a 2-component PCA cannot be fitted; the
        1-D plots do not use the PCA.
        """
        if self.epoch == 1 and self.cfg.INPUT_DIM >= 2:
            target_samples = self.sampler_q(1000)
            self.plot_highD_Pk(target_samples)

    def sampler_q(self, num_samples):
        """Draw ``num_samples`` CPU samples from the mixture q."""
        return g_data.torch_samples_generate_GM(
            num_samples, self.density_q[0].detach().cpu(), self.density_q[1].detach().cpu())

    def on_train_epoch_end(self):
        """Plot the current samples and, every ``cfg.epochs`` epochs (unless ``cfg.mu_equal_q``), advance k."""
        if self.cfg.debug_h:
            # In debug mode training_step bypasses T (Tx = x), so the pushed
            # samples are the current samples themselves. P_k is placed on CUDA
            # like the non-debug path, so a refitted mu stays on the same device.
            P_k = self.train_data[:, :, 0].cuda()
            self.debug_plot(self.train_data[:, :, 0])
        else:
            P_k = self.map_forward_nograd(self.train_data[:, :, 0].cuda())
            self._basic_plot(P_k[:10000])

        if not self.cfg.mu_equal_q and self.epoch % self.cfg.epochs == 0:
            self.iterate_dataloader(P_k)
        self.epoch += 1

    def _basic_plot(self, P_k):
        """Dispatch to the plot matching ``cfg.INPUT_DIM`` and ``cfg.type_data``."""
        if self.cfg.INPUT_DIM == 1:
            self.plot_1d(P_k)
        elif self.cfg.INPUT_DIM == 2:
            if self.cfg.type_data == 'GM':
                self.plot_2d(P_k)
            elif self.cfg.type_data == 'two_moons':
                self.plot_scatter(P_k)
        else:
            self.plot_highD(P_k)

    def plot_1d(self, P_k):
        """1-D diagnostics: q density, KDE of ``P_k``, h, and the clipped ratio KDE(P_k) / mu."""
        PLU.GM_density_1d(self.cfg, self.density_q)
        _, Pk_kde = PLU.sample2density_1d(
            self.cfg, P_k.detach().cpu(), self.image_save_path, self.epoch,
            save=True, energy_or_density='density')
        x_axis = torch.linspace(-self.cfg.plot_size, self.cfg.plot_size, 100).reshape(-1, 1).cuda()
        x_plot = x_axis.detach().cpu()
        plt.plot(x_plot, self.h(x_axis).detach().cpu(), c='C0')

        mean_p, inv_cov_p = self.density_mu[0], self.density_mu[1]
        p_centered = x_axis - mean_p
        log_mu = -((p_centered @ inv_cov_p) *
                   p_centered).sum(axis=1) / 2

        # 1 / inv_cov_p is the covariance only because the data are 1-D here.
        plt.plot(x_plot, torch.clamp_max(Pk_kde /
                                         (log_mu.exp() / torch.sqrt(2 * np.pi * torch.det(1 / inv_cov_p)) + 1e-10).detach().cpu(), 100), c='C1')
        PLU.save_fig(self.image_save_path + f'/epoch{self.epoch}_h.png')

    def plot_2d(self, P_k):
        """2-D GM diagnostics: contour and scatter of ``P_k``, P_k / mu and h surfaces, and q at epoch 1."""
        x_grid, y_grid, xy_list, Pk_density = self.handle.contour_from_sample(
            P_k.detach().cpu())
        self.handle.scatter(
            P_k.detach().cpu()[:4000], self.image_save_path +
            f'/epoch{self.epoch}_my.png', new_fig=False)

        self.handle.ax_param.new_fig = True
        self.plot_h_comparison_2d(x_grid, y_grid, xy_list, Pk_density, self.handle.ax_param)

        if self.epoch == 1:
            target_samples = self.sampler_q(4000)
            self.handle.scatter(target_samples)
            PLU.GM_density_2d(
                self.cfg, self.density_q,
                self.image_save_path + f'/target.png')

    def plot_scatter(self, P_k):
        """2-D two-moons diagnostics: scatter of ``P_k`` plus P_k / mu and h surfaces."""
        P_k_cpu = P_k.detach().cpu()
        PLU.sns_scatter_handle(
            P_k_cpu[:4000], -self.cfg.plot_size, self.cfg.plot_size, self.image_save_path +
            f'/epoch{self.epoch}_my.png', opacity=0.4)

        x_grid, y_grid, xy_list, Pk_density = PLU.sample2density_2d(P_k_cpu, self.handle.ax_param)
        self.plot_h_comparison_2d(x_grid, y_grid, xy_list, Pk_density, self.handle.ax_param)

    def plot_highD(self, P_k):
        """High-dimensional diagnostics on the 2-D PCA plane.

        Always saves the PCA scatter of ``P_k``. When ``cfg.exp_h`` is set, it
        also compares KDE(P_k) / mu with h on the PCA grid lifted back to the
        input space.
        """
        # TODO I turn off plotting h and go to solve for the criteria.
        x_grid, y_grid, xy_list_highD, Pk_cpu = self.plot_highD_Pk(P_k)
        if self.cfg.exp_h:
            # The KDE is only consumed by the comparison plot, so fit it only here.
            kde = KernelDensity(
                kernel='gaussian', bandwidth=self.handle.ax_param.bandwidth).fit(Pk_cpu)
            num_grid = self.handle.ax_param.num_grid
            Pk_density_highD = np.exp(kde.score_samples(xy_list_highD)).reshape(num_grid, -1)
            self.plot_h_comparison_2d(x_grid, y_grid, xy_list_highD,
                                      Pk_density_highD, self.handle.ax_param)

    def plot_h_comparison_2d(self, x_grid, y_grid, xy_list, Pk_density, ax_param=None):
        """Save the surface of P_k / mu (clipped to [0, 100]) and the surface of h on the same grid."""
        mu_density_plot = self._get_mu_density_plot(xy_list)

        Pk_over_mu = np.clip(Pk_density / (mu_density_plot + 1e-10), 0, 100)
        PLU.surface_alone(
            x_grid, y_grid, Pk_over_mu, ax_params=ax_param,
            save_path=self.image_save_path + f'/epoch{self.epoch}_Pk_mu_surface.png')

        xy_list = torch.from_numpy(xy_list).cuda().float()
        self.plot_single_h_2d(x_grid, y_grid, xy_list, ax_param)

    def plot_highD_Pk(self, P_k, before_epoch="after"):
        """Save side-by-side PCA scatter/KDE plots of q samples and ``P_k``; return the PCA grid.

        On epoch 1 the PCA is (re)fitted on 10000 samples from q; afterwards
        the existing projection is reused. Returns ``(x_grid, y_grid,
        xy_list_highD, Pk_cpu)``. ``xy_list_highD`` is the 2-D plotting grid
        mapped back to the input space via ``pca.inverse_transform``.
        """
        if self.epoch == 1:
            target_samples = self.sampler_q(10000)
            X_targ_pca = self.pca.fit_transform(target_samples)[:4000]
        else:
            X_targ_pca = self.pca.transform(self.sampler_q(4000))

        Pk_cpu = P_k.detach().cpu()
        X_trf_pca = self.pca.transform(Pk_cpu)[:4000]
        fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(11, 5))

        sns.kdeplot(x=X_targ_pca[:, 0], y=X_targ_pca[:, 1], ax=ax[0],
                    color='mediumaquamarine', linewidths=3.0, levels=10, alpha=0.7)
        xlims = (-12, 12)
        ylims = (-12, 12)
        ax[0].set_xlim(xlims)
        ax[0].set_ylim(ylims)
        ax[0].set_title('Stationary measure', fontsize=18)
        ax[0].scatter(X_targ_pca[:, 0], X_targ_pca[:, 1], color='darkslategray', alpha=0.1)

        sns.kdeplot(x=X_trf_pca[:, 0], y=X_trf_pca[:, 1], ax=ax[1],
                    color='darkturquoise', linewidths=3.0, levels=10, alpha=0.7)
        ax[1].set_xlim(xlims)
        ax[1].set_ylim(ylims)
        ax[1].set_title('Fitted measure (ours)', fontsize=18)
        ax[1].scatter(X_trf_pca[:, 0], X_trf_pca[:, 1], color='darkslategray', alpha=0.1,)

        fig.savefig(self.image_save_path +
                    f'/epoch{self.epoch}_my_{before_epoch}.png', bbox_inches='tight', dpi=200)
        # Close this specific figure rather than whatever pyplot considers current.
        plt.close(fig)
        x_grid, y_grid, xy_list = self._get_xy_grid()
        xy_list_highD = self.pca.inverse_transform(xy_list)
        return x_grid, y_grid, xy_list_highD, Pk_cpu

    def debug_plot_h(self, x_grid, y_grid, xy_list, ax_param=None):
        """Plot h on the grid; on epoch 1 also plot the clipped ratio q / mu (contour and surface)."""
        if self.epoch == 1:
            num_grid = self.handle.ax_param.num_grid
            mu_density_plot = self._get_mu_density_plot(xy_list)
            xy_list = torch.from_numpy(xy_list).cuda().float()
            # NOTE(review): gauss_density already divides by sqrt(det(2*pi*cov)),
            # so the extra (2*pi)^(d/2) factor below may double-normalize --
            # confirm against np_PDF_generate_multi_normal's convention.
            gm_density = self.gm_density(
                xy_list, self.density_q, self.cfg.NUM_GMM_COMPONENT[1]).reshape(num_grid, -1).detach().cpu().numpy() / (2 * np.pi)**(self.cfg.INPUT_DIM / 2)
            Pk_over_mu = np.clip(gm_density / (mu_density_plot + 1e-10), 0, 100)
            PLU.contour_alone(
                x_grid, y_grid, Pk_over_mu, ax_params=ax_param,
                save_path=self.image_save_path + f'/epoch{self.epoch}_Pk_mu_contour.png')
            PLU.surface_alone(
                x_grid, y_grid, Pk_over_mu, ax_params=ax_param,
                save_path=self.image_save_path + f'/epoch{self.epoch}_Pk_mu_surface.png')
        else:
            xy_list = torch.from_numpy(xy_list).cuda().float()
        self.plot_single_h_2d(x_grid, y_grid, xy_list, ax_param)

    def debug_plot(self, P_k):
        """Debug-mode (``cfg.debug_h``) plotting of h on the PCA grid lifted to the input space."""
        # * this is only for highD now
        if self.epoch == 1:
            x_grid, y_grid, xy_list_highD, _ = self.plot_highD_Pk(P_k)
        else:
            x_grid, y_grid, xy_list = self._get_xy_grid()
            xy_list_highD = self.pca.inverse_transform(xy_list)
        self.debug_plot_h(x_grid, y_grid, xy_list_highD, self.handle.ax_param)

    def _get_mu_density_plot(self, xy_list):
        """Evaluate the Gaussian mu density on a plotting grid.

        For ``INPUT_DIM > 2`` the density is evaluated at ``xy_list`` (the
        lifted PCA grid of ``ax_param.num_grid`` x ``num_grid`` points) and
        reshaped to that grid. Otherwise it is evaluated on a fixed 100 x 100
        grid over [-plot_size, plot_size]^2.
        """
        # if self.epoch > 1:
        #     mu_density_plot = np.exp(self.gmm_mu_estimator.score_samples(xy_list)).reshape(100, -1)
        # else:
        mean_mu = self.density_mu[0].detach().cpu()
        cov_mu = torch.inverse(self.density_mu[1].detach().cpu())
        if self.cfg.INPUT_DIM > 2:
            num_grid = self.handle.ax_param.num_grid
            return g_data.np_PDF_generate_multi_normal(
                xy_list, mean_mu, cov_mu).reshape(num_grid, -1)
        pos_n_n_2 = PLU.grid_N_N_2_generator(100, -self.cfg.plot_size, self.cfg.plot_size)
        return g_data.np_PDF_generate_multi_normal(pos_n_n_2, mean_mu, cov_mu)

    # from sklearn.mixture import GaussianMixture as GM
    # import src.utils.pytorch_utils as PTU
    # def log_mu_loss(self, Tx):
    #     if self.epoch > 1:
    #         return -self.gm_density(Tx, self.density_mu, self.gmm_mu_estimator.means_.shape[0]).log().mean()
    #     else:
    #         return -self.gaussian_energy(Tx, self.density_mu)

    # def update_mu(self, P_k):
    #     # TODO later we need to change the tuning of number of components, we could use estimator.aic the lower the better
    #     self.gmm_mu_estimator = GM(
    #         n_components=6, covariance_type='full',
    #         random_state=0).fit(P_k.cpu())
    #     mean_mu = PTU.numpy2torch(self.gmm_mu_estimator.means_).cuda()
    #     cov_mu = PTU.numpy2torch(self.gmm_mu_estimator.covariances_).cuda()
    #     weights = PTU.numpy2torch(self.gmm_mu_estimator.weights_).cuda()
    #     self.density_mu = [mean_mu, cov_mu, weights]
    #     self.train_data[:, :, 1] = PTU.numpy2torch(
    #         self.gmm_mu_estimator.sample(self.cfg.N_TRAIN_SAMPLES)[0]).cuda()
