import math
import time

import torch

from code.realnvp_v2 import RealNVP
from code.exp_utils import normal_logp, diag_normal_logp, ForwardOp


class TransitionKernel:
    """Latent-space proposal kernel g(z' | z) used by PL-MCMC.

    Only the ``'perturb'`` kernel is implemented: an additive Gaussian random
    walk ``z' = z + sigma_p * eps`` with ``eps ~ N(0, I)``. The
    ``'perturb_or_resample'`` variant is accepted by name but raises
    ``NotImplementedError`` on construction.

    Args:
        type_: kernel name, ``'perturb'`` or ``'perturb_or_resample'``.
        sigma_p: step size of the random-walk perturbation.
        sigma_r: resampling scale for ``'perturb_or_resample'`` (unused by
            ``'perturb'``).
    """

    def __init__(self, type_: str, sigma_p: float, sigma_r: float = None):
        assert type_ in ('perturb', 'perturb_or_resample')

        if type_ == 'perturb_or_resample':
            # Declared but not implemented yet; it will also need sigma_r.
            raise NotImplementedError

        self.type_ = type_
        self.sigma_p = sigma_p
        self.sigma_r = sigma_r

    @torch.no_grad()
    def sample(self, z: torch.Tensor):
        """Draw a proposal z' ~ g(. | z) with the same shape and dtype as ``z``.

        Raises:
            NotImplementedError: if the kernel type is not ``'perturb'``.
        """
        assert z.ndim == 4
        if self.type_ != 'perturb':
            raise NotImplementedError(f'sampling for kernel {self.type_!r}')
        zp = z + torch.randn_like(z) * self.sigma_p
        assert zp.shape == z.shape and zp.dtype == z.dtype
        return zp

    @torch.no_grad()
    def log_prob(self, zp: torch.Tensor, z: torch.Tensor, reverse=False):
        """Log proposal density used in the Metropolis-Hastings ratio.

        ``reverse=False`` means g(z' | z) and ``reverse=True`` means g(z | z').
        For the Gaussian random walk these are equal, since the density depends
        only on ``zp - z`` up to sign. ``reverse`` therefore does not change the
        result.

        NOTE(review): the value is ``normal_logp`` of the rescaled difference
        and does not add a ``-log(sigma_p)`` term per dimension. That constant
        cancels in an MH ratio of forward/reverse proposals, but the value is
        not a normalized density on its own. Confirm against ``normal_logp``.

        Raises:
            NotImplementedError: if the kernel type is not ``'perturb'``.
        """
        assert zp.shape == z.shape
        assert zp.ndim == z.ndim == 4
        if self.type_ != 'perturb':
            raise NotImplementedError(f'log_prob for kernel {self.type_!r}')
        return normal_logp((zp - z) / self.sigma_p)


class AuxDensity:
    """Auxiliary density q(y_obs | x_obs) tied to a fixed set of observations.

    A candidate's observed part is scored by ``normal_logp`` of its residual
    from the stored observations, scaled by ``sigma_a``. (Presumably
    ``normal_logp`` is a standard-normal log density; confirm in exp_utils.)

    Args:
        x_obs: the reference observations, a 4-D tensor.
        sigma_a: scale applied to the residual.
    """

    def __init__(self, x_obs: torch.Tensor, sigma_a: float):
        assert x_obs.ndim == 4
        self.x_obs, self.sigma_a = x_obs, sigma_a

    def log_prob(self, y_obs: torch.Tensor):
        """Return ``normal_logp((y_obs - x_obs) / sigma_a)``.

        ``y_obs`` must have exactly the shape of the stored observations.
        """
        assert y_obs.shape == self.x_obs.shape
        assert y_obs.ndim == 4
        standardized_residual = (y_obs - self.x_obs) / self.sigma_a
        return normal_logp(standardized_residual)


def plmcmc_step(base_model: RealNVP,
                x_obs: torch.Tensor,
                z: torch.Tensor,
                forward_op: ForwardOp,
                g: TransitionKernel,
                q: AuxDensity) -> torch.Tensor:
    """Perform one batched Metropolis-Hastings step of PL-MCMC in latent space.

    A proposal z' ~ g(. | z) is drawn. Both z and z' are mapped to data space
    with ``base_model(..., inverse=True)``. Each batch element then independently
    replaces its row of ``z`` with z' with probability min(1, exp(log_alpha)).

    Args:
        base_model: flow model. Called as ``base_model(z, inverse=True)``, which
            returns ``(y, logdet)``, and as ``base_model.log_prob(x)[0]``, which
            returns per-sample log densities.
        x_obs: observed data. ``len(x_obs)`` must equal the batch size (asserted
            against the shape of the acceptance probabilities).
        z: current latent state, batch along dim 0. Modified IN PLACE:
            accepted rows are overwritten by the proposal.
        forward_op: provides ``observe(y) -> (y_obs, y_mis)`` and
            ``combine(x_obs, y_mis)``, which builds the candidate scored by
            ``base_model.log_prob``.
        g: latent proposal kernel (``sample`` and ``log_prob``).
        q: auxiliary density scoring the observed part of a candidate.

    Returns:
        The same tensor object ``z`` after the in-place update.
    """
    # Sample new proposal (this RNG draw happens before the acceptance uniform below)
    zp = g.sample(z)
    # Map current and proposed latents to data space. Values are clamped to
    # [0, 1] before observe(); the asserts below rely on that. logdet_* come
    # from the unclamped map.
    y, logdet_z = base_model(z, inverse=True)
    y_obs, y_mis = forward_op.observe(y.clamp(0, 1))
    yp, logdet_zp = base_model(zp, inverse=True)
    yp_obs, yp_mis = forward_op.observe(yp.clamp(0, 1))

    assert y_obs.min() >= 0.0 and y_obs.max() <= 1.0 and y_mis.min() >= 0.0 and y_mis.max() <= 1.0
    assert yp_obs.min() >= 0.0 and yp_obs.max() <= 1.0 and yp_mis.min() >= 0.0 and yp_mis.max() <= 1.0
    assert y_obs.shape == yp_obs.shape == x_obs.shape

    # NOTE(review): the logdet signs below are correct if base_model(..., inverse=True)
    # returns log|det d f / d y| of the data->latent map f at y. If it instead
    # returns log|det d f^{-1} / d z| of the latent->data map, the two logdet
    # terms should have flipped signs. Confirm against the RealNVP convention.

    # log-numerator: proposed state, with reverse proposal density g(z | z')
    log_alpha = (q.log_prob(yp_obs) + base_model.log_prob(forward_op.combine(x_obs, yp_mis))[0]
                 + g.log_prob(zp, z, reverse=True) - logdet_zp)

    # log-denominator: current state, with forward proposal density g(z' | z)
    log_alpha = log_alpha - (
            q.log_prob(y_obs) + base_model.log_prob(forward_op.combine(x_obs, y_mis))[0]
            + g.log_prob(zp, z) - logdet_z)

    # Metropolis-Hastings step
    # alpha = min(1, exp(log_alpha)); an overflow of exp() to inf is capped at 1 by the min.
    alpha = torch.min(torch.ones_like(log_alpha), log_alpha.exp())
    noise = torch.rand_like(alpha)
    assert alpha.shape == noise.shape == (len(x_obs),)
    # Accept independently per batch element when u < alpha (NaN alpha compares False -> reject).
    mask = alpha > noise
    # In-place: the caller's tensor is updated for accepted rows.
    z[mask] = zp[mask]
    # Drops local references to intermediates; has no effect on the result.
    y_obs, y_mis, yp_obs, yp_mis = None, None, None, None

    return z


@torch.no_grad()
def run_plmcmc_singlesample(*,
                            base_model: RealNVP,
                            x_obs: torch.Tensor,
                            z_init: torch.Tensor,
                            forward_op: ForwardOp,
                            g: TransitionKernel,
                            q: AuxDensity,
                            num_steps: int,
                            sample_steps):
    """Run a PL-MCMC chain for ``num_steps`` steps and collect data-space samples.

    Args:
        base_model: flow model; ``base_model(z, inverse=True)[0]`` maps a
            latent state to the sample that is recorded.
        x_obs: observed data passed to every ``plmcmc_step``.
        z_init: initial latent state. It is cloned, so the caller's tensor is
            not modified by the in-place updates of ``plmcmc_step``.
        forward_op: observation operator forwarded to ``plmcmc_step``.
        g: latent proposal kernel.
        q: auxiliary density.
        num_steps: number of MH steps (must be > 0).
        sample_steps: non-empty collection of 1-based step indices at which
            to record a sample. The final step is always recorded.

    Returns:
        dict mapping step index to a CPU copy of ``base_model(z, inverse=True)[0]``
        at that step.
    """
    assert num_steps > 0 and len(sample_steps) > 0
    z = z_init.clone()

    samples = {}
    # perf_counter is monotonic and high-resolution, which suits elapsed-time measurement.
    start_time = time.perf_counter()
    for i in range(1, num_steps + 1):
        z = plmcmc_step(base_model, x_obs, z, forward_op, g, q)

        # Monitoring. Guard the rate against a zero elapsed time (possible when
        # a step finishes within the clock's resolution).
        etime = time.perf_counter() - start_time
        rate = i / etime if etime > 0 else float('inf')
        print(f'\rstep {i} / {num_steps} '
              f'({etime:.2f} sec, {rate:.2f} steps/s)', end='', flush=True)
        if i in sample_steps or i == num_steps:
            x = base_model(z, inverse=True)[0]
            samples[i] = x.detach().clone().cpu()
            print(f' --> sampled at step {i}')

    print(' finished', flush=True)

    return samples


@torch.enable_grad()
def run_sgld(*,
             base_model: RealNVP,
             x_star: torch.Tensor,
             x_init: torch.Tensor,
             forward_op: ForwardOp,
             sigma: float,
             learning_rate: float,
             num_steps: int,
             log_every: int = None):
    """Run Langevin dynamics in data space, conditioned on observations of ``x_star``.

    Each step follows the gradient of
    ``base_model.log_prob(x)[0] + diag_normal_logp(observe(x) - observe(x_star), sigma)``
    and adds Gaussian noise of variance ``learning_rate``. The iterate is then
    clamped to [0, 1].

    Args:
        base_model: flow model providing per-sample ``log_prob``.
        x_star: reference data; only ``forward_op.observe(x_star)[0]`` is used.
        x_init: starting point; copied, never modified.
        forward_op: observation operator.
        sigma: scale passed to ``diag_normal_logp`` for the observation term.
        learning_rate: Langevin step size.
        num_steps: number of Langevin steps.
        log_every: if given, record a snapshot every ``log_every`` steps
            (plus the initial point and the final step).

    Returns:
        ``(x_final, x_list)``. ``x_final`` is detached. ``x_list`` is ``None``
        when ``log_every`` is ``None``; otherwise it is the list of detached
        snapshots.
    """
    lr = learning_rate
    x_obs, _ = forward_op.observe(x_star)
    x = x_init.detach().clone().requires_grad_(True)
    x_list = None if log_every is None else [x.detach().clone()]

    for step in range(1, num_steps + 1):
        # Compute the score at the CURRENT iterate. The observation term must
        # be recomputed from x every step, otherwise it contributes no gradient.
        xcur_obs, _ = forward_op.observe(x)
        log_px = base_model.log_prob(x)[0]
        log_smooth = diag_normal_logp(xcur_obs - x_obs, sigma)
        assert log_px.shape == log_smooth.shape == (len(x),)
        # Passing per-sample scalars makes autograd return each sample's own gradient.
        log_cond = (log_px + log_smooth).split(1, dim=0)
        score = torch.autograd.grad(log_cond, x)[0]
        assert score.shape == x.shape

        # Perform one step of Langevin dynamics
        dx = 0.5 * lr * score + math.sqrt(lr) * torch.randn_like(x)
        # Restart from a fresh leaf so autograd history does not accumulate
        # across iterations (clamp would otherwise keep every past iterate alive).
        x = (x + dx).clamp(0, 1).detach().requires_grad_(True)

        if log_every is not None and (step % log_every == 0 or step == num_steps):
            x_list.append(x.detach().clone())

    return x.detach(), x_list


@torch.enable_grad()
def run_sgld_in_z(*,
                  base_model: RealNVP,
                  x_star: torch.Tensor,
                  z_init: torch.Tensor,
                  forward_op: ForwardOp,
                  sigma: float,
                  learning_rate: float,
                  num_steps: int,
                  sample_steps):
    """Run Langevin dynamics in the flow's latent space, conditioned on observations.

    Each step follows the gradient w.r.t. ``z`` of
    ``base_model.log_prior(z) + diag_normal_logp(observe(x(z)) - observe(x_star), sigma)``,
    where ``x(z) = base_model(z, inverse=True)[0]``, and adds Gaussian noise of
    variance ``learning_rate``.

    Args:
        base_model: flow model providing ``log_prior`` and the latent->data map.
        x_star: 4-D reference data; only ``forward_op.observe(x_star)[0]`` is used.
        z_init: starting latent state; copied, never modified.
        forward_op: observation operator.
        sigma: scale passed to ``diag_normal_logp`` for the observation term.
        learning_rate: Langevin step size.
        num_steps: number of Langevin steps.
        sample_steps: 1-based step indices at which ``x(z)`` is recorded.
            The final step is always recorded.

    Returns:
        dict mapping step index to a detached CPU copy of ``x(z)`` at that step.
    """
    assert x_star.ndim == 4
    lr = learning_rate

    x_star_obs, _ = forward_op.observe(x_star)
    z = z_init.detach().clone().requires_grad_(True)
    x_obs, _ = forward_op.observe(base_model(z, inverse=True)[0])
    samples = {}
    start_time = time.perf_counter()

    for i in range(1, num_steps + 1):
        # Compute score function at the current z (x_obs was computed from this z).
        log_pz = base_model.log_prior(z)
        log_smooth = diag_normal_logp(x_obs - x_star_obs, sigma)
        assert log_pz.shape == log_smooth.shape == (len(z),)
        # Passing per-sample scalars makes autograd return each sample's own gradient.
        log_cond = (log_pz + log_smooth).split(1, dim=0)
        score = torch.autograd.grad(log_cond, z)[0]
        assert score.shape == z.shape

        # Perform one step of Langevin dynamics. Restart from a fresh leaf so the
        # autograd history does not grow as a chain back to z_init over all steps.
        dz = 0.5 * lr * score + math.sqrt(lr) * torch.randn_like(z)
        z = (z + dz).detach().requires_grad_(True)
        x = base_model(z, inverse=True)[0]
        x_obs, _ = forward_op.observe(x)

        # Monitoring
        print(f'\rstep {i} / {num_steps} ({time.perf_counter()-start_time:.2f} sec)', end='', flush=True)
        if i in sample_steps or i == num_steps:
            samples[i] = x.detach().clone().cpu()
            print(f' --> sampled at step {i}')

    print(' finished', flush=True)

    return samples

