from typing import Callable
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.distributions as D
from torch.distributions.mixture_same_family import MixtureSameFamily

from .base_set import BaseSet


def rejection_sampling(
    n_samples: int,
    proposal: torch.distributions.Distribution,
    target_log_prob_fn: Callable,
    k: float,
    max_rounds: int = 1000,
) -> torch.Tensor:
    """Draw ``n_samples`` from an unnormalised target by rejection sampling.

    See Pattern Recognition and ML by Bishop, Chapter 11.1. A candidate ``z``
    drawn from ``proposal`` is accepted when ``u * k * q(z) < p~(z)`` with
    ``u ~ U[0, 1)``. The comparison is carried out in log space, so proposal
    densities that underflow to zero (e.g. high-dimensional proposals) do not
    break the sampler.

    Args:
        n_samples: Number of samples to return.
        proposal: Distribution providing ``sample`` and ``log_prob``.
        target_log_prob_fn: Maps a batch of candidates to the unnormalised
            target log-density of each candidate.
        k: Strictly positive envelope constant; ``k * q(z) >= p~(z)`` must hold
            for the returned samples to follow the target.
        max_rounds: Upper bound on the number of proposal batches drawn before
            giving up.

    Returns:
        Tensor whose leading dimension is ``n_samples``; trailing dimensions are
        those of a single proposal sample.

    Raises:
        ValueError: If ``k`` is not strictly positive.
        RuntimeError: If fewer than ``n_samples`` candidates were accepted after
            ``max_rounds`` rounds.
    """
    k_tensor = torch.as_tensor(k)
    if not bool(torch.all(k_tensor > 0)):
        raise ValueError("k must be strictly positive for rejection sampling")
    log_k = torch.log(k_tensor)

    accepted = []
    remaining = n_samples
    for _ in range(max_rounds):
        # Oversample 10x so that a single round usually suffices.
        z = proposal.sample((remaining * 10,))
        log_q = proposal.log_prob(z)
        # Same draw order as sampling Uniform(0, k * q(z)): proposal first, then u.
        u = torch.rand(log_q.shape, dtype=log_q.dtype, device=log_q.device)
        log_ratio = target_log_prob_fn(z) - log_q - log_k
        # u * k * q(z) < p~(z)   <=>   log u < log p~(z) - log q(z) - log k
        keep = z[torch.log(u) < log_ratio][:remaining]
        accepted.append(keep)
        remaining -= keep.shape[0]
        if remaining <= 0:
            break
    else:
        raise RuntimeError(
            f"rejection sampling accepted only {n_samples - remaining} of {n_samples} "
            f"samples after {max_rounds} rounds; check the proposal and the constant k"
        )
    return torch.cat(accepted, dim=0)


class DistortedManyWell(BaseSet):
    """
    Product of ``dim // 2`` independent, randomly distorted double wells.

    Each well owns a pair of coordinates (x1, x2) with unnormalised log-density
    log p(x1, x2) = −a1 * x1^4 + a2 * 6*x1^2 + a3 * 1/2*x1 − 1/2*x2^2 + constant
    a1, a2, a3 are drawn once per well (fixed seed 42) uniformly from
    [1 - C/2, 1 + C/2), where C = ``distortion_coef`` lies in [0, 1].
    """

    def __init__(self, device, distortion_coef=0.1, dim=32, is_linear=True):
        """Build the per-well coefficients and their normalising constants.

        Args:
            device: Device on which the placeholder ``data`` tensor is stored.
                Coefficients and normalising constants stay on the default device.
            distortion_coef: Width C of the uniform perturbation applied to the
                three polynomial coefficients of every well; must be in [0, 1].
            dim: Total dimensionality; must be even (two coordinates per well).
            is_linear: Not used by this class.
        """
        super().__init__()
        self.device = device

        # Validate before allocating anything that depends on the arguments.
        assert dim % 2 == 0, "dim must be even: each well spans two coordinates"
        assert 0 <= distortion_coef and distortion_coef <= 1, "distortion_coef must lie in [0, 1]"

        self.data = torch.ones(dim, dtype=float).to(self.device)
        self.data_ndim = dim
        self.n_wells = dim // 2

        # as rejection sampling proposal
        self.component_mix = torch.tensor([0.2, 0.8])
        self.means = torch.tensor([-1.7, 1.7])
        self.scales = torch.tensor([0.5, 0.5])

        # Dedicated generator with a fixed seed so the distortion is reproducible
        # and independent of the global RNG state.
        g = torch.Generator()
        g.manual_seed(42)
        print(f"{self.n_wells=}")
        self.coeffs = 1 + distortion_coef * (torch.rand(self.n_wells, 3, generator=g) - 0.5)
        print(f"{self.coeffs=}")

        # Normalising constant of the x1 marginal of every well, obtained with a
        # Riemann sum on [-8, 8].
        self.Z_x1 = torch.zeros(self.n_wells)
        integral_space = torch.linspace(-8, 8, 10000)
        delta_x1 = integral_space[1] - integral_space[0]
        for n_well in range(self.n_wells):
            self.Z_x1[n_well] = torch.sum(torch.exp(self._x1_logprob(integral_space, n_well)) * delta_x1)
        # 11784.50927 (reference value, presumably for the undistorted well)
        # x2 is standard normal, so its log normaliser is 0.5 * log(2 * pi).
        self.logZ_x2 = 0.5 * np.log(2 * np.pi)
        # torch.log keeps the result a tensor without relying on NumPy/torch
        # ufunc interop.
        self.logZ_doublewell = torch.log(self.Z_x1) + self.logZ_x2

    @property
    def bounds(self):
        """Return the (lower, upper) pair ``(-4.0, 4.0)``."""
        return (-4.0, 4.0)

    @property
    def is_many_well(self):
        """Always ``True`` for this class."""
        return True

    @property
    def gt_logz(self):
        """Log normalising constant of the full density (sum over all wells)."""
        return torch.sum(self.logZ_doublewell)

    def energy(self, x):
        """Return the energy, i.e. the negative unnormalised log-density, per row of ``x``."""
        energy = -self.manywell_logprob(x)
        assert energy.shape[0] == x.shape[0] and energy.ndim == 1
        return energy

    def _x1_logprob(self, x1, n_well):
        """Unnormalised log-density of the first coordinate of well ``n_well``."""
        a = self.coeffs[n_well]
        return a[2] * 0.5 * x1 + a[1] * 6 * x1.pow(2) - a[0] * x1.pow(4)

    def doublewell_logprob(self, x, n_well):
        """Unnormalised log-density of well ``n_well`` for ``x`` of shape (batch, 2)."""
        # Check the rank first so that a 1-D input fails the assertion instead of
        # raising IndexError on ``x.shape[1]``.
        assert x.ndim == 2 and x.shape[1] == 2
        x1 = x[:, 0]
        x2 = x[:, 1]
        x2_term = -0.5 * x2.pow(2)
        return self._x1_logprob(x1, n_well) + x2_term

    def manywell_logprob(self, x):
        """Unnormalised log-density of ``x``: the sum of all per-well log-densities.

        Well ``i`` reads columns ``2*i`` and ``2*i + 1`` of the 2-D input ``x``.
        """
        assert x.ndim == 2
        logprob = torch.stack([self.doublewell_logprob(x[:, i * 2 : i * 2 + 2], i) for i in range(self.n_wells)], dim=1).sum(dim=1)
        return logprob

    def sample_first_dimension(self, batch_size, n_well):
        """Draw ``batch_size`` samples of the x1 coordinate of well ``n_well``.

        Uses rejection sampling with a two-component Gaussian mixture proposal.
        """

        def target_log_prob(x):
            return self._x1_logprob(x, n_well)

        # Define proposal
        mix = torch.distributions.Categorical(self.component_mix)
        com = torch.distributions.Normal(self.means, self.scales)
        proposal = torch.distributions.MixtureSameFamily(mixture_distribution=mix, component_distribution=com)

        # Envelope constant: three times the x1 normaliser of this well.
        k = self.Z_x1[n_well] * 3
        samples = rejection_sampling(batch_size, proposal, target_log_prob, k)
        return samples

    def sample_doublewell(self, batch_size, n_well):
        """Draw ``batch_size`` samples (x1, x2) of well ``n_well``, shape (batch, 2)."""
        x1 = self.sample_first_dimension(batch_size, n_well)
        # x2 is an independent standard normal.
        x2 = torch.randn_like(x1)
        return torch.stack([x1, x2], dim=1)

    def sample(self, batch_size):
        """Draw ``batch_size`` samples, concatenating every well along the last axis."""
        return torch.cat([self.sample_doublewell(batch_size, n_well) for n_well in range(self.n_wells)], dim=-1)

    def viz_pdf(self, fsave="density.png", lim=3):
        """Not implemented for this class."""
        raise NotImplementedError

    def __getitem__(self, idx):
        """Return the first element of the placeholder ``data`` tensor; ``idx`` is ignored."""
        del idx
        return self.data[0]
