import os
import pickle
import logging

import numpy as np
import wandb
import torch
import torch.nn.functional as F
from tqdm import tqdm
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import json
import os
import argparse

# Command-line options:
#   --model     prefix of the result files to read and write (default 'deberta')
#   --seed_str  optional suffix inserted into those file names (default '')
parser = argparse.ArgumentParser()
parser.add_argument('--model', default='deberta', type=str, help='model name')
parser.add_argument('--seed_str', default='', type=str)
args = parser.parse_args()

model_name = args.model
# The literal string 'None' on the command line is treated as "no seed suffix".
seed_str = '' if args.seed_str == 'None' else args.seed_str

# Device that both the model weights and the tokenized inputs are placed on.
DEVICE = 'cuda'


class EntailmentDeberta():
    """Natural-language-inference classifier backed by DeBERTa-v2-xxlarge-MNLI.

    Loads the ``microsoft/deberta-v2-xxlarge-mnli`` tokenizer and sequence
    classification model onto ``DEVICE`` and exposes ``check_implication``
    for classifying (premise, hypothesis) pairs.
    """

    def __init__(self):
        self.tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v2-xxlarge-mnli")
        self.model = AutoModelForSequenceClassification.from_pretrained(
            "microsoft/deberta-v2-xxlarge-mnli",
            # Keep the weights on the same device the inputs are moved to.
            device_map=DEVICE,
        )

    def check_implication(self, text1, text2, *args, **kwargs):
        """Predict the NLI class for ``text1 -> text2``.

        The model checks whether ``text2`` follows from ``text1``:

            check_implication('The weather is good',
                              'The weather is good and I like you') --> 1
            check_implication('The weather is good and I like you',
                              'The weather is good') --> 2

        Args:
            text1: Premise string, or a list of premise strings.
            text2: Hypothesis string, or a list of the same length as
                ``text1``; items are paired position by position.
            *args, **kwargs: Accepted for call compatibility and ignored.

        Returns:
            numpy array holding, for each pair, the index of the
            highest-scoring class. Deberta-mnli returns `neutral` and
            `entailment` at indices 1 and 2 (index 0 is presumably
            `contradiction` -- confirm against the model config).
        """
        inputs = self.tokenizer(text1, text2, return_tensors="pt",
                                padding='longest', truncation=True, max_length=256).to(DEVICE)
        # Inference only: without no_grad() every call would record an
        # autograd graph for the whole network and hold on to activations.
        with torch.no_grad():
            logits = self.model(**inputs).logits
        # Softmax is monotonic, so the argmax of the raw logits already
        # identifies the predicted class.
        largest_index = torch.argmax(logits, dim=-1)  # pylint: disable=no-member
        return largest_index.cpu().numpy()

# Score every noise level's ensemble log with the entailment model and save
# one score array per noise level.
# The model is loaded once and reused for all files.
model = EntailmentDeberta()
save_dir = ''  # directory the score arrays are written to ('' = current directory)
base_dir = ''  # directory the rejection logs are read from ('' = current directory)
for noise_level in [1, 2, 3, 4, 5]:
    file_name = f'{model_name}_gaussian_noise_{noise_level}_ensemble_3{seed_str}__rejection.json'
    # Use a context manager so the handle is closed as soon as the log is
    # parsed; JSON text is read as UTF-8 regardless of the platform locale.
    with open(os.path.join(base_dir, file_name), 'r', encoding='utf-8') as log_file:
        log = json.load(log_file)
    scores = []
    for entry in tqdm(log):
        # 'extra_info' holds the ensemble outputs as a '- '-separated list;
        # the text before the first '- ' is discarded.
        outputs = entry['extra_info'].split('- ')[1:]
        # Position-wise these two lists form the three unordered pairs of the
        # first three outputs: (0, 1), (0, 2) and (1, 2).
        pairs = [
            [outputs[0], outputs[0], outputs[1]],
            [outputs[1], outputs[2], outputs[2]],
        ]
        # Classify each pair in both directions and average the two predicted
        # class indices, giving one value per pair.
        forward = model.check_implication(pairs[0], pairs[1])
        backward = model.check_implication(pairs[1], pairs[0])
        score = (forward + backward) / 2
        scores.append(score)
    save_name = f'{model_name}_gaussian_noise_{noise_level}_ensemble_3{seed_str}_diversity_score.npy'
    np.save(os.path.join(save_dir, save_name), np.array(scores))