from copy import deepcopy
from typing import Dict, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn
from transformers import AutoModel

from args import ModelArguments
from .lfrm_vae import LFRMVAE


class DSLFMKGCPredictor(nn.Module):
    """Two-encoder predictor that feeds pooled BERT outputs through an LFRM VAE.

    The module owns two transformer encoders loaded from ``args.pretrained_model``:
    ``bert_hr`` encodes the ``hr_*`` inputs and ``bert_entity`` encodes the
    ``tail_*`` inputs. ``bert_entity`` is created as a deep copy of ``bert_hr``,
    so both start from the same weights but are separate parameter sets. A single
    ``LFRMVAE`` instance is applied to the pooled output of either encoder.
    """

    def __init__(self, args: ModelArguments) -> None:
        """Build the encoders, the VAE and the temperature parameter from ``args``.

        Args:
            args: configuration object; this constructor reads ``pooling``, ``t``,
                ``finetune_t``, ``additive_margin``, ``use_self_negative`` and
                ``pretrained_model`` from it, and passes the whole object to
                ``LFRMVAE``.
        """
        super().__init__()
        self.pooling = args.pooling
        # Stores log(1 / t); it only receives gradients when args.finetune_t is truthy.
        self.log_inv_t = torch.nn.Parameter(torch.tensor(1.0 / args.t).log(), requires_grad=args.finetune_t)
        # The two settings below are kept on the module but are not read by any
        # method in this class; presumably the training code consumes them.
        self.add_margin = args.additive_margin
        self.use_self_negative = args.use_self_negative

        # LFRM vae
        self.lfrm_vae = LFRMVAE(args)
        # BERT encoders
        self.bert_hr = AutoModel.from_pretrained(pretrained_model_name_or_path=args.pretrained_model)
        self.bert_entity = deepcopy(self.bert_hr)

    def forward(
        self,
        hr_token_ids: Optional[torch.LongTensor] = None,
        hr_mask: Optional[torch.ByteTensor] = None,
        hr_token_type_ids: Optional[torch.LongTensor] = None,
        tail_token_ids: Optional[torch.LongTensor] = None,
        tail_mask: Optional[torch.ByteTensor] = None,
        tail_token_type_ids: Optional[torch.LongTensor] = None,
        encode_hr: bool = False,
    ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
        """Encode either the ``hr_*`` inputs or the ``tail_*`` inputs.

        Only one side is encoded per call, selected by ``encode_hr``; the inputs of
        the other side are ignored.

        Args:
            hr_token_ids: token ids for the ``bert_hr`` encoder.
            hr_mask: attention mask for the ``bert_hr`` encoder.
            hr_token_type_ids: token type ids for the ``bert_hr`` encoder.
            tail_token_ids: token ids for the ``bert_entity`` encoder.
            tail_mask: attention mask for the ``bert_entity`` encoder.
            tail_token_type_ids: token type ids for the ``bert_entity`` encoder.
            encode_hr: when true, encode the ``hr_*`` inputs with ``bert_hr``;
                otherwise encode the ``tail_*`` inputs with ``bert_entity``.

        Returns:
            A 2-tuple ``(lfrm_output, normalized)``. ``lfrm_output`` is whatever
            ``self.lfrm_vae`` returns (presumably a dict of tensors; only its
            ``"X_recon"`` entry is read here). ``normalized`` is that
            ``"X_recon"`` entry passed through ``F.normalize`` along dim 1.
        """
        if encode_hr:
            hr_vector = self.encode_bert(self.bert_hr, hr_token_ids, hr_mask, hr_token_type_ids)
            # The hr side calls the VAE with its default arguments.
            hr_lfrm = self.lfrm_vae(hr_vector)
            return hr_lfrm, F.normalize(hr_lfrm["X_recon"], dim=1)

        # encode entities
        tail_vector = self.encode_bert(self.bert_entity, tail_token_ids, tail_mask, tail_token_type_ids)
        # The entity side explicitly marks itself as a non-query input.
        tail_lfrm = self.lfrm_vae(tail_vector, is_query=False)
        return tail_lfrm, F.normalize(tail_lfrm["X_recon"], dim=1)

    def encode_bert(self, encoder, token_ids, mask, token_type_ids) -> torch.FloatTensor:
        """Run ``encoder`` on one batch and pool its last hidden state.

        Args:
            encoder: the transformer to call (``self.bert_hr`` or ``self.bert_entity``).
            token_ids: passed to the encoder as ``input_ids``.
            mask: passed to the encoder as ``attention_mask`` and reused for pooling.
            token_type_ids: passed to the encoder as ``token_type_ids``.

        Returns:
            The pooled vector produced by ``_pool_output`` using ``self.pooling``.
        """
        outputs = encoder(input_ids=token_ids, attention_mask=mask, token_type_ids=token_type_ids)

        last_hidden_state = outputs.last_hidden_state
        # Hidden state at sequence position 0, used when pooling is "cls".
        cls_output = last_hidden_state[:, 0, :]
        output = _pool_output(self.pooling, cls_output, mask, last_hidden_state)
        return output


def _pool_output(
    pooling: str, cls_output: torch.Tensor, mask: torch.Tensor, last_hidden_state: torch.Tensor
) -> torch.Tensor:
    if pooling == "cls":
        output_vector = cls_output
    elif pooling == "mean":
        input_mask_expanded = mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
        sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded, 1)
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-4)
        output_vector = sum_embeddings / sum_mask
    else:
        assert False, "Unknown pooling mode: {}".format(pooling)

    return output_vector
