import torch
import numpy as np
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer


class OptModelWrapper:
    """Score captions by their likelihood under a pretrained OPT language model.

    The scores depend only on the caption text; images are never looked at.
    The per-caption score is broadcast across the image options so the
    returned arrays have the (test, image option, caption option) layout.
    """

    def __init__(self, root_dir, device, variant="opt-2.7b"):
        """Load the tokenizer and float16 weights of ``facebook/<variant>``.

        Args:
            root_dir: Stored on the instance; not otherwise used by this class.
            device: Torch device (or device string) that the model and every
                input tensor are moved to.
            variant: Checkpoint name appended to ``facebook/``.
        """
        # Every caption is padded/truncated to exactly this many tokens.
        self.max_length = 35

        self.variant = variant
        self.root_dir = root_dir
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(f"facebook/{variant}", use_fast=False)
        # Honour the requested device instead of unconditionally calling
        # .cuda(): the inputs are moved to `device`, so the model must live
        # on the same one.
        self.model = AutoModelForCausalLM.from_pretrained(
            f"facebook/{variant}", torch_dtype=torch.float16
        ).to(device)
        self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")

    @torch.no_grad()
    def get_bartscore(self, text):
        """Return ``exp(-mean token NLL)`` of each caption under the model.

        ``self.prompt`` and ``self.src_mask`` must already be set; in this
        class that is done by ``get_scores_batched``.

        Args:
            text: Sequence of caption strings.

        Returns:
            1-D float32 tensor with one score per caption. A caption that
            yields no tokens beyond the prompt gives NaN (0 / 0).
        """
        encoded_tgt = self.tokenizer(
            [self.prompt + caption for caption in text],
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt',
        )
        tgt_tokens = encoded_tgt['input_ids'].to(self.device)
        tgt_mask = encoded_tgt['attention_mask'].to(self.device)

        # Number of leading tokens that belong to the prompt. The code only
        # supports a prompt that encodes to exactly one token (presumably the
        # BOS token the tokenizer prepends -- confirm against the tokenizer).
        src_len = int(self.src_mask.sum(dim=1))
        assert src_len == 1
        # Count of real (non-padding) caption tokens per row.
        tgt_len = tgt_mask.sum(dim=1) - src_len

        # No `labels` are passed: the model's own averaged loss was never
        # used, only the logits are needed for the per-token loss below.
        output = self.model(input_ids=tgt_tokens, attention_mask=tgt_mask)

        # Targets are the caption tokens; padding becomes -100, which
        # CrossEntropyLoss ignores by default (contributing zero loss).
        # NOTE(review): slicing off the first `src_len` positions assumes the
        # tokenizer pads on the right -- confirm.
        targets = tgt_tokens.masked_fill(tgt_tokens == self.model.config.pad_token_id, -100)
        targets = targets[:, src_len:]

        # The logit at position t predicts token t + 1, hence the one-step
        # shift. Upcast so the loss is computed and summed in float32 rather
        # than in the model's float16.
        logits = output.logits[:, src_len - 1:-1].float()
        loss = self.loss_fct(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
        loss = loss.view(tgt_tokens.shape[0], -1)
        loss = loss.sum(dim=1) / tgt_len
        return (-loss).exp()

    def run_scores_batched(self, num_test, num_image_options, text):
        """Score every caption option of every test case in one batch.

        Args:
            num_test: Number of test cases in the batch.
            num_image_options: Number of image options per test case; the
                caption scores are repeated this many times.
            text: Sequence of caption options, where ``text[j][i]`` is the
                j-th caption option of the i-th test case.

        Returns:
            Tuple ``(score_matrix_i2t, score_matrix_t2i)`` of numpy arrays
            shaped ``(num_test, num_image_options, num_text_options)`` and
            ``(num_test, num_text_options, num_image_options)``; the second
            is the first with its last two axes swapped.
        """
        num_text_options = len(text)
        # Allocate directly on the target device instead of creating on the
        # default device and copying.
        score_matrix_i2t = torch.full(
            (num_test, 1, num_text_options), -100.0, device=self.device
        )

        for i in range(num_test):
            # All caption options belonging to test case i.
            score_matrix_i2t[i, 0] = self.get_bartscore([options[i] for options in text])

        # The score ignores the image, so every image option gets the same row.
        score_matrix_i2t = score_matrix_i2t.repeat(1, num_image_options, 1)

        score_matrix_t2i = score_matrix_i2t.permute(0, 2, 1)
        return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy()

    @torch.no_grad()
    def get_scores_batched(self, joint_loader, prompt=''):
        """Compute the LM scores for each caption in the joint loader.

        Args:
            joint_loader (DataLoader): batches have "image_options" and
                "caption_options" fields. "image_options" is a list of image
                batches and "caption_options" is a list of caption batches.
            prompt: Text prepended to every caption. Only the empty prompt is
                supported.

        Returns:
            Tuple ``(t2i_scores, i2t_scores)`` of numpy arrays, both of shape
            NxKxL, where N is the number of test cases, K is the number of
            image options per test case, and L is the number of caption
            options per test case. Both arrays hold the same values.

        Raises:
            ValueError: If ``prompt`` is not empty.
        """
        # Reject unsupported prompts with an exception rather than exit(0),
        # which terminated the whole process while reporting success. Checked
        # before any state is modified.
        if prompt != "":
            raise ValueError("We don't support prompt for OPT model.")
        self.prompt = prompt
        print(f"Using prompt: {self.prompt}")

        encoded_src = self.tokenizer(self.prompt, return_tensors='pt')
        self.src_tokens = encoded_src['input_ids'].to(self.device)
        self.src_mask = encoded_src['attention_mask'].to(self.device)

        t2i_scores, i2t_scores = [], []
        for batch in tqdm(joint_loader):
            num_image_options = len(batch["image_options"])
            num_test = batch["image_options"][0].shape[0]

            s_i2t, s_t2i = self.run_scores_batched(num_test, num_image_options, batch["caption_options"])
            t2i_scores.append(s_t2i)
            i2t_scores.append(s_i2t)

        t2i_scores = np.concatenate(t2i_scores, axis=0)  # N x N_t x N_i
        t2i_scores = np.transpose(t2i_scores, (0, 2, 1))  # N x N_i x N_t
        i2t_scores = np.concatenate(i2t_scores, axis=0)  # N x N_i x N_t
        print(t2i_scores.shape, i2t_scores.shape)
        return t2i_scores, i2t_scores