import torch
import numpy as np
from tqdm import tqdm
from transformers import BartTokenizer, BartForConditionalGeneration


class BartModelWrapper:
    """Image-blind caption scorer based on BART's conditional likelihood.

    Each caption is fed to the BART decoder, with a fixed text prompt on the
    encoder side. It is scored as ``exp(-mean token cross-entropy)``. Images
    are never inspected: only the number of image options is used. So every
    image option of a test case receives the same caption scores.
    """

    def __init__(self, root_dir, device, variant="bart-large", max_length=35):
        """
        Args:
            root_dir: stored on the instance; not otherwise used by this class.
            device: torch device for the model and every tensor built here.
            variant: checkpoint name under the ``facebook/`` namespace.
            max_length: captions are padded/truncated to this many tokens
                (default 35, the previously hard-coded value).
        """
        self.max_length = max_length

        self.variant = variant
        self.root_dir = root_dir
        self.device = device
        self.tokenizer = BartTokenizer.from_pretrained(f"facebook/{variant}")
        self.model = BartForConditionalGeneration.from_pretrained(f"facebook/{variant}").to(device)
        # Per-token losses; label positions equal to the pad id contribute 0.
        self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none", ignore_index=self.model.config.pad_token_id)
        # Encode the default (empty) prompt now. Without this, get_bartscore /
        # run_scores_batched raise AttributeError if called before
        # get_scores_batched.
        self._set_prompt('')

    def _set_prompt(self, prompt):
        """Tokenize ``prompt`` once and cache it as the encoder input on ``self.device``."""
        self.prompt = prompt
        encoded_src = self.tokenizer(prompt, return_tensors='pt')
        self.src_tokens = encoded_src['input_ids'].to(self.device)
        self.src_mask = encoded_src['attention_mask'].to(self.device)

    @torch.no_grad()
    def get_bartscore(self, text):
        """Score each caption by its likelihood given the cached prompt.

        Args:
            text: list of caption strings.

        Returns:
            1-D tensor on ``self.device`` with one score per caption. Each score
            is ``exp(-(sum of token cross-entropies) / (number of non-pad tokens))``,
            i.e. the geometric mean of per-token probabilities. Higher is better.
        """
        encoded_tgt = self.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt',
        )
        tgt_tokens = encoded_tgt['input_ids'].to(self.device)
        # Non-pad token count per caption, used to average the summed loss.
        tgt_len = encoded_tgt['attention_mask'].sum(dim=1).to(self.device)
        batch_size = tgt_tokens.shape[0]

        # Broadcast the single encoded prompt across the caption batch.
        # `labels` is passed so the model derives its decoder inputs from the
        # caption tokens. The model's own `loss` output is not used; the loss
        # is recomputed below with pad positions ignored.
        output = self.model(
            input_ids=self.src_tokens.expand(batch_size, *self.src_tokens.shape[1:]),
            attention_mask=self.src_mask.expand(batch_size, *self.src_mask.shape[1:]),
            labels=tgt_tokens,
        )

        logits = output.logits.view(-1, self.model.config.vocab_size)
        loss = self.loss_fct(logits, tgt_tokens.view(-1)).view(batch_size, -1)
        loss = loss.sum(dim=1) / tgt_len
        return (-loss).exp()

    def run_scores_batched(self, num_test, num_image_options, text):
        """Build the caption-score matrices for one batch of test cases.

        Args:
            num_test: number of test cases in the batch.
            num_image_options: number of image options per test case; only
                the count is used, since scores do not depend on the image.
            text: list with one entry per caption option (length L);
                ``text[j][i]`` is caption option ``j`` of test case ``i``.

        Returns:
            ``(i2t, t2i)`` numpy arrays. ``i2t`` has shape
            ``(num_test, num_image_options, L)``; ``t2i`` is the same data with
            the last two axes swapped, shape ``(num_test, L, num_image_options)``.
        """
        num_text_options = len(text)
        # One score row per test case; every row is overwritten below.
        scores = torch.full((num_test, 1, num_text_options), -100.0, device=self.device)
        for i in range(num_test):
            scores[i, 0] = self.get_bartscore([t[i] for t in text])

        # The LM is image-blind: copy the same row for every image option.
        score_matrix_i2t = scores.repeat(1, num_image_options, 1)
        score_matrix_t2i = score_matrix_i2t.permute(0, 2, 1)
        return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy()

    @torch.no_grad()
    def get_scores_batched(self, joint_loader, prompt=''):
        """Computes the LM scores for each caption in the joint loader.

        Args:
            joint_loader (DataLoader): batches have "image_options" and "caption_options" fields.
                "image_options" is a list of K per-option entries; the first
                dimension of ``image_options[0].shape`` is taken as the number of
                test cases in the batch. "caption_options" is a list of L caption
                lists, each indexed by test case.
            prompt: text fed to the BART encoder; every caption is scored
                conditioned on it. It is also cached on the instance for later
                get_bartscore calls.

        Returns:
            (t2i_scores, i2t_scores): two numpy arrays, both of shape NxKxL.
            N is the number of test cases, K is the number of image options
            per test case, and L is the number of caption options per test case.
        """
        print(f"Using prompt: {prompt}")
        self._set_prompt(prompt)
        t2i_scores, i2t_scores = [], []
        for batch in tqdm(joint_loader):
            num_image_options = len(batch["image_options"])
            num_test = batch["image_options"][0].shape[0]

            s_i2t, s_t2i = self.run_scores_batched(num_test, num_image_options, batch["caption_options"])
            t2i_scores.append(s_t2i)
            i2t_scores.append(s_i2t)

        t2i_scores = np.concatenate(t2i_scores, axis=0)  # N x N_t x N_i
        t2i_scores = np.transpose(t2i_scores, (0, 2, 1))  # N x N_i x N_t
        i2t_scores = np.concatenate(i2t_scores, axis=0)  # N x N_i x N_t
        print(t2i_scores.shape, i2t_scores.shape)
        return t2i_scores, i2t_scores