from .priors import *
from .base_net import *

import torch.nn.functional as F
import torch.nn as nn
import copy
import math


def sample_gaussian(self, mean, std_rho):
    """Draw one reparameterised sample from a factorised Gaussian.

    Args:
        self: Unused. Kept in the signature so existing call sites that pass
            a leading object keep working.
        mean: Tensor holding the Gaussian means.
        std_rho: Tensor of unconstrained scale parameters. It is mapped to a
            strictly positive standard deviation with softplus and must
            broadcast against ``mean``.

    Returns:
        ``mean + std_dev * epsilon`` where ``epsilon`` is standard normal
        noise with the same shape, dtype and device as ``mean``.
    """
    epsilon = mean.data.new(mean.size()).normal_(0, 1)
    # The 1e-6 floor keeps the standard deviation strictly positive.
    std_dev = 1e-6 + F.softplus(std_rho, beta=1, threshold=20)
    # Use the arguments / locals; `self` carries no `mean` or `std_dev`.
    return mean + std_dev * epsilon

def sample_weights(W_mu, b_mu, W_p, b_p):
    """Sample one weight matrix (and optionally one bias) for export.

    Args:
        W_mu: Tensor of weight means.
        b_mu: Tensor of bias means, or ``None`` when there is no bias.
        W_p: Unconstrained weight scales; softplus maps them to std devs.
        b_p: Unconstrained bias scales; only read when ``b_mu`` is given.

    Returns:
        Tuple ``(W, b)`` of sampled tensors; ``b`` is ``None`` when ``b_mu``
        is ``None``.
    """
    def _to_std(rho):
        # Softplus keeps the scale positive; 1e-6 keeps it away from zero.
        return 1e-6 + F.softplus(rho, beta=1, threshold=20)

    # Weight noise is drawn first so the random stream order is fixed.
    weight_noise = W_mu.data.new(W_mu.size()).normal_()
    sampled_W = W_mu + _to_std(W_p) * weight_noise

    if b_mu is None:
        return sampled_W, None

    bias_std = _to_std(b_p)
    bias_noise = b_mu.data.new(b_mu.size()).normal_()
    return sampled_W, b_mu + bias_std * bias_noise

def KLD_cost(mu_p, sig_p, mu_q, sig_q):
    """Closed-form KL(q || p) between factorised Gaussians, summed over elements.

    Args:
        mu_p: Prior mean (scalar or tensor).
        sig_p: Prior standard deviation (scalar or tensor).
        mu_q: Posterior mean.
        sig_q: Posterior standard deviation.

    Returns:
        Scalar tensor: the sum over all elements of
        ``log(sig_p / sig_q) + (sig_q^2 + (mu_p - mu_q)^2) / (2 sig_p^2) - 1/2``.
    """
    # Related closed form for a standard-normal prior:
    # https://arxiv.org/abs/1312.6114 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    log_ratio = torch.log(sig_p / sig_q)
    scale_ratio_sq = (sig_q / sig_p).pow(2)
    mean_shift_sq = ((mu_p - mu_q) / sig_p).pow(2)
    per_element = 2 * log_ratio - 1 + scale_ratio_sq + mean_shift_sq
    return 0.5 * per_element.sum()


class GaussianLinear(nn.Module):
    """Linear layer whose weights follow a fully factorised Gaussian posterior.

    Every weight and bias has a learnable mean (``W_mu`` / ``b_mu``) and a
    learnable unconstrained scale (``W_p`` / ``b_p``) that softplus maps to a
    standard deviation. Forward passes return only the layer output; the KL
    term of the ELBO is obtained separately from ``kl_divergence`` (closed
    form) or ``kl_div_mc`` (single-sample Monte Carlo estimate).
    """

    def __init__(self, n_in, n_out, prior_class, prior_bias=None):
        """Create the variational parameters.

        Args:
            n_in: Number of input features.
            n_out: Number of output features.
            prior_class: Prior object for the weights. ``kl_divergence`` reads
                its ``sigma`` attribute and ``kl_div_mc`` calls its
                ``loglike`` method.
            prior_bias: Prior object for the bias, used the same way. With
                the default ``None`` both KL methods fail, so pass a prior if
                the KL term is needed.
        """
        super(GaussianLinear, self).__init__()
        self.n_in = n_in
        self.n_out = n_out
        self.prior = prior_class
        self.prior_bias = prior_bias

        # Learnable parameters -> Initialisation is set empirically.
        self.W_mu = nn.Parameter(torch.Tensor(self.n_in, self.n_out).uniform_(-0.1, 0.1))
        self.W_p = nn.Parameter(torch.Tensor(self.n_in, self.n_out).uniform_(-3, -2))

        self.b_mu = nn.Parameter(torch.Tensor(self.n_out).uniform_(-0.1, 0.1))
        self.b_p = nn.Parameter(torch.Tensor(self.n_out).uniform_(-3, -2))

        # Most recent noise draws, kept so a later call can reuse them
        # (``previous=True``).
        self.eps_W = None
        self.eps_b = None

    def __repr__(self):
        return "GaussianLinear(n_in={}, n_out={}, w_prior={}, b_prior={})".format(
                    self.n_in, self.n_out, str(self.prior), str(self.prior_bias))

    @staticmethod
    def _softplus_std(rho):
        """Map an unconstrained scale to a strictly positive standard deviation."""
        return 1e-6 + F.softplus(rho, beta=1, threshold=20)

    def _mean_output(self, X):
        """Deterministic output computed from the posterior means only."""
        return torch.mm(X, self.W_mu) + self.b_mu.expand(X.size()[0], self.n_out)

    def forward(self, X, sample=False, previous=False):
        """Apply the layer, sampling activations via local reparameterisation.

        Args:
            X: Input of shape (batch_size, n_in).
            sample: Force sampling even when the module is in eval mode. In
                training mode the layer always samples.
            previous: Reuse the noise stored by the last sampling ``forward``
                call instead of drawing new noise. The stored activation
                noise has shape (batch_size, n_out), so the batch size must
                match, and a sampling call must have happened before.

        Returns:
            Tensor of shape (batch_size, n_out).
        """
        if not self.training and not sample:
            return self._mean_output(X)

        std_w = self._softplus_std(self.W_p)
        std_b = self._softplus_std(self.b_p)

        # Pre-activation mean and variance: X @ W_mu and X^2 @ std_w^2.
        act_W_mu = torch.mm(X, self.W_mu)
        act_W_var = torch.mm(X.pow(2), std_w.pow(2))
        # d sqrt(v)/dv is infinite at v == 0 (e.g. an all-zero input row),
        # which turns the gradient of W_p into NaN. The tiny offset keeps the
        # derivative finite while leaving non-degenerate values unchanged.
        act_W_std = torch.sqrt(act_W_var + 1e-16)

        if not previous:
            self.eps_W = self.W_mu.data.new(act_W_std.size()).normal_(mean=0, std=1)
            self.eps_b = self.b_mu.data.new(std_b.size()).normal_(mean=0, std=1)

        act_W_out = act_W_mu + act_W_std * self.eps_W  # (batch_size, n_output)
        act_b_out = self.b_mu + std_b * self.eps_b

        return act_W_out + act_b_out.unsqueeze(0).expand(X.shape[0], -1)  # (batch_size, n_output)

    def forward_eval(self, X, sample=True, previous=False):
        """Apply the layer with explicitly sampled weights.

        Used in the evaluation phase when one weight sample is reused across
        several calls, instead of the local reparameterisation trick used by
        ``forward``.

        Args:
            X: Input of shape (batch_size, n_in).
            sample: When false, return the deterministic mean output.
            previous: Reuse the weight noise stored by the last sampling
                ``forward_eval`` call. The stored noise has the shape of
                ``W_mu`` / ``b_mu``, so it is not interchangeable with noise
                stored by ``forward``.

        Returns:
            Tensor of shape (batch_size, n_out).
        """
        if not sample:
            return self._mean_output(X)

        std_w = self._softplus_std(self.W_p)
        std_b = self._softplus_std(self.b_p)

        if not previous:
            self.eps_W = self.W_mu.data.new(self.W_mu.size()).normal_(mean=0, std=1)
            self.eps_b = self.b_mu.data.new(self.b_mu.size()).normal_(mean=0, std=1)

        W = self.W_mu + std_w * self.eps_W
        b = self.b_mu + std_b * self.eps_b

        return torch.mm(X, W) + b.unsqueeze(0).expand(X.shape[0], -1)  # (batch_size, n_output)

    def kl_divergence(self):
        """Closed-form KL between the posterior and zero-mean Gaussian priors.

        Reads ``self.prior.sigma`` and ``self.prior_bias.sigma`` as the prior
        standard deviations.

        Returns:
            Scalar tensor: weight KL plus bias KL.
        """
        std_w = self._softplus_std(self.W_p)
        std_b = self._softplus_std(self.b_p)
        kld = KLD_cost(mu_p=0, sig_p=self.prior.sigma, mu_q=self.W_mu, sig_q=std_w) \
            + KLD_cost(mu_p=0, sig_p=self.prior_bias.sigma, mu_q=self.b_mu, sig_q=std_b)
        return kld

    def kl_div_mc(self):
        """One-sample Monte Carlo estimate of the KL divergence.

        Draws a single W and b from the posterior and returns
        ``log q(W, b | theta) - log p(W, b)``, using
        ``isotropic_gauss_loglike`` for the posterior term and the priors'
        ``loglike`` methods for the prior term.
        """
        eps_W = self.W_mu.data.new(self.W_mu.size()).normal_()
        eps_b = self.b_mu.data.new(self.b_mu.size()).normal_()

        # sample parameters
        std_w = self._softplus_std(self.W_p)
        std_b = self._softplus_std(self.b_p)

        W = self.W_mu + std_w * eps_W
        b = self.b_mu + std_b * eps_b

        lqw = isotropic_gauss_loglike(W, self.W_mu, std_w) + isotropic_gauss_loglike(b, self.b_mu, std_b)
        lpw = self.prior.loglike(W) + self.prior_bias.loglike(b)

        return lqw - lpw

class FCBNN(nn.Module):
    """Fully connected Bayesian neural network built from GaussianLinear layers.

    The network maps inputs to an embedding of size ``output_dim``. When
    ``use_pred_loss`` is truthy, an extra GaussianLinear head maps the
    activated embedding to ``num_class`` logits.
    """

    def __init__(self, input_dim, output_dim, n_hid, prior_instance, bias_prior_instance, num_class, use_pred_loss, activation_func='tanh'):
        """Build the layer stack.

        Args:
            input_dim: Flattened input size; inputs are viewed as
                (-1, input_dim).
            output_dim: Size of the embedding produced by the last layer.
            n_hid: Iterable of hidden layer widths.
            prior_instance: Weight prior handed to every GaussianLinear.
            bias_prior_instance: Bias prior handed to every GaussianLinear.
            num_class: Output size of the optional prediction head.
            use_pred_loss: When truthy, create the prediction head ``fc_pred``.
            activation_func: Accepted for interface compatibility.
                NOTE(review): the value is not used here; the activation is
                always ``nn.Tanh()``.
        """
        super(FCBNN, self).__init__()

        # prior_instance = spike_slab_2GMM(mu1=0, mu2=0, sigma1=0.135, sigma2=0.001, pi=0.5)

        self.prior_instance = prior_instance
        self.num_class = num_class
        self.use_pred_loss = use_pred_loss
        self.input_dim = input_dim
        self.output_dim = output_dim

        self.act_func = nn.Tanh()

        # Hidden stack: GaussianLinear followed by the shared activation.
        modules = []
        layer_in_size = input_dim
        for hs in n_hid:
            modules.append(GaussianLinear(layer_in_size, hs, prior_instance, bias_prior_instance))
            modules.append(self.act_func)
            layer_in_size = hs

        # The final embedding layer has no activation inside the stack.
        modules.append(GaussianLinear(layer_in_size, output_dim, prior_instance, bias_prior_instance))
        self.layers = nn.Sequential(*modules)

        if self.use_pred_loss:
            self.fc_pred = GaussianLinear(output_dim, num_class, prior_instance, bias_prior_instance)

    def _run(self, x, sample, previous, weight_space):
        """Shared body of ``forward`` and ``forward_eval``.

        ``weight_space`` selects ``GaussianLinear.forward_eval`` (explicit
        weight samples) over ``GaussianLinear.forward`` (local
        reparameterisation).
        """
        y = x.view(-1, self.input_dim)  # view(batch_size, input_dim)
        for layer in self.layers:
            if isinstance(layer, GaussianLinear):
                apply = layer.forward_eval if weight_space else layer
                y = apply(y, sample, previous)
            else:
                y = layer(y)

        # y = y / y.pow(2).sum(1, keepdim=True).sqrt()

        logits = self.act_func(y)
        if self.use_pred_loss:
            head = self.fc_pred.forward_eval if weight_space else self.fc_pred
            logits = head(logits, sample, previous)
        return y, logits

    def forward(self, x, sample=False, previous=False):
        """Run the network with local reparameterisation in each layer.

        Args:
            x: Input tensor, viewed as (-1, input_dim).
            sample: Passed to every GaussianLinear; forces sampling in eval
                mode.
            previous: Passed to every GaussianLinear; reuse stored noise.

        Returns:
            Tuple ``(y, logits)``: ``y`` is the embedding from the last
            layer; ``logits`` is the activated embedding, passed through the
            prediction head when ``use_pred_loss`` is truthy.
        """
        return self._run(x, sample, previous, weight_space=False)

    def forward_eval(self, x, sample=False, previous=False):
        """Run the network using ``GaussianLinear.forward_eval`` in each layer.

        Takes the same arguments and returns the same ``(y, logits)`` tuple as
        ``forward``, but each layer samples explicit weights, which can be
        reused across calls with ``previous=True``.
        """
        return self._run(x, sample, previous, weight_space=True)

    def kl_divergence(self):
        """Sum of the closed-form KL terms of every GaussianLinear layer."""
        kldiv = 0.
        for layer in self.layers:
            if isinstance(layer, GaussianLinear):
                kldiv += layer.kl_divergence()

        if self.use_pred_loss:
            kldiv += self.fc_pred.kl_divergence()

        return kldiv

    def kl_divergence_mc(self):
        """Sum of the one-sample Monte Carlo KL estimates of every layer."""
        kldiv = 0.
        for layer in self.layers:
            if isinstance(layer, GaussianLinear):
                kldiv += layer.kl_div_mc()

        if self.use_pred_loss:
            kldiv += self.fc_pred.kl_div_mc()
        return kldiv

    def sample_predict(self, x, Nsamples):
        """Collect ``Nsamples`` stochastic embeddings for ``x``.

        Args:
            x: Input batch; ``x.shape[0]`` is used as the batch size.
            Nsamples: Number of sampling forward passes.

        Returns:
            Tuple ``(predictions, tlqw_vec, tlpw_vec)``. ``predictions`` has
            shape (Nsamples, batch_size, output_dim) and holds the embedding
            of each pass. ``forward`` returns only ``(y, logits)`` and no
            log-likelihood terms, so ``tlqw_vec`` and ``tlpw_vec`` are
            returned as zero vectors purely to keep the 3-tuple return shape.
        """
        # Just copies type from x, initializes new vector
        predictions = x.data.new(Nsamples, x.shape[0], self.output_dim)
        tlqw_vec = np.zeros(Nsamples)
        tlpw_vec = np.zeros(Nsamples)

        for i in range(Nsamples):
            # forward yields exactly two values: the embedding and the logits.
            y, _ = self.forward(x, sample=True)
            predictions[i] = y

        return predictions, tlqw_vec, tlpw_vec

class Wrapper_FCBNN(BaseNet):
    """Training / evaluation wrapper around an FCBNN triplet-embedding model.

    Holds the model, the optimizer and scheduler supplied through
    ``create_opt``, and the hyper-parameters used to weight the KL term.
    """
    eps = 1e-6

    def __init__(self, input_dim, output_dim, lr_hyperparam=None, cuda=True, classes=10, batch_size=128, Nbatches=0,
                 nhid=None, prior_instance=None, bias_prior_instance=None, use_pred_loss=None, reweight_constant=1.0, activation_func='tanh'):
        """Store the configuration and build the model.

        Args:
            input_dim: Flattened input size passed to FCBNN.
            output_dim: Embedding size passed to FCBNN.
            lr_hyperparam: Stored on the instance; not read in this class.
            cuda: Move the model and batches to the GPU when true.
            classes: Number of classes for the optional prediction head.
            batch_size: Stored on the instance; not read in this class.
            Nbatches: Number of mini-batches per epoch, used by the KL
                re-weighting in ``train_batch``. Must be positive there; the
                default 0 makes that ratio divide by zero.
            nhid: List of hidden layer widths; defaults to ``[1200]``.
            prior_instance: Weight prior passed to FCBNN.
            bias_prior_instance: Bias prior passed to FCBNN.
            use_pred_loss: Enable the classification head and its loss.
            reweight_constant: Extra multiplier on the KL term of the loss.
            activation_func: Forwarded to FCBNN.
        """
        super(Wrapper_FCBNN, self).__init__()
        cprint('y', ' Creating Net!! ')
        self.schedule = None  # [] #[50,200,400,600]
        self.cuda = cuda
        self.device = torch.device('cuda') if cuda else torch.device('cpu')
        self.classes = classes
        self.batch_size = batch_size
        self.Nbatches = Nbatches
        self.prior_instance = prior_instance
        self.bias_prior_instance = bias_prior_instance
        # A None sentinel avoids sharing one mutable default list across
        # instances.
        self.nhid = [1200] if nhid is None else nhid
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.use_pred_loss = use_pred_loss
        self.lr_hyperparam = lr_hyperparam
        self.reweight_constant = reweight_constant
        self.activation_func = activation_func
        self.create_net()
        self.epoch = 0

    def create_net(self):
        """Instantiate the FCBNN model and move it to the GPU if requested."""
        self.model = FCBNN(input_dim=self.input_dim, output_dim=self.output_dim,
                            n_hid=self.nhid, prior_instance=self.prior_instance,
                            bias_prior_instance=self.bias_prior_instance,
                            num_class=self.classes, use_pred_loss=self.use_pred_loss,
                            activation_func=self.activation_func)
        if self.cuda:
            self.model.cuda()
        print('    Total params: %.2fM' % (self.get_nb_parameters() / 1000000.0))

    def __str__(self):
        return "{}".format(self.model)

    def create_opt(self, optimizer, scheduler):
        """Attach an externally constructed optimizer and scheduler."""
        self.optimizer = optimizer
        self.scheduler = scheduler

    @staticmethod
    def _count_zero(values):
        """Number of entries of a tensor that are exactly zero, as an int."""
        # Counting in torch avoids depending on how NumPy dispatches
        # ``argwhere`` for tensor inputs.
        return int((values == 0.).sum().item())

    def train_batch(self, xijl_batch, targets, loss_func=None, **kwargs):
        """Run one optimisation step on a batch of triplets.

        Args:
            xijl_batch: Tuple ``(xp1, xp2, xn)`` of anchor, positive and
                negative inputs.
            targets: Tuple ``(yp, yn)`` of labels for the positive pair and
                the negative; only used when ``use_pred_loss`` is truthy.
            loss_func: Callable invoked as ``loss_func(dfij, dfil)`` with the
                squared anchor-positive and anchor-negative distances. Its
                result is summed and treated as a log-likelihood to maximise.
            **kwargs: Must contain ``n_samples`` (Monte Carlo samples per
                step) and ``cnt_batch`` (index of this batch within the
                epoch). Other keys are ignored.

        Returns:
            ``(Edkl, -log_p_Dw, bad_triplets)``, or
            ``(Edkl, -log_p_Dw, loss_preds, bad_triplets)`` when
            ``use_pred_loss`` is truthy. ``Edkl`` is the re-weighted KL term
            and ``bad_triplets`` counts the entries of the accumulated
            per-triplet term that are exactly zero.
        """
        n_samples = kwargs['n_samples']
        cnt_batch = kwargs['cnt_batch']

        xp1, xp2, xn = xijl_batch
        yp, yn = targets

        xp1, xp2, xn, yp, yn = to_variable(var=(xp1, xp2, xn, yp, yn), cuda=self.cuda)

        self.optimizer.zero_grad()

        log_p_Dw_cum = 0
        Edkl_cum = 0
        if self.use_pred_loss:
            loss_pred_cum = 0

        for _ in range(n_samples):
            # The anchor pass draws fresh noise; the positive and negative
            # passes reuse it so all three share one network sample.
            anchor, logits_xp1 = self.model(xp1, sample=True, previous=False)
            positive, logits_xp2 = self.model(xp2, sample=True, previous=True)
            negative, logits_xn = self.model(xn, sample=True, previous=True)

            dfij = (anchor - positive).pow(2).sum(dim=1)
            dfil = (anchor - negative).pow(2).sum(dim=1)
            log_p_Dw_cum += loss_func(dfij, dfil)

            Edkl_cum += self.model.kl_divergence()

            if self.use_pred_loss:
                logits_trip = (logits_xp1, logits_xp2, logits_xn)
                targets_trip = (yp, yp, yn)
                for (logit_x, target_y) in zip(logits_trip, targets_trip):
                    loss_pred_cum += F.cross_entropy(logit_x, target_y, reduction='mean')

        bad_triplets = self._count_zero(log_p_Dw_cum)

        log_p_Dw_cum = log_p_Dw_cum.sum()
        log_p_Dw = log_p_Dw_cum * 1. / n_samples
        Edkl = Edkl_cum * 1. / n_samples

        if self.use_pred_loss:
            loss_preds = loss_pred_cum / n_samples
            # Accumulates classification gradients; the graph is kept for the
            # second backward pass below.
            loss_preds.backward(retain_graph=True)
            # log_p_Dw = log_p_Dw * (self.llh_ratio) + (1 - self.llh_ratio) * loss_preds

        # Geometric KL weight 2^(M-i-1) / (2^M - 1): earlier batches of the
        # epoch carry more of the KL term. Requires Nbatches > 0.
        reweight_ratio = (2**(self.Nbatches-cnt_batch-1))/(2**self.Nbatches - 1)
        Edkl = Edkl * reweight_ratio

        loss = self.reweight_constant * Edkl - log_p_Dw

        loss.backward()
        self.optimizer.step()

        if self.use_pred_loss:
            return Edkl.item(), -log_p_Dw.item(), loss_preds.item(), bad_triplets
        return Edkl.item(), -log_p_Dw.item(), bad_triplets

    def eval_pred(self, x, y):
        """Evaluate the classification head on one batch.

        Returns:
            ``(loss, err)``: the mean cross-entropy and the number of
            misclassified examples, both as tensors.
        """
        # No gradients are needed for evaluation, as in eval_dist.
        with torch.no_grad():
            x, y = to_variable(var=(x, y), cuda=self.cuda)
            _, out = self.model(x)
            loss = F.cross_entropy(out, y, reduction='mean')
            pred = out.data.max(dim=1, keepdim=False)[1]  # get the index of the max log-probability
            err = pred.ne(y.data).sum()
            return loss.data, err

    def eval_dist(self, xp1, xp2, xn, loss_func=None):
        """Evaluate the triplet objective on one batch without gradients.

        Args:
            xp1, xp2, xn: Anchor, positive and negative inputs.
            loss_func: Callable invoked as ``loss_func(dfij, dfil)``.

        Returns:
            ``(-log_p_Dw, bad_triplets)``: the negated summed objective and
            the number of per-triplet terms that are exactly zero.
        """
        with torch.no_grad():
            xp1, xp2, xn = to_variable(var=(xp1, xp2, xn), cuda=self.cuda)
            out_xp1, _ = self.model(xp1)
            out_xp2, _ = self.model(xp2, previous=True)
            out_xn, _ = self.model(xn, previous=True)

            dfij = (out_xp1 - out_xp2).pow(2).sum(dim=1)
            dfil = (out_xp1 - out_xn).pow(2).sum(dim=1)

            log_p_Dw = loss_func(dfij, dfil)
            bad_triplets = self._count_zero(log_p_Dw)

            log_p_Dw = log_p_Dw.sum()
            return -log_p_Dw.item(), bad_triplets

    def get_embeddings(self, dataset, sample=False, bs=128):
        """Embed a whole dataset in batches.

        Args:
            dataset: Indexable as ``dataset[begin:end, :]`` and supporting
                ``len``.
            sample: When true, one weight sample is drawn on the first batch
                and reused for all later batches.
            bs: Batch size.

        Returns:
            Tensor of shape (len(dataset), output_dim).
        """
        self.set_mode_train(train=False)

        embeddings = torch.zeros((len(dataset), self.output_dim))
        num_iter = math.ceil(len(dataset)*1./bs)
        with torch.no_grad():
            for i in range(num_iter):
                # Only the first batch draws noise; later batches reuse it.
                previous = i != 0
                begin = i * bs
                end = begin + bs
                x_batch = dataset[begin:end, :]
                x_batch, = to_variable(var=(x_batch, ), cuda=self.cuda)
                batch_embeddings, _ = self.model.forward_eval(x_batch, sample=sample, previous=previous)
                embeddings[begin:end, :] = batch_embeddings
            return embeddings
