import torch
from torch.distributions import MultivariateNormal, TransformedDistribution, Uniform
from invert_linear import CondInvLinear, CondScale
from mixture_weight_networks import ConstantMixtureWeights

class MCMC_Proposal:
    """Common interface shared by the proposal classes in this module.

    The base class only stores placeholders; concrete proposals are
    expected to fill in:

    * ``distribution`` -- object providing ``log_prob``, ``sample`` and
      ``rsample`` (every method below simply forwards to it),
    * ``transform`` -- object providing ``condition``,
    * ``parameters`` -- a ``torch.nn.ParameterList``,
    * ``dimension`` -- left as ``None`` here.
    """

    def __init__(self, device=torch.device('cpu')):
        """Create an empty proposal bound to ``device``."""
        # Placeholders only; nothing below works until a subclass (or the
        # caller) assigns a real distribution / transform.
        self.distribution = None
        self.transform = None
        self.dimension = None
        self.parameters = torch.nn.ParameterList()
        self.device = device

    def log_prob(self, x):
        """Return ``self.distribution.log_prob(x)``."""
        proposal = self.distribution
        return proposal.log_prob(x)

    def condition(self, x):
        """Forward ``x`` to ``self.transform.condition``; returns ``None``."""
        conditioner = self.transform
        conditioner.condition(x)

    def rsample(self, x):
        """Return ``self.distribution.rsample(x)``.

        ``x`` is handed to the distribution as-is, i.e. it plays the role
        of the sample shape, not of a point in the state space.
        """
        proposal = self.distribution
        return proposal.rsample(x)

    def sample(self, x):
        """Return ``self.distribution.sample(x)``.

        ``x`` is handed to the distribution as-is, i.e. it plays the role
        of the sample shape, not of a point in the state space.
        """
        proposal = self.distribution
        return proposal.sample(x)

class Proposal_Mixture(MCMC_Proposal):
    """Mixture of proposals with unnormalised log mixture weights.

    ``mixture_weights`` is a callable that, called with no arguments,
    returns one unnormalised log-weight per component; every method here
    normalises it with ``logsumexp``.  The object must also offer
    ``condition(x)`` and ``parameters()``.

    Attributes:
        components: the component proposals, in the order given.
        num_components: ``len(components)``.
        component_parameters: ``ParameterList`` of the components'
            parameters, de-duplicated by identity.
        mixture_weights: the log-weight callable described above.
        mixture_parameters: whatever ``mixture_weights.parameters()``
            returned at construction time.
        dimension: ``dimension`` of the first component.
    """

    def __init__(self, components, weights=None, device=torch.device('cpu')):
        """Build the mixture.

        Args:
            components: sized, indexable collection of proposals; each must
                expose ``parameters`` (iterable) and ``dimension``.
            weights: log-weight object; when ``None`` a
                ``ConstantMixtureWeights`` over all components is created.
            device: device used for tensors created by this class.
        """
        self.device = device
        self.num_components = len(components)
        self.components = components

        # Collect each parameter once, keyed by identity, in first-seen
        # order.  A plain ``set`` would also de-duplicate, but its
        # iteration order follows object hashes, so the resulting
        # ParameterList order would change from run to run.
        unique_parameters = []
        seen_ids = set()
        for component in components:
            for parameter in component.parameters:
                if id(parameter) not in seen_ids:
                    seen_ids.add(id(parameter))
                    unique_parameters.append(parameter)
        self.component_parameters = torch.nn.ParameterList(unique_parameters)

        if weights is None:
            weights = ConstantMixtureWeights(self.num_components, device=self.device)
        self.mixture_weights = weights
        # NOTE(review): if ``parameters()`` returns a generator (as
        # ``torch.nn.Module.parameters`` does) this attribute can only be
        # iterated once -- confirm against callers before relying on it.
        self.mixture_parameters = weights.parameters()
        self.dimension = self.components[0].dimension

    def log_prob(self, x):
        """Return the mixture log-density at ``x``.

        Computes ``logsumexp_k(log p_k(x) + w_k) - logsumexp_k(w_k)`` over
        the last axis, where ``w`` are the unnormalised log-weights.
        """
        log_weights = self.mixture_weights()
        component_log_probs = torch.stack(
            [component.log_prob(x) for component in self.components], -1
        )
        return (torch.logsumexp(component_log_probs + log_weights, -1)
                - torch.logsumexp(log_weights, -1))

    def condition(self, x):
        """Condition the weight object and every component on ``x``."""
        self.mixture_weights.condition(x)
        for component in self.components:
            component.condition(x)

    def _draw_component_indices(self, x):
        """Sample component indices of shape ``x`` from the mixture weights."""
        log_weights = self.mixture_weights()
        probabilities = torch.exp(log_weights - torch.logsumexp(log_weights, 0))
        chooser = torch.distributions.categorical.Categorical(probabilities)
        return chooser.sample(x)

    def _sample_from_mixture(self, x, reparameterized):
        """Draw one sample per selected component and concatenate them."""
        draws = []
        # ``tolist`` converts the indices to Python ints in a single
        # transfer instead of indexing the tensor element by element.
        for index in self._draw_component_indices(x).tolist():
            component = self.components[index]
            if reparameterized:
                draw = component.rsample([1])
            else:
                draw = component.sample([1])
            draws.append(draw.unsqueeze(-1))
        # One concatenation at the end instead of growing a tensor inside
        # the loop (which re-copies all earlier draws on every iteration).
        return torch.cat(draws, 0).squeeze()

    def sample(self, x):
        """Draw samples from the mixture.

        Args:
            x: one-dimensional sample shape, e.g. ``[n]``.

        Returns:
            The concatenated draws after ``squeeze()``.  Because
            ``squeeze()`` drops every size-1 axis, a single draw or a
            one-dimensional component loses that axis.
        """
        return self._sample_from_mixture(x, reparameterized=False)

    def rsample(self, x):
        """Like :meth:`sample`, but every draw uses the component's ``rsample``.

        The component index itself is still sampled without gradient; only
        the draw inside the chosen component is reparameterised.  All
        draws, including the first, go through ``rsample``.
        """
        return self._sample_from_mixture(x, reparameterized=True)

    def _sample_every_component(self, x, reparameterized):
        """Draw ``x`` samples from each component with matching log-weights."""
        log_weights = self.mixture_weights()
        if not reparameterized:
            log_weights = log_weights.detach()
        ones = torch.ones(x, device=self.device)

        samples = []
        sample_log_weights = []
        for i, component in enumerate(self.components):
            draw = component.rsample(x) if reparameterized else component.sample(x)
            samples.append(draw.view(x[0], self.dimension))
            sample_log_weights.append(log_weights[i] * ones)

        log_normaliser = torch.logsumexp(log_weights, 0)
        return (torch.cat(samples, 0),
                torch.cat(sample_log_weights, 0) - log_normaliser)

    def uniform_sample(self, x):
        """Draw ``x`` samples from *each* component, ignoring the weights.

        Args:
            x: sample shape; ``x[0]`` draws are taken per component.

        Returns:
            ``(samples, log_weights)``: the draws of all components
            concatenated along axis 0 (component 0 first), each reshaped to
            ``(x[0], dimension)``, and for every draw the normalised
            log-weight of the component it came from.  The log-weights are
            detached from the autograd graph.
        """
        return self._sample_every_component(x, reparameterized=False)

    def uniform_rsample(self, x):
        """Like :meth:`uniform_sample`, but differentiable.

        Draws use each component's ``rsample`` and the returned log-weights
        stay attached to the autograd graph.
        """
        return self._sample_every_component(x, reparameterized=True)
        

class IsotropicUniform(MCMC_Proposal):
    """Proposal with a ``Uniform(0, 1)`` base per coordinate and a ``CondScale`` transform.

    At construction the transform's ``weight`` is divided by ``sqrt(d)``
    and, when ``bias`` is enabled, its ``bias`` by ``d``.
    """

    def __init__(self, d, bias=False, device=torch.device('cpu')):
        """Build the proposal.

        Args:
            d: number of coordinates.
            bias: forwarded to ``CondScale``; also enables bias rescaling.
            device: device for the base distribution and the transform.
        """
        self.device = device
        self.dimension = d

        lower = torch.zeros(d, device=device)
        upper = torch.ones(d, device=device)
        self.base_distribution = Uniform(lower, upper, validate_args=False)

        scale_transform = CondScale(d, bias=bias, device=device)
        with torch.no_grad():
            scale_transform.weight.data /= d ** 0.5
            if bias:
                scale_transform.bias.data /= d
        self.transform = scale_transform

        self.parameters = torch.nn.ParameterList(scale_transform.parameters())
        self.distribution = TransformedDistribution(
            self.base_distribution, [scale_transform]
        )

    def log_prob(self, x):
        """Return the distribution's log-probability summed over the last axis."""
        per_coordinate = self.distribution.log_prob(x)
        return per_coordinate.sum(-1)

class IsotropicRWM(MCMC_Proposal):
    """Proposal with a standard multivariate normal base and a ``CondScale`` transform.

    At construction the transform's ``weight`` is divided by ``sqrt(d)``
    and, when ``bias`` is enabled, its ``bias`` by ``d``.  Conditioning is
    inherited from the base class, i.e. forwarded to the transform.
    """

    def __init__(self, d, bias=False, device=torch.device('cpu')):
        """Build the proposal.

        Args:
            d: dimension of the state space.
            bias: forwarded to ``CondScale``; also enables bias rescaling.
            device: device for the base distribution and the transform.
        """
        self.device = device
        self.dimension = d

        mean = torch.zeros(d, device=device)
        covariance = torch.eye(d, device=device)
        self.base_distribution = MultivariateNormal(mean, covariance)

        scale_transform = CondScale(d, bias=bias, device=device)
        self.transform = scale_transform
        self.parameters = torch.nn.ParameterList(scale_transform.parameters())

        # Rescale the freshly initialised transform in place; the
        # ParameterList above holds references, so it sees the new values.
        with torch.no_grad():
            scale_transform.weight.data /= d ** 0.5
            if bias:
                scale_transform.bias.data /= d

        self.distribution = TransformedDistribution(
            self.base_distribution, [scale_transform]
        )

class IsotropicMALA(MCMC_Proposal):
    """Gradient-informed proposal: normal base, ``CondScale`` transform, drifted centre.

    ``condition`` shifts the conditioning point by the gradient of the
    target log-density times ``|weight|**2 / 2`` before handing it to the
    transform.  The transform's ``weight`` is divided by ``d**0.15`` at
    construction and no bias is used.
    """

    def __init__(self, d, target, device=torch.device('cpu')):
        """Build the proposal.

        Args:
            d: dimension of the state space.
            target: object whose ``log_prob`` is differentiated in
                :meth:`condition`.
            device: device for the base distribution and the transform.
        """
        self.device = device
        self.dimension = d
        self.target = target

        mean = torch.zeros(d, device=device)
        covariance = torch.eye(d, device=device)
        self.base_distribution = MultivariateNormal(mean, covariance)

        scale_transform = CondScale(d, bias=False, device=device)
        with torch.no_grad():
            scale_transform.weight.data /= d ** 0.15
        self.transform = scale_transform

        self.parameters = torch.nn.ParameterList(scale_transform.parameters())
        self.distribution = TransformedDistribution(
            self.base_distribution, [scale_transform]
        )

    def condition(self, x):
        """Condition the transform on ``x`` plus a gradient drift.

        ``x`` must take part in the autograd graph, because the gradient of
        ``target.log_prob(x)`` is taken with respect to it.  The gradient
        is built with ``create_graph=True`` so it can itself be
        differentiated.

        Returns:
            Whatever ``self.transform.condition`` returns.
        """
        log_density = self.target.log_prob(x)
        # An all-ones ``grad_outputs`` gives the gradient of the summed
        # log-density, so ``log_density`` need not be a scalar.
        (score,) = torch.autograd.grad(
            log_density,
            x,
            grad_outputs=torch.ones_like(log_density),
            create_graph=True,
        )
        squared_scale = self.transform.weight.abs() ** 2
        drifted = x + score * squared_scale / 2.0
        return self.transform.condition(drifted)

class IsotropicMALAGaussian(MCMC_Proposal):
    """Variant of the gradient-drift proposal with the gradient hard-coded to ``-x``.

    ``-x`` is the gradient of a standard normal log-density, so no
    autograd call is needed.  ``target`` is stored but not consulted by
    :meth:`condition`.  The transform's ``weight`` is divided by
    ``d**0.15`` at construction and no bias is used.
    """

    def __init__(self, d, target, device=torch.device('cpu')):
        """Build the proposal.

        Args:
            d: dimension of the state space.
            target: stored on the instance as ``self.target``.
            device: device for the base distribution and the transform.
        """
        self.device = device
        self.dimension = d
        self.target = target

        mean = torch.zeros(d, device=device)
        covariance = torch.eye(d, device=device)
        self.base_distribution = MultivariateNormal(mean, covariance)

        scale_transform = CondScale(d, bias=False, device=device)
        with torch.no_grad():
            scale_transform.weight.data /= d ** 0.15
        self.transform = scale_transform

        self.parameters = torch.nn.ParameterList(scale_transform.parameters())
        self.distribution = TransformedDistribution(
            self.base_distribution, [scale_transform]
        )

    def condition(self, x):
        """Condition the transform on ``x`` shifted by ``-x * |weight|**2 / 2``.

        Returns:
            Whatever ``self.transform.condition`` returns.
        """
        score = -x
        squared_scale = self.transform.weight.abs() ** 2
        drifted = x + score * squared_scale / 2.0
        return self.transform.condition(drifted)

class FullRWM(MCMC_Proposal):
    """Proposal with a standard multivariate normal base and a ``CondInvLinear`` transform.

    At construction the transform's ``weight`` is divided by ``sqrt(d)``
    and, when ``bias`` is enabled, its ``bias`` by ``d``.  Conditioning is
    inherited from the base class, i.e. forwarded to the transform.
    """

    def __init__(self, d, bias=False, device=torch.device('cpu')):
        """Build the proposal.

        Args:
            d: dimension of the state space.
            bias: forwarded to ``CondInvLinear``; also enables bias rescaling.
            device: device for the base distribution and the transform.
        """
        self.device = device
        self.dimension = d

        mean = torch.zeros(d, device=device)
        covariance = torch.eye(d, device=device)
        self.base_distribution = MultivariateNormal(mean, covariance)

        linear_transform = CondInvLinear(d, bias=bias, device=device)
        with torch.no_grad():
            linear_transform.weight.data /= d ** 0.5
            if bias:
                linear_transform.bias.data /= d
        self.transform = linear_transform

        self.parameters = torch.nn.ParameterList(linear_transform.parameters())
        self.distribution = TransformedDistribution(
            self.base_distribution, [linear_transform]
        )

class FullRWMFixed(MCMC_Proposal):
    """``CondInvLinear``-based normal proposal whose ``condition`` does nothing.

    Construction matches the conditioned variant: standard multivariate
    normal base, transform ``weight`` divided by ``sqrt(d)`` and, when
    ``bias`` is enabled, ``bias`` divided by ``d``.  The only difference is
    that :meth:`condition` ignores its argument, so the transform is never
    re-conditioned through this class.
    """

    def __init__(self, d, bias=False, device=torch.device('cpu')):
        """Build the proposal.

        Args:
            d: dimension of the state space.
            bias: forwarded to ``CondInvLinear``; also enables bias rescaling.
            device: device for the base distribution and the transform.
        """
        self.device = device
        self.dimension = d

        mean = torch.zeros(d, device=device)
        covariance = torch.eye(d, device=device)
        self.base_distribution = MultivariateNormal(mean, covariance)

        linear_transform = CondInvLinear(d, bias=bias, device=device)
        with torch.no_grad():
            linear_transform.weight.data /= d ** 0.5
            if bias:
                linear_transform.bias.data /= d
        self.transform = linear_transform

        self.parameters = torch.nn.ParameterList(linear_transform.parameters())
        self.distribution = TransformedDistribution(
            self.base_distribution, [linear_transform]
        )

    def condition(self, x):
        """Deliberate no-op: ``x`` is ignored and ``None`` is returned."""
        return None

class IsotropicRWMFixed(MCMC_Proposal):
    """``CondScale``-based normal proposal whose ``condition`` does nothing.

    Construction matches the conditioned variant: standard multivariate
    normal base, transform ``weight`` divided by ``sqrt(d)`` and, when
    ``bias`` is enabled, ``bias`` divided by ``d``.  The only difference is
    that :meth:`condition` ignores its argument, so the transform is never
    re-conditioned through this class.
    """

    def __init__(self, d, bias=False, device=torch.device('cpu')):
        """Build the proposal.

        Args:
            d: dimension of the state space.
            bias: forwarded to ``CondScale``; also enables bias rescaling.
            device: device for the base distribution and the transform.
        """
        self.device = device
        self.dimension = d

        mean = torch.zeros(d, device=device)
        covariance = torch.eye(d, device=device)
        self.base_distribution = MultivariateNormal(mean, covariance)

        scale_transform = CondScale(d, bias=bias, device=device)
        with torch.no_grad():
            scale_transform.weight.data /= d ** 0.5
            if bias:
                scale_transform.bias.data /= d
        self.transform = scale_transform

        self.parameters = torch.nn.ParameterList(scale_transform.parameters())
        self.distribution = TransformedDistribution(
            self.base_distribution, [scale_transform]
        )

    def condition(self, x):
        """Deliberate no-op: ``x`` is ignored and ``None`` is returned."""
        return None
