import os
import pickle as pkl
from openai import OpenAI
import pathlib
import fcntl
import time
import json
import re

# Identifiers of the chat models that can be queried. The first entry is the
# default `lm_name` of `LanguageModel`.
MODEL_NAME = [
    'Pro/deepseek-ai/DeepSeek-V3',
    'deepseek-ai/DeepSeek-V3',
    'deepseek-chat',
]

# Further language-model identifiers. Not referenced elsewhere in the visible
# code, except that `LanguageModel.try_query` special-cases the "test_lm" name
# and returns a canned answer for it.
LM = [
    "no_lm",
    "test_lm",
    "THUDM/glm-4-9b-chat",
    "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
    "Qwen/Qwen2.5-7B-Instruct",
]


# Base URL and API key handed to the OpenAI client in `LanguageModel.try_query`.
# Both are empty here and have to be filled in before real queries can succeed.
URL = ''
API_KEY = ''


# Directory containing this file; prompt, cache and data paths are resolved
# relative to it.
SYS_PATH = pathlib.Path(os.path.dirname(os.path.realpath(__file__)))
class PromptFormat:
    """Build prompts from template files and parse bulleted model replies."""

    def format_prompt(self, prefix, text_state, suffix):
        """Return the contents of file `prefix`, then `text_state`, then `suffix`.

        Args:
            prefix: Path of a text file whose contents open the prompt.
            text_state: Text placed on its own line after the template.
            suffix: Text placed on the final line of the prompt.

        Returns:
            The three parts joined with single newlines.
        """
        with open(prefix, 'r') as template_file:
            template_text = template_file.read()
        return template_text + f"\n{text_state}\n{suffix}"

    def get_system_prompt(self, system_prompt):
        """Read the file at `system_prompt`, remember it on the instance and return it.

        Args:
            system_prompt: Path of the text file holding the system prompt.

        Returns:
            The file contents, also stored as `self.system_prompt`.
        """
        with open(system_prompt, 'r') as prompt_file:
            contents = prompt_file.read()
        self.system_prompt = contents
        return contents

    def parse_response(self, response):
        """Extract the item texts from a bulleted or numbered reply.

        An item starts after a "-" or a run of digits and ends before the
        next colon, the next newline, or the end of the text.

        Args:
            response: Raw reply text.

        Returns:
            List of captured item strings, in order of appearance.

        Raises:
            ValueError: If no item is found in `response`.
        """
        # NOTE(review): the "." in `\d+.` is unescaped, so it consumes any one
        # character after the digits (".", ")", ":", a space, ...). Left as is
        # because escaping it would change what replies like "1) foo" yield.
        bullet_re = re.compile(r'(?:-|\d+.|\d+)\s*(.*?)(?=\s*:|\s*$|\n)')
        items = [found.group(1) for found in bullet_re.finditer(response)]
        if not items:
            raise ValueError(f"Response ********{response} ********** is not a valid response to a bullet prompt")
        return items

class LanguageModel:
    """Ask an OpenAI-compatible chat endpoint for sub-goals, with an on-disk cache.

    Responses are memoised in a dict keyed by
    ``(model, text_state, max_tokens, temperature, stop)`` and periodically
    merged with a pickle file shared between processes (guarded by ``flock``).
    Every fresh query is also appended to a JSON-lines log.
    """

    # Text returned by `try_query` when every attempt fails. It is compared
    # case-insensitively so that the historical "Any Goal ..." spelling is
    # recognised as well; such responses are never cached.
    _FAILED_RESPONSE = "Any goal has been reached."

    def __init__(self,
                 lm_name = MODEL_NAME[0],
                 max_tokens: int = 100,
                 temperature: float = .5,
                 stop_token=None,
                 num_env: int = 512,
                 key_i: int = 0,
                 env_name: str = "Craftax-Classic-Pixels-v1",
                 alg_name: str = 'PPO'):
        """Set up counters, read the goal system prompt and load the cache.

        Args:
            lm_name: Model identifier sent to the endpoint.
            max_tokens: Completion length limit used for queries.
            temperature: Sampling temperature in [0, 1] used for queries.
            stop_token: Sequence of stop strings; only used as part of the
                cache key. Defaults to ``['\\n\\n']``.
            num_env: Scales how many new cache entries accumulate before the
                cache is flushed to disk (``num_env * 100``).
            key_i: Stored on the instance; not used by the visible code.
            env_name: Used in the cache and log file names.
            alg_name: Used in the cache file name only (see note below).
        """
        assert 0 <= temperature <= 1, f"invalid temperature {temperature}; must be in [0, 1]"
        self.lm = lm_name
        self.url = URL
        self.key_i = key_i
        self.api_key = API_KEY
        self.env_name = env_name
        # NOTE(review): the attribute is pinned to 'ELLM' while the cache file
        # name below uses the `alg_name` argument; kept so existing cache and
        # log files keep their names.
        self.alg_name = 'ELLM'
        self.prompt_format = PromptFormat()
        self.api_querie_times = 0  # number of API attempts made so far
        self.update_count = 0      # new cache entries since the last flush
        self.lm_data_num = 0       # running index of logged queries
        self.update_goal_num = 200
        # Starts equal to update_goal_num so the first uncached call queries.
        self.update_goal = self.update_goal_num
        self.last_inst = ""
        self.update_threshold = num_env * 100
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.tokens_per_word = 4 / 3
        self.system_prompt_1 = self.prompt_format.get_system_prompt(
            system_prompt=SYS_PATH / "prompt/system_prompt_for_goal.txt")
        # A fresh list per instance avoids sharing one mutable default.
        self.stop = ['\n\n'] if stop_token is None else stop_token
        if "/" in self.lm:
            self.lm_name = self.lm.split("/")[-1]
        else:
            self.lm_name = self.lm
        self.cache_path = SYS_PATH / 'lm_cache' / f'{env_name}+{alg_name}+DeepSeek-V3.pkl'
        self.cache = self.load_cache()
        self.data_path = SYS_PATH / "lm_data"

    def load_cache(self):
        """Return the cache dict stored at `self.cache_path`.

        If the file does not exist it is created holding an empty dict. If it
        cannot be read or unpickled, the error is printed and an empty dict is
        returned (best effort).
        """
        if self.cache_path.exists():
            try:
                with open(self.cache_path, 'rb') as f:
                    # Readers only need a shared lock; an exclusive lock on a
                    # read-only descriptor is rejected on some filesystems.
                    fcntl.flock(f, fcntl.LOCK_SH)
                    try:
                        # SECURITY: pickle executes code from the file; the
                        # cache file must only ever be written by trusted
                        # processes.
                        cache = pkl.load(f)
                    finally:
                        fcntl.flock(f, fcntl.LOCK_UN)
            except Exception as e:
                print(f'Error loading cache: {e}')
                cache = {}
        else:
            cache = {}
            # The cache directory may not exist on a first run.
            self.cache_path.parent.mkdir(parents=True, exist_ok=True)
            with open(self.cache_path, 'wb') as f:
                pkl.dump({}, f)
        return cache

    def save_cache(self):
        """Write `self.cache` to `self.cache_path` under an exclusive lock."""
        self.cache_path.parent.mkdir(parents=True, exist_ok=True)
        # Open in append mode so the file is NOT truncated before the lock is
        # held; truncation happens only once we own the lock.
        with open(self.cache_path, 'ab') as f:
            fcntl.flock(f, fcntl.LOCK_EX)
            try:
                f.truncate(0)
                pkl.dump(self.cache, f)
                # Push buffered bytes to the file before releasing the lock.
                f.flush()
            finally:
                fcntl.flock(f, fcntl.LOCK_UN)

    def load_and_save_cache(self):
        """Merge the on-disk cache into memory and write the result back.

        For keys present in both, the on-disk value wins.
        """
        merged = dict(self.load_cache())
        for key, value in self.cache.items():
            merged.setdefault(key, value)
        self.cache = merged
        self.save_cache()

    def store_in_cache(self, inputs, response):
        """Record `response` under `inputs`; flush to disk every `update_threshold` stores."""
        self.cache[inputs] = response
        self.update_count += 1
        if self.update_count >= self.update_threshold:
            self.load_and_save_cache()
            self.update_count = 0

    def check_in_cache(self, inputs):
        """Return True if `inputs` is a key of the in-memory cache."""
        return inputs in self.cache

    def retrieve_from_cache(self, inputs):
        """Return the cached response for `inputs` (raises KeyError if absent)."""
        return self.cache[inputs]

    def _is_failed_response(self, text):
        """Return True if `text` is the failure sentinel, ignoring case and outer whitespace."""
        return text.strip().lower() == self._FAILED_RESPONSE.lower()

    def _append_query_log(self, text_state, goal_text):
        """Append one state/response record to the JSON-lines query log."""
        record = {
            "number": self.lm_data_num,
            "text_state": text_state,
            "response": goal_text,
        }
        self.lm_data_num += 1
        # The log directory may not exist on a first run.
        pathlib.Path(self.data_path).mkdir(parents=True, exist_ok=True)
        with open(f'{self.data_path}/prompt_response+{self.env_name}+{self.alg_name}_DeepSeek-V3.jsonl', 'a') as f:
            f.write(json.dumps(record) + '\n')

    def try_query(self, user_prompt, system_prompt, lm=None):
        """Send one chat request, retrying up to three times.

        Args:
            user_prompt: Content of the "user" message.
            system_prompt: Content of the "system" message.
            lm: If equal to "test_lm", no request is made and a canned answer
                is returned. The model actually queried is always `self.lm`.

        Returns:
            The reply text, or the failure sentinel
            ("Any goal has been reached.") if every attempt failed.
        """
        if lm == "test_lm":
            return "eat_cow,collect_drink,wake_up"
        response = None
        client = None
        max_attempts = 3
        for attempt in range(1, max_attempts + 1):
            try:
                self.api_querie_times += 1
                if client is None:
                    client = OpenAI(
                        api_key=self.api_key,
                        base_url=self.url
                    )
                # Use the configured sampling parameters: they are part of the
                # cache key, so the request must honour them.
                completion = client.chat.completions.create(
                    model=self.lm,
                    messages=[
                        {"role": "system", "content": system_prompt},
                        {"role": "user", "content": user_prompt},
                    ],
                    temperature=self.temperature,
                    max_tokens=self.max_tokens,
                )
                # Kept in a separate name so a malformed completion cannot be
                # mistaken for a successful text response.
                response = completion.choices[0].message.content
            except Exception as e:
                print(f"Attempt {attempt}/{max_attempts} failed: {e}")
                time.sleep(0.01)  # Wait before retrying
            if response is not None:
                break
        if response is None:
            response = self._FAILED_RESPONSE
        return response

    def predict_options(self, text_state, l_score):
        """Return sub-goal text for `text_state`.

        A cached answer is returned when available. Otherwise the model is
        queried only once every `update_goal_num` uncached calls; in between,
        the previous answer is repeated. When `l_score` is positive an extra
        analysis query is made first and its answer is appended to the prompt.

        Args:
            text_state: Textual description of the current state.
            l_score: Score embedded in the analysis prompt as the player's
                comprehension score; values <= 0 skip the analysis query.

        Returns:
            The goal text (also remembered as `self.last_inst`).
        """
        adapted_prompt = None
        inputs = (self.lm, text_state, self.max_tokens, self.temperature, tuple(self.stop))
        if self.check_in_cache(inputs):
            goal_text = self.retrieve_from_cache(inputs)
        elif self.update_goal == self.update_goal_num:
            self.update_goal = 0
            prefix_pt = f"\n{text_state}\n.Past sub-goals:{self.last_inst}."
            if l_score > 0:
                llm_sys_pt = """You are a professional game analyst. A player is playing a 2D Minecraft game. You will get the player's observation, status information, and its comprehension score of language guidance. You will be asked to provide concise summaries and suggestions about this player."""
                llm_user_pt = prefix_pt + f"Player's comprehension score:{l_score:.3f}."
                # try_query takes (user_prompt, system_prompt, lm).
                adapted_prompt = self.try_query(llm_user_pt, llm_sys_pt, self.lm)
                prefix_pt = prefix_pt + f"Analysis: <{adapted_prompt}>."
            user_prompt = prefix_pt + "Based on the provided information, suggest 3 sub-goals that the player should accomplish next. No explanation.\n"
            goal_text = self.try_query(user_prompt, self.system_prompt_1, self.lm)
            if self.api_querie_times % 100 == 0:
                print(f"Analysis: {adapted_prompt}")
                print(f"Query times: {self.api_querie_times}\nState:{text_state}\nGoal: {goal_text}")

            self._append_query_log(text_state, goal_text)
            if self._is_failed_response(goal_text):
                # Do not cache failures, so the state can be retried later.
                print("Failed to generate goals. Using default goals.")
            else:
                self.store_in_cache(inputs, goal_text)
        else:
            goal_text = self.last_inst
            self.update_goal += 1
        self.last_inst = goal_text
        return goal_text
