import functools
import json
import logging
import re
from typing import Any, Dict, List, Optional

# from symbolic_compiler.retriever import retrieve_fn_dict, setup_bm25
from symbolic_compiler.data_structs.parameter import Parameter
from symbolic_compiler.data_structs.sentence import Sentence
from symbolic_compiler.llm_interface.large_language_model import LargeLanguageModel

# Module-level logger, looked up under the fixed name "global_logger"
# (not ``__name__``), so every module requesting that name gets this object.
logger = logging.getLogger(name="global_logger")

class NER:
    """Label the objects of a sentence by prompting a large language model.

    A prompt template is read from disk once at construction time. For each
    query, :meth:`recognition` fills the ``[label_set]``, ``[cases]`` and
    ``[query]`` placeholders of that template, samples several completions
    and returns the first one that parses into a valid labelling.
    """

    def __init__(self, dsl, label_list, llm: LargeLanguageModel, temperature, freq_penalty, max_tokens, llm_cache_dir,
                 prompt_path: str = "data/ner_prompt.txt", num_completions: int = 4):
        """Store the sampling configuration and load the prompt template.

        Args:
            dsl: Name of the target DSL; selects the few-shot cases inserted
                into the prompt. Only ``"autodsl"`` has cases defined here.
            label_list: Collection of labels a completion is allowed to use.
            llm: Model wrapper; its ``sample_completions`` method is called.
            temperature: Forwarded unchanged to ``llm.sample_completions``.
            freq_penalty: Forwarded unchanged to ``llm.sample_completions``.
            max_tokens: Forwarded unchanged to ``llm.sample_completions``.
            llm_cache_dir: Forwarded unchanged to ``llm.sample_completions``.
            prompt_path: Location of the prompt template. The default is
                relative to the current working directory.
            num_completions: How many completions to request per query.

        Raises:
            OSError: If the prompt template cannot be opened.
        """
        # self.retrieve_fn = retrieve_fn_dict[retrieve_fn]
        self.dsl = dsl
        self.label_list = label_list
        self.llm = llm

        self.temperature = temperature
        self.freq_penalty = freq_penalty
        self.max_tokens = max_tokens
        self.llm_cache_dir = llm_cache_dir
        self.num_completions = num_completions

        # Explicit encoding: the prompt material contains non-ASCII text
        # (e.g. "μ"), so do not depend on the platform default.
        with open(prompt_path, "r", encoding="utf-8") as f:
            self.NER_prompt = f.read()

        # Always define ``cases`` so that a DSL without few-shot examples
        # yields an empty ``[cases]`` section instead of an AttributeError
        # the first time ``recognition`` is called.
        self.cases = ""
        if self.dsl == "autodsl":
            # Runtime prompt text: reproduced verbatim on purpose.
            self.cases = """Text: Combine 20-100 μg of [total RNA], to generate [a 250 μL reaction].\nAnswer: [{"total RNA": "reagent"}, {"a 250 μL reaction": "output"}]\nText: Add [the biotinylation reaction] to [the coulmn].\nAnswer: [{"the biotinylation reaction": "reagent"}, {"the coulmn": "output"}]"""

        # self.train_examples = []
        # if retrieve_fn == "bm25":
        #     bm25 = setup_bm25(self.train_examples)
        #     self.retrieve_fn = functools.partial(self.retrieve_fn, batch_size=batch_size, bm25=bm25)
        # else:
        #     self.retrieve_fn = functools.partial(self.retrieve_fn, batch_size=batch_size)
        # TODO: Self-Improving for Zero-Shot Named Entity Recognition with Large Language Models https://github.com/Emma1066/Self-Improve-Zero-Shot-NER/tree/main
        # Problem: how to embed and retrieve (vector database or memory)
        # Preliminarily: BERT, vector database

    def recognition(self, x: Sentence) -> Optional[List[Parameter]]:
        """Assign a label to each bracketed object of ``x``.

        Args:
            x: Sentence providing ``text`` and ``objects``; the first
                occurrence of every object in the text is wrapped in square
                brackets before the text is placed into the prompt.

        Returns:
            The parameters built from the first completion that is valid
            (see :meth:`_parse_completion`), or ``None`` when no sampled
            completion is valid.
        """
        text = x.text
        # NOTE(review): replacements are applied one after another on the
        # already-bracketed text, so an object that is a substring of an
        # earlier one may end up nested inside its brackets -- confirm that
        # callers never pass overlapping objects.
        for entity in x.objects:
            text = text.replace(entity, "[" + entity + "]", 1)

        prompt = (
            self.NER_prompt
            .replace("[label_set]", str(sorted(self.label_list)))
            .replace("[cases]", self.cases)
            .replace("[query]", text)
        )
        results = self.llm.sample_completions(
            prompt=prompt,
            temperature=self.temperature,
            freq_penalty=self.freq_penalty,
            max_tokens=self.max_tokens,
            llm_cache_dir=self.llm_cache_dir,
            num_completions=self.num_completions,
        )

        for result in results:
            parameters = self._parse_completion(result.response_text)
            if parameters is not None:
                return parameters
            # Lazy %-formatting: logging must not fail itself when the
            # response text is not a string.
            logger.info("wrong json format: %s", result.response_text)
        return None

    def _parse_completion(self, response_text) -> Optional[List[Parameter]]:
        """Convert one completion into parameters, or ``None`` if unusable.

        A completion is usable when it is JSON that iterates over objects,
        the first value of every object is a member of ``self.label_list``,
        and no entity (the first key of an object) appears twice.
        """
        try:
            data = json.loads(response_text)
            parameters = []
            seen_entities = set()
            for pair in data:
                # Only the first key/value of each object is considered.
                entity, label = next(iter(pair.items()))
                if label not in self.label_list or entity in seen_entities:
                    return None
                seen_entities.add(entity)
                parameters.append(Parameter(label, entity))
            return parameters
        except Exception:
            # Deliberately best-effort: malformed JSON, an unexpected shape or
            # a failing Parameter construction all mean "try the next
            # completion". Unlike a bare ``except``, this lets
            # KeyboardInterrupt and SystemExit propagate.
            return None