import string
import numpy as np
import os

from concurrent.futures import ThreadPoolExecutor
# import logging

# from tqdm import tqdm
# from factscore.abstain_detection import is_response_abstained
from retrieval import DocDB, Retrieval

class Scorer(object):

    def __init__(self,
                 client,
                 config,
                 model_name="retrieval+ChatGPT",
                 batch_size=256):
        assert model_name in ["retrieval+llama", "retrieval+llama+npm", "retrieval+ChatGPT", "npm", "retrieval+ChatGPT+npm", "retrieval"]
        self.model_name = model_name
        self.client = client
        self.config = config

        self.data_dir = config.model.annotator.data_path
        self.cache_dir = config.model.annotator.retrieval_cache_path

        self.db = {}
        self.retrieval = {}
        self.npm = {}
        self.batch_size = batch_size # batch size for retrieval
        # self.abstain_detection_type = abstain_detection_type

        # self.data_dir = data_dir
        # self.cache_dir = cache_dir
        # if not os.path.exists(cache_dir):
        #     os.makedirs(cache_dir)

        self.af_generator = None

    def save_cache(self):
        self.client.save_cache()
        if "npm" in self.model_name:
            for k, v in self.npm.items():
                v.save_cache()
        for k, v in self.retrieval.items():
            v.save_cache()
        for k, v in self.db:
            v.save_cache()

    def register_knowledge_source(self, name="enwiki-20230401", db_path=None, data_path=None):
        assert name not in self.retrieval, f"{name} already registered"
        if db_path is None:
            db_path = os.path.join(self.data_dir, f"{name}.db")

        if data_path is None:
            data_path = os.path.join(self.data_dir, f"{name}.jsonl")

        db_cache_path = os.path.join(self.cache_dir, f"db-{name}.pkl")
        cache_path = os.path.join(self.cache_dir, f"retrieval-{name}.json")
        embed_cache_path = os.path.join(self.cache_dir, f"retrieval-{name}.pkl")

        self.db[name] = DocDB(db_path=db_path, data_path=data_path, cache_path=db_cache_path)
        self.retrieval[name] = Retrieval(self.db[name], cache_path, embed_cache_path, retrieval_type="bm25", batch_size=self.batch_size)
        # if "npm" in self.model_name:
        #     cache_path = os.path.join(self.cache_dir, f"bm25-{name}.json")
        #     embed_cache_path = os.path.join(self.cache_dir, f"bm25-{name}.pkl")
        #     self.npm[name] = NPM(Retrieval(self.db[name], cache_path, embed_cache_path, "bm25"),
        #                          "npm-single",
        #                          cache_file=os.path.join(self.cache_dir, f"npm-{name}.pkl"))


    def get_score(self,
                  topics,
                  generations,
                  atomic_facts,
                  gamma=10,
                  knowledge_source=None):
        if knowledge_source is None:
            # use the default knowledge source
            knowledge_source = "enwiki-20230401"

        if knowledge_source not in self.retrieval:
            self.register_knowledge_source(knowledge_source)

        if type(topics)==type(generations)==str:
            topics = [topics]
            generations = [generations]
            atomic_facts = [atomic_facts]
        else:
            assert type(topics)==type(generations)==list, "`topics` and `generations` should be lists."
            assert len(topics)==len(generations), "`topics` and `generations` should have the same length"
            assert len(topics)==len(atomic_facts), "`topics` and `atomic_facts` should have the same length"

        respond_ratio = np.mean([facts is not None for facts in atomic_facts])

        scores = []
        init_scores = []
        decisions = []
        for topic, generation, facts in zip(topics, generations, atomic_facts):
            if facts is None:
                decisions.append(None)
            else:
                decision = []
                for fact in facts:
                    decision.append(
                        self._get_score(topic, generation, fact, knowledge_source, decision)
                    )
                score = np.mean([d["is_supported"] for d in decision])
                
                if gamma:
                    init_scores.append(score)
                    penalty = 1.0 if len(facts)>gamma else np.exp(1-gamma/len(facts))
                    score = penalty * score
                
                decisions.append(decision)
                scores.append(score)
                # if len(scores) % 10 == 0:
                #     self.save_cache()

        out = {"score": np.mean(scores),
               "respond_ratio": respond_ratio,
               "decisions": decisions,
               "num_facts_per_response": np.mean([len(d) for d in decisions if d is not None])}

        if gamma:
            out["init_score"] = np.mean(init_scores)
        
        return out

    def _get_score(self, topic, generation, atom, knowledge_source, prev_decisions = []):
        atom = atom.strip()
        passages = self.retrieval[knowledge_source].get_passages(topic, atom, k=5)
        definition = f"Answer the question about {topic} based on the given context and your previous answers.\n\n"
        context = ""
        for psg in reversed(passages):
            context += "Title: {}\nText: {}\n\n".format(psg["title"], psg["text"].replace("<s>", "").replace("</s>", ""))
        definition += context.strip()
        if not definition[-1] in string.punctuation:
            definition += "."
        prompt = f"{definition.strip()}\n\n"
        for prev_decision in prev_decisions:
            prev_score = "True" if prev_decision["is_supported"] else "False"
            prompt += f"Previous input: {prev_decision['atom']} True or False? Output: {prev_score}"
        
        prompt += f"Input: {atom.strip()} True or False?\nOutput:"

        # output = [{'message': 'blah blah blah'}]
        output = self.client.query(prompt) 

        # if type(output[1])==np.ndarray:
        #     # when logits are available
        #     logits = np.array(output[1])
        #     assert logits.shape[0] in [32000, 32001]
        #     true_score = logits[5852]
        #     false_score = logits[7700]
        #     is_supported = true_score > false_score
        # else:
        # when logits are unavailable
        generated_answer = output[0]['message'].lower()
        if "true" in generated_answer or "false" in generated_answer:
            if "true" in generated_answer and "false" not in generated_answer:
                is_supported = True
            elif "false" in generated_answer and "true" not in generated_answer:
                is_supported = False
            else:
                is_supported = generated_answer.index("true") > generated_answer.index("false")
        else:
            is_supported = all([keyword not in generated_answer.lower().translate(str.maketrans("", "", string.punctuation)).split() for keyword in ["not", "cannot", "unknown", "information"]])

        if is_supported and "npm" in self.model_name:
            npprob = self.npm[knowledge_source].get_probabilty(topic, atom)
            is_supported = npprob > 0.3

        decision = {"atom": atom, "is_supported": is_supported}
        
        # self.client.cache_outputs(
        #     [prompt],
        #     np.zeros((1,), dtype=int),
        #     [output]
        # )

        return decision