import torch
from torch import nn

# FROM: https://github.com/yiftachbeer/mmd_loss_pytorch/blob/master/mmd_loss.py

class RBF(nn.Module):
    """Sum of Gaussian (RBF) kernels evaluated at several bandwidth scales.

    For a 2-D input ``X`` with ``n`` rows, ``forward`` returns an ``(n, n)``
    matrix whose entry ``(i, j)`` is
    ``sum_k exp(-||x_i - x_j||^2 / (bandwidth * multiplier_k))``.

    Args:
        n_kernels: Number of bandwidth scales that are summed.
        mul_factor: Base of the geometric progression of multipliers; the
            multipliers are ``mul_factor ** (k - n_kernels // 2)`` for
            ``k = 0 .. n_kernels - 1``.
        bandwidth: Fixed base bandwidth. When ``None`` the bandwidth is
            estimated from the data on every call (see ``get_bandwidth``).
        device: Device on which the multipliers are created.
    """

    def __init__(self, n_kernels=5, mul_factor=2.0, bandwidth=None, device='cuda'):
        super().__init__()
        self.device = device
        self.bandwidth_multipliers = mul_factor ** (torch.arange(n_kernels) - n_kernels // 2).to(self.device)
        self.bandwidth = bandwidth

    def get_bandwidth(self, L2_distances):
        """Return the base bandwidth for a matrix of squared distances.

        If a fixed bandwidth was given at construction it is returned as is.
        Otherwise the sum of all entries is divided by the number of
        off-diagonal pairs ``n**2 - n``. The result is detached, so no
        gradient flows through the bandwidth estimate.
        """
        if self.bandwidth is not None:
            return self.bandwidth

        n_samples = L2_distances.shape[0]
        # A single sample has no off-diagonal pairs; avoid dividing by zero.
        n_pairs = max(n_samples ** 2 - n_samples, 1)
        mean_sq = L2_distances.detach().sum() / n_pairs
        # If the mean is exactly zero, every distance is zero and any positive
        # bandwidth yields the same kernel values (exp(0) = 1). Substituting 1
        # avoids the 0/0 = NaN that would otherwise poison the whole matrix.
        # NaN inputs are left untouched (NaN == 0 is False).
        return torch.where(mean_sq == 0, torch.ones_like(mean_sq), mean_sq)

    def forward(self, X):
        """Return the multi-scale kernel matrix of ``X`` against itself."""
        sq_dists = torch.cdist(X, X) ** 2
        # Follow the input's device so a module built with one device setting
        # still works on inputs living elsewhere (no-op when they match).
        multipliers = self.bandwidth_multipliers.to(X.device)
        scales = self.get_bandwidth(sq_dists) * multipliers
        # Broadcast to (n_kernels, n, n), then sum over the kernel axis.
        return torch.exp(-sq_dists[None, ...] / scales[:, None, None]).sum(dim=0)


class MMDLoss(nn.Module):
    """Squared maximum-mean-discrepancy style loss between two sample sets.

    The kernel is evaluated once on the row-wise stack of both sets and the
    loss is ``mean(K_xx) - 2 * mean(K_xy) + mean(K_yy)``, where each mean runs
    over the full corresponding block of the kernel matrix.

    Args:
        device: Device handed to the underlying ``RBF`` kernel.
    """

    def __init__(self, device='cuda'):
        super().__init__()
        self.kernel = RBF(device=device)
        self.device = device

    def forward(self, X, Y):
        """Return the loss for samples ``X`` and ``Y`` (rows are samples)."""
        n_x = X.shape[0]
        gram = self.kernel(torch.vstack([X, Y]))

        # Split the joint kernel matrix into its within-set and cross-set blocks.
        within_x = gram[:n_x, :n_x].mean()
        cross = gram[:n_x, n_x:].mean()
        within_y = gram[n_x:, n_x:].mean()
        return within_x - 2 * cross + within_y
    
    
from pyro.infer import MCMC, NUTS

class SAMPLER():
    """Draw samples of a latent ``z`` with Pyro's NUTS sampler.

    Args:
        prior: Object exposing ``log_prob(z)``.
        likelihood: Object exposing ``log_prob(z)``.
        dim_z: Dimensionality of ``z``; the chain starts from a random
            ``(1, dim_z)`` tensor.
        device: Device on which the initial ``z`` is placed.

    Attributes set while sampling:
        max_prop_to_date: Smallest value returned by ``log_posterior`` so far
            (starts at ``9e9``).
        max_prob_z: Detached copy of the ``z`` that produced that value, or
            ``None`` before the first evaluation.
    """

    def __init__(self, prior, likelihood, dim_z, device='cuda'):
        self.dim_z = dim_z
        self.prior = prior
        self.likelihood = likelihood
        self.device = device

        self.max_prop_to_date = 9e9
        self.max_prob_z = None

    def log_posterior(self, param):
        """Return ``prior.log_prob(z) + likelihood.log_prob(z)`` for ``param['z']``.

        Also records the smallest value seen so far and the ``z`` that
        produced it.

        NOTE(review): this function is handed to NUTS as ``potential_fn``,
        which Pyro interprets as a potential energy (negative log density),
        and the tracking below keeps the *minimum*. That is consistent only if
        ``prior``/``likelihood`` return negative log densities -- confirm
        against their implementations.
        """
        z = param['z']

        log_prior = self.prior.log_prob(z)
        log_like = self.likelihood.log_prob(z)
        log_prob = log_prior + log_like

        if log_prob < self.max_prop_to_date:
            print(f"new map: {log_prob.item()}")
            # Store detached copies so the bookkeeping does not keep the
            # autograd graph of every improving evaluation alive.
            self.max_prop_to_date = log_prob.detach()
            self.max_prob_z = z.detach().clone()

        return log_prob

    def run(self, num_samples=100, warmup_steps=100):
        """Run a single NUTS chain and return the samples.

        Args:
            num_samples: Number of samples to draw after warm-up.
            warmup_steps: Number of warm-up steps.

        Returns:
            A NumPy array of the drawn ``z`` values reshaped to
            ``(-1, dim_z)``. A summary of the run is printed as a side effect.
        """
        kernel = NUTS(potential_fn=self.log_posterior, full_mass=False, jit_compile=False)

        mcmc = MCMC(
            kernel,
            initial_params={'z': torch.randn(1, self.dim_z).to(self.device)},
            num_samples=num_samples,
            warmup_steps=warmup_steps,
            num_chains=1,
        )
        mcmc.run()

        samples = mcmc.get_samples()['z'].reshape(-1, self.dim_z)
        mcmc.summary()
        return samples.cpu().detach().numpy()
    
    
    