import openai
from openai import OpenAI
import re
import json
import sys
from copy import deepcopy
import os
import importlib
import logging
import os
import datetime
import pdb
import random
from prompt.prompt import get_solution_prompt, code_check_prompt, feedback_fix_prompt, simple_example, puttwo_example, examine_example, clean_example, heat_example, cool_example, get_start_from_prompt

import re
import traceback
# ANSI escape sequences used to colour console output.
RED = "\033[31m"
BLUE = "\033[34m"
GREEN = "\033[32m"
YELLOW = "\033[33m"
RESET = "\033[0m"

# Task type name -> short prefix.
prefixes = {
    'pick_and_place': 'put',
    'look_at_obj': 'examine',
    'pick_clean_then_place': 'clean',
    'pick_heat_then_place': 'heat',
    'pick_cool_then_place': 'cool',
    'pick_two_obj_then_place': 'puttwo',
}

# Task ids are 1-based and follow the insertion order of `prefixes`.
task_id2task_name = dict(enumerate(prefixes, start=1))

# Signature line the LLM is asked to implement for each task id.
# Only task 2 (look_at_obj) takes no `receptacletype` argument.
task_id2task_method_name = {
    task_id: 'def ' + task_name + (
        '(self, objecttype)'
        if task_id == 2
        else '(self, objecttype, receptacletype)'
    )
    for task_id, task_name in task_id2task_name.items()
}

# Natural-language task templates, per task id:
#   task1: put a objecttype in receptacletype. / put some objecttype on receptacletype.
#   task2: examine the objecttype with the desklamp. / look at objecttype under the desklamp.
#   task3: put a clean objecttype in receptacletype. / clean some objecttype and put it in receptacletype.
#   task4: heat some objecttype and put it in receptacletype. / put a hot objecttype in receptacletype.
#   task5: cool some objecttype and put it in receptacletype. / put a cool objecttype in receptacletype.
#   task6: put two objecttype in receptacletype. / find two objecttype and put them in receptacletype.
#
# Method descriptions substituted into the prompts for each task id.
# Earlier wording tried for task 2:
#   'with given objecttype, pick an object, turn on a desklamp to examine the object.'
# Earlier wording tried for task 6:
#   1. pick an object and record its name by `self.old_object_name`, and place it in a
#      receptacle, then, pick another object and place it in the receptacle.
task_id2task_method_desc = {
    1: 'with given objecttype and receptacletype, pick an object and place it in a receptacle.',
    2: 'with given objecttype, pick an object, go to a desklamp and turn on it to examine the object.',
    3: 'with given objecttype and receptacletype, pick an object, go to a sinkbasin and clean it, then place it in a receptacle.',
    4: 'with given objecttype and receptacletype, pick an object, go to a microwave and heat it, and place it in a receptacle.',
    5: 'with given objecttype and receptacletype, pick an object, go to a fridge and cool it, and place it in a receptacle.',
    6: (
        'with given objecttype and receptacletype, \n'
        '# 1. pick an object and record its name by `self.old_object_name`, '
        'and go to a receptacle then place it in the receptacle. \n'
        '# 2. pick another object that `object.name != self.old_object_name`, '
        'and go to a receptacle then place it in the receptacle.'
    ),
}



class llm2planner:
    """Ask an LLM for Python planner programs for one task type and manage them on disk.

    The solution prompt (``get_solution_prompt``) and the feedback/"evolve"
    prompt (``feedback_fix_prompt``) are filled with the task name, the method
    signature and description the LLM must implement, and an in-context example.
    Generated programs are written as ``planner_<n>.py`` under
    ``./planners/<model_name>/task_<task_id>/[<time_stamp>/]``.
    ``self.core_text`` keeps the latest program text for each of the
    ``num_repeat_sample`` planner slots.
    """

    def __init__(
        self,
        model_name,
        temperature,
        num_repeat_sample=3,
        time_stamp=None,
        task_id=None,
    ) -> None:
        """Build the prompts, prepare the output directory and create the chat client.

        Args:
            model_name: chat model name. It also selects the endpoint: names
                containing "gpt" use OpenAI (key from env ``openai_key``); names
                containing "deepseek" use the DeepSeek API (key from env
                ``deepseek_key``); anything else uses a local server at
                ``http://localhost:8000/v1``.
            temperature: sampling temperature for every completion request.
            num_repeat_sample: number of planner slots kept in parallel.
            time_stamp: optional sub-directory name under the per-task directory.
            task_id: key into ``task_id2task_name`` (1-6).
        """
        self.model_name = model_name
        self.temperature = temperature
        self.num_repeat_sample = num_repeat_sample
        self.time_stamp = time_stamp

        task_name = task_id2task_name[task_id]
        # Task types 3-5 reuse `examine_example`; their dedicated examples
        # (clean_example / heat_example / cool_example) are currently unused.
        if task_name.startswith('pick_two_obj'):
            example = puttwo_example
            example_method_name = "self.look_at_two_obj(self, objecttype)"
        elif task_name.startswith('look_at_obj'):
            example = examine_example
            example_method_name = "self.pick_and_place(self, objecttype, receptacletype)"
        elif task_name.startswith('pick_and_place'):
            example = simple_example
            example_method_name = "self.pick_object(self, objecttype)"
        elif task_name.startswith('pick_clean_then_place'):
            example = examine_example
            example_method_name = "self.pick_and_place(self, objecttype, receptacletype)"
        elif task_name.startswith('pick_heat_then_place'):
            example = examine_example
            example_method_name = "self.pick_and_place(self, objecttype, receptacletype)"
        elif task_name.startswith('pick_cool_then_place'):
            example = examine_example
            example_method_name = "self.pick_and_place(self, objecttype, receptacletype)"

        # Order matters: placeholders are substituted one after another.
        replacements = (
            ('<task>', task_name),
            ('<task_method>', task_id2task_method_name[task_id]),
            ('<task_method_desc>', task_id2task_method_desc[task_id]),
            ('<example>', example),
            ('<example_method_name>', example_method_name),
        )
        self.prompt = self._fill_template(get_solution_prompt, replacements)

        self.planners_path = os.path.join(
            f'./planners/{model_name}/task_{task_id}/',
            time_stamp if time_stamp is not None else ""
        )
        # Later code builds file paths by plain string concatenation, so keep a trailing "/".
        if not self.planners_path.endswith("/"):
            self.planners_path += "/"
        os.makedirs(self.planners_path, exist_ok=True)

        self.success_planners_path = os.path.join(
            self.planners_path,
            'success_planners.json'
        )

        self.evolve_text = self._fill_template(feedback_fix_prompt, replacements)

        # Latest program text for each planner slot (filled by fill_planner).
        self.core_text = [None for _ in range(self.num_repeat_sample)]
        # Bookkeeping persisted to llm_calls.json by store_llm_calls().
        self.llm_calls = {
            "times": 0,
            "error_planner": []
        }
        # Error count at which a planner becomes a regeneration candidate.
        self.error_times_threshold = 5

        if "gpt" in model_name:
            self.client = OpenAI(api_key=os.environ.get("openai_key"))
        elif "deepseek" in model_name:
            self.client = OpenAI(api_key=os.environ.get("deepseek_key"), base_url="https://api.deepseek.com")
        else:
            # open-source llm served in local
            self.client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")

    @staticmethod
    def _fill_template(template, replacements):
        """Apply ``str.replace`` for each ``(placeholder, value)`` pair, in order."""
        for placeholder, value in replacements:
            template = template.replace(placeholder, value)
        return template

    @staticmethod
    def _planner_index(file_name):
        """Extract ``<n>`` from a ``planner_<n>.py`` file name."""
        return int(file_name.split('_')[1].split('.')[0])

    def _planner_files(self):
        """Return the planner ``.py`` file names in ``self.planners_path``, sorted by index."""
        names = [
            item for item in os.listdir(self.planners_path)
            if 'planner' in item and item.endswith('.py')
        ]
        return sorted(names, key=self._planner_index)

    def warm_up(self):
        """Generate planners until ``num_repeat_sample`` planner files exist.

        Returns:
            ``None`` when enough planner files already exist. Otherwise, the
            ``(round_message, planner_path)`` of the last planner generated.
        """
        added_num = self.num_repeat_sample - len(self._planner_files())
        if added_num <= 0:
            return

        print(YELLOW + "Warm up the planner..." + RESET)
        # NOTE(review): planners that already exist on disk are not loaded into
        # self.core_text, so fill_evolveText fails for those slots until they
        # are regenerated.
        for idx in range(added_num):
            round_message, _, planner_path = self.generate_planner(idx, self.prompt)

        return round_message, planner_path

    def store_llm_calls(self):
        """Persist ``self.llm_calls`` to ``llm_calls.json`` in the planners directory."""
        with open(os.path.join(self.planners_path, 'llm_calls.json'), 'w', encoding="utf-8") as file:
            json.dump(self.llm_calls, file, indent=4)

    def _rewrite_planner(self, idx, file_name):
        """Regenerate slot ``idx`` and overwrite ``file_name`` in place with the new program."""
        _, planner_text, _ = self.generate_planner(idx, self.prompt, False)
        with open(os.path.join(self.planners_path, file_name), 'w', encoding="utf-8") as file:
            file.write(planner_text)
        self.llm_calls["error_planner"].append(file_name)

    def regenerate_planner(self, error_planner=None, error_idx=None):
        """Replace failing planners with freshly generated ones.

        Exactly one of the two arguments must be given.

        Args:
            error_planner: mapping from slot index (into the latest
                ``num_repeat_sample`` planner files) to an error count. Slots
                whose count reaches ``self.error_times_threshold`` are
                regenerated, but only when more than half of the slots reached it.
            error_idx: a single slot index to regenerate unconditionally.

        Returns:
            True if at least one planner file was rewritten, otherwise False.

        Raises:
            ValueError: if both or neither of the arguments are given.
        """
        if (error_planner is None) == (error_idx is None):
            raise ValueError("Pass exactly one of error_planner or error_idx.")

        py_files = self._planner_files()[-self.num_repeat_sample:]
        regenerate_flag = False
        if error_planner is not None:
            failing = [k for k, v in error_planner.items() if v >= self.error_times_threshold]
            if len(failing) <= self.num_repeat_sample // 2:
                return regenerate_flag
            for k in failing:
                regenerate_flag = True
                print(YELLOW + f"Regenerate the {py_files[k]}..." + RESET)
                self._rewrite_planner(k, py_files[k])

        if error_idx is not None:
            regenerate_flag = True
            self._rewrite_planner(error_idx, py_files[error_idx])

        if regenerate_flag:
            self.store_llm_calls()

        return regenerate_flag

    def generate_planner(self, idx, text=None, is_store=True):
        """Ask the LLM for one planner program and record it in slot ``idx``.

        Args:
            idx: index into ``self.core_text`` that receives the parsed program.
            text: user prompt; defaults to ``self.prompt``.
            is_store: when True, also write the program to a new ``planner_<n>.py``.

        Returns:
            ``(round_message, planner_text, planner_path)``. These are the chat
            history (system, user, assistant), the parsed program, and the
            written file path (``None`` when ``is_store`` is False).

        Any exception while querying, parsing or storing is logged and the
        request is retried; there is no retry limit.
        """
        if text is None:
            text = self.prompt
        round_message = [
            {
                "role": "system",
                # NOTE(review): this mentions NetHack although the task prompts
                # describe pick/place household tasks; confirm it is intended.
                "content": "You are an agent who plays the rogue-like game NetHack 3.6.6 using a `Python` program."
            },
            {
                "role": "user",
                "content": text
            },
        ]
        planner_path = None
        while True:
            # Reset on every attempt so a failed request never reports a stale reply.
            planner_answer = None
            try:
                planner_answer = self.client.chat.completions.create(
                    model=self.model_name,
                    messages=round_message,
                    temperature=self.temperature
                )
                planner_text = self.parse_planner(planner_answer)
                planner_text = self.fill_planner(idx, planner_text)

                if is_store:
                    # Only the extracted program is stored, not the full reply.
                    planner_path = self.store_planner(planner_text, None)
            except Exception as e:
                logging.error(f"Error when generate planner: {e}")
                print(RED + traceback.format_exc() + RESET)
                # The request itself may have failed, leaving no reply to show.
                if planner_answer is not None:
                    print(planner_answer.choices[0].message.content)
                continue
            break

        round_message.append(
            {
                "role": "assistant",
                "content": planner_answer.choices[0].message.content
            }
        )
        self.llm_calls["times"] += 1
        self.store_llm_calls()

        return round_message, planner_text, planner_path

    def fill_planner(self, idx, text):
        """Record ``text`` as the current program of slot ``idx`` and return it unchanged."""
        self.core_text[idx] = text
        return text

    def fill_evolveText(self, results):
        """Fill the evolve prompt with every slot's program and the interaction results.

        Raises:
            ValueError: if some planner slot has no program recorded yet.
        """
        evolve_text = self.evolve_text
        # Fill the old/reference Python programs.
        for idx in range(self.num_repeat_sample):
            core = self.core_text[idx]
            if core is None:
                raise ValueError(f"Planner slot {idx} has no program yet; generate it before evolving.")
            evolve_text = evolve_text.replace(f"<python_program_{idx}>", core)

        # Fill the interaction results.
        return evolve_text.replace("<interaction_history>", results)

    def evolve_planner(self, results):
        """Regenerate every planner slot from the evolve prompt built from ``results``.

        Returns:
            ``(round_message, planner_path)`` of the last planner generated.
        """
        print(YELLOW + "Evolve the planner..." + RESET)
        evolve_text = self.fill_evolveText(results)
        for idx in range(self.num_repeat_sample):
            round_message, _, planner_path = self.generate_planner(idx, evolve_text)
        return round_message, planner_path

    def parse_planner(self, planner_answer):
        """Extract the first fenced Python program from a chat completion.

        The reply content is expected to look like::

            ...
            ```python
                planner program
            ```
            ...

        Raises:
            ValueError: if no fenced program can be found.
        """
        planner_content = planner_answer.choices[0].message.content
        if self.model_name == "Meta-Llama-3-8B-Instruct":
            # Accept ```Python, ```python, or a bare ``` fence (tried in that order).
            fence_start = planner_content.find('```Python')
            if fence_start == -1:
                fence_start = planner_content.find('```python')
            if fence_start != -1:
                body = planner_content[fence_start + len('```python'):]
            else:
                fence_start = planner_content.find('```')
                if fence_start == -1:
                    raise ValueError("Cannot find the planner program.")
                body = planner_content[fence_start + len('```'):]
            fence_end = body.find('```')
            if fence_end == -1:
                raise ValueError("Cannot find the planner program.")
            return body[:fence_end].strip()

        planners = re.findall(r'```python(.*?)```', planner_content, re.DOTALL)
        if not planners:
            raise ValueError("Cannot find the planner program.")
        return planners[0].strip()

    def store_planner(self, planner, planner_content=None):
        """Write ``planner`` to the next free ``planner_<n>.py`` and return its path.

        If ``planner_content`` is given, it is written alongside the program as
        ``planner_<n>.md``.
        """
        idxs = [self._planner_index(name) for name in self._planner_files()]
        idx = idxs[-1] + 1 if idxs else 0
        planner_path = self.planners_path + f'planner_{idx}.py'
        with open(planner_path, 'w', encoding="utf-8") as file:
            file.write(planner)

        if planner_content is not None:
            with open(planner_path.replace('.py', '.md'), 'w', encoding="utf-8") as file:
                file.write(planner_content)
        return planner_path

    def store_success_planner_into_json(self):
        """Append the latest ``num_repeat_sample`` planners to ``success_planners.json``.

        Entries are keyed by the current time stamp, and each entry maps
        ``planner_<slot>_<file index>`` to the program text. Nothing is written
        if the result is identical to the most recent stored entry.
        """
        print("Store success planners into json...")
        idxs = [self._planner_index(name) for name in self._planner_files()]

        if not os.path.exists(self.success_planners_path):
            success_planners = {}
        else:
            with open(self.success_planners_path, 'r', encoding="utf-8") as file:
                success_planners = json.load(file)

        current_time_stamp = datetime.datetime.now().strftime("%Y-%m-%d-%H:%M:%S")
        temp_success_planners = {}
        for i, idx in enumerate(idxs[-self.num_repeat_sample:]):
            planner_path = self.planners_path + f'planner_{idx}.py'
            with open(planner_path, 'r', encoding="utf-8") as file:
                temp_success_planners[f"planner_{i}_{idx}"] = file.read()

        if success_planners:
            # Skip duplicates: the time-stamp format sorts lexicographically,
            # so the last sorted key is the most recent entry.
            last_time_stamp = sorted(success_planners.keys())[-1]
            if success_planners[last_time_stamp] == temp_success_planners:
                return
        success_planners[current_time_stamp] = temp_success_planners

        with open(self.success_planners_path, 'w', encoding="utf-8") as file:
            json.dump(success_planners, file, indent=4)

        return
    
if __name__ == "__main__":
    # Smoke test: generate the initial planners for task 1.
    # `datetime` is already imported at module level. For "gpt" models the
    # planner creates its own OpenAI client from the "openai_key" environment
    # variable, so the global `openai.api_key` does not need to be set here.
    _llm2planner = llm2planner(
        model_name="gpt-3.5-turbo",  # alternatives: "gpt-4-1106-preview", "gpt-4o"
        temperature=0.6,
        num_repeat_sample=3,
        time_stamp=datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S"),
        task_id=1,
    )
    _llm2planner.warm_up()