import os.path
import numpy as np, random
from query_strategy import SentenceEncoder
import torch
from os.path import join


class qmsum_bench():
    """Iterable QA benchmark over ``dataset['qa_list']`` with answer-extraction helpers.

    ``dataset`` must support ``dataset['qa_list']`` returning an indexable,
    sized sequence of QA items (e.g. a dict loaded from JSON). Iteration,
    ``len()`` and ``getData`` all operate on that sequence.
    """

    # Multiple-choice labels "(A)".."(Z)"; extract_ans checks them in this
    # (alphabetical) order and returns the first label found.
    _MC_OPTIONS = tuple('(' + letter + ')' for letter in 'ABCDEFGHIJKLMNOPQRSTUVWXYZ')

    def __init__(self, dataset, mode):
        """Store the dataset and answer mode and build the sentence encoder.

        Args:
            dataset: mapping-like object exposing ``dataset['qa_list']``.
            mode: ``'multiple_choice'`` or ``'free_form'``; controls ``extract_ans``.
        """
        self.dataset = dataset
        self.cur_idx = 0  # position of the next item returned by __next__
        self.mode = mode
        self.encoder = SentenceEncoder()

    def _qa_list(self):
        """Return the sequence of QA items this benchmark iterates over."""
        return self.dataset['qa_list']

    def judge(self, gt: str, answer: str):
        """Score ``answer`` against the ground truth ``gt`` by embedding cosine similarity.

        Both strings are embedded with ``self.encoder.encode`` and compared
        with ``torch.nn.CosineSimilarity(dim=-1, eps=1e-6)``. The return value
        is whatever that module produces for the encoder's outputs (presumably
        a tensor — TODO confirm the encoder's output type/shape).
        """
        target_embedding = self.encoder.encode(gt)
        answer_embedding = self.encoder.encode(answer)
        cos = torch.nn.CosineSimilarity(dim=-1, eps=1e-6)
        return cos(target_embedding, answer_embedding)

    def extract_ans(self, ans):
        """Extract the final answer from a model response.

        The text after the last ``'answer is '`` is taken and stripped. If the
        marker is absent, the whole response is returned unchanged.

        In ``'multiple_choice'`` mode, the letter of the first option label
        ``(A)``..``(Z)`` (alphabetical order) found in the answer is returned.
        If no label is present, the stripped text is returned.

        In ``'free_form'`` mode, a single trailing period is removed. Any other
        mode returns the stripped text as-is.
        """
        parts = ans.split('answer is ')
        # Expect to see 'answer is'. If not, return the whole string.
        if len(parts) == 1:
            return ans
        ans = parts[-1].strip()

        if self.mode == 'multiple_choice':
            for option in self._MC_OPTIONS:
                if option in ans:
                    return option[1]
            return ans
        if self.mode == 'free_form':
            # endswith() is safe on '' (nothing after the marker), unlike ans[-1].
            return ans[:-1] if ans.endswith('.') else ans
        # Unknown mode: fall back to the extracted text rather than None.
        return ans

    def eval(self, question, answer):
        """Placeholder for per-question evaluation; not implemented (returns None)."""
        pass

    def getData(self):
        """Return the QA item at the current iteration position."""
        return self._qa_list()[self.cur_idx]

    def __len__(self):
        """Return the number of QA items (not the number of keys in ``dataset``)."""
        return len(self._qa_list())

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next QA item; reset to the start and stop when exhausted."""
        qa_list = self._qa_list()
        # Check the bound before indexing so the last item is yielded and an
        # exhausted iterator raises StopIteration instead of IndexError.
        if self.cur_idx >= len(qa_list):
            self.cur_idx = 0
            raise StopIteration()
        res = qa_list[self.cur_idx]
        self.cur_idx += 1
        return res

import json


class qmsum_interface(qmsum_bench):
    """Free-form QMSum benchmark loaded from the JSON file ``./data/<task>/<task_name>``."""

    def __init__(self, task, task_name, mask_rate = 0.5):
        """Read the JSON file and hand its contents to the base benchmark.

        Args:
            task: sub-directory under ``./data`` (resolved relative to the CWD).
            task_name: JSON file name inside that sub-directory.
            mask_rate: accepted for interface compatibility; not used here.
        """
        path = join('./data', task, task_name)
        with open(path, "r", encoding="utf-8") as handle:
            loaded = json.load(handle)
        super().__init__(loaded, "free_form")

