from collections import defaultdict
from importlib import reload
from typing import List, Union

import ipdb
import numpy as np
import torch
import tqdm
import transformers

import models
import utils


class ClassifyWrapper:
    """Thin wrapper around a sequence-pair (NLI) classifier.

    The two sentences of a pair are concatenated into ONE string separated by
    the literal text ``' [SEP] '`` and tokenized as a single sequence (not via
    the tokenizer's native sentence-pair API). Raw model logits are returned.

    Logit order is assumed to be [contradiction, neutral, entailment]
    (TODO confirm against ``self.model.config.id2label``).
    """

    # Separator inserted between the two texts of a pair. Kept as a plain-text
    # join (rather than the tokenizer's pair API) so predictions stay identical
    # to those produced by earlier versions of this wrapper.
    _SEP = ' [SEP] '

    def __init__(self, model_name='microsoft/deberta-large-mnli', device='cuda:3') -> None:
        """Load the classifier and its tokenizer.

        Args:
            model_name: Model identifier passed to ``models.load_model_and_tokenizer``.
            device: Device string passed to the same loader (defaults to ``'cuda:3'``).
        """
        self.model_name = model_name
        self.model, self.tokenizer = models.load_model_and_tokenizer(model_name, device)

    @torch.no_grad()
    def _batch_pred(self, sen_1: List[str], sen_2: List[str]) -> torch.Tensor:
        """Return logits for every pair ``(sen_1[i], sen_2[i])`` in one batch.

        Inputs are padded to the longest pair and truncated to the tokenizer's
        maximum length.

        Raises:
            ValueError: if the two lists differ in length (``zip`` would
                otherwise silently drop the unmatched tail) or are empty
                (the tokenizer would yield no usable ids).
        """
        if len(sen_1) != len(sen_2):
            raise ValueError(
                f'sen_1 and sen_2 must have the same length, got {len(sen_1)} and {len(sen_2)}'
            )
        if not sen_1:
            raise ValueError('sen_1 and sen_2 must be non-empty')

        texts = [first + self._SEP + second for first, second in zip(sen_1, sen_2)]
        encoded = self.tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
        device = self.model.device
        # Only ids and mask are forwarded, matching what the model received before.
        return self.model(
            encoded['input_ids'].to(device),
            attention_mask=encoded['attention_mask'].to(device),
        )['logits']

    @torch.no_grad()
    def _pred(self, sen_1: str, sen_2: str) -> torch.Tensor:
        """Return logits for a single ``(sen_1, sen_2)`` pair.

        The result has a leading batch dimension of 1 (presumably shape
        ``(1, num_labels)``). Input is truncated to the tokenizer's maximum
        length, consistent with ``_batch_pred``.
        """
        text = sen_1 + self._SEP + sen_2
        input_ids = self.tokenizer.encode(
            text, return_tensors='pt', truncation=True
        ).to(self.model.device)
        # logits: assumed [contradiction, neutral, entailment] — see class docstring.
        return self.model(input_ids)['logits']

    @torch.no_grad()
    def pred_qa(self, question: str, ans_1: str, ans_2: str) -> torch.Tensor:
        """Return logits for premise ``"question ans_1"`` vs. hypothesis ``"question ans_2"``."""
        return self._pred(f"{question} {ans_1}", f'{question} {ans_2}')

    @torch.no_grad()
    def _compare(self, question: str, ans_1: str, ans_2: str) -> dict:
        """Compare two answers to the same question in both directions.

        Returns:
            A dict with:
              - ``'deberta_prediction'``: ``0`` if either direction's argmax is
                class 0 (assumed "contradiction"), otherwise ``1``.
              - ``'prob'``: softmax probabilities averaged over both
                directions, moved to CPU.
              - ``'pred'``: the raw logits of both directions concatenated
                along dim 0 (forward first), moved to CPU.
        """
        forward = self.pred_qa(question, ans_1, ans_2)
        backward = self.pred_qa(question, ans_2, ans_1)
        preds = torch.cat([forward, backward], 0)

        # A single direction predicting class 0 is enough to mark the pair as 0.
        deberta_prediction = 0 if preds.argmax(1).min() == 0 else 1
        return {'deberta_prediction': deberta_prediction,
                'prob': torch.softmax(preds, 1).mean(0).cpu(),
                'pred': preds.cpu()
                }

if __name__ == '__main__':
    # Smoke check: constructing the wrapper passes the default model name and
    # device to the model/tokenizer loader.
    wrapper = ClassifyWrapper()