import math
import torch
import torch.nn.functional as F

from hydra.utils import instantiate
from omegaconf import DictConfig
from transformers import T5ForConditionalGeneration, AutoTokenizer

from src.utils.losses import calc_class_probs


class TaskLanguageModel(torch.nn.Module):
    """Wrapper around a pretrained T5 model used as a task LM or an auxiliary LM.

    The wrapper selects encoder / cross-attention masks according to
    ``io_mode`` and ``lm_mode``, scores every candidate target sequence
    (non-generative modes), and can expose encoder / decoder hidden states
    for knowledge distillation.

    Args:
        arch: Name of the pretrained T5 checkpoint; must be one of
            ``'t5-small'``, ``'t5-base'``, ``'t5-large'``, ``'t5-3b'``.
        optimizer: Optimizer config; only stored on the instance here.
        io_mode: Input/output layout. ``'I-OR'`` and ``'I-RO'`` use plain
            seq2seq training / generation; ``'IR-O'`` enables the
            token-type based masking. Any other value is treated like a
            per-class scoring mode without special masking.
        lm_mode: ``'task'`` or ``'aux'``. In ``'aux'`` mode the batch keys are
            read with an ``'aux_'`` prefix.
        kd_input: If truthy, attach ``kd_input_states`` to the outputs.
        kd_target: If truthy, attach ``kd_target_states`` to the outputs.
        no_bottleneck: In ``'IR-O'`` + ``'aux'`` mode, use the full attention
            mask (instead of ``token_type``) for cross-attention.
        ftr_dropout_rate: Per-example probability, during training, of
            replacing the encoder attention mask with ``token_type``. Only
            allowed with ``lm_mode == 'task'`` and ``io_mode == 'IR-O'``.
        **kwargs: Ignored.
    """

    def __init__(
            self, arch: str, optimizer: DictConfig, io_mode: str,
            lm_mode: str, kd_input: bool, kd_target: bool,
            no_bottleneck: bool,
            ftr_dropout_rate: float,
            **kwargs,
        ):
        super().__init__()

        self.arch = arch
        self.optimizer = optimizer
        self.io_mode = io_mode
        self.lm_mode = lm_mode
        self.kd_input, self.kd_target = kd_input, kd_target
        self.no_bottleneck = no_bottleneck
        self.ftr_dropout_rate = ftr_dropout_rate

        # Validate the configuration before the (expensive) model load.
        assert arch in ['t5-small', 't5-base', 't5-large', 't5-3b']
        assert lm_mode in ['task', 'aux']
        assert (ftr_dropout_rate is None) or (lm_mode == 'task' and io_mode == 'IR-O')

        # Initialize model
        self.lm = T5ForConditionalGeneration.from_pretrained(arch)
        self.lm_dim = self.lm.config.d_model
        self.lm_embed_layer = self.lm.shared

        # Initialize tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(arch)

        # Projection layers are created lazily by `set_aux_lm`.
        self.encoder_projection = None
        self.decoder_projection = None
        # Kept in a plain dict, so the auxiliary head is NOT registered as a
        # submodule of this module (its parameters do not show up in
        # `self.parameters()` / `state_dict()`).
        self.aux = {
            'lm_head': None,
            'lm_dim': None,
        }

    def set_aux_lm(self, aux_lm_dim, aux_lm_head):
        """Attach an auxiliary LM head and create projections into its hidden size.

        Creates bias-free linear layers mapping ``self.lm_dim`` to
        ``aux_lm_dim`` for both encoder and decoder states. Only valid when
        ``lm_mode == 'task'``.
        """
        assert self.lm_mode == 'task'
        self.encoder_projection = torch.nn.Linear(self.lm_dim, aux_lm_dim, bias=False)
        self.decoder_projection = torch.nn.Linear(self.lm_dim, aux_lm_dim, bias=False)
        self.set_aux_head(aux_lm_dim, aux_lm_head)

    def set_aux_head(self, aux_lm_dim, aux_lm_head):
        """Attach an auxiliary LM head without creating projection layers.

        Only valid when ``lm_mode == 'task'``.
        """
        assert self.lm_mode == 'task'
        self.aux['lm_head'] = aux_lm_head
        self.aux['lm_dim'] = aux_lm_dim

    def _select_attn_masks(self, example_attn_mask, token_type, split):
        """Return ``(encoder_attn_mask, cross_attn_mask)`` for the current mode.

        The input tensors are never modified; when feature dropout changes
        the mask, a copy is returned.
        """
        if self.io_mode != 'IR-O':
            # No rationale tokens in model input
            return example_attn_mask, example_attn_mask

        if self.lm_mode != 'task':
            # Teacher LM: encoder sees everything; unless `no_bottleneck`,
            # the decoder attends through `token_type` only.
            cross_attn_mask = example_attn_mask if self.no_bottleneck else token_type
            return example_attn_mask, cross_attn_mask

        if split != 'train':
            # During inference, `token_type` is the mask for both encoder
            # and decoder (masks out rationale tokens).
            return token_type, token_type

        # During training the full attention mask is used, except for
        # examples selected by feature dropout.
        if self.ftr_dropout_rate is not None:
            drop_probs = torch.full(
                (len(example_attn_mask),), float(self.ftr_dropout_rate),
                device=example_attn_mask.device,
            )
            dropped = torch.bernoulli(drop_probs).bool()
            # Clone first: the mask tensor belongs to the caller's batch and
            # must not be modified in place.
            example_attn_mask = example_attn_mask.clone()
            example_attn_mask[dropped] = token_type[dropped]
        return example_attn_mask, example_attn_mask

    def forward(self, batch, split):
        """Run the LM on ``batch``.

        Args:
            batch: Mapping with keys ``'<prefix>example'``,
                ``'<prefix>example_attn_mask'``, ``'<prefix>token_type'`` and
                ``'target_seq'``, where ``<prefix>`` is ``''`` in task mode
                and ``'aux_'`` otherwise. Outside the ``'I-OR'`` / ``'I-RO'``
                modes ``target_seq`` must be 3-D:
                ``(batch_size, num_classes, target_len)``.
            split: ``'train'`` for training; any other value is treated as
                inference.

        Returns:
            For ``'I-OR'`` / ``'I-RO'``: the LM outputs (training) or the
            result of ``generate`` (inference). Otherwise the LM outputs with
            ``logits`` replaced by the result of ``calc_class_probs`` and,
            if enabled, ``kd_input_states`` / ``kd_target_states`` attached.
        """
        prefix = '' if self.lm_mode == 'task' else 'aux_'
        example = batch[f'{prefix}example']
        example_attn_mask = batch[f'{prefix}example_attn_mask']
        target_seq = batch['target_seq']
        token_type = batch[f'{prefix}token_type']

        # Set encoder/decoder self-attention masks
        example_attn_mask, cross_attn_mask = self._select_attn_masks(
            example_attn_mask, token_type, split)

        if self.io_mode in ['I-OR', 'I-RO']:
            if split == 'train':
                return self.lm(input_ids=example, attention_mask=example_attn_mask, labels=target_seq)
            return self.lm.generate(
                input_ids=example, attention_mask=example_attn_mask, return_dict_in_generate=True)

        batch_size, num_classes = target_seq.shape[: -1]

        def copy_num_classes(inputs):
            # (batch_size, input_len) -> (batch_size*num_classes, input_len),
            # each row repeated num_classes times consecutively.
            inputs = inputs.unsqueeze(dim=1)
            inputs = torch.cat([inputs] * num_classes, dim=1)
            return inputs.reshape(batch_size * num_classes, -1)

        # Expand encoder inputs w.r.t. num_classes
        example = copy_num_classes(example)
        example_attn_mask = copy_num_classes(example_attn_mask)
        cross_attn_mask = copy_num_classes(cross_attn_mask)
        target_seq = target_seq.reshape(batch_size * num_classes, -1)

        # Decoder hidden states are needed for distillation and also for the
        # auxiliary head below; without this, `decoder_hidden_states` would be
        # None when an aux head is attached but both kd flags are off.
        use_aux_head = self.aux['lm_head'] is not None
        output_hidden_states = self.kd_input or self.kd_target
        if use_aux_head:
            output_hidden_states = True

        # Run the encoder with its own mask, then the decoder with the
        # (possibly different) cross-attention mask.
        encoder_outputs = self.lm.encoder(
            input_ids=example, attention_mask=example_attn_mask, output_hidden_states=True)
        outputs = self.lm(
            encoder_outputs=encoder_outputs,
            attention_mask=cross_attn_mask, labels=target_seq,
            output_hidden_states=output_hidden_states,
        )

        if use_aux_head:
            sequence_output = outputs.decoder_hidden_states[-1]
            if self.decoder_projection is not None:
                sequence_output = self.decoder_projection(sequence_output) # Apply projection to decoder output

            # Rescale output before projecting on vocab
            # Ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py#L1671-L1676
            if self.lm.config.tie_word_embeddings:
                sequence_output = sequence_output * (self.aux['lm_dim']**-0.5)
            outputs.logits = self.aux['lm_head'](sequence_output)

        # Compute class probabilities
        outputs.logits = calc_class_probs(outputs, target_seq, num_classes)

        if self.kd_input: # Compute kd input states
            token_type = copy_num_classes(token_type)
            states = outputs.encoder_hidden_states[-1] # states.shape = (batch_size*num_classes, length, lm_dim)
            if self.encoder_projection is not None:
                states = self.encoder_projection(states)
            # Keep only positions where token_type is non-zero.
            outputs.kd_input_states = states[token_type.bool()]

        if self.kd_target: # Compute kd target states
            states = outputs.decoder_hidden_states[-1] # states.shape = (batch_size*num_classes, length, lm_dim)
            if self.decoder_projection is not None:
                states = self.decoder_projection(states)
            # Keep only target positions whose label is not -100.
            outputs.kd_target_states = states[target_seq != -100]

        return outputs