import torch
import numpy as np
from typing import Callable, Tuple
from tqdm import tqdm
from .HMC import HMCSampler
import os, sys
# shorthand alias for torch.Tensor, used in the type annotations of this module
_tensor = torch.Tensor


def _get_def_sampler():
    """
    Build the fallback transition kernel used when the caller does not supply a sampler.
    :return: a function with the API sampler(points, nll_func) that forwards to the `sample` method of a
             default-constructed HMCSampler, passing 10 as the third argument (presumably the number of leapfrog
             steps - confirm against HMCSampler.sample).
    """
    hmc = HMCSampler()

    def _default_kernel(pt, nll_fcn):
        # a single HMCSampler instance is shared by every call of the returned kernel
        return hmc.sample(pt, nll_fcn, 10)

    return _default_kernel


def sigmoid_schedule(T: int, rad: float=4) -> Callable:
    """
    Sigmoid schedule, defined as:
            g_t = sigmoid( rad * (2t/T - 1) )
            beta_t = (g_t - g_0)/(g_T - g_0)
    for t = 0, ..., T, so that beta_0 = 0 and beta_T = 1.
    Using this scheduler allocates more time near t=0 and t=T, which is good for AIS as those are the limits when
    the distribution changes the most rapidly.

    In the degenerate cases where the sigmoid values are all identical (rad = 0 or T = 0) the normalization above
    would divide zero by zero; the schedule then falls back to the linear schedule beta_t = t/T, which is the limit
    of the normalized sigmoid as rad goes to 0.

    :param T: total number of iterations for the scheduler.
    :param rad: a float controlling the amount of time spent in flatter areas of the sigmoid function.
    :return: a function that receives an index 0<=i<=T and returns the scheduled value for that time point; any
             index larger than T returns 1.
    """
    t = np.linspace(-rad, rad, T+1)
    s = 1.0 / (1.0 + np.exp(-t))
    lo, hi = np.min(s), np.max(s)
    if hi == lo:
        # flat sigmoid (rad = 0 or a single point): avoid 0/0 -> NaN and use the linear limit instead
        s = np.linspace(0.0, 1.0, T+1)
    else:
        s = (s - lo) / (hi - lo)
    return lambda i: s[i] if i < len(s) else 1


def AIS(points: _tensor, nll_func: Callable, sampler: Callable=None, T: int=5000, n_chains: int=16,
        verbose: bool=True, initial_log_norm: float=0) -> Tuple[_tensor, _tensor]:
    """
    Implementation of the Annealed Importance Sampling algorithm introduced by Radford M. Neal (1998) in
    https://arxiv.org/abs/physics/9803008 .

    :param points: initial position of points for the sampling algorithm; a tensor with shape [N, ...]. All
                   dimensions after the first are flattened into a single dimension of size D before sampling.

    :param nll_func: a pytorch-differentiable function that receives as input a torch tensor with shape
                     [M, D] as well as the iteration i and returns the negative log-likelihood (or energy) of the
                     i-th distribution in the sequence for each of the M points; a tensor of shape [M,].
                     The expected API of this function is:
                                        NLL_i = nll_func(points, i)
                     i=T should return the negative log-likelihood for the target distribution while i=0 should return
                     the same for the base distribution.
                     **Important note**: in order to add chains, the input to the nll_func will be a tensor of shape
                     [N, n_chains, ...] reshaped into [N*n_chains, D] - this should be taken into
                     account if the nll_func is data-point dependent.

    :param sampler: the transition kernel used in each step of the AIS algorithm; this should be a Callable function
                    that receives as input a tensor of shape [M, D] and the Callable nll_func and returns a new set
                    of points as a tensor of shape [M, D]. The expected API of the function is:
                                        points_i = sampler( points_old, lambda x: nll_func(x, i) )
                    If no sampler is provided, the sampler returned by _get_def_sampler() is used: a
                    default-constructed Hamiltonian Monte Carlo sampler called with 10 steps per transition
                    (see https://github.com/friedmanroy/HMC-AIS/blob/main/HMC.py ).

    :param T: number of steps to run the algorithm; the nll_func is evaluated for every index i in 0, 1, ..., T
              (inclusive), so it should be defined for all of them

    :param n_chains: number of chains to run concurrently - the higher the number, the better the approximation of the
                     normalization constant of nll_func(points, T)

    :param verbose: a boolean indicating whether to print a progress bar or not

    :param initial_log_norm: the log of the normalization constant of the first distribution from the given sequence of
                             distributions; if not given, assumed to be 0 (i.e. the normalization constant is 1).

    :return: a tuple of the sampled points as a tensor with shape [N, n_chains, D] and the log-weights for each of the
             samples, as a tensor of shape [N, n_chains].
    """
    # honour the documented default instead of failing later with "'NoneType' object is not callable"
    if sampler is None: sampler = _get_def_sampler()

    dev = points.device
    N = points.shape[0]

    # replicate every point n_chains times and flatten to [N*n_chains, D]
    points = points[:, None] * torch.ones(N, n_chains, *points.shape[1:], device=dev)
    points = points.reshape(N*n_chains, -1)

    # initialize weights
    log_weights = torch.ones(N*n_chains, device=dev)*initial_log_norm

    pbar = tqdm(range(T), disable=not verbose)
    for i in pbar:
        # update the log-weights according to the previous sample
        with torch.no_grad():
            new_ll = - nll_func(points, i+1)
            old_ll = - nll_func(points, i)
            log_weights = log_weights + new_ll - old_ll

        # only build the progress string (two .item() calls) when a progress bar is actually shown
        if verbose:
            pbar.set_postfix_str(f'avg. log-weight: {torch.mean(log_weights).item():.2f}; '
                                 f'avg. ll: {torch.mean(new_ll).item():.2f}')

        # move the points with a transition kernel that targets the (i+1)-th distribution; the result is written
        # through .data so that the assignment itself is not tracked by autograd
        points.data = sampler(points, lambda x: nll_func(x, i+1))

    return points.reshape(N, n_chains, -1), log_weights.reshape(N, n_chains)


def AIS_likelihood(points: _tensor, data: _tensor, nll_fcn: Callable, sampler: Callable=None, schedule: Callable=None,
                   steps: int=1000, n_chains: int=16, verbose: bool=True) -> Tuple[_tensor, _tensor]:
    """
    Run AIS to estimate a lower bound on the log-likelihood of data points under a generative model, following
    Wu et al. (2017), https://arxiv.org/pdf/1611.04273v2.pdf . The prior over the latent variables is taken to be
    a standard Gaussian N(0, I), and the intermediate energies are
                        schedule(i) * nll_fcn(z, x) + 0.5 * ||z||^2
    (cf. equation 5 of Wu et al.).

    :param points: a tensor with shape [N, dim] describing the latent points. Only its shape and dtype are used:
                   after being replicated for every chain, the actual starting positions are drawn anew from a
                   standard normal distribution.

    :param data: the data whose likelihood needs to be calculated; a tensor with shape [N, ...]. Each data point is
                 replicated n_chains times, so nll_fcn receives data with shape [N*n_chains, ...].

    :param nll_fcn: a pytorch-differentiable function that calculates the negative conditional log-likelihood:
                                            nll_func(z, x) = - log p(x|z)
                    up to the normalization constant.

    :param sampler: the transition kernel used in each step of the AIS algorithm; full specification in the API for
                    AIS. If None, the sampler returned by _get_def_sampler() is used.

    :param schedule: the schedule for beta (see equation 5 in Wu et al.) where:
                                            0 <= beta(t) <= 1
                      as a Callable; it is called with every index 0 <= i <= steps. If None, sigmoid_schedule(steps)
                      is used.

    :param steps: number of steps to run the algorithm - a larger number of steps gives tighter lower bounds

    :param n_chains: number of chains to run concurrently - more chains give better approximations of the lower bound

    :param verbose: a boolean indicating whether to print a progress bar or not

    :return: a tuple of the sampled points as a tensor with shape [N, n_chains, dim] and the log weights
             of each of the data points in each chain as a tensor with shape [N, n_chains]
    """
    n_data = data.shape[0]
    target_device = data.device
    n_total = n_data * n_chains

    # fall back to the module defaults where nothing was supplied
    if sampler is None:
        sampler = _get_def_sampler()
    if schedule is None:
        schedule = sigmoid_schedule(steps)

    # one copy of every data point per chain, laid out as [N*n_chains, ...]
    tiled_data = data[:, None] * torch.ones(n_data, n_chains, *data.shape[1:], device=target_device)
    tiled_data = tiled_data.reshape(n_total, *tiled_data.shape[2:])

    # the replicated latents serve only as a template; the chains start from fresh standard-normal draws
    template = points[:, None] * torch.ones(n_data, n_chains, *points.shape[1:], device=target_device)
    template = template.reshape(n_total, *template.shape[2:])
    start = torch.randn_like(template)

    def annealed_nll(z, step):
        # tempered conditional likelihood plus the energy of the N(0, I) prior
        likelihood_term = schedule(step) * nll_fcn(z, tiled_data)
        prior_term = .5 * torch.sum((z * z).reshape(z.shape[0], -1), dim=1)
        return likelihood_term + prior_term

    # the chains are already unrolled along the first axis, hence n_chains=1 in the call to AIS
    flat_points, flat_weights = AIS(start, nll_func=annealed_nll, sampler=sampler, T=steps, n_chains=1,
                                    verbose=verbose, initial_log_norm=0)

    # fold the chain axis back out
    chain_points = flat_points[:, 0].reshape(n_data, n_chains, -1)
    chain_weights = flat_weights[:, 0].reshape(n_data, n_chains)
    return chain_points, chain_weights


if __name__ == '__main__':
    # the following is an example of how the above code should be used to calculate the AIS log-likelihood of a GAN
    # NOTE: `generator_func` and `data` are placeholders that are not defined in this file; they must be supplied
    #       by the user before this example can run
    GAN = generator_func  # a function that gets as input the latent code and label of the GAN and generates an image
    x = data[:10]  # the data used to calculate the posterior, here we use only the first 10 samples
    gamma = .1  # the inverse variance of the observation model
    label = 0  # label to use during log-likelihood calculation
    dim = 64*64*3  # dimension of the data, for instance 64 x 64 x 3 images
    latent = 120  # latent dimension of the GAN, for instance 120

    def nll_fcn(z, x):
        # definition of the negative log-likelihood of the target distribution, assuming a Gaussian observation model
        # with covariance I/gamma:
        #           -log p(x|z) = gamma/2 * ||x - G(z)||^2 + dim/2 * log(2*pi/gamma)
        norm = .5*dim*np.log(2*np.pi/gamma)
        return .5*gamma*torch.sum((x-GAN(z, label))**2, dim=(1, 2, 3)) + norm

    lr = 5e-2  # initial learning rate of the sampler
    n_steps = 10  # number of leapfrogs per AIS step
    sampler = HMCSampler(init_step_sz=lr, min_step=1e-4, max_step=.5)
    sample_fcn = lambda x, nll: sampler.sample(x, nll, n_steps=n_steps)

    # run AIS on randomly initialized points
    zs_init = torch.randn(x.shape[0], latent, device=x.device)
    zs, weights = AIS_likelihood(zs_init, x, nll_fcn, sampler=sample_fcn, steps=500, n_chains=1, verbose=True)
