from copy import deepcopy
from typing import Dict, Optional

import torch
import torch.nn.functional as F
from torch import nn
from transformers import AutoModel

from args import ModelArguments
from .lfrm_vae import LFRMVAE


class DSLFMKGC(nn.Module):
    """Two-encoder text model whose pooled outputs feed a shared ``LFRMVAE``.

    ``bert_hr`` encodes the ``hr_*`` inputs and ``bert_entity`` encodes the
    ``tail_*`` / ``head_*`` inputs.  ``bert_entity`` is created as a deep copy
    of ``bert_hr``, so the two encoders start from identical pretrained
    weights but are separate modules afterwards.
    """

    def __init__(self, args: "ModelArguments") -> None:
        """Build the encoders, the LFRM VAE and the temperature parameter.

        Fields read from ``args`` here: ``pooling``, ``t``, ``finetune_t``,
        ``additive_margin``, ``use_self_negative`` and ``pretrained_model``.
        ``args`` is also forwarded unchanged to ``LFRMVAE``.
        """
        super().__init__()
        self.pooling = args.pooling
        # The temperature is stored as log(1 / t); it is only trainable when
        # ``args.finetune_t`` is truthy.
        self.log_inv_t = torch.nn.Parameter(torch.tensor(1.0 / args.t).log(), requires_grad=args.finetune_t)
        self.add_margin = args.additive_margin
        self.use_self_negative = args.use_self_negative

        # LFRM vae
        self.lfrm_vae = LFRMVAE(args)
        # BERT encoders
        self.bert_hr = AutoModel.from_pretrained(pretrained_model_name_or_path=args.pretrained_model)
        self.bert_entity = deepcopy(self.bert_hr)

    def forward(
        self,
        hr_token_ids: torch.LongTensor,
        hr_mask: torch.ByteTensor,
        hr_token_type_ids: torch.LongTensor,
        tail_token_ids: torch.LongTensor,
        tail_mask: torch.ByteTensor,
        tail_token_type_ids: torch.LongTensor,
        head_token_ids: torch.LongTensor,
        head_mask: torch.ByteTensor,
        head_token_type_ids: torch.LongTensor,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """Encode the inputs and pass each pooled vector through the LFRM VAE.

        The ``hr_*`` inputs go through ``bert_hr``; the ``tail_*`` and
        ``head_*`` inputs go through ``bert_entity``.  Entity vectors are
        passed to the VAE with ``is_query=False``.  Extra keyword arguments
        are accepted and ignored.

        Returns:
            A dict with keys ``"hr_lfrm"``, ``"tail_lfrm"`` and
            ``"head_lfrm"`` holding whatever ``LFRMVAE`` returns.
            ``"head_lfrm"`` is ``None`` unless the module is in training mode
            and ``use_self_negative`` is set.
        """
        hr_vector = self.encode_bert(self.bert_hr, hr_token_ids, hr_mask, hr_token_type_ids)
        tail_vector = self.encode_bert(self.bert_entity, tail_token_ids, tail_mask, tail_token_type_ids)

        hr_lfrm = self.lfrm_vae(hr_vector)
        tail_lfrm = self.lfrm_vae(tail_vector, is_query=False)
        head_lfrm = None
        # The head entity is only encoded when it will be used as a self
        # negative, which saves one encoder pass otherwise.
        if self.training and self.use_self_negative:
            head_vector = self.encode_bert(self.bert_entity, head_token_ids, head_mask, head_token_type_ids)
            head_lfrm = self.lfrm_vae(head_vector, is_query=False)

        return {"hr_lfrm": hr_lfrm, "tail_lfrm": tail_lfrm, "head_lfrm": head_lfrm}

    def encode_bert(self, encoder, token_ids, mask, token_type_ids) -> torch.FloatTensor:
        """Run ``encoder`` and pool its last hidden state per ``self.pooling``.

        ``encoder`` must return an object exposing ``last_hidden_state``
        indexable as ``[:, 0, :]``; position 0 along the second axis is taken
        as the CLS vector.
        """
        outputs = encoder(input_ids=token_ids, attention_mask=mask, token_type_ids=token_type_ids)

        last_hidden_state = outputs.last_hidden_state
        cls_output = last_hidden_state[:, 0, :]
        output = _pool_output(self.pooling, cls_output, mask, last_hidden_state)
        return output

    def compute_kl_loss(
        self,
        hr_lfrm: dict,
        tail_lfrm: dict,
        head_lfrm: Optional[dict] = None,
    ) -> Dict[str, torch.Tensor]:
        """Sum the LFRM VAE losses over the query, tail and (optionally) head.

        Each ``*_lfrm`` argument must provide the keys ``"X"``, ``"X_recon"``
        and ``"kl_dict"``; ``"kl_dict"`` is expanded as keyword arguments to
        ``LFRMVAE.compute_loss``.  The head term is added only when the module
        is in training mode and ``use_self_negative`` is set.

        Returns:
            A dict with keys ``"loss_lfrm"``, ``"loss_kl"`` and
            ``"loss_recon"``, each the sum of the matching ``"loss"``,
            ``"loss_kl"`` and ``"loss_recon"`` entries returned by
            ``LFRMVAE.compute_loss``.

        Raises:
            ValueError: if the head term is required but ``head_lfrm`` is
                ``None``.
        """
        include_head = self.training and self.use_self_negative
        # Fail early with a clear message instead of a "'NoneType' object is
        # not subscriptable" error further down.
        if include_head and head_lfrm is None:
            raise ValueError("head_lfrm is required when training with use_self_negative")

        loss_hr = self.lfrm_vae.compute_loss(
            hr_lfrm["X"], hr_lfrm["X_recon"], **hr_lfrm["kl_dict"], is_query=True
        )
        loss_tail = self.lfrm_vae.compute_loss(
            tail_lfrm["X"], tail_lfrm["X_recon"], **tail_lfrm["kl_dict"], is_query=False
        )

        loss_lfrm = loss_hr["loss"] + loss_tail["loss"]
        loss_kl = loss_hr["loss_kl"] + loss_tail["loss_kl"]
        loss_recon = loss_hr["loss_recon"] + loss_tail["loss_recon"]
        if include_head:
            loss_head = self.lfrm_vae.compute_loss(
                head_lfrm["X"], head_lfrm["X_recon"], **head_lfrm["kl_dict"], is_query=False
            )
            loss_lfrm = loss_lfrm + loss_head["loss"]
            loss_kl = loss_kl + loss_head["loss_kl"]
            loss_recon = loss_recon + loss_head["loss_recon"]

        return {
            "loss_lfrm": loss_lfrm,
            "loss_kl": loss_kl,
            "loss_recon": loss_recon,
        }

    def compute_contrastive_loss(
        self,
        hr_vector: torch.Tensor,
        tail_vector: torch.Tensor,
        logits_mask: torch.Tensor,
        head_vector: Optional[torch.Tensor] = None,
        self_neg_logits_mask: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """Compute in-batch contrastive logits and the loss over them.

        ``hr_vector`` and ``tail_vector`` are L2-normalised along dim 1 and
        multiplied as 2-D matrices, giving one logit per (query, tail) pair.
        In training mode ``add_margin`` is subtracted from the main diagonal.
        All logits are scaled by ``exp(log_inv_t)``.

        When training with ``use_self_negative``, one extra column is
        appended holding the scaled query/head similarity; entries where
        ``self_neg_logits_mask`` is ``False`` are set to ``-1e4``.

        Args:
            hr_vector: 2-D query embeddings.
            tail_vector: 2-D tail embeddings with the same feature size.
            logits_mask: multiplied elementwise with the per-entry negative
                log-probabilities, so non-zero entries mark the positions that
                contribute to the loss.  It must match the final logits shape,
                including the extra self-negative column when present.
            head_vector: 2-D head embeddings; required for self negatives.
            self_neg_logits_mask: boolean mask with one entry per row;
                required for self negatives.

        Returns:
            A dict with ``"logits"`` (detached) and ``"loss_contrastive"``.

        Raises:
            ValueError: if self negatives are active but ``head_vector`` or
                ``self_neg_logits_mask`` is ``None``.
        """
        include_self_neg = self.training and self.use_self_negative
        if include_self_neg and (head_vector is None or self_neg_logits_mask is None):
            raise ValueError(
                "head_vector and self_neg_logits_mask are required when training with use_self_negative"
            )

        inv_t = self.log_inv_t.exp()
        hr_vector, tail_vector = F.normalize(hr_vector, dim=1), F.normalize(tail_vector, dim=1)
        logits = hr_vector.mm(tail_vector.t())
        if self.training:
            # Build the margin on the logits' own device/dtype instead of
            # allocating on the CPU and copying it over every step.
            margin = torch.zeros_like(logits).fill_diagonal_(self.add_margin)
            logits = logits - margin
        logits = logits * inv_t

        # self negative
        if include_self_neg:
            head_vector = F.normalize(head_vector, dim=1)
            self_neg_logits = torch.sum(hr_vector * head_vector, dim=1) * inv_t
            # A large negative logit effectively removes the entry from the softmax.
            self_neg_logits = self_neg_logits.masked_fill(~self_neg_logits_mask, -1e4)
            logits = torch.cat([logits, self_neg_logits.unsqueeze(1)], dim=-1)

        loss_contrastive = self._contrastive_loss(logits, logits_mask)  # contrastive loss

        return {
            "logits": logits.detach(),
            "loss_contrastive": loss_contrastive,
        }

    def _contrastive_loss(
        self, logits: torch.Tensor, logits_mask: torch.Tensor
    ) -> torch.Tensor:
        """Mean over rows of the average negative log-softmax at masked positions.

        For each row, the negative log-probabilities (softmax over dim 1) at
        positions where ``logits_mask`` is non-zero are summed and divided by
        the number of such positions.  A row with no positive positions
        contributes 0 rather than NaN (0 / 0).
        """
        mask = logits_mask.float()
        neg_log_prob = -F.log_softmax(logits, dim=1) * mask
        num_pos = mask.sum(dim=1)
        # Only the zero-count case is altered; every other count is kept as is.
        num_pos = num_pos.masked_fill(num_pos == 0, 1.0)
        loss = neg_log_prob.sum(dim=1) / num_pos

        return loss.mean()


def _pool_output(
    pooling: str, cls_output: torch.Tensor, mask: torch.Tensor, last_hidden_state: torch.Tensor
) -> torch.Tensor:
    """Reduce per-token hidden states to one vector per sequence.

    Args:
        pooling: ``"cls"`` returns ``cls_output`` unchanged; ``"mean"``
            averages ``last_hidden_state`` over the positions where ``mask``
            is non-zero.
        cls_output: vector returned as-is for ``"cls"`` pooling.
        mask: attention mask shaped ``(batch, seq)``; non-zero marks real
            tokens.  Only used for ``"mean"`` pooling.
        last_hidden_state: hidden states shaped ``(batch, seq, hidden)``.

    Returns:
        One pooled vector per sequence, shaped ``(batch, hidden)`` for
        ``"mean"`` pooling.

    Raises:
        ValueError: if ``pooling`` is neither ``"cls"`` nor ``"mean"``.
    """
    if pooling == "cls":
        return cls_output
    if pooling == "mean":
        # Zero out padded positions so they contribute neither to the sum nor
        # to the token count.
        weights = mask.unsqueeze(-1).to(last_hidden_state.dtype)
        summed = torch.sum(last_hidden_state * weights, dim=1)
        # Guard against an all-padding row, which would otherwise divide by zero.
        token_count = weights.sum(dim=1).clamp(min=1.0)
        return summed / token_count
    # Raise explicitly: an assert would be stripped under ``python -O``.
    raise ValueError("Unknown pooling mode: {}".format(pooling))
