import logging
import pdb

import torch
from transformers import T5ForConditionalGeneration

logger = logging.getLogger(__name__)


class T5ForSequenceClassification(T5ForConditionalGeneration):
    """Sequence classifier built on T5's language-modelling head.

    Each class index is mapped to a single vocabulary token id
    (``self.label2id``). The model is trained to emit that token as the first
    decoder output, and the class scores are the first-step vocabulary logits
    restricted to those token ids.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.num_labels = self.config.num_labels

        # Token id standing in for each class index (position == class index).
        # NOTE(review): presumably the tokenizer ids of the verbalizer words
        # for a binary task -- verify against the tokenizer actually in use.
        self.label2id = [2841, 1465]

        # Four-way setup (marked "AGN" by the original author).
        if self.num_labels == 4:
            self.label2id = [268, 5256, 2100, 296]

    def forward(self, **kwargs):
        """Run T5 and return ``(loss, class_logits)``.

        Args:
            **kwargs: Keyword arguments forwarded to
                ``T5ForConditionalGeneration.forward``. ``token_type_ids`` is
                discarded. ``labels``, when given, holds class indices (one
                per example) and is translated to the matching token ids.

        Returns:
            A tuple ``(loss, class_logits)``. ``loss`` is the loss reported by
            the parent forward pass, or ``None`` when no ``labels`` were
            supplied. ``class_logits`` has one column per entry of
            ``self.label2id``, taken from the first decoder position.
        """
        kwargs.pop('token_type_ids', None)
        labels = kwargs.pop('labels', None)

        if labels is not None:
            # Vectorised class-index -> token-id lookup; yields a one-token
            # target sequence per example, i.e. shape (batch, 1).
            token_lookup = torch.as_tensor(
                self.label2id, dtype=torch.long, device=labels.device
            )
            kwargs['labels'] = token_lookup[labels.reshape(-1).long()].unsqueeze(-1)
        elif (
            kwargs.get('decoder_input_ids') is None
            and kwargs.get('decoder_inputs_embeds') is None
        ):
            # Inference without labels: T5 still needs a decoder input, so
            # feed the decoder start token. This is the same decoder input
            # that a one-token label sequence is shifted into during training.
            reference = kwargs.get('input_ids')
            if reference is None:
                reference = kwargs.get('inputs_embeds')
            if reference is not None:
                kwargs['decoder_input_ids'] = torch.full(
                    (reference.shape[0], 1),
                    self.config.decoder_start_token_id,
                    dtype=torch.long,
                    device=reference.device,
                )

        kwargs['return_dict'] = True
        outputs = super().forward(**kwargs)

        # Keep only the vocabulary columns that represent classes, and build
        # the index on the same device as the logits.
        class_columns = torch.as_tensor(
            self.label2id, dtype=torch.long, device=outputs.logits.device
        )
        return outputs.loss, outputs.logits[:, 0, class_columns]

class T5ForMultipleChoice(T5ForConditionalGeneration):
    """Multiple-choice scorer built on T5's language-modelling head.

    Every candidate answer is a token sequence. Each candidate is scored by
    the mean log-probability the decoder assigns to its non-padding tokens,
    and the per-candidate scores are used as classification logits.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.num_labels = self.config.num_labels

        # Kept for compatibility with existing users of the attribute; the
        # forward pass below does not read it.
        # NOTE(review): presumably tokenizer ids of the two answer words --
        # verify against the tokenizer actually in use.
        self.label2id = [[2841], [1465]]

        # Four-way setup (marked "AGN" by the original author) is unsupported.
        if self.num_labels == 4:
            raise NotImplementedError(
                "T5ForMultipleChoice does not support num_labels == 4"
            )

    def generate(self, *args, **kwargs):
        """Generation is deliberately unsupported for this head."""
        raise NotImplementedError

    def forward(self, output_mc_logits=True, **kwargs):
        """Score every answer choice and return the classification loss.

        Args:
            output_mc_logits: When true, return ``(loss, choice_logits)``;
                otherwise return only ``loss``.
            **kwargs: Keyword arguments forwarded to
                ``T5ForConditionalGeneration.forward``. ``token_type_ids`` is
                discarded. ``choices_input_ids`` (required) holds the token
                ids of all candidates, indexed as
                ``[example, choice, token]``, with id ``0`` treated as
                padding. ``labels`` (optional) holds the index of the correct
                choice for each example.

        Returns:
            ``(loss, choice_logits)`` or just ``loss`` (see
            ``output_mc_logits``). ``loss`` is the cross-entropy between
            ``choice_logits`` and ``labels``, or ``None`` when ``labels`` was
            not supplied. ``choice_logits`` has one column per choice.

        Raises:
            ValueError: If ``choices_input_ids`` is missing.
        """
        kwargs.pop('token_type_ids', None)
        labels = kwargs.pop('labels', None)
        choice_seqs = kwargs.pop('choices_input_ids', None)
        if choice_seqs is None:
            raise ValueError("choices_input_ids is required to score answer choices")
        kwargs['return_dict'] = True

        per_choice_scores = []
        for choice_idx in range(choice_seqs.shape[1]):
            choice_ids = choice_seqs[:, choice_idx]
            # Passing the candidate as ``labels`` makes T5 derive the
            # teacher-forcing decoder inputs from it.
            outputs = super().forward(labels=choice_ids, **kwargs)

            # The encoder output does not depend on the candidate, so reuse
            # the states from the first pass for every later candidate.
            if kwargs.get('encoder_outputs') is None:
                kwargs['encoder_outputs'] = (
                    outputs.encoder_last_hidden_state,
                    outputs.encoder_hidden_states,
                    outputs.encoder_attentions,
                )

            # log_softmax avoids the underflow of log(softmax(x)).
            log_probs = torch.log_softmax(outputs.logits, dim=-1)
            token_scores = log_probs.gather(2, choice_ids.unsqueeze(-1)).squeeze(-1)

            # Only positions holding real (non-padding) tokens count.
            token_mask = (choice_ids != 0).to(token_scores.dtype)

            # One aggregate score per example: mean log-prob over real tokens.
            per_choice_scores.append(
                (token_scores * token_mask).sum(dim=-1) / token_mask.sum(dim=-1)
            )

        all_logits = torch.stack(per_choice_scores, dim=1)

        loss = None
        if labels is not None:
            loss_fct = torch.nn.CrossEntropyLoss()
            loss = loss_fct(all_logits, labels)

        if output_mc_logits:
            return loss, all_logits

        return loss

