import torch
from torch import nn
from config import cfg



class SP_Loss(object):
    """Trace-form Gaussian loss for three latents placed between two endpoints.

    The latents ``z1, z2, z3`` (at times ``t1, t2, t3``) are compared with the
    line that interpolates linearly from ``z_first`` (t = 0) to ``z_last``
    (t = ``total_t``)::

        mu_k = z_first + (z_last - z_first) * t_k / total_t

    Per sample the loss term is ``tr(Sigma^{-1} D C^{-1} D^T)``, where ``D``
    (d x 3) stacks the residuals ``z_k - mu_k``, ``Sigma^{-1}`` is ``sigma_inv``
    and ``C`` (3 x 3) has entries ``t_i * (T - t_j) / T`` for ``i <= j``
    (mirrored below the diagonal). That is twice the data term of a
    matrix-normal negative log-likelihood; the 1/2 factor and log-determinant
    terms are not included. The entries of ``C`` are written assuming
    ``t1 <= t2 <= t3`` (NOTE(review): confirm callers guarantee this ordering).
    """

    def __init__(self,
                 z1, z2, z3,
                 t1, t2, t3, total_t,
                 z_first, z_last,
                 sigma_inv
                 ):
        """Store the batch.

        Args:
            z1, z2, z3: 2-D latents, ``(bsz, d)`` (unpacked from ``z1.shape``).
            t1, t2, t3, total_t: per-sample times, ``(bsz,)``; stored as
                ``(bsz, 1)`` so they broadcast against the latents.
            z_first, z_last: interpolation endpoints, presumably ``(bsz, d)``.
            sigma_inv: inverse row covariance; ``(d, d)`` or anything that
                broadcasts against ``(bsz, d, 3)`` under matmul.
        """
        self.z1 = z1
        self.z2 = z2
        self.z3 = z3

        # add one dimension: (bsz,) -> (bsz, 1)
        self.t1 = t1.unsqueeze(1)
        self.t2 = t2.unsqueeze(1)
        self.t3 = t3.unsqueeze(1)
        self.total_t = total_t.unsqueeze(1)

        self.z_first = z_first
        self.z_last = z_last
        self.sigma_inv = sigma_inv

        self.bsz, self.dim = z1.shape

    def neg_log_likelihood(self):
        """Return the batch sum of ``tr(Sigma^{-1} D C^{-1} D^T)`` (0-dim tensor)."""
        times = (self.t1, self.t2, self.t3)
        total = self.total_t  # (bsz, 1)
        span = self.z_last - self.z_first  # (bsz, d)

        # Observed latents and their interpolated means, stacked on the last axis.
        all_z = torch.stack([self.z1, self.z2, self.z3], dim=2)  # (bsz, d, 3)
        all_mu = torch.stack(
            [self.z_first + span * t / total for t in times], dim=2
        )  # (bsz, d, 3)
        z_mu_diff = all_z - all_mu  # (bsz, d, 3)

        # Entry (i, j) = t_min(i,j) * (T - t_max(i,j)) / T, laid out row-major;
        # each entry is (bsz, 1), so concatenating gives (bsz, 9).
        cov_entries = [
            times[min(i, j)] * (total - times[max(i, j)]) / total
            for i in range(3)
            for j in range(3)
        ]
        cov_mat = torch.cat(cov_entries, dim=1).view(self.bsz, 3, 3)

        # D C^{-1} via a linear solve rather than an explicit inverse. C is
        # symmetric, so (C^{-1} D^T)^T == D C^{-1}.
        diff_cinv = torch.linalg.solve(
            cov_mat, z_mu_diff.transpose(1, 2)
        ).transpose(1, 2)  # (bsz, d, 3)

        # matmul broadcasts a (d, d) sigma_inv over the batch; no repeat needed.
        weighted = self.sigma_inv @ diff_cinv  # (bsz, d, 3)

        # tr(M D^T) == sum(M * D): same value as tracing the (bsz, d, d)
        # product without ever materializing it.
        loss = (weighted * z_mu_diff).sum(dim=(1, 2))  # (bsz,)

        return loss.sum()

    def get_loss(self):
        """Return :meth:`neg_log_likelihood`."""
        return self.neg_log_likelihood()


class CL_Loss(object):
    """Contrastive loss in which every triplet is modeled as a Brownian bridge.

    For a bridge pinned at ``z_0`` (time ``t_``) and ``z_T`` (time ``T``), with
    times measured relative to the start (``t <- t - t_``, ``T <- T - t_``),
    the intermediate latent ``z_t`` is scored under::

        p(z_t | z_0, z_T) = N(z_0 * (1 - t/T) + z_T * t/T, I * t*(T-t)/T)

    so the coefficient of the squared residual in ``log p`` is
    ``-1 / (2 * t*(T-t)/T)``. ``eps`` is added to denominators for stability.
    """

    def __init__(self,
                 z_0, z_t, z_T,
                 t_, t, T,
                 alpha, var,
                 # log_q_y_T,
                 eps,
                 max_seq_len,
                 C_eta=None,
                 label=None):
        """Store the batch and loss hyper-parameters.

        Args:
            z_0, z_t, z_T: start / middle / end latents; indexed as 2-D
                ``(bsz, d)`` by :meth:`reg_loss`.
            t_, t, T: per-sample start / middle / end times, 1-D ``(bsz,)``.
            alpha, var, label: stored; not used by the methods of this class.
            eps: small constant added to denominators in :meth:`_log_p`.
            max_seq_len: bridges with ``T == max_seq_len - 1`` are end-pinned
                by :meth:`reg_loss`.
            C_eta: additive constant in ``log p``; ``None`` means 0.0.
        """
        super().__init__()
        # self.log_q_y_T = log_q_y_T
        self.z_0 = z_0
        self.z_t = z_t
        self.z_T = z_T
        self.t_ = t_
        self.t = t
        self.T = T
        self.alpha = alpha
        self.var = var
        self.loss_f = self.simclr_loss
        self.eps = eps
        self.max_seq_len = max_seq_len
        self.sigmoid = nn.Sigmoid()
        self.label = label

        if C_eta is None:
            C_eta = 0.0
        self.C_eta = C_eta
        # Target magnitude for |z_T| of bridges ending at the last position.
        self.end_pin_val = 1.0
        self.device = cfg['experiment_params']['device']

    def _log_p(self, z_0, z_t, z_T, t_0, t_1, t_2):
        """Return the bridge log-density of ``z_t`` (without the Gaussian's
        normalizing term) plus ``C_eta``, one value per bridge: ``(bsz,)``.

        ``z_t``/``t_1`` may be a single latent/time, in which case they are
        broadcast against every bridge in the batch.
        """
        T = t_2 - t_0  # bridge length
        t = t_1 - t_0  # elapsed time since the start pin

        alpha = (t / (T + self.eps)).view(-1, 1)  # interpolation weight, column
        delta = z_0 * (1 - alpha) + z_T * alpha - z_t  # mean - z_t
        var = t * (T - t) / (T + self.eps)  # bridge variance
        log_p = -1 / (2 * var + self.eps) * (delta * delta).sum(-1) + self.C_eta  # (bsz,)
        if len(log_p.shape) > 1:  # (1, bsz)
            log_p = log_p.squeeze(0)
        return log_p

    def _logit(self, z_0, z_T, z_t, t_, t, T):
        """Return the contrastive logit ``log p(z_t | z_0, z_T)`` as ``(bsz, 1)``.

        (A ``- log q`` baseline term is disabled; see the commented lines.)
        """
        log_p = self._log_p(z_0=z_0, z_t=z_t, z_T=z_T,
                            t_0=t_, t_1=t, t_2=T)
        log_p = log_p.unsqueeze(-1)
        # log_q = self.log_q_y_T
        logit = log_p  # - log_q
        return logit  # (bsz, 1)

    def reg_loss(self):
        """Return the pinning regularizer (0.0 if no sample is pinned).

        * Bridges starting at time 0 have ``z_0`` pulled towards 0 (MSE).
        * Bridges ending at ``max_seq_len - 1`` have ``|z_T|`` pulled towards
          ``end_pin_val`` (MSE).
        """
        loss = 0.0
        mse_loss_f = nn.MSELoss()
        # start reg
        start_idxs = torch.where(self.t_ == 0)[0]
        if start_idxs.nelement():
            vals = self.z_0[start_idxs, :]
            # *_like targets match the latents' dtype and device, which the
            # configured default device/dtype may not.
            loss += mse_loss_f(vals, torch.zeros_like(vals))
        # end reg
        end_idxs = torch.where(self.T == self.max_seq_len - 1)[0]
        if end_idxs.nelement():
            vals = torch.abs(self.z_T[end_idxs, :])
            loss += mse_loss_f(vals, torch.full_like(vals, self.end_pin_val))
        return loss

    def simclr_loss(self):
        """Return the InfoNCE-style contrastive loss plus :meth:`reg_loss`.

        For sample ``i`` the positive logit is ``log p(z_t[i] | bridge i)``. The
        negatives score the same ``z_t[i]`` (at time ``t[i]``) under every bridge
        ``j`` of the batch, ``i`` included. The per-sample loss is
        ``-(pos_i - logsumexp_j neg_ij)``; it is averaged over the batch.
        The result is a ``(1,)``-shaped tensor.
        """
        loss = 0.0
        # Positive pair
        pos_logit = self._logit(z_0=self.z_0, z_T=self.z_T, z_t=self.z_t,
                                t_=self.t_, t=self.t, T=self.T)  # (bsz, 1)
        bsz = self.z_T.shape[0]
        for idx in range(bsz):
            # Negatives: z_t[idx] contrasted against every bridge in the batch.
            neg_i_logit = self._logit(
                z_0=self.z_0, z_T=self.z_T, z_t=self.z_t[idx],
                t_=self.t_, t=self.t[idx], T=self.T)  # (bsz, 1)
            # logsumexp rather than log(sum(exp(.)) + eps): the logits are large
            # negative squared distances, so exp() underflows to 0 and the old
            # form collapsed to log(eps), giving the negatives zero gradient.
            log_partition = torch.logsumexp(neg_i_logit.reshape(-1), dim=0)
            loss += -(pos_logit[idx] - log_partition)

        loss = loss / bsz
        # Regularization for pinning start and end of bridge
        loss += self.reg_loss()
        return loss

    def get_loss(self):
        """Return the configured loss (``simclr_loss``)."""
        return self.loss_f()