import torch
import torch.nn as nn
from loguru import logger
from utils.helper_utils import (
    sequence_mask,
)

class GPT2Captions(nn.Module):
    """Run a wrapped model on the packed caption history of each sample.

    All captions except the last one of every sample are concatenated into a
    single right-padded token sequence, padded with the wrapped model's
    ``config.pad_token_id``, and passed to the model together with the
    matching attention mask.
    """

    def __init__(self, model, tokenizer, cfg):
        """Store the wrapped model, its tokenizer and the experiment config.

        Args:
            model: Callable accepting ``input_ids`` and ``attention_mask`` and
                returning an object with a ``logits`` attribute; it must also
                expose ``config.pad_token_id``.
            tokenizer: Kept for callers; not used by ``forward``.
            cfg: Kept for callers; not used by ``forward``.
        """
        super().__init__()
        self.cfg = cfg
        self.tokenizer = tokenizer
        self.model = model

    @staticmethod
    def _pack_caption_history(captions, captions_length):
        """Concatenate every caption but the last into one padded sequence.

        Args:
            captions: 3-D integer tensor indexed as (sample, caption, token).
            captions_length: 2-D integer tensor holding the number of valid
                tokens of each caption, indexed as (sample, caption).

        Returns:
            A tuple ``(packed, packed_mask, total_lengths)``: the zero-padded
            concatenated token ids, the mask marking their valid positions,
            and the per-sample number of valid tokens.
        """
        history = captions[:, :-1, :]
        batch_size, num_history, caption_width = history.shape

        total_lengths = torch.sum(captions_length[:, :-1], dim=1)
        max_len = torch.max(total_lengths)

        # One mask per history caption, laid out side by side so that it lines
        # up with the flattened (sample, caption * token) view of ``history``.
        # The mask width has to equal the token dimension of ``captions`` for
        # the boolean indexing below, so it is taken from the tensor itself
        # rather than from a dataset-specific config entry.
        # NOTE(review): assumes sequence_mask returns a boolean mask of shape
        # (len(lengths), max_len) — confirm against utils.helper_utils.
        per_caption_masks = torch.cat(
            [
                sequence_mask(captions_length[:, i], max_len=caption_width)
                for i in range(num_history)
            ],
            dim=1,
        )
        assert torch.equal(torch.sum(per_caption_masks, dim=1), total_lengths)

        packed_mask = sequence_mask(total_lengths, max_len=max_len)
        # Allocate directly with the final dtype and device instead of
        # creating a float CPU tensor and converting it twice.
        packed = torch.zeros(
            batch_size, max_len, dtype=torch.long, device=captions.device
        )
        # reshape (not view) so a non-contiguous ``captions`` is accepted too.
        packed[packed_mask] = history.reshape(batch_size, -1)[per_caption_masks]
        return packed, packed_mask, total_lengths

    def forward(self, batch):
        """Return the wrapped model's logits for the packed caption history.

        Args:
            batch: Mapping with the keys ``"captions"`` and
                ``"captions_length"`` (see ``_pack_caption_history``).

        Returns:
            ``outputs.logits`` of the wrapped model; its second dimension is
            asserted to be 4.
        """
        pad_token_id = self.model.config.pad_token_id

        input_ids, attention_mask, _ = self._pack_caption_history(
            batch["captions"], batch["captions_length"]
        )

        # Fill every position past the valid tokens with the pad token id.
        # NOTE(review): presumably this lets the wrapped model locate the last
        # real token of each sequence — confirm against the model in use.
        input_ids[~attention_mask] = pad_token_id

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        logits = outputs.logits
        assert logits.shape[1] == 4
        return logits

class RobertaCaptions(nn.Module):
    """Run a wrapped model on the packed caption history of each sample.

    All captions except the last one of every sample are concatenated into a
    single zero-padded token sequence and passed to the wrapped model together
    with the matching attention mask.
    """

    def __init__(self, model, tokenizer, cfg):
        """Store the wrapped model, its tokenizer and the experiment config.

        Args:
            model: Callable accepting ``input_ids`` and ``attention_mask`` and
                returning an object with a ``logits`` attribute.
            tokenizer: Kept for callers; not used by ``forward``.
            cfg: Kept for callers; not used by ``forward``.
        """
        super().__init__()
        self.cfg = cfg
        self.tokenizer = tokenizer
        self.model = model

    @staticmethod
    def _pack_caption_history(captions, captions_length):
        """Concatenate every caption but the last into one padded sequence.

        Args:
            captions: 3-D integer tensor indexed as (sample, caption, token).
            captions_length: 2-D integer tensor holding the number of valid
                tokens of each caption, indexed as (sample, caption).

        Returns:
            A tuple ``(packed, packed_mask, total_lengths)``: the zero-padded
            concatenated token ids, the mask marking their valid positions,
            and the per-sample number of valid tokens.
        """
        history = captions[:, :-1, :]
        batch_size, num_history, caption_width = history.shape

        total_lengths = torch.sum(captions_length[:, :-1], dim=1)
        max_len = torch.max(total_lengths)

        # One mask per history caption, laid out side by side so that it lines
        # up with the flattened (sample, caption * token) view of ``history``.
        # The mask width has to equal the token dimension of ``captions`` for
        # the boolean indexing below, so it is taken from the tensor itself
        # rather than from a dataset-specific config entry.
        # NOTE(review): assumes sequence_mask returns a boolean mask of shape
        # (len(lengths), max_len) — confirm against utils.helper_utils.
        per_caption_masks = torch.cat(
            [
                sequence_mask(captions_length[:, i], max_len=caption_width)
                for i in range(num_history)
            ],
            dim=1,
        )
        assert torch.equal(torch.sum(per_caption_masks, dim=1), total_lengths)

        packed_mask = sequence_mask(total_lengths, max_len=max_len)
        # Allocate directly with the final dtype and device instead of
        # creating a float CPU tensor and converting it twice.
        packed = torch.zeros(
            batch_size, max_len, dtype=torch.long, device=captions.device
        )
        # reshape (not view) so a non-contiguous ``captions`` is accepted too.
        packed[packed_mask] = history.reshape(batch_size, -1)[per_caption_masks]
        return packed, packed_mask, total_lengths

    def forward(self, batch):
        """Return the wrapped model's logits for the packed caption history.

        Args:
            batch: Mapping with the keys ``"captions"`` and
                ``"captions_length"`` (see ``_pack_caption_history``).

        Returns:
            ``outputs.logits`` of the wrapped model; its second dimension is
            asserted to be 4.
        """
        input_ids, attention_mask, _ = self._pack_caption_history(
            batch["captions"], batch["captions_length"]
        )

        # Padding positions keep the id 0 and are excluded via the mask.
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        logits = outputs.logits
        assert logits.shape[1] == 4
        return logits

class BertCaptions(nn.Module):
    """Run a wrapped model on the packed caption history of each sample.

    All captions except the last one of every sample are concatenated into a
    single zero-padded token sequence, which is then framed by the tokenizer's
    CLS token in front and its SEP token directly after the last valid token
    before being passed to the wrapped model.
    """

    def __init__(self, model, tokenizer, cfg):
        """Store the wrapped model, its tokenizer and the experiment config.

        Args:
            model: Callable accepting ``input_ids`` and ``attention_mask`` and
                returning an object with a ``logits`` attribute.
            tokenizer: Must expose ``cls_token``, ``sep_token`` and
                ``encode(text, add_special_tokens=False)``.
            cfg: Kept for callers; not used by ``forward``.
        """
        super().__init__()
        self.cfg = cfg
        self.tokenizer = tokenizer
        self.model = model

    @staticmethod
    def _pack_caption_history(captions, captions_length):
        """Concatenate every caption but the last into one padded sequence.

        Args:
            captions: 3-D integer tensor indexed as (sample, caption, token).
            captions_length: 2-D integer tensor holding the number of valid
                tokens of each caption, indexed as (sample, caption).

        Returns:
            A tuple ``(packed, packed_mask, total_lengths)``: the zero-padded
            concatenated token ids, the mask marking their valid positions,
            and the per-sample number of valid tokens.
        """
        history = captions[:, :-1, :]
        batch_size, num_history, caption_width = history.shape

        total_lengths = torch.sum(captions_length[:, :-1], dim=1)
        max_len = torch.max(total_lengths)

        # One mask per history caption, laid out side by side so that it lines
        # up with the flattened (sample, caption * token) view of ``history``.
        # The mask width has to equal the token dimension of ``captions`` for
        # the boolean indexing below, so it is taken from the tensor itself
        # rather than from a dataset-specific config entry.
        # NOTE(review): assumes sequence_mask returns a boolean mask of shape
        # (len(lengths), max_len) — confirm against utils.helper_utils.
        per_caption_masks = torch.cat(
            [
                sequence_mask(captions_length[:, i], max_len=caption_width)
                for i in range(num_history)
            ],
            dim=1,
        )
        assert torch.equal(torch.sum(per_caption_masks, dim=1), total_lengths)

        packed_mask = sequence_mask(total_lengths, max_len=max_len)
        # Allocate directly with the final dtype and device instead of
        # creating a float CPU tensor and converting it twice.
        packed = torch.zeros(
            batch_size, max_len, dtype=torch.long, device=captions.device
        )
        # reshape (not view) so a non-contiguous ``captions`` is accepted too.
        packed[packed_mask] = history.reshape(batch_size, -1)[per_caption_masks]
        return packed, packed_mask, total_lengths

    def forward(self, batch):
        """Return the wrapped model's logits for the framed caption history.

        Args:
            batch: Mapping with the keys ``"captions"`` and
                ``"captions_length"`` (see ``_pack_caption_history``).

        Returns:
            ``outputs.logits`` of the wrapped model; its second dimension is
            asserted to be 4.
        """
        packed, _, total_lengths = self._pack_caption_history(
            batch["captions"], batch["captions_length"]
        )
        batch_size = packed.shape[0]
        device = packed.device

        # Ids of the special tokens, obtained by encoding their string form.
        cls_id = self.tokenizer.encode(
            self.tokenizer.cls_token, add_special_tokens=False
        )[0]
        sep_id = self.tokenizer.encode(
            self.tokenizer.sep_token, add_special_tokens=False
        )[0]

        # Layout per row: [CLS] tokens... [SEP] padding...
        # One extra zero column is appended so the longest row has room for
        # its SEP token.
        cls_column = torch.full(
            (batch_size, 1), cls_id, dtype=torch.long, device=device
        )
        spare_column = torch.zeros(batch_size, 1, dtype=torch.long, device=device)
        input_ids = torch.cat([cls_column, packed, spare_column], dim=1)

        # After the CLS shift the valid tokens occupy 1..length, so SEP goes
        # to index length + 1. The row index is created on the same device as
        # the column index.
        rows = torch.arange(batch_size, device=device)
        input_ids[rows, total_lengths + 1] = sep_id

        # Account for the two added special tokens.
        framed_lengths = total_lengths + 2

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=sequence_mask(
                framed_lengths, max_len=torch.max(framed_lengths)
            ),
        )
        logits = outputs.logits
        assert logits.shape[1] == 4
        return logits