# Code adapted from https://github.com/lollcat/fab-torch/blob/master/fab/target_distributions/double_well.py

import numpy as np
import torch
import torch.nn as nn

from .target_distribution import TargetDistribution
from ..samplers.rejection_sampling import rejection_sampling


def double_well_energy_dim_1(x_1, a=-0.5, b=-6.0, c=1.0):
    """Quartic double-well energy of the first coordinate.

    Evaluates ``a * x_1 + b * x_1**2 + c * x_1**4`` element-wise.

    Arguments:
        x_1: tensor of first-coordinate values (any shape).
        a, b, c: coefficients of the linear, quadratic and quartic terms.

    Returns:
        Tensor of energies with the same shape as ``x_1``.
    """
    linear_term = a * x_1
    quadratic_term = b * x_1.pow(2)
    quartic_term = c * x_1.pow(4)
    return linear_term + quadratic_term + quartic_term

def double_well_energy_dim_2(x_2):
    """Quadratic energy of the second coordinate, ``0.5 * x_2**2`` element-wise.

    Arguments:
        x_2: tensor of second-coordinate values (any shape).

    Returns:
        Tensor of energies with the same shape as ``x_2``.
    """
    squared = x_2.pow(2)
    return 0.5 * squared

def double_well_energy(
    x : torch.Tensor,
    a : float = -0.5,
    b : float = -6.0,
    c : float = 1.0
) -> torch.Tensor:
    """
    Compute the energy of the double well potential.

    The energy is separable: a quartic double well in the first coordinate
    plus a quadratic term in the second,
    ``E(x) = a*x_1 + b*x_1**2 + c*x_1**4 + 0.5*x_2**2``.

    Arguments:
        x: (..., 2) tensor of samples. The last axis holds (x_1, x_2); any
            number of leading batch dimensions (including none) is accepted,
            so the usual (batch_size, 2) input works unchanged.
        a, b, c: parameters of the double well potential (first coordinate).

    Returns:
        energy: (...) tensor of energies, i.e. (batch_size,) for a
            (batch_size, 2) input. The trailing coordinate axis is removed;
            no singleton dimension is added here.
    """
    # Index the last axis so that inputs without a batch axis, or with
    # several batch axes, are handled the same way as (batch_size, 2).
    x_1 = x[..., 0]
    x_2 = x[..., 1]
    e1 = double_well_energy_dim_1(x_1, a, b, c)
    e2 = double_well_energy_dim_2(x_2)
    return e1 + e2

class DoubleWell(TargetDistribution):
    """Two-dimensional double-well target distribution.

    The unnormalised density is ``exp(-E(x))`` with
    ``E(x) = a*x_1 + b*x_1**2 + c*x_1**4 + 0.5*x_2**2``. The first
    coordinate follows a bimodal double-well potential; the second is an
    independent standard normal.
    """

    # Normalising constant of exp(-E_1(x_1)) for the default parameters
    # (a, b, c) = (-0.5, -6.0, 1.0). It is only valid for those values.
    _Z_DIM_0 = 11784.50927

    def __init__(self):
        super().__init__()
        # Define energy params
        self._dim = 2
        self._a = -0.5
        self._b = -6.0
        self._c = 1.0

        # Parameters of the two-component Gaussian-mixture proposal used to
        # rejection-sample x_1. The components sit near the two wells. More
        # weight goes to the positive well, which has the lower energy for
        # a = -0.5.
        self.register_buffer("component_mix", torch.tensor([0.2, 0.8]))
        self.register_buffer("means", torch.tensor([-1.7, 1.7]))
        self.register_buffer("scales", torch.tensor([0.5, 0.5]))

    @property
    def dim(self):
        """Dimensionality of the sample space."""
        return self._dim

    def energy(self, x):
        """Return the energy ``E(x)`` with a trailing singleton dimension.

        Arguments:
            x: (batch_size, 2) tensor of samples.

        Returns:
            (batch_size, 1) tensor of energies.
        """
        return double_well_energy(x, a=self._a, b=self._b, c=self._c).unsqueeze(-1)

    def log_prob(self, x):
        """Return the unnormalised log density ``-E(x)``, shape (batch_size, 1)."""
        return -self.energy(x)

    def _has_default_params(self) -> bool:
        """Whether (a, b, c) match the values the hard-coded constants assume."""
        return (self._a, self._b, self._c) == (-0.5, -6.0, 1.0)

    def _sample_first_dimension(
        self,
        num_samples: int
    ) -> torch.Tensor:
        """Draw ``num_samples`` values of x_1 by rejection sampling.

        The target is the unnormalised density ``exp(-E_1(x_1))``. The
        proposal is the Gaussian mixture defined by the registered buffers.
        """
        a, b, c = self._a, self._b, self._c

        # Reuse the shared energy so the sampler's target cannot drift from
        # the energy used by ``energy``/``log_prob``.
        class TargetLogProb(nn.Module):
            def forward(self, x):
                return -double_well_energy_dim_1(x, a, b, c)

            def log_prob(self, x):
                return self.forward(x)

        mix = torch.distributions.Categorical(self.component_mix)
        com = torch.distributions.Normal(self.means, self.scales)
        proposal = torch.distributions.MixtureSameFamily(
            mixture_distribution=mix,
            component_distribution=com,
        )

        # NOTE(review): presumably the envelope constant with
        # exp(-E_1) <= k * proposal. It is a multiple of the normaliser
        # because the target is unnormalised. Confirm against
        # rejection_sampling's contract.
        k = self._Z_DIM_0 * 3

        return rejection_sampling(TargetLogProb(), num_samples, proposal, k)

    def sample(
        self,
        num_samples: int
    ) -> torch.Tensor:
        """Draw samples from the target.

        x_1 is drawn by rejection sampling. x_2 is drawn from a standard
        normal, which is the normalised ``exp(-0.5 * x_2**2)``.

        Returns:
            (num_samples, 2) tensor of samples.

        Raises:
            NotImplementedError: if (a, b, c) differ from the defaults the
                proposal and normalising constant were chosen for.
        """
        if not self._has_default_params():
            raise NotImplementedError
        dim1_samples = self._sample_first_dimension(num_samples)
        device = dim1_samples.device
        dim2_samples = torch.distributions.Normal(
            torch.tensor(0.0, device=device),
            torch.tensor(1.0, device=device),
        ).sample((num_samples,))
        return torch.stack([dim1_samples, dim2_samples], dim=-1)

    @property
    def log_Z(self):
        """Log normalising constant of ``exp(-E(x))``.

        The density factorises over the coordinates, so log Z is the sum of
        the tabulated log Z for x_1 and ``log sqrt(2*pi)`` for the
        standard-normal x_2.

        Raises:
            NotImplementedError: if (a, b, c) differ from the defaults the
                tabulated constant is valid for (mirrors ``sample``).
        """
        if not self._has_default_params():
            raise NotImplementedError
        log_Z_dim0 = np.log(self._Z_DIM_0)
        log_Z_dim1 = 0.5 * np.log(2 * torch.pi)
        return log_Z_dim0 + log_Z_dim1

# if __name__ == '__main__':
#     # Test that rejection sampling is working as desired.
#     import matplotlib.pyplot as plt
#     target = DoubleWell()


#     x_linspace = torch.linspace(-4, 4, 200)

#     Z_dim_1 = 11784.50927
#     samples = target.sample(10000)
#     p_1 = torch.exp(-double_well_energy_dim_1(x_linspace))
#     # plot first dimension vs normalised density
#     plt.plot(x_linspace, p_1/Z_dim_1, label="p_1 normalised")
#     plt.hist(samples[:, 0], density=True, bins=100, label="sample density")
#     plt.legend()
#     plt.show()

#     # Now dimension 2.
#     Z_dim_2 = (2 * torch.pi)**0.5
#     p_2 = torch.exp(-double_well_energy_dim_2(x_linspace))
#     # plot second dimension vs normalised density
#     plt.plot(x_linspace, p_2/Z_dim_2, label="p_2 normalised")
#     plt.hist(samples[:, 1], density=True, bins=100, label="sample density")
#     plt.legend()
#     plt.show()