from __future__ import print_function

import torch
import torch.nn as nn
import numpy as np
from numpy import linalg as LA
import torch.nn.functional as F

import matplotlib.pyplot as plt


class ConditionalSamplingLoss(nn.Module):
    """Continuous version of the conditional sampling loss.

    The module consumes a ``[2n, 2n]`` score matrix whose main diagonal holds
    the positive-pair scores, together with conditioning weights that
    re-weight the exponentiated scores used as negatives.
    """

    def __init__(self, temperature=0.1, mode='hardnegatives',
                 temp_z=0.1, scale=1, lambda_=0.1, lsh='off',
                 weight_clip_threshold=1e-6, inverse_gradient=False):
        """Store the configuration.

        Args:
            temperature: accepted for interface compatibility; it is not
                stored or read anywhere in this class.
            mode, temp_z, scale, lambda_, weight_clip_threshold,
            inverse_gradient: stored as attributes; none of the methods
                defined in this class read them.
            lsh: when different from ``'off'``, ``forward`` expects the
                conditioned matrices to be precomputed by the caller.
        """
        super(ConditionalSamplingLoss, self).__init__()
        self.temp_z = temp_z
        self.lambda_ = lambda_
        self.ce_loss = nn.CrossEntropyLoss()
        self.mode = mode
        self.scale = scale
        self.cosinesim = nn.CosineSimilarity(dim=-1)
        self.weight_clip_threshold = weight_clip_threshold
        self.inverse_gradient = inverse_gradient
        self.softmax = nn.Softmax(dim=-1)
        self.lsh = lsh

    def forward(self, raw_score, z_score, high_threshold=0.8, low_threshold=0.2, device='cuda:0', warmup=False):
        """Compute the conditional sampling loss.

        Args:
            raw_score: ``[2n, 2n]`` score matrix. The diagonals of the
                top-left and bottom-right ``n x n`` quadrants are used as the
                positive scores.
            z_score: if ``self.lsh != 'off'``, an iterable of four
                precomputed matrices ``(Mxy, Mxx, Myy, Myx)``. Otherwise a
                matrix that is sliced into the same four ``n x n`` quadrants
                as ``raw_score``; each quadrant right-multiplies the
                exponentiated ``raw_score`` quadrant at the same position.
                NOTE(review): presumably the caller has already applied any
                ``(K_Z + lambda I)^-1 K_Z`` style normalisation; no inverse
                is computed here -- confirm against the caller.
            high_threshold, low_threshold, device: unused; kept so existing
                call sites continue to work.
            warmup: when true, return a plain cross-entropy loss over
                ``raw_score`` with target ``i`` for row ``i`` and ignore
                ``z_score``.

        Returns:
            Scalar loss tensor.
        """
        n = raw_score.shape[0] // 2

        if warmup:
            # Cross-entropy with the diagonal entry as the target class.
            targets = torch.arange(2 * n, dtype=torch.long, device=raw_score.device)
            return self.ce_loss(raw_score, targets)

        if self.lsh != 'off':
            # The conditioned matrices were computed upstream, so the
            # exponentiated score quadrants are not needed on this path.
            Mxy, Mxx, Myy, Myx = z_score
        else:
            # Weight each exponentiated score quadrant by the conditioning
            # quadrant located at the same position in z_score.
            Mxy = torch.matmul(torch.exp(raw_score[:n, :n]), z_score[:n, :n])
            Mxx = torch.matmul(torch.exp(raw_score[:n, n:]), z_score[:n, n:])
            Myy = torch.matmul(torch.exp(raw_score[n:, :n]), z_score[n:, :n])
            Myx = torch.matmul(torch.exp(raw_score[n:, n:]), z_score[n:, n:])

        loss_x = self._view_loss(torch.diagonal(raw_score[:n, :n], 0), Mxy, Mxx, n)
        loss_y = self._view_loss(torch.diagonal(raw_score[n:, n:], 0), Myx, Myy, n)

        return (loss_x + loss_y) / 2

    @staticmethod
    def _view_loss(pos, m_first, m_second, n):
        """Return ``-(pos - log(denominator)).mean()`` for one view.

        The denominator is ``exp(pos) + (n - 1) * (diag(m_first) +
        diag(m_second))``, clamped to ``[1e-7, 1e+20]`` before the log so the
        logarithm never sees a non-positive or overflowing value.
        """
        negatives = torch.diagonal(m_first, 0) + torch.diagonal(m_second, 0)
        deno = torch.clamp(torch.exp(pos) + (n - 1) * negatives, 1e-7, 1e+20)
        return -(pos - torch.log(deno)).mean()

    def compute_approximation(self, triple):
        """Rebuild a matrix from a ``(U, d, V)`` triple as ``U @ diag(d) @ V.T``."""
        U, D, V = triple[0], torch.diag(triple[1]), triple[2]

        return torch.matmul(torch.matmul(U, D), V.T)

    def normalization(self, distance):
        """Apply a softmax over the last dimension of ``distance``."""
        return F.softmax(distance, dim=-1)




