import torch
import numpy as np

def MCMCStep(dist, target, init):
    """Perform one Metropolis-Hastings step with a conditional proposal.

    Args:
        dist: proposal object exposing ``condition(x)``, ``rsample(shape)`` and
            ``log_prob(x)``. ``condition`` is called first with the current
            state, giving the forward density q(x' | x). It is then called with
            the drawn proposal batch, giving the reverse density q(x | x'). So
            ``dist`` is left conditioned on the proposal batch when this
            function returns.
        target: object whose ``log_prob(x)`` gives the (possibly
            unnormalized) log target density.
        init: current state of the chain. It is detached and re-marked as
            requiring grad before ``dist`` sees it. NOTE(review): presumably
            this lets gradient-based proposals differentiate through it;
            confirm.

    Returns:
        ``(next_state, acc_prob)``, both detached.

        ``next_state`` is the proposal if accepted, otherwise the current
        state.

        ``acc_prob`` is ``min(1, p(x') q(x | x') / (p(x) q(x' | x)))``. It is
        the acceptance probability, not a 0/1 indicator. If it is NaN (e.g.
        both target log-densities are -inf), the proposal is rejected.
    """
    start = init.detach().requires_grad_()
    dist.condition(start)
    # A batch of one proposal; row 0 is the candidate state x'.
    samples = dist.rsample([1])
    log_target_ratio = (target.log_prob(samples) - target.log_prob(start)).detach()
    forward_prob = dist.log_prob(samples).detach()
    dist.condition(samples)
    reverse_prob = dist.log_prob(start).detach()
    log_acc = torch.clamp(log_target_ratio - forward_prob + reverse_prob, max=0)
    acc_rate = torch.exp(log_acc)[0].detach()
    # Draw a fresh uniform instead of overwriting part of the sample tensor.
    u = torch.rand((), dtype=acc_rate.dtype, device=acc_rate.device)
    # Select instead of blending with 0/1 masks. With a mask blend, a NaN
    # acceptance probability made both masks False and returned zeros. Inf/NaN
    # values in the rejected branch could also leak through as 0 * inf.
    next_state = torch.where(u < acc_rate, samples[0], start)
    return next_state.detach(), acc_rate

def MCMCChain(dist, target, start, num_samples):
    """Run ``num_samples - 1`` consecutive ``MCMCStep`` transitions from ``start``.

    Returns:
        ``(chain, mean_acceptance)``.

        ``chain`` is a float NumPy array of shape ``(num_samples, len(start))``
        whose row 0 is ``start`` and whose row k is the state after k steps.

        ``mean_acceptance`` is the NumPy mean of the second value returned by
        each ``MCMCStep`` call. That mean is NaN, with a NumPy warning, when
        ``num_samples == 1``.
    """
    dim = len(start)
    chain = np.zeros([num_samples, dim])
    acceptances = np.zeros([num_samples - 1])
    chain[0] = start.cpu().detach().numpy()
    # Work on a copy so the caller's tensor is never touched.
    state = start.clone()
    for step in range(num_samples - 1):
        state, acc = MCMCStep(dist, target, state)
        chain[step + 1] = state.cpu().detach().numpy()
        acceptances[step] = acc.cpu().detach().numpy()
    return chain, np.mean(acceptances)

def AuxMCMCChain(dist, target, start,num_aux, num_samples):
    """Run an ``MCMCStep`` chain that refreshes trailing auxiliary coordinates.

    Before every step, the last ``num_aux`` entries of the current state are
    redrawn in place with ``normal_()``, i.e. from a standard normal. A
    non-positive ``num_aux`` disables the refresh.

    Returns:
        ``(chain, mean_acceptance)``.

        ``chain`` is a float NumPy array of shape ``(num_samples, len(start))``
        whose row 0 is ``start`` as given, before any refresh.

        ``mean_acceptance`` is the NumPy mean of the second value returned by
        each ``MCMCStep`` call.
    """
    dim = len(start)
    chain = np.zeros([num_samples, dim])
    acceptances = np.zeros([num_samples - 1])
    chain[0] = start.cpu().detach().numpy()
    # The in-place refresh below must not mutate the caller's tensor, hence the clone.
    state = start.clone()
    for step in range(num_samples - 1):
        if num_aux > 0:
            state[-num_aux:].normal_()
        state, acc = MCMCStep(dist, target, state)
        chain[step + 1] = state.cpu().detach().numpy()
        acceptances[step] = acc.cpu().detach().numpy()
    return chain, np.mean(acceptances)

def GibbsStep(target, init, index, scale=1.0):
    """Perform one Metropolis-within-Gibbs update of a single coordinate.

    Coordinate ``index`` is perturbed by ``scale`` times a standard-normal
    draw. This is a symmetric random-walk proposal, so no Hastings correction
    is applied. The move is accepted with probability ``min(1, p(x') / p(x))``.

    Args:
        target: object whose ``log_prob(x)`` gives the log target density.
            It is called on the unbatched state and on a batch holding one
            proposal (leading dimension 1).
        init: current state of the chain. Presumably a 1-D tensor, since it
            is indexed by ``index``.
        index: coordinate to update.
        scale: multiplier applied to the standard-normal step.

    Returns:
        ``(next_state, acc_prob)``, both detached. ``acc_prob`` is the
        acceptance probability, not a 0/1 indicator. If it is NaN (e.g. both
        log-densities are -inf), the proposal is rejected.
    """
    start = init.detach().requires_grad_()
    proposal = start.clone()
    step = torch.randn((), dtype=start.dtype, device=start.device) * scale
    proposal[index] += step
    log_ratio = (target.log_prob(proposal.unsqueeze(0)) - target.log_prob(start)).detach()
    acc_rate = torch.exp(torch.clamp(log_ratio, max=0))[0].detach()
    u = torch.rand((), dtype=acc_rate.dtype, device=acc_rate.device)
    # Select instead of blending with 0/1 masks. With a mask blend, a NaN
    # acceptance probability made both masks False and the state collapsed
    # to zeros.
    next_state = torch.where(u < acc_rate, proposal, start)
    return next_state.detach(), acc_rate

def GibbsChain(target, start, num_samples, scale=1.0):
    """Run a systematic-scan Metropolis-within-Gibbs chain from ``start``.

    At step ``i`` (``1 <= i < num_samples``), coordinate ``i % len(start)`` is
    updated with ``GibbsStep``. The sweep therefore begins at coordinate 1
    and wraps around.

    Returns:
        ``(chain, mean_acceptance)``.

        ``chain`` is a float NumPy array of shape ``(num_samples, len(start))``
        whose row 0 is ``start``.

        ``mean_acceptance`` is the NumPy mean of the second value returned by
        each ``GibbsStep`` call. That mean is NaN, with a NumPy warning, when
        ``num_samples == 1``.
    """
    dim = len(start)
    chain = np.zeros([num_samples, dim])
    acceptances = np.zeros([num_samples - 1])
    chain[0] = start.cpu().detach().numpy()
    current = start.clone()
    for i in range(1, num_samples):
        # The leftover per-iteration debug print(i) was removed; it spammed stdout.
        current, accepted = GibbsStep(target, current, i % dim, scale=scale)
        chain[i] = current.cpu().detach().numpy()
        acceptances[i - 1] = accepted.cpu().detach().numpy()
    return chain, np.mean(acceptances)
