import re
import transformers
import torch
from openai import OpenAI


from utils.prompts2 import distiller_prompt, general_code_template, combined_input_template, \
    code_debug_prompt, instantiation_prompt, code_verification_template, \
    new_template_prompt, new_template_format, \
    result_solution_system, result_solution_user, \
    need_exec_system, need_exec_user
from utils.utils import execute_py_code, peel_py_mk_wrapper
from utils.bot_templates import bot_template_game24, bot_template_checkmate, bot_template_word_sorting
from utils.chroma import ChromaVdb

import multiprocessing
import time

def has_sat_function_call(code_snippet):
    """Return True if ``code_snippet`` appears to invoke ``sat``.

    The check is purely lexical. It looks for the whole word ``sat``, optional
    whitespace, an opening parenthesis, and a closing parenthesis further along
    the same line. A ``def sat(...)`` header therefore matches as well.
    """
    sat_call = re.search(r'\bsat\s*\(.*\)', code_snippet)
    return sat_call is not None

def compress_text(text, char_limit=500):
    """Truncate ``text`` to ``char_limit`` characters, noting how much was cut.

    Text no longer than ``char_limit`` is returned unchanged. Otherwise the
    first ``char_limit`` characters are kept and ``"...(+N chars)"`` is
    appended, where N is the number of dropped characters.
    """
    overflow = len(text) - char_limit
    if overflow > 0:
        return f"{text[:char_limit]}...(+{overflow} chars)"
    return text

def _run_and_report(result_queue, func, args, kwargs):
    """Child-process entry point: run ``func`` and put its outcome on ``result_queue``.

    Defined at module level (not nested) so that it can be pickled when
    multiprocessing uses the "spawn" or "forkserver" start method. If ``func``
    raises an ``Exception``, the exception object itself is sent as the outcome.
    """
    try:
        outcome = func(*args, **kwargs)
    except Exception as e:
        outcome = e
    result_queue.put(outcome)


def run_with_timeout(func, *args, timeout_duration=3, **kwargs):
    """Run ``func(*args, **kwargs)`` in a child process, giving up after a timeout.

    Returns ``(result, status)``:

    * ``(return value, "success")`` when ``func`` finishes in time.
    * ``(exception instance, "success")`` when ``func`` raised an
      ``Exception``. The exception is returned, not re-raised.
    * ``("timeout", "timeout")`` when no result arrives within
      ``timeout_duration`` seconds, including when the child dies without
      reporting. The child is terminated in that case.

    The result travels through a ``multiprocessing.Queue`` and must be
    picklable. Under the "spawn" or "forkserver" start methods, ``func`` and
    its arguments must be picklable too.
    """
    import queue as queue_lib  # only needed for queue_lib.Empty

    result_queue = multiprocessing.Queue()
    process = multiprocessing.Process(
        target=_run_and_report, args=(result_queue, func, args, kwargs)
    )
    process.start()
    try:
        # Read the result BEFORE joining. A child that has queued a large
        # object cannot exit until the parent drains the pipe. Joining first
        # would stall until the timeout and misreport a finished run as
        # "timeout".
        try:
            result = result_queue.get(timeout=timeout_duration)
        except queue_lib.Empty:
            return "timeout", "timeout"
        # The child has reported; give it a bounded grace period to exit.
        process.join(timeout_duration)
        return result, "success"
    finally:
        # Never leave a running child behind (timeout, stuck exit, or an
        # error/interrupt while waiting).
        if process.is_alive():
            process.terminate()
        process.join()


def remove_function(code_str, func_name):
    """Cut every definition of ``func_name`` out of ``code_str``.

    A definition runs from its ``def func_name(`` line up to, but not
    including, the next line that begins with ``def`` (after optional
    indentation), or to the end of the string. Any code between the function
    and that next ``def``, including top-level statements, is removed with it.

    Returns ``(cleaned_code, removed)``: the remaining code, stripped, and the
    first removed definition, stripped, or ``None`` if none was found.
    """
    # The header is matched only up to the opening parenthesis, so return
    # annotations (``-> bool``), nested parentheses in defaults and multi-line
    # signatures still match. The end is recognised only at the start of a
    # line, so a "def " inside a string or identifier (e.g. "undef x") cannot
    # cut the function short.
    pattern = re.compile(
        r"^[ \t]*def\s+" + re.escape(func_name) + r"\s*\([\s\S]*?(?=^[ \t]*def\s|\Z)",
        re.MULTILINE,
    )
    first = pattern.search(code_str)
    cleaned_code = pattern.sub("", code_str)
    return cleaned_code.strip(), first.group(0).strip() if first else None

def remove_sol_call_and_after(text):
    """Split ``text`` at the first ``answer = sol(`` assignment.

    Returns ``(cleaned, removed)``. ``cleaned`` is everything before the
    assignment and ``removed`` is the assignment plus everything after it;
    both are stripped. If no such assignment exists, returns ``(text, None)``
    with ``text`` unchanged (not stripped).
    """
    # \b stops names such as ``best_answer = sol(...)`` (e.g. a recursive call
    # inside ``sol``) from being mistaken for the call site.
    match = re.search(r'\banswer\s*=\s*sol\(', text)
    if match is None:
        return text, None
    split_at = match.start()
    return text[:split_at].strip(), text[split_at:].strip()

def extract_answer_and_sat(text):
    """Pull the ``Answer:`` and ``SAT:`` sections out of execution output.

    Returns ``(answer, sat)``:

    * ``answer`` is the text after the first ``Answer:``, up to the next
      ``SAT:`` or the end. It is stripped and shortened with ``compress_text``.
    * ``sat`` is everything after the first ``SAT:``, stripped.

    A missing section yields ``""``.
    """
    def section(pattern):
        hit = re.search(pattern, text, re.DOTALL)
        return hit.group(1).strip() if hit else ""

    sat = section(r'SAT:\s*(.*?)\s*$')
    answer = compress_text(section(r'Answer:\s*(.*?)\s*(?=SAT:|$)'))
    return answer, sat

class LLM:
    """Chat wrapper over a local transformers pipeline or an OpenAI-compatible API.

    Construct with ``api_key=None`` to load ``model_id`` locally. Otherwise,
    requests are sent to ``api_base`` with the given key. Call an instance as
    ``llm(system_prompt, user_prompt)`` to get the reply text.
    """

    def __init__(self, model_id, api_key=None, api_base=None):
        self.model_id = model_id

        if api_key is None:
            self.local = True
            self.pipeline = transformers.pipeline(
                "text-generation",
                model=self.model_id,
                model_kwargs={"torch_dtype": torch.bfloat16},
                device_map='auto'
            )
        else:
            self.local = False
            # api_base may be None; the client then uses its own default base URL.
            self.client = OpenAI(
                api_key=api_key,
                base_url=api_base
            )

    def __call__(self, system_prompt, user_prompt):
        """Send one system + user message pair and return the assistant's text.

        A remote response with no content (``None``) is returned as ``""``, so
        callers can always apply string operations to the result.
        """
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]

        if not self.local:
            response = self.client.chat.completions.create(
                model=self.model_id,
                messages=messages
            )
            return response.choices[0].message.content or ""

        tokenizer = self.pipeline.tokenizer
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        outputs = self.pipeline(
            prompt,
            max_new_tokens=2048,
            eos_token_id=self._terminators(tokenizer),
            do_sample=True,
            temperature=0.4,
            top_p=0.9,
        )
        # generated_text begins with the prompt; keep only the newly generated part.
        return outputs[0]["generated_text"][len(prompt):]

    @staticmethod
    def _terminators(tokenizer):
        """Return the token ids that should stop generation.

        Includes ``eos_token_id`` when it is set. Adds ``<|eot_id|>`` only when
        the tokenizer actually knows that token. For unknown tokens,
        ``convert_tokens_to_ids`` yields the unk id or ``None``, and neither
        should be used as a stop token.
        """
        terminators = []
        if tokenizer.eos_token_id is not None:
            terminators.append(tokenizer.eos_token_id)
        eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
        if eot_id is not None and eot_id != tokenizer.unk_token_id and eot_id not in terminators:
            terminators.append(eot_id)
        return terminators
        
class Workflow:
    """Question-solving pipeline built on an ``LLM`` and a ``ChromaVdb`` template store.

    ``run`` performs these steps:

    1. Distills the question.
    2. Retrieves a template via ``self.vdb.search``.
    3. Has the LLM write Python code for it, then repeatedly debug that code
       until the executed output reports ``True`` in its ``SAT:`` section.
    4. May store the new solution as a template.
    """

    def __init__(self, model_id=None, api_key=None, api_base=None,
                 vdb_collection="solution_templates",
                 vdb_model_id="intfloat/e5-mistral-7b-instruct",
                 vdb_api_key="EMPTY",
                 vdb_api_base='http://localhost:8001/v1/'):
        """Build the LLM client and open the template store.

        ``model_id``, ``api_key`` and ``api_base`` are forwarded to ``LLM``
        (``api_key=None`` selects a local transformers pipeline). The
        ``vdb_*`` arguments configure the ``ChromaVdb`` collection and its
        embedding endpoint; the defaults point at a server on localhost:8001.
        """
        self.model_id = model_id
        self.api_key = api_key
        self.api_base = api_base

        self.llm = LLM(self.model_id, self.api_key, self.api_base)
        self.vdb = ChromaVdb(
            collection_name=vdb_collection,
            model_id=vdb_model_id, api_key=vdb_api_key, api_base=vdb_api_base)
        print(f"Loaded Documents: {len(self.vdb)}")

    def problem_distillation(self, question):
        """Ask the LLM to distill ``question`` using ``distiller_prompt``."""
        return self.llm(distiller_prompt, question)

    def template_retrieve(self, distilled_question):
        """Return the template that ``self.vdb.search`` finds for ``distilled_question``."""
        # For debugging, a fixed template (bot_template_game24,
        # bot_template_checkmate or bot_template_word_sorting) can be returned here.
        return self.vdb.search(distilled_question)

    @staticmethod
    def _execute(code):
        """Run ``code`` through ``execute_py_code`` in a subprocess with a timeout.

        Returns ``(result_text, code, status)``, where ``status`` is
        ``"success"`` or ``"timeout"``. On success, ``result_text`` and
        ``code`` are what ``execute_py_code`` returned. On a timeout, or if
        ``execute_py_code`` itself raised, ``result_text`` is an explanatory
        message and ``code`` is the input unchanged.
        """
        start = time.time()
        print("Start Exec")
        output, status = run_with_timeout(execute_py_code, code)
        print(f"Exec time: {time.time()-start}")
        if status == "timeout":
            return "Execution timeout. There might be infinite loops or recursions.", code, status
        if isinstance(output, BaseException):
            # run_with_timeout returns (does not raise) the child's exception.
            # Report it as an error so the debug loop retries, instead of
            # crashing while tuple-unpacking an exception object.
            return f"An error occurred while executing the code: {output!r}", code, status
        result_text, code = output
        return result_text, code, status

    @staticmethod
    def _debug_reason(temp_result, sat, status, solution_code):
        """Return why the latest execution needs another debug round, or None if it passed."""
        if status == "timeout":
            return "Timeout"
        if not temp_result or 'An error occurred' in temp_result:
            return "Error occurred"
        if not sat or 'True' not in sat:
            return "Answer incorrect"
        if has_sat_function_call(solution_code):
            return "Used SAT code"
        return None

    def reasoner_instantiation(self, question, distilled_question, template, max_debug_rounds=7):
        """Have the LLM instantiate ``template`` as code, then debug it until it passes.

        The generated code passes only when all of these hold:

        * its output has a ``SAT:`` section containing ``True``;
        * the run did not error or time out;
        * the solution part does not call ``sat``.

        Otherwise the LLM is asked to fix the code, for at most
        ``max_debug_rounds`` rounds. Each round re-attaches the first attempt's
        ``sat`` function and ``answer = sol(...)`` call.

        Returns a 5-tuple whose last item is ``failed``:

        * Success: ``(execution output, last raw LLM reply, latest code,
          solution code without sat/call, False)``.
        * Failure: ``(extracted answer, extracted SAT text, latest code,
          solution code, True)``. Note that the first two items differ in
          meaning from the success case.
        """
        combined_input = combined_input_template.format(
            question=question,
            distilled_question=distilled_question,
            template=template)
        temp_solution = self.llm(instantiation_prompt, combined_input)
        print("Solution initialized")

        temp_code = peel_py_mk_wrapper(temp_solution)
        temp_result, temp_code, status = self._execute(temp_code)
        answer, sat = extract_answer_and_sat(temp_result)
        # Keep the first attempt's `sat` function and call site; debug rounds
        # only take a new `sol` from the LLM.
        solution_code, sat_code = remove_function(temp_code, "sat")
        solution_code, call_code = remove_sol_call_and_after(solution_code)

        cnt = 0
        while True:
            reason = self._debug_reason(temp_result, sat, status, solution_code)
            if reason is None:
                break
            print(f"REPEAT: {reason}")
            if reason == "Used SAT code":
                # Replace the output with explicit feedback for the debug prompt.
                temp_result = "You called `sat` code inside `sol`, never do this!"

            if cnt >= max_debug_rounds:
                print("Error unsolved")
                return answer, sat, temp_code, solution_code, True  # failed

            code_verification_input = code_verification_template.format(
                question=question, code=temp_code, result=temp_result)
            temp_solution = self.llm(code_debug_prompt, code_verification_input)

            temp_code = peel_py_mk_wrapper(temp_solution)
            solution_code, _ = remove_function(temp_code, "sat")
            solution_code, _ = remove_sol_call_and_after(solution_code)
            # Skip parts that were not found instead of emitting a literal "None" line.
            temp_code = "\n".join(part for part in (solution_code, sat_code, call_code) if part)

            temp_result, temp_code, status = self._execute(temp_code)
            answer, sat = extract_answer_and_sat(temp_result)
            cnt += 1

        return temp_result, temp_solution, temp_code, solution_code, False  # succeeded

    def result_selector(self, distilled_question, result, solution):
        """Send ``distilled_question``, ``result`` and ``solution`` through the result_solution prompts.

        Returns the LLM reply. ``run`` does not currently call this method.
        """
        result_solution_user_prompt = result_solution_user.format(distilled_question=distilled_question, result=result, code=solution)
        return self.llm(result_solution_system, result_solution_user_prompt)

    def need_update_template(self, distilled_question, solution, template):
        """Ask the LLM whether ``solution`` should be stored as a new template.

        Returns True when the reply contains the whole word "yes"
        (case-insensitive).
        """
        new_template_system = new_template_prompt.format(template=template)
        new_template_user = new_template_format.format(distilled_question=distilled_question, solution=solution)
        resp = self.llm(new_template_system, new_template_user)
        # Whole-word match, so words like "eyes" or "yesterday" don't count as a yes.
        return re.search(r"\byes\b", resp or "", flags=re.IGNORECASE) is not None

    def run(self, question):
        """Solve ``question`` end to end and print per-stage timings.

        Returns ``(final_result, failed)``:

        * ``final_result`` is the solution code with the ``sat`` function and
          the ``answer = sol(...)`` call stripped.
        * ``failed`` is True when debugging gave up.

        On success, the solution may also be added to the template store.
        """
        start = time.time()
        distilled_question = self.problem_distillation(question)
        print(f'distill time: {time.time()-start}')

        start = time.time()
        template = self.template_retrieve(distilled_question)
        print(f'template retrieval time: {time.time()-start}')

        start = time.time()
        _result, solution, full_code, code, failed = self.reasoner_instantiation(question, distilled_question, template)
        print(f'inst time: {time.time()-start}')
        print(f"Full Code:\n{full_code}")

        start = time.time()
        if not failed:
            need_update = self.need_update_template(distilled_question, solution, template)
            if need_update:
                # TODO: store templates that are not just code solutions.
                self.vdb.add([distilled_question], [{"keywords": "solvable by codes;", "template": solution}])
                print(f"New Document Size: {len(self.vdb)}")
        print(f'new template time: {time.time()-start}')

        # NOTE: result_selector is intentionally bypassed (hard-coded choice);
        # the stripped solution code is returned as the final result.
        final_result = code
        return final_result, failed
