from mucoco.losses import BaseLoss, register_loss


import torch 
import torch.nn.functional as F

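# Unused draft helper, superseded by torch.cdist(...).square() below.
# Computes pairwise squared Euclidean distances; assumes x is (B, N, D) and
# y is (B, M, D) (bmm requires batched 3-D inputs) and returns (B, N, M).
# def squared_cdist(x, y):
#     # |x_i - y_j|_2^2 = <x_i - y_j, x_i - y_j> = <x_i, x_i> + <y_j, y_j> - 2*<x_i, y_j>
#     x_sq_norm = x.square().sum(dim=-1, keepdim=True)                     # (B, N, 1)
#     y_sq_norm = y.square().sum(dim=-1, keepdim=True).transpose(-2, -1)   # (B, 1, M)
#     x_dot_y = x.bmm(y.transpose(-2, -1))                                 # (B, N, M)
#     sq_dist = x_sq_norm + y_sq_norm - 2*x_dot_y
#     # print(sq_dist.size())
#     return sq_dist
#     # For numerical issues (rounding can make sq_dist slightly negative), clamp before sqrt:
#     # sq_dist.clamp_(min=0.0)
#     # return torch.sqrt(sq_dist)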

@register_loss("uniquel2")
class UniqueNessLoss(BaseLoss):
    """Loss that rewards token embeddings for being far from earlier tokens in the sequence.

    For each position i, the distances to all strictly-earlier positions are averaged
    with softmax(-distance) weights, so the average leans towards the closest preceding
    token. Position 0 has no predecessor and is dropped. The remaining per-position
    scores are combined with softmax(-score / tau) weights, so the least distinct
    positions dominate. The result is negated, so minimising the loss increases those
    distances.
    """

    def __init__(self, model, tokenizer, args):
        super().__init__()

        self.model = model
        self.tokenizer = tokenizer
        self.args = args
        self.device = model.device

        self.bos_token_id = self.tokenizer.bos_token_id
        self.eos_token_id = self.tokenizer.eos_token_id

        # Strictly lower-triangular 0/1 mask: C[i, j] == 1 iff j < i (identical to
        # tril(ones) - eye). Precomputed for args.max_output_length; _causal_mask
        # builds a larger one on demand when a sequence is longer than that.
        n = args.max_output_length
        self.C = torch.tril(torch.ones((n, n)), diagonal=-1).to(self.device).detach()

    def _causal_mask(self, length, device):
        """Return the (length, length) strictly lower-triangular 0/1 mask on `device`."""
        if length <= self.C.size(0):
            return self.C[:length, :length].to(device)
        # prefix + prediction can be longer than max_output_length; slicing self.C would
        # then return a too-small matrix and the broadcast against `dist` would fail.
        return torch.tril(torch.ones((length, length), device=device), diagonal=-1)

    def _uniqueness_loss(self, embeds, squared=True):
        """Compute the uniqueness loss for a batch of embedding sequences.

        embeds: assumed (batch, seq_len, hidden) -- TODO confirm; positions are read
            from dim 1 and the pairwise distances are taken over the last dim.
        squared: use squared Euclidean distances (True) or plain distances (False).

        Returns one loss value per batch element (the last dim is summed out).
        """
        dist = torch.cdist(embeds, embeds)
        if squared:
            dist = dist.square()

        mask = self._causal_mask(embeds.size(1), dist.device).unsqueeze(0)

        # Softmax over all positions (self included), then keep only strictly-earlier
        # positions and renormalise. The 1e-8 keeps row 0 (no predecessors) at 0, not NaN.
        weights = F.softmax(-dist, dim=-1) * mask
        weights = weights / (weights.sum(dim=-1, keepdim=True) + 1e-8)

        # Soft-min distance of each position to the earlier tokens; position 0 is dropped
        # because it has nothing before it.
        per_position = (dist * weights).sum(dim=-1)[:, 1:]

        # Emphasise the least distinct positions, then negate so minimisation pushes them apart.
        tau = 1.0
        position_weights = F.softmax(-per_position / tau, dim=-1)
        return -(position_weights * per_position).sum(dim=-1)

    def compute_loss(self, batch, preds, **kwargs):
        """Compute the uniqueness loss for the predicted continuation.

        batch: a tuple (source, prefix). If giving a prompt to the decoder, it can be
            specified using "prefix".
        preds: a tuple containing (predicted tokens, predicted embeddings, predicted
            probabilities), obtained through a forward pass on the optimizable target
            parameters (see utils/target.py).

        Returns (loss, logging_output); the prefix embeddings are prepended to the
        predicted embeddings so prediction tokens are also compared with the prefix.
        """
        # prompt is the real deal; prefix can be provided as an extended prompt
        # (generated by the model autoregressively). Only the prefix is used here.
        _prompt, prefix = batch
        _pred_tokens, pred_embeds, _pred_probs = preds

        embed_lut = self.model.get_input_embeddings()
        input_embeds = torch.cat([embed_lut(prefix), pred_embeds], dim=1)

        loss = self._uniqueness_loss(input_embeds, squared=True)
        logging_output = {"loss": loss.data.cpu()}
        return loss, logging_output

    def compute_gold_loss(self, batch, **kwargs):
        """Compute the loss for a discrete target output. Useful in debugging.

        batch: a tuple (prompt, target); only the target token ids are embedded and scored.

        Returns (loss, logging_output).
        """
        _prompt, target = batch

        input_embeds = self.model.get_input_embeddings()(target)

        # NOTE(review): unlike compute_loss this uses plain (unsquared) distances, as the
        # original code did (`.square()` was commented out). Gold and predicted losses are
        # therefore on different scales -- confirm this is intended.
        loss = self._uniqueness_loss(input_embeds, squared=False)
        logging_output = {"loss": loss.data.cpu()}
        return loss, logging_output
    
