import random

import numpy as np
import torch
import torch.nn as nn
from functorch import jacfwd, vmap
from torch.distributions.uniform import Uniform
from torch.utils.data import IterableDataset
from tqdm import tqdm

from .manifold import ImplicitManifold
from .mcmc import ConstrainedLangevinMC, LangevinMC
from .autoencoders import Autoencoder, VariationalAutoencoder


class EBM:
    """An energy-based model.

    Provides sampling and training functionality.

    Args:
        energy: the energy function with which to sample; it must expose either a `dom_shape` or
            a `dom_dim` attribute describing its input domain
        lims: a 2-tuple of either floats or tensors defining the boundaries of the data
        device: the device on which the computations will be performed (all networks will be moved
            to this device)

    Raises:
        ValueError: if `energy` has neither a `dom_shape` nor a `dom_dim` attribute
    """

    def __init__(self, energy, lims, device):
        self.device = device
        self.energy = energy.to(device)
        self.mcmc = LangevinMC(energy)

        if hasattr(energy, "dom_shape"):
            self.dom_shape = self.energy.dom_shape
        elif hasattr(energy, "dom_dim"):
            self.dom_shape = (self.energy.dom_dim,)
        else:
            raise ValueError("Provided energy must have attribute for domain dimension or shape")

        # Must be assigned after `dom_shape`: the setter broadcasts scalar limits to that shape
        self.lims = lims

    @property
    def lims(self):
        """The (lower, upper) boundaries of the data domain."""
        return self._lims

    @lims.setter
    def lims(self, value):
        try:
            iter(value[0])
        except TypeError:  # Broadcast lims into input shape if input is not iterable
            # float() makes integer bounds such as (-1, 1) produce floating-point tensors;
            # integer tensors cannot be used as Uniform bounds in `sample_noise`
            value = (
                torch.full(self.dom_shape, float(value[0])),
                torch.full(self.dom_shape, float(value[1])),
            )
        self._lims = value

    def init_buffer(self, buffer_size=10000, **noise_kwargs):
        """Initialize the EBM buffer.

        Args:
            buffer_size: the number of samples to store in the buffer
            noise_kwargs: keyword arguments to be passed to self.sample_noise
        """
        self.buffer = self.sample_noise(buffer_size, **noise_kwargs).cpu() # Buffer is stored on CPU
        self.buffer_size = buffer_size

    def _clamp(self, x):
        """Clamp a batch `x` elementwise into the box defined by `self.lims`."""
        # A leading axis is added so the limits broadcast over the batch dimension of `x`
        lower = self.lims[0][None, ...].to(x.device)
        upper = self.lims[1][None, ...].to(x.device)
        return torch.maximum(torch.minimum(x, upper), lower)

    def partition(self, sample_num=5000):
        """Estimate the partition function by Monte Carlo over uniform noise.

        Args:
            sample_num: the number of noise samples used for the estimate

        Returns:
            A scalar tensor: mean of exp(-energy) over the samples times the volume of the lims box
        """
        with torch.no_grad():
            uniform_noise = self.sample_noise(sample_num)
            unnorm_probs = torch.exp(-self.energy(uniform_noise))
            latent_volume = torch.prod(self.lims[1] - self.lims[0]).to(unnorm_probs.device)
            return torch.mean(unnorm_probs) * latent_volume

    def sample_noise(self, size):
        """Samples noise to initialize new MCMC samples.

        Args:
            size: the number of samples to return

        Returns:
            A tensor of noise
        """
        noise = Uniform(*self.lims).sample((size,)).to(self.device)
        return self._clamp(noise)

    def sample(self, size, buffer_frac=0.95, update_buffer=False, noise_kwargs=None,
               mc_kwargs=None):
        """Sample from this energy-based model.

        Args:
            size: the number of samples to return
            buffer_frac: the probability with which each initial point is drawn from the sample
                buffer rather than from fresh noise
            update_buffer: whether to update the sample buffer (should be true during training)
            noise_kwargs: dictionary of arguments to be passed to self.sample_noise
            mc_kwargs: dictionary of additional arguments to be passed to self.mcmc.sample_chain

        Returns:
            A tensor of samples, detached from any autograd graph

        Raises:
            ValueError: if the buffer is needed but `init_buffer` has not been called
        """
        if (buffer_frac > 0 or update_buffer) and not hasattr(self, "buffer"):
            raise ValueError("Run `ebm.init_buffer()` before using buffer during sampling.")

        noise_kwargs = {} if noise_kwargs is None else noise_kwargs
        mc_kwargs = {} if mc_kwargs is None else mc_kwargs

        # Each of the `size` initial points comes from the buffer with probability `buffer_frac`
        num_buffer = int(np.random.binomial(size, buffer_frac))
        num_rand = size - num_buffer

        if num_rand > 0:
            rand_samples = self.sample_noise(num_rand, **noise_kwargs)
        else:
            rand_samples = torch.empty(0, *self.dom_shape).to(self.device)

        if num_buffer > 0:
            # Draw row indices (with replacement) and gather them in a single indexing op
            picks = random.choices(range(len(self.buffer)), k=num_buffer)
            buffer_samples = self.buffer[torch.tensor(picks, dtype=torch.long)].to(self.device)
        else:
            buffer_samples = torch.empty(0, *self.dom_shape).to(self.device)

        init_samples = torch.cat((rand_samples, buffer_samples))
        samples = self.mcmc.sample_chain(init_samples, **mc_kwargs)
        # Detach before anything is stored so the buffer never keeps an autograd graph alive
        samples = self._clamp(samples).detach()

        if update_buffer:
            # Newest samples go first; the oldest entries fall off the end
            self.buffer = torch.cat((samples.cpu(), self.buffer))[:self.buffer_size]

        return samples

    def train(self, optim, dataloader, epochs, neg_weight=1., beta=1., clip_norm=1.,
              **sample_kwargs):
        """Train this EBM using contrastive divergence.

        Args:
            optim: an optimizer for the parameters of `self.energy`
            dataloader: iterable from which to load training batches; 2-element tuple/list batches
                are treated as (data, _) pairs and only the first element is used
            epochs: number of epochs to train
            neg_weight: the factor by which the energies of negative samples are multiplied
            beta: the coefficient for the regularizer
            clip_norm: the maximum norm to which gradient will be clipped
            sample_kwargs: additional arguments to be passed into `self.sample`
        """
        self.energy.train()

        for epoch in range(epochs):

            pbar = tqdm(dataloader)
            for batch in pbar:
                if isinstance(batch, (tuple, list)) and len(batch) == 2:
                    batch, _ = batch

                optim.zero_grad()

                pos_samples = batch.to(self.device)
                neg_samples = self.sample(len(batch), update_buffer=True, **sample_kwargs)

                pos = self.energy(pos_samples)
                neg = self.energy(neg_samples) * neg_weight

                # Lower the energy of data, raise the energy of model samples
                cd_loss = (pos-neg).mean()
                # Penalize large energy magnitudes
                reg_loss = (pos.square() + neg.square()).mean()

                loss = cd_loss + beta * reg_loss
                loss.backward()
                nn.utils.clip_grad_norm_(self.energy.parameters(), clip_norm)
                optim.step()

                pbar.set_description(f"[E{epoch:3d}] loss: {loss.detach():4.5f}")

        self.energy.eval()


class ConstrainedEBM(EBM):
    """An energy-based model restricted to an implicitly defined manifold.

    Provides sampling and training functionality.

    Args:
        mdf: a manifold-defining function (the manifold is given by M = {x: mdf(x) = 0})
        energy: the energy function with which to sample
        device: the device on which the computations will be performed; it is forwarded to the
            parent constructor by keyword
        *args, **kwargs: remaining arguments forwarded to `EBM.__init__` (e.g. `lims`)
    """

    def __init__(self, mdf, energy, device, *args, **kwargs):
        super().__init__(energy, *args, device=device, **kwargs)

        # The parent installs the unconstrained sampler; keep it around for ambient
        # initialization and use the constrained sampler for regular sampling
        unconstrained_mcmc = self.mcmc
        self.init_mcmc = unconstrained_mcmc
        self.mcmc = ConstrainedLangevinMC(mdf, energy)
        self.manifold = ImplicitManifold(mdf, device)

    def sample_noise(self, size, ambient_sample=False, opt_steps=30, **mc_kwargs):
        """Draw noise and project it onto the manifold.

        The result is used to initialize new samples before MCMC.

        Args:
            size: the number of samples to return
            ambient_sample: if True, run the unconstrained sampler (`self.init_mcmc`) on the noise
                before projecting to the manifold
            opt_steps: passed to `self.manifold.project` (the maximum number of LBFGS steps for
                projecting noise to the manifold)
            mc_kwargs: additional keyword arguments passed to `self.init_mcmc.sample_chain`

        Returns:
            A tensor of noise projected to the manifold and clamped to `self.lims`
        """
        start_points = super().sample_noise(size)
        if ambient_sample:
            start_points = self.init_mcmc.sample_chain(start_points, **mc_kwargs)
        projected = self.manifold.project(start_points, opt_steps)
        return self._clamp(projected)


class AutoencoderEBM:
    """An autoencoder with an EBM in the latent space.

    Provides sampling and training functionality.

    Args:
        encoder: a Map defining the encoder
        decoder: a Map defining the decoder
        energy: the energy function for the latent space
        device: the device on which the computations will be performed (all networks will be moved
            to this device)
        infer_lims: if True, `train_ebm` replaces the initial latent limits of (-10, 10) with
            limits inferred from the encoded training data
    """

    def __init__(self, encoder, decoder, energy, device, infer_lims=True):
        self.device = device
        self.encoder = encoder.to(device)
        self.decoder = decoder.to(device)
        self.energy = energy.to(device)

        self._init_encoder(encoder, decoder, device)
        # Initial latent limits; overwritten in `train_ebm` when `infer_lims` is set
        self.ebm = EBM(energy, lims=(-10., 10.), device=device)
        self.infer_lims = infer_lims

    def _init_encoder(self, encoder, decoder, device):
        """Build the autoencoder component (hook overridden by subclasses)."""
        self.autoencoder = Autoencoder(encoder, decoder, device)

    def init_buffer(self, *args, **kwargs):
        """Initialize buffer for the EBM component"""
        self.ebm.init_buffer(*args, **kwargs)

    def sample(self, *args, **kwargs):
        """Sample latents from the EBM and decode them.

        All arguments are forwarded to `self.ebm.sample`.

        Returns:
            The decoder output for the sampled latents, computed without gradient tracking
        """
        latent_sample = self.ebm.sample(*args, **kwargs)
        with torch.no_grad():
            return self.decoder(latent_sample)

    def prob(self, x=None, z=None, sample_num=5000):
        """Computes the probability density in n-dimensional ambient space.

        Args:
            x: ambient points (used only when `z` is None)
            z: latent points
            sample_num: the number of samples used to estimate the partition function

        Returns:
            exp(-ambient_energy) divided by the estimated partition function
        """
        ambient_energy = self.ambient_energy(x, z)
        partition = self.ebm.partition(sample_num=sample_num)
        return torch.exp(-ambient_energy)/partition

    def ambient_energy(self, x=None, z=None):
        """Computes the energy in n-dimensional ambient space.

        Args:
            x: ambient points; encoded to latents when `z` is None
            z: latent points; takes precedence over `x` when given

        Returns:
            The latent energy plus half the log-determinant of J^T J, where J is the decoder
            Jacobian at each latent point
        """
        assert x is not None or z is not None, "Provide at least one of `x` or `z`"
        if z is None:
            z = self.encoder(x)

        # Compute change-of-volume expression to go from latent to ambient energy
        def single_z_decoder(single_z):
            # Flatten the output so the per-sample Jacobian is always a 2-D matrix, including
            # when the decoder output has a single element or several axes
            return self.decoder(single_z[None, ...]).reshape(-1)

        jac = vmap(jacfwd(single_z_decoder))(z)
        jtj = torch.bmm(jac.transpose(1, 2), jac)
        # log det(J^T J) = 2 * sum(log(diag(L))) for the Cholesky factor L of J^T J
        cholesky_factor = torch.linalg.cholesky(jtj)
        cholesky_diagonal = torch.diagonal(cholesky_factor, dim1=1, dim2=2)
        half_log_det_jtj = torch.sum(torch.log(cholesky_diagonal), dim=1, keepdim=True)

        return self.energy(z) + half_log_det_jtj

    def train_ae(self, *args, **kwargs):
        """Train the autoencoder component of the model."""
        self.autoencoder.train(*args, **kwargs)

    def train_ebm(self, optim, dataloader, epochs, beta=1., **sample_kwargs):
        """Fit the EBM component of the model.

        The encoder is run without gradient tracking: only the energy is trained here.

        Args:
            optim: an optimizer for the parameters of `self.energy`
            dataloader: iterable from which to load training batches; 2-element tuple/list batches
                are treated as (data, _) pairs and only the first element is encoded
            epochs: number of epochs to train
            beta: the coefficient for the regularizer
            sample_kwargs: additional keyword arguments forwarded to `self.ebm.train`

        Returns:
            Whatever `self.ebm.train` returns

        Raises:
            ValueError: if `self.infer_lims` is set and the dataloader yields no latent points
        """
        encoder = self.encoder
        device = self.device

        # Create latent dataloader
        class LatentLoader:
            """Re-iterable view of `dataloader` that yields encoded batches."""

            def __iter__(self):
                for batch in dataloader:
                    if isinstance(batch, (tuple, list)) and len(batch) == 2:
                        batch, _ = batch
                    if isinstance(batch, torch.Tensor):
                        batch = batch.to(device)
                    # Keep the `yield` outside the no_grad block: yielding inside it would
                    # leave gradients disabled in the consumer's loop body
                    with torch.no_grad():
                        latent_batch = encoder(batch)
                    yield latent_batch

        latent_loader = LatentLoader()

        if self.infer_lims: # Update lims for EBM
            min_input = None
            max_input = None

            for latent_batch in latent_loader:
                if len(latent_batch) == 0:
                    continue
                # Reduce over the batch axis only, keeping the latent shape
                batch_min = latent_batch.amin(dim=0)
                batch_max = latent_batch.amax(dim=0)
                if min_input is None:
                    min_input, max_input = batch_min, batch_max
                else:
                    min_input = torch.minimum(min_input, batch_min)
                    max_input = torch.maximum(max_input, batch_max)

            if min_input is None:
                raise ValueError("Cannot infer lims: dataloader yielded no latent points")

            pad = 0.5 * (max_input - min_input) # Pad lims on each side
            self.ebm.lims = (min_input - pad, max_input + pad)

        return self.ebm.train(optim, latent_loader, epochs, beta=beta, **sample_kwargs)


class VariationalAutoencoderEBM(AutoencoderEBM):
    """A variational autoencoder with an EBM in the latent space.

    Provides sampling and training functionality. Construction is inherited from
    `AutoencoderEBM`; only the autoencoder component differs.

    Args:
        encoder: a Map defining the encoder (map to space of dimension 2*latent_dim)
        decoder: a Map defining the decoder
        energy: the energy function for the latent space
        device: the device on which the computations will be performed (all networks will be moved
            to this device)
        infer_lims: passed through to `AutoencoderEBM.__init__`
    """

    def _init_encoder(self, encoder, decoder, device):
        """Build the VAE component and expose its encoder as `self.encoder`."""
        vae = VariationalAutoencoder(encoder, decoder, device)
        self.autoencoder = vae
        # Replace the raw encoder with the VAE's encoder; `self.encoder` only outputs the mean
        self.encoder = vae.encoder
