import torch
from torch.distributions import MultivariateNormal, TransformedDistribution, Uniform
from mcmc_kernels import MCMC_Proposal
from nice_transforms import NICE
from conditioning_networks import ReluMeanScale
import time
from torch.nn.parameter import Parameter

class NICEResample(MCMC_Proposal):
    """Proposal that pushes standard-normal draws through a NICE transform.

    The ``conditioner`` offset added to the latent draw is a tensor of zeros
    both at construction and after ``condition``, so it never changes the
    values of the draw.
    """

    def __init__(self, d, transform=None, device=torch.device('cpu')):
        """Build the base distribution and the flow.

        Args:
            d: Dimension of the sample space.
            transform: Flow to use. When ``None``, a ``NICE`` transform with an
                alternating mask, 3 layers, depth 1 and width 4 is created.
            device: Device on which the tensors created here are placed.
        """
        self.device = device

        # Zero-mean, identity-covariance Gaussian in d dimensions.
        zero_mean = torch.zeros(d).to(self.device)
        identity_cov = torch.eye(d).to(self.device)
        self.base_distribution = MultivariateNormal(zero_mean, identity_cov)

        if transform is None:
            transform = NICE(
                d,
                alternating_mask=True,
                num_layers=3,
                layer_depth=1,
                layer_width=4,
                device=self.device,
            )
        self.transform = transform

        # Only the flow carries trainable parameters in this proposal.
        self.parameters = torch.nn.ParameterList(self.transform.parameters())
        self.distribution = TransformedDistribution(
            self.base_distribution, [self.transform]
        )
        self.conditioner = torch.zeros(d).to(device)
        self.dimension = d

    def rsample(self, x):
        """Draw reparameterized samples.

        Args:
            x: Sample shape, forwarded unchanged to the base distribution's
                ``rsample``.

        Returns:
            The base draw, offset by ``conditioner``, passed through the
            transform.
        """
        latent = self.base_distribution.rsample(x)
        # In-place add: the conditioner must broadcast into the draw's shape.
        latent += self.conditioner
        return self.transform(latent)

    def log_prob(self, x):
        """Return the log-density of ``x`` under the proposal.

        ``x`` is mapped back through the transform's ``_inverse``; the base
        log-density of the de-offset latent is corrected by the transform's
        log-abs-det-Jacobian term.
        """
        latent = self.transform._inverse(x)
        base_log_prob = self.base_distribution.log_prob(latent - self.conditioner)
        # NOTE(review): arguments are (data, latent); torch's Transform API
        # documents (input, output). Confirm NICE expects this order.
        log_det = self.transform.log_abs_det_jacobian(x, latent)
        return base_log_prob - log_det

    def sample(self, x):
        """Same as ``rsample`` but executed under ``torch.no_grad()``."""
        with torch.no_grad():
            drawn = self.rsample(x)
        return drawn

    def condition(self, x):
        """Reset ``conditioner`` to zeros with the same shape as ``x``."""
        zeros_like_state = torch.zeros(x.shape)
        self.conditioner = zeros_like_state.to(self.device)

class NICERelu(MCMC_Proposal):
    """NICE-flow proposal whose latent noise is shifted and scaled by a
    ``ReluMeanScale`` conditioning network.

    ``condition`` maps the conditioning state into latent space with the
    transform's ``_inverse`` and hands it to the conditioning network; the
    network is then called without arguments to obtain ``(mean, scale)``.
    """

    def __init__(self, d, mean_params, scale_params, transform=None, device=torch.device('cpu'), rw_bias=0.0):
        """Build the base distribution, the flow and the conditioning network.

        Args:
            d: Dimension of the sample space.
            mean_params: Forwarded to ``ReluMeanScale``.
            scale_params: Forwarded to ``ReluMeanScale``.
            transform: Flow to use. When ``None``, a ``NICE`` transform with an
                alternating mask, 3 layers, depth 1 and width 4 is created.
            device: Device on which the tensors created here are placed.
            rw_bias: Multiplier applied to the latent conditioner when it is
                added to the latent sample.
        """
        self.device = device
        self.base_distribution = MultivariateNormal(
            torch.zeros(d).to(self.device), torch.eye(d).to(self.device)
        )
        if transform is not None:
            self.transform = transform
        else:
            self.transform = NICE(
                d,
                alternating_mask=True,
                num_layers=3,
                layer_depth=1,
                layer_width=4,
                device=self.device,
            )
        self.conditioning_network = ReluMeanScale(d, mean_params, scale_params, device=self.device)
        # Stored as a one-element tensor so it lives on the target device.
        self.rw_bias = rw_bias * torch.ones(1).to(device)
        # Trainable parameters: the flow's followed by the conditioning network's.
        self.parameters = torch.nn.ParameterList(
            list(self.transform.parameters())
            + list(self.conditioning_network.parameters())
        )
        self.distribution = TransformedDistribution(self.base_distribution, [self.transform])
        self.conditioner = torch.zeros(d).to(device)
        self.dimension = d

    def rsample(self, x):
        """Draw reparameterized samples.

        Args:
            x: Sample shape drawn per conditioning row. May be a list, tuple
                or ``torch.Size``. The reshape below expects a
                three-dimensional sample tensor, so presumably ``x`` has one
                entry -- confirm against callers.

        Returns:
            The transformed samples, with the first two sample dimensions
            flattened into one.
        """
        mean, scale = self.conditioning_network()
        # list(x) so that tuple / torch.Size shapes work, not only lists.
        sample_shape = [mean.shape[0]] + list(x)
        noise = self.base_distribution.rsample(sample_shape)
        samples = scale.unsqueeze(1) * (noise + mean.unsqueeze(1)) + self.rw_bias * self.conditioner.unsqueeze(1)

        flat = samples.view(samples.shape[0] * samples.shape[1], samples.shape[2])
        return self.transform(flat)

    def log_prob(self, x):
        """Return the log-density of ``x`` under the conditioned proposal.

        Undoes the steps of ``rsample``: inverse transform, remove the
        ``rw_bias``-weighted conditioner, divide by ``scale`` and subtract
        ``mean``; then subtracts the log-scale and Jacobian terms.
        """
        y = self.transform._inverse(x)
        mean, scale = self.conditioning_network()
        standardized = (y - self.rw_bias * self.conditioner) / scale - mean
        # NOTE(review): arguments are (data, latent); torch's Transform API
        # documents (input, output). Confirm NICE expects this order.
        return (
            self.base_distribution.log_prob(standardized)
            - torch.log(scale).sum(dim=-1)
            - self.transform.log_abs_det_jacobian(x, y)
        )

    def sample(self, x):
        """Same as ``rsample`` but executed under ``torch.no_grad()``."""
        with torch.no_grad():
            return self.rsample(x)

    def condition(self, x):
        """Condition the proposal on state ``x``.

        A one-dimensional ``x`` gets a leading batch dimension first. The
        state is mapped to latent space and passed to the conditioning
        network.
        """
        if len(x.shape) == 1:
            x = x.unsqueeze(0)
        self.conditioner = self.transform._inverse(x)
        self.conditioning_network.condition(self.conditioner)

class NICEDoubleRelu(MCMC_Proposal):
    """NICE-flow proposal with two ``ReluMeanScale`` conditioning networks.

    One network shifts and scales the noise before the transform (latent
    side); the other shifts and scales the transform's output (data side).
    ``condition`` feeds the inverse-transformed state to the latent network
    and the state itself to the data network.
    """

    def __init__(self, d, mean_params, scale_params, transform=None, device=torch.device('cpu'), rwm_bias = 0.0):
        """Build the base distribution, the flow and both conditioning networks.

        Args:
            d: Dimension of the sample space.
            mean_params: Forwarded to both ``ReluMeanScale`` networks.
            scale_params: Forwarded to both ``ReluMeanScale`` networks.
            transform: Flow to use. When ``None``, a ``NICE`` transform with an
                alternating mask, 3 layers, depth 1 and width 4 is created.
            device: Device on which the tensors created here are placed.
            rwm_bias: Multiplier applied to the data conditioner when it is
                added to the data-side mean.
        """
        self.device = device
        self.base_distribution = MultivariateNormal(
            torch.zeros(d).to(self.device), torch.eye(d).to(self.device)
        )
        if transform is not None:
            self.transform = transform
        else:
            self.transform = NICE(
                d,
                alternating_mask=True,
                num_layers=3,
                layer_depth=1,
                layer_width=4,
                device=self.device,
            )
        self.latent_conditioning_network = ReluMeanScale(d, mean_params, scale_params, device=self.device)
        self.data_conditioning_network = ReluMeanScale(d, mean_params, scale_params, device=self.device)
        # Trainable parameters: flow, then latent network, then data network.
        self.parameters = torch.nn.ParameterList(
            list(self.transform.parameters())
            + list(self.latent_conditioning_network.parameters())
            + list(self.data_conditioning_network.parameters())
        )
        self.distribution = TransformedDistribution(self.base_distribution, [self.transform])
        self.latent_conditioner = torch.zeros(d).to(device)
        self.data_conditioner = torch.zeros(d).to(device)
        # Stored as a one-element tensor so it lives on the target device.
        self.rwm_bias = rwm_bias * torch.ones(1).to(device)
        self.dimension = d

    def rsample(self, x):
        """Draw reparameterized samples.

        Args:
            x: Sample shape drawn per conditioning row. May be a list, tuple
                or ``torch.Size``. The reshape below expects a
                three-dimensional sample tensor, so presumably ``x`` has one
                entry -- confirm against callers.

        Returns:
            The data-side scaled and shifted transform output, with the first
            two sample dimensions flattened into one.
        """
        latent_mean, latent_scale = self.latent_conditioning_network()
        data_mean, data_scale = self.data_conditioning_network()
        data_mean = data_mean + self.data_conditioner * self.rwm_bias
        # list(x) so that tuple / torch.Size shapes work, not only lists.
        sample_shape = [latent_mean.shape[0]] + list(x)
        noise = self.base_distribution.rsample(sample_shape)
        # The 0.1 factors on both means must match the ones in log_prob.
        samples = latent_scale.unsqueeze(1) * noise + 0.1 * latent_mean.unsqueeze(1)

        proposed = data_scale.unsqueeze(1) * self.transform(samples) + 0.1 * data_mean.unsqueeze(1)
        return proposed.view(samples.shape[0] * samples.shape[1], samples.shape[2])

    def log_prob(self, x):
        """Return the log-density of ``x`` under the conditioned proposal.

        Undoes the steps of ``rsample``: remove the data-side shift and scale,
        inverse transform, remove the latent-side shift and scale; then
        subtracts both log-scale terms and the Jacobian term.
        """
        data_mean, data_scale = self.data_conditioning_network()
        data_mean = data_mean + self.data_conditioner * self.rwm_bias
        latent_mean, latent_scale = self.latent_conditioning_network()
        y = self.transform._inverse((x - 0.1 * data_mean) / data_scale)
        standardized = (y - 0.1 * latent_mean) / latent_scale
        # NOTE(review): arguments are (data, latent) and the data argument is
        # the raw x, not the de-scaled value fed to _inverse; torch's Transform
        # API documents (input, output). Confirm NICE expects this.
        return (
            self.base_distribution.log_prob(standardized)
            - torch.log(latent_scale).sum(dim=-1)
            - torch.log(data_scale).sum(dim=-1)
            - self.transform.log_abs_det_jacobian(x, y)
        )

    def sample(self, x):
        """Same as ``rsample`` but executed under ``torch.no_grad()``."""
        with torch.no_grad():
            return self.rsample(x)

    def condition(self, x):
        """Condition both networks on state ``x``.

        A one-dimensional ``x`` gets a leading batch dimension first. The
        latent network receives the inverse-transformed state; the data
        network receives the state itself.
        """
        if len(x.shape) == 1:
            x = x.unsqueeze(0)
        self.latent_conditioner = self.transform._inverse(x)
        self.data_conditioner = x
        self.latent_conditioning_network.condition(self.latent_conditioner)
        self.data_conditioning_network.condition(self.data_conditioner)


