import distrax
import jax
import jax.numpy as jnp
import chex
import jax.random as random
import matplotlib
import numpy as np
import numpyro.distributions as dist
import wandb
from jax._src.scipy.special import logsumexp
from scipy.stats import wishart
from matplotlib import pyplot as plt
from targets.base_target import Target
from typing import List

from utils.path_utils import project_path


# matplotlib.use('agg')


class GaussianMixtureModel(Target):
    """Gaussian mixture target with full-covariance components.

    Component covariances are drawn from a Wishart distribution with identity
    scale and ``dim + 2`` degrees of freedom. Mixture weights are uniform.
    For the 2-D, four-component configuration the component means are the fixed
    layout in ``_FIXED_2D_LOCS``; for every other configuration they are drawn
    uniformly from ``[-7, 7]`` per coordinate.
    """

    # Hand-picked component means for the 2-D, four-component layout.
    _FIXED_2D_LOCS = ((-4, -4), (-4, 4), (4, -4), (4, 4))

    def __init__(self, dim=3, num_components=15, seed=0):
        """Build the mixture.

        Args:
            dim: Dimensionality of the target.
            num_components: Number of Gaussian components.
            seed: Integer seed used to create the JAX PRNG key from which all
                random parameters are derived.
        """
        super().__init__(dim=dim, log_Z=0, can_sample=True)

        self.num_components = num_components

        # parameters
        min_mean_val = -7
        max_mean_val = 7
        min_val_mixture_weight = 0.4
        max_val_mixture_weight = 0.6
        degree_of_freedom_wishart = dim + 2

        key = jax.random.PRNGKey(seed)

        # set mixture components
        if dim == 2 and num_components == len(self._FIXED_2D_LOCS):
            locs = jnp.array(self._FIXED_2D_LOCS)
        else:
            # The fixed layout only fits the (4, 2) case; applying it to any
            # other configuration gives means whose shape does not match the
            # covariances, so fall back to random means here.
            locs = jax.random.uniform(key, minval=min_mean_val, maxval=max_mean_val,
                                      shape=(num_components, dim))

        covariances = []
        for _ in range(num_components):
            key, subkey = random.split(key)

            # Derive an integer seed for SciPy from the JAX key and hand it to
            # `rvs` directly, so NumPy's global RNG state is left untouched.
            scipy_seed = int(random.randint(key=subkey, shape=(), minval=0, maxval=2 ** 30))
            cov_matrix = wishart.rvs(df=degree_of_freedom_wishart, scale=jnp.eye(dim),
                                     random_state=scipy_seed)
            covariances.append(cov_matrix)

        self.component_dist = distrax.MultivariateNormalFullCovariance(locs, jnp.array(covariances))

        # set mixture weights
        uniform_mws = True
        if uniform_mws:
            # Constant logits: every component gets the same weight.
            mixture_weights = distrax.Categorical(logits=jnp.ones(num_components) / num_components)
        else:
            mixture_weights = distrax.Categorical(logits=dist.Uniform(
                low=min_val_mixture_weight, high=max_val_mixture_weight).sample(key, sample_shape=(num_components,)))

        self.mixture_distribution = distrax.MixtureSameFamily(mixture_distribution=mixture_weights,
                                                              components_distribution=self.component_dist)

    def sample(self, seed: chex.PRNGKey, sample_shape: chex.Shape) -> chex.Array:
        """Draw samples of shape ``sample_shape`` from the mixture using PRNG key ``seed``."""
        return self.mixture_distribution.sample(seed=seed, sample_shape=sample_shape)

    def log_prob(self, x: chex.Array) -> chex.Array:
        """Evaluate the mixture log-density at ``x``.

        A 2-D ``x`` is passed through as a batch. Any other rank gets a leading
        axis added before evaluation, and that axis is removed from the result.
        """
        batched = x.ndim == 2

        if not batched:
            x = x[None,]

        log_prob = self.mixture_distribution.log_prob(x)

        if not batched:
            log_prob = jnp.squeeze(log_prob, axis=0)

        return log_prob

    def entropy(self, samples: chex.Array = None):
        """Entropy of the empirical assignment of ``samples`` to components.

        Each sample is assigned to the component with the highest component
        log-density. The entropy of the resulting assignment frequencies is
        computed with logarithm base ``num_components``; components that
        receive no sample do not contribute.

        Args:
            samples: Batch of samples (indexed along axis 0). Required in
                practice: the ``None`` default is not handled.

        NOTE(review): with ``num_components == 1`` the base-change divides by
        ``log(1) == 0``.
        """
        expanded = jnp.expand_dims(samples, axis=-2)
        # Compute `log_prob` in every component.
        idx = jnp.argmax(self.mixture_distribution.components_distribution.log_prob(expanded), 1)
        unique_elements, counts = jnp.unique(idx, return_counts=True)
        mode_dist = counts / samples.shape[0]
        entropy = -jnp.sum(mode_dist * (jnp.log(mode_dist) / jnp.log(self.num_components)))
        return entropy

    @staticmethod
    def _log_current_figure_to_wandb() -> None:
        """Best-effort upload of the current pyplot figure to Weights & Biases."""
        try:
            wandb.log({"images/target_vis": wandb.Image(plt)})
        except Exception:
            # Logging is optional (e.g. no active wandb run); never let it
            # break plotting. KeyboardInterrupt/SystemExit still propagate.
            pass

    def visualise(self, samples: chex.Array = None, axes: List[plt.Axes] = None, show=False, clip=False) -> None:
        """Plot the target, optionally overlaying ``samples``.

        For ``dim == 2`` the density is drawn as filled contours on a grid over
        ``[-9, 9]^2`` together with the boundaries between regions of the most
        likely component; the figure is logged to wandb (best effort) and saved
        as ``gmm.png`` under the project's ``figures`` path. Otherwise the
        first two coordinates of 500 target samples are scattered and the
        figure is logged to wandb (best effort).

        Args:
            samples: Optional model samples to overlay. In the 2-D plot only
                the first 300 are drawn.
            axes: Currently unused.
            show: If true, call ``plt.show()`` before closing the figure.
            clip: If true, clip ``samples`` to the plotted range first.
        """
        import os

        plt.clf()

        plot_range = [-9, 9]
        # clipping samples because of FABs outlier
        if clip and samples is not None:
            samples = jnp.clip(samples, plot_range[0], plot_range[1])

        if self.dim == 2:
            fig = plt.figure()
            ax = fig.add_subplot()

            x, y = jnp.meshgrid(jnp.linspace(plot_range[0], plot_range[1], 1000),
                                jnp.linspace(plot_range[0], plot_range[1], 1000))
            grid = jnp.c_[x.ravel(), y.ravel()]
            pdf_values = jnp.reshape(jnp.exp(self.log_prob(grid)), x.shape)

            # Compute `log_prob` in every component; the index of the largest
            # one marks the region each grid point belongs to. Normalising the
            # values first would not change the argmax.
            per_comp_log_prob = self.mixture_distribution.components_distribution.log_prob(
                jnp.expand_dims(grid, axis=-2))
            gating = jnp.argmax(per_comp_log_prob, 1).reshape(x.shape)

            ax.contourf(x, y, pdf_values, levels=50)
            ax.contour(x, y, gating, colors='w')
            if samples is not None:
                plt.scatter(samples[:300, 0], samples[:300, 1], c='r', alpha=0.6, marker='o')
            if self.num_components == len(self._FIXED_2D_LOCS):
                # The label positions are tied to the four-component layout.
                plt.annotate('$\\xi_1$', (-8, 6), color='w', fontsize=20)
                plt.annotate('$\\xi_2$', (7, 6), color='w', fontsize=20)
                plt.annotate('$\\xi_3$', (-8, -6), color='w', fontsize=20)
                plt.annotate('$\\xi_4$', (7, -6), color='w', fontsize=20)
            plt.xticks([])
            plt.yticks([])
            plt.xlim(plot_range)
            plt.ylim(plot_range)
            plt.axis('off')

            self._log_current_figure_to_wandb()

            plt.savefig(os.path.join(project_path('./figures/'), "gmm.png"), bbox_inches='tight', pad_inches=0.1)

        else:
            target_samples = self.sample(jax.random.PRNGKey(0), (500,))
            plt.scatter(target_samples[:, 0], target_samples[:, 1], c='b', label='target')
            if samples is not None:
                plt.scatter(samples[:, 0], samples[:, 1], c='r', label='model')
            plt.legend()

            self._log_current_figure_to_wandb()

        if show:
            plt.show()

        plt.close()


if __name__ == "__main__":
    # Demo: plot the 2-D, four-component target with samples that cover a
    # growing number of components, and print the mode entropy of each set.
    key = jax.random.PRNGKey(0)
    gmm = GaussianMixtureModel(dim=2, num_components=4, seed=1)
    # Sampled from the component distribution, so axis 1 indexes the component.
    samples = gmm.component_dist.sample(seed=key, sample_shape=(200,))

    # Samples from the first 1, 2 and all 4 components respectively.
    for num_covered in (1, 2, 4):
        subset = samples[:, :num_covered].reshape(-1, 2)
        gmm.visualise(samples=subset, show=True)
        print(gmm.entropy(subset))

    # f_elbo = lambda x: -1 / x
    # f_eubo = lambda x, a, b: jnp.exp(-a * x) * jnp.sin(b * x) * 25 + 0.2
    # ln_z = lambda x: jnp.zeros_like(x)
    # x = jnp.linspace(0.1, 6, 100)
    # x_long = jnp.linspace(-0.1, 20, 100)
    # plt.plot(x, f_elbo(x), c='b')
    # plt.annotate(text='ELBO', xy=(5, -2), c='b')
    # plt.plot(x, f_eubo(x, 1, 0.5), c='r')
    # plt.annotate(text='EUBO', xy=(5, 2), c='r')
    # plt.plot(x_long, ln_z(x_long), c='k')
    # plt.xlim([0, 6])
    # plt.ylim([-5, 5])
    # plt.xticks([])
    # plt.yticks([])
    # plt.annotate(text='$\log Z$', xy=(6., 0.), c='k')
    # import tikzplotlib
    # import os
    # tikzplotlib.save(os.path.join(project_path('./figures/'), f"elbo_eubo.tex"))
    # plt.show()
