import json
import re
import gensim
import logging

from nltk.stem import WordNetLemmatizer
from nltk.tokenize import word_tokenize
from nltk import pos_tag

from llm_compiler.utils.loader import load_llm
from llm_compiler.datatype import DataType
from llm_compiler.corpora import Corpora_Feature

# Project-wide logger looked up by a fixed name (not __name__); its handlers
# and level are presumably configured elsewhere in the project.
logger = logging.getLogger(name="global_logger")

class llm_compiler:
    """Compile free text into one DSL-style JSON record per sentence.

    For every sentence:
      1. an LLM prompt extracts a single operation token,
      2. that token is mapped onto the most similar opcode (key of ``self.dsl``)
         using word2vec similarity,
      3. entities are extracted with ``Corpora_Feature`` and aligned against the
         opcode's parameter patterns,
      4. a second LLM prompt extracts the sentence's output ("emit") text.
    """

    def __init__(self, dataset, dsl_name, engine, temperature, freq_penalty, max_tokens, llm_cache_dir) -> None:
        """Load the LLM, DSL definition, word vectors and prompt templates.

        Args:
            dataset: Dataset name; used for the DSL file path and ``DataType``.
            dsl_name: DSL variant. Only ``'autodsl'`` loads
                ``data/<dsl_name>/<dataset>.json`` into ``self.dsl``.
                NOTE(review): any other value leaves ``self.dsl`` unset, so
                ``compile``/``similarity_params`` will fail with AttributeError.
            engine: Identifier passed to ``load_llm``.
            temperature, freq_penalty, max_tokens, llm_cache_dir: Sampling
                options forwarded to every ``sample_completions`` call and to
                ``Corpora_Feature``.

        All paths are relative to the current working directory.
        """
        self.dataset = dataset
        self.dsl_name = dsl_name
        self.engine = engine
        self.temperature = temperature
        self.freq_penalty = freq_penalty
        self.max_tokens = max_tokens
        self.llm_cache_dir = llm_cache_dir

        self.llm = load_llm(self.engine)
        if self.dsl_name == 'autodsl':
            with open(f"data/{dsl_name}/{dataset}.json", 'r', encoding="utf-8") as f:
                # Maps opcode -> list of {"pattern": [label, ...], ...} entries.
                self.dsl = json.load(f)

        self.lemmatizer = WordNetLemmatizer()
        self.word2vec_model = gensim.models.KeyedVectors.load_word2vec_format('../GoogleNews-vectors-negative300.bin.gz', binary=True)
        self.datatype = DataType(dataset)
        # The leading 100 is Corpora_Feature's first positional argument; its
        # meaning is defined there.
        self.entity_extraction = Corpora_Feature(100, self.datatype, self.llm, temperature=self.temperature, freq_penalty=self.freq_penalty, max_tokens=self.max_tokens, llm_cache_dir=self.llm_cache_dir)
        with open("data/operation_extraction.txt", encoding="utf-8") as file:
            self.operation_extraction_prompt = file.read()
        with open("data/emit_extraction.txt", encoding="utf-8") as file:
            self.emit_extraction_prompt = file.read()
        # Entity label -> key used in the emitted JSON records.
        self.label_mapping = {
            "REG": "reagent",
            "Container": "container",
            "Device": "device",
            "Time": "time",
            "Temperature": "temperature",
            "Mass": "mass",
            "Speed": "speed",
            "Concentration": "concentration",
            "Volume": "volume",
            "Length": "length",
            "String": "string",
            "Force": "force",
            "Bool": "bool",
            "Voltage": "voltage",
            "Frequency": "frequency",
        }

    def compile(self, x: str) -> str:
        """Compile free text ``x`` into newline-separated JSON records.

        The text is whitespace-normalised and split on ``". "``. Fragments of
        ten characters or fewer (after stripping) are dropped entirely.

        Each remaining sentence yields one JSON object. It has an ``action``
        key (the lower-cased opcode), an ``output`` key (the lower-cased emit
        text), and one list-valued key per slot label. A sentence with no
        usable operation or opcode yields ``{"action": "", "output": ""}``.
        """
        # \s already matches newlines, so a single substitution normalises all
        # whitespace runs to one space.
        text = re.sub(r'\s+', ' ', x).strip()
        sentences = [s.strip() for s in text.split(". ") if len(s.strip()) > 10]
        records = []
        for sentence in sentences:
            logger.info(sentence)
            operation = self.__operation_extraction(sentence)
            if "NONE" in operation:
                records.append({"action": "", "output": ""})
                continue
            opcode = self.__similarity_opcode(operation)
            if "NONE" in opcode:
                records.append({"action": "", "output": ""})
                continue
            entity_list = self.__entity_extraction(sentence, self.datatype)
            slot = self.similarity_params(opcode, entity_list)
            emit = self.__get_emits(sentence)

            record = {"action": opcode.lower(), "output": emit.lower()}
            for label, value in slot:
                # Labels absent from label_mapping fall back to their
                # lower-cased name (the mapping's own convention for every
                # entry except "REG") instead of raising KeyError.
                key = self.label_mapping.get(label, str(label).lower())
                if isinstance(value, str):
                    value = value.lower()
                record.setdefault(key, []).append(value)
            records.append(record)
        return "\n".join(json.dumps(record) for record in records)

    def __operation_extraction(self, sentence):
        """Ask the LLM for the single operation token in ``sentence``.

        Returns the token upper-cased. Returns ``"NONE"`` when the answer
        contains "NONE" (case-insensitively) or does not tokenize to exactly
        one token.
        """
        prompt = self.operation_extraction_prompt.replace("------", sentence)
        completions = self.llm.sample_completions(prompt, temperature=self.temperature, freq_penalty=self.freq_penalty, max_tokens=self.max_tokens, llm_cache_dir=self.llm_cache_dir, num_completions=1)
        answer = completions[0].response_text.strip()
        if "NONE" in answer.upper():
            return "NONE"
        # Only the token count and the token text matter here; POS tags are
        # not consulted.
        words = word_tokenize(answer)
        if len(words) == 1:
            return words[0].upper()
        return "NONE"

    def __similarity_opcode(self, operation):
        """Map an operation token onto the most similar opcode in ``self.dsl``.

        Similarity is computed between ``operation.lower()`` and each opcode's
        lower-cased, verb-lemmatized form. Opcodes whose lemmatized form is not
        in the word2vec vocabulary are skipped.

        Returns the original DSL key of the best match. Returns ``"NONE"`` when
        the operation is out of vocabulary or no opcode can be compared.
        """
        query = operation.lower()
        # Check membership for exactly the strings passed to similarity(); the
        # model raises KeyError for out-of-vocabulary words.
        if query not in self.word2vec_model:
            return "NONE"
        closest_word = None
        max_similarity = float("-inf")
        for target_word in self.dsl:
            candidate = self.lemmatizer.lemmatize(target_word.lower(), pos='v')
            if candidate not in self.word2vec_model:
                continue
            similarity = self.word2vec_model.similarity(query, candidate)
            if similarity > max_similarity:
                max_similarity = similarity
                closest_word = target_word
        # Never return None: compile() tests the result with `"NONE" in ...`.
        return closest_word if closest_word is not None else "NONE"

    def __entity_extraction(self, sentence, datatype):
        """Return ``[label, text]`` pairs from ``sentence`` whose label is in ``datatype.type``."""
        origin, label = self.entity_extraction.data_annotate(sentence)
        return [[lab, ori] for ori, lab in zip(origin, label) if lab in datatype.type]

    def similarity_params(self, opcode, entity_list):
        """Align ``entity_list`` with the best-matching parameter pattern of ``opcode``.

        Each pattern in ``self.dsl[opcode]`` is scored by the length of its
        longest common subsequence with the entity labels. Ties go to the
        longer pattern.

        If the winning alignment covers every entity, the aligned list is
        returned; pattern slots no entity filled appear as ``[label, None]``.
        Otherwise ``entity_list`` is returned unchanged.
        """
        patterns = [entry["pattern"] for entry in self.dsl[opcode]]
        max_similarity = -1
        max_pattern = []
        matched_pattern = []
        matched_result = []
        for p in patterns:
            sim, lcs, result = self.__lcs_similarity(p, entity_list)
            if sim > max_similarity or (sim == max_similarity and len(p) > len(max_pattern)):
                max_similarity = sim
                max_pattern = p
                matched_pattern = lcs
                matched_result = result

        if len(matched_pattern) == len(entity_list):
            return matched_result
        return entity_list

    def __get_emits(self, sentence):
        """Ask the LLM for the output ("emit") of ``sentence``.

        Returns the stripped answer if it tokenizes to fewer than 50 tokens,
        otherwise ``""``.
        """
        prompt = self.emit_extraction_prompt.replace("+-+-+-", sentence)
        completions = self.llm.sample_completions(prompt, temperature=self.temperature, freq_penalty=self.freq_penalty, max_tokens=self.max_tokens, llm_cache_dir=self.llm_cache_dir, num_completions=1)
        answer = completions[0].response_text.strip()
        if len(word_tokenize(answer)) < 50:
            return answer
        return ""

    def __lcs_similarity(self, list1, list2):
        """Longest-common-subsequence alignment of a label pattern with entities.

        ``list1`` is a sequence of labels. ``list2`` holds ``[label, value]``
        entries, compared on their first element.

        Returns ``(score, lcs, result)``:
          * ``score``: the LCS length.
          * ``lcs``: the matched ``list2`` entries, in order.
          * ``result``: ``list1`` in order. Matched labels are replaced by their
            ``list2`` entry; unmatched labels become ``[label, None]``.
            ``list2`` entries outside the LCS are omitted.
        """
        m = len(list1)
        n = len(list2)

        # dp[i][j] = LCS length of list1[:i] and the labels of list2[:j].
        dp = [[0] * (n + 1) for _ in range(m + 1)]
        for i in range(1, m + 1):
            for j in range(1, n + 1):
                if list1[i - 1] == list2[j - 1][0]:
                    dp[i][j] = dp[i - 1][j - 1] + 1
                else:
                    dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])

        # Walk back from the bottom-right corner, building both outputs in
        # reverse order.
        lcs = []
        result = []
        i, j = m, n
        while i > 0 and j > 0:
            if list1[i - 1] == list2[j - 1][0]:
                lcs.append(list2[j - 1])
                result.append(list2[j - 1])
                i -= 1
                j -= 1
            elif dp[i - 1][j] > dp[i][j - 1]:
                # Pattern slot list1[i-1] is unmatched: emit a placeholder.
                result.append([list1[i - 1], None])
                i -= 1
            else:
                # Entity list2[j-1] is not part of the alignment: drop it.
                j -= 1
        # Any leading pattern slots left over are unmatched.
        while i > 0:
            result.append([list1[i - 1], None])
            i -= 1
        lcs.reverse()
        result.reverse()
        return len(lcs), lcs, result