from .utils import *
from .base import *
import torch
import torch.nn.functional as F

import numpy as np
from abc import *
import pdb

def get_input_seq_template(args, seq, meta):
    """Build the ``"query: ..."`` text prompt describing a sequence of items.

    Each item contributes ``'<title>', a product from <category> category. \\n``
    where ``<title>`` is ``meta[item][0]`` and ``<category>`` keeps only the
    last two ``", "``-separated components of ``meta[item][2]``.

    Args:
        args: Config object; only ``args.dataset_code`` is read.
        seq: Iterable of item keys into ``meta``.
        meta: Indexable by item; each entry must expose the title at index 0
            and a ``", "``-separated category string at index 2.

    Returns:
        The prompt string (just ``"query: "`` when ``seq`` is empty).

    Raises:
        ValueError: If ``args.dataset_code`` has no prompt template.
        KeyError: If an item in ``seq`` has no entry in ``meta``.
    """
    supported = ('beauty', 'games', 'auto', 'toys_new', 'sports', 'office')
    if args.dataset_code not in supported:
        # Fail loudly instead of reaching `return` with no prompt built.
        raise ValueError(
            f"No prompt template for dataset_code={args.dataset_code!r}; "
            f"expected one of {supported}"
        )

    parts = ["query: "]
    for item in seq:
        try:
            entry = meta[item]
        except (KeyError, IndexError) as err:
            raise KeyError(f"item {item!r} has no metadata entry") from err
        title = entry[0]
        # Keep only the two most specific category levels.
        category = ', '.join(entry[2].split(', ')[-2:])
        parts.append(f"'{title}', a product from {category} category. \n")
    return ''.join(parts)

class E5Trainer_non_cl(BaseTrainer):
    """Trainer combining a purchase-ranking loss with two matrix-based losses.

    ``calculate_loss`` returns::

        purchase_CE + args.mce_gamma_a * alignment + args.mce_gamma_u * uniformity

    The alignment/uniformity terms compare the target-network and online-network
    embeddings of each user's input text and of a second ("DP") text built
    from ``self.dp_seqs`` (the ``mce_*`` argument names suggest matrix
    cross-entropy — TODO confirm).
    """

    # Divisor applied to the dot-product scores before softmax/ranking.
    logit_temperature = 0.01

    def __init__(self, args, model, meta, train_loader, val_loader, DP_loader):
        super().__init__(args, model, train_loader, val_loader, DP_loader)
        # Item metadata used by get_input_seq_template (title at [0], category at [2]).
        self.meta = meta

        self.ce = torch.nn.CrossEntropyLoss()
        # Not used in this class; kept for subclasses/external callers.
        self.mr = torch.nn.MarginRankingLoss(self.args.margin)

    def calculate_loss(self, batch):
        """Return the scalar training loss for one batch.

        ``batch`` is unpacked as
        ``(user_id, seq, purchase_tokens, input_text, _, batch_size)``.
        ``input_text`` is assumed to be a list of ``batch_size`` strings
        (it is list-concatenated in ``prepare_tokens``).
        """
        user_id, seq, purchase_tokens, input_text, _, batch_size = batch
        input_text_dp = self.prepare_DP_samples(user_id)

        # Rows [:batch_size] are the user texts, rows [batch_size:] the DP texts.
        input_tokens_all = self.prepare_tokens(input_text, input_text_dp)
        p_all, z_all = self.calculate_embeddings(input_tokens_all)

        # Cross-view terms: target view of one text set vs online view of the other, both ways.
        loss_a = self.get_alignment_loss(p_all[:batch_size], z_all[batch_size:]) + \
            self.get_alignment_loss(p_all[batch_size:], z_all[:batch_size])
        loss_u = self.get_uniformity_loss(p_all[:batch_size], z_all[batch_size:]) + \
            self.get_uniformity_loss(p_all[batch_size:], z_all[:batch_size])

        loss = (self.get_contrastive_loss_purchase(purchase_tokens, batch_size)
                + self.args.mce_gamma_a * loss_a
                + self.args.mce_gamma_u * loss_u)
        return loss

    def calculate_embeddings(self, input_tokens_all):
        """Return ``(p_all, z_all)`` from the model's target and online forwards.

        The original author notes ``forward_target`` runs without gradients;
        that is implemented in the model, not here.
        """
        p_all = self.model.forward_target(input_tokens_all)
        z_all = self.model.forward_online(input_tokens_all)
        return p_all, z_all

    def _purchase_scores_and_labels(self, embeddings):
        """Score anchors against candidates for a ``(B, 2 + num_non_purchase, D)`` tensor.

        Slot 0 is the anchor, slot 1 the positive and slots ``2:`` the
        negatives. Returns ``(scores, labels)``. ``scores`` is
        ``(B, 1 + num_non_purchase)`` with the negatives first and the positive
        last, divided by ``logit_temperature``. ``labels`` is a long tensor
        pointing at that last column.
        """
        anchor = embeddings[:, 0]
        # Dot products: cosines only if forward_model L2-normalizes (not visible here).
        positive_logit = (anchor * embeddings[:, 1]).sum(dim=-1, keepdim=True)
        negative_logit = (anchor.unsqueeze(1) * embeddings[:, 2:]).sum(dim=-1)

        scores = torch.cat([negative_logit, positive_logit], dim=-1) / self.logit_temperature
        labels = torch.full(
            (scores.shape[0],), scores.shape[-1] - 1,
            dtype=torch.long, device=scores.device,
        )
        return scores, labels

    def get_contrastive_loss_purchase(self, purchases_contrastive, batch_size):
        """Cross-entropy loss that ranks the positive (slot 1) above the negatives."""
        embeddings = self.model.forward_model(purchases_contrastive).reshape(
            batch_size, (1 + 1 + self.args.num_non_purchase), -1)
        logits, labels = self._purchase_scores_and_labels(embeddings)
        return self.ce(logits, labels)

    def get_alignment_loss(self, p, z):
        """Return ``trace(-P @ log Q)`` between the centered covariances of ``p`` and ``z``.

        ``P = pᵀJp/m + mu·I`` and ``Q = zᵀJz/m + mu·I``, where ``J`` is the
        ``m x m`` centering matrix. ``p`` and ``z`` are ``(m, n)``.
        ``log`` is the truncated series from ``matrix_log``.
        """
        m = z.shape[0]
        n = z.shape[1]
        J_m = self.centering_matrix(m, device=z.device)
        eye_n = torch.eye(n, device=z.device)

        P = (1. / m) * (p.T @ J_m @ p) + self.args.mce_mu * eye_n
        Q = (1. / m) * (z.T @ J_m @ z) + self.args.mce_mu * eye_n

        return torch.trace(- P @ self.matrix_log(Q, self.args.mce_order))

    def get_uniformity_loss(self, p, z):
        """Return ``trace(-lamda·I @ log Q)`` with ``Q = pᵀJz/m + mu·I`` (cross-covariance)."""
        m = z.shape[0]
        n = z.shape[1]
        J_m = self.centering_matrix(m, device=z.device)
        eye_n = torch.eye(n, device=z.device)

        P = self.args.mce_lamda * eye_n
        Q = (1. / m) * (p.T @ J_m @ z) + self.args.mce_mu * eye_n

        return torch.trace(- P @ self.matrix_log(Q, self.args.mce_order))

    def centering_matrix(self, m, device=None):
        """Return the ``m x m`` centering matrix ``I - (1/m)·11ᵀ`` on ``device``.

        ``device`` defaults to the default device (CPU unless configured
        otherwise). Building the matrix there avoids a host allocation and copy.
        """
        return torch.eye(m, device=device) - torch.full((m, m), 1.0 / m, device=device)

    def matrix_log(self, Q, order=4):
        """Approximate the matrix logarithm of a square ``Q`` by a truncated Taylor series.

        Computes ``sum_{k=1..order} (-1)^(k+1) (Q - I)^k / k``. This only
        converges when the spectral radius of ``Q - I`` is below 1.
        ``order <= 0`` returns zeros.
        """
        n = Q.shape[0]
        X = Q - torch.eye(n, device=Q.device)
        res = torch.zeros_like(X)
        power = X
        for k in range(1, order + 1):
            term = power * (1. / float(k))
            res = res + term if k % 2 == 1 else res - term
            if k < order:  # the next power is only needed for another term
                power = power @ X
        return res

    def calculate_metrics(self, batch):
        """Return ranking metrics of the positive (slot 1) against the negatives.

        ``batch[-1]`` is used as the batch size. NOTE(review): the whole
        ``batch`` (including that trailing size) is passed to ``forward_model``.
        This presumably relies on the model ignoring it — confirm.
        """
        batch_size = batch[-1]
        embeddings = self.model.forward_model(batch)
        embeddings = embeddings.reshape(batch_size, (1 + 1 + self.args.num_non_purchase), -1)

        scores, labels = self._purchase_scores_and_labels(embeddings)
        return absolute_recall_mrr_ndcg_for_ks(scores, labels.view(-1), self.metric_ks)

    def prepare_DP_samples(self, user_id):
        """Return one prompt string per user, built from ``self.dp_seqs[uid]``.

        ``self.dp_seqs`` is not set in this class. Presumably ``BaseTrainer``
        populates it (e.g. from ``DP_loader``) — TODO confirm.
        """
        return [
            get_input_seq_template(self.args, self.dp_seqs[uid], self.meta)
            for uid in user_id
        ]

    def prepare_tokens(self, input_text, input_text_dp):
        """Tokenize ``input_text + input_text_dp`` as one padded batch.

        Returns ``(input_ids, token_type_ids, attention_mask)``. This assumes the
        tokenizer emits ``token_type_ids`` (BERT-style) — TODO confirm for
        other backbones.
        """
        text_all = input_text + input_text_dp
        input_tokens_dp = self.model.tokenizer(
            text_all, max_length=256, truncation=True, padding=True, return_tensors="pt")

        return (input_tokens_dp['input_ids'],
                input_tokens_dp['token_type_ids'],
                input_tokens_dp['attention_mask'])