import re
import transformers
import torch
from openai import OpenAI


from utils import distiller_prompt, general_code_template, \
    combined_input_template, instantiation_prompt, \
    combined_verification_template, verification_prompt, \
    combined_verification_template2, verification_prompt2, \
    code_debug_prompt, code_debug_template, \
    code_error_prompt, code_error_template, \
    new_template_prompt, new_template_format, \
    result_solution_system, result_solution_user, \
    need_exec_system, need_exec_user
from utils import execute_py_code, peel_py_mk_wrapper
from utils.bot_templates import bot_template_game24, bot_template_checkmate, bot_template_word_sorting
from utils.chroma import ChromaVdb

import multiprocessing
import time

def compress_text(text, char_limit=500):
    """Shorten ``text`` to at most ``char_limit`` leading characters.

    Args:
        text: The string to shorten.
        char_limit: Maximum number of leading characters to keep.

    Returns:
        ``text`` unchanged when it already fits; otherwise the first
        ``char_limit`` characters followed by a ``...(+N chars)`` marker,
        where ``N`` is the number of characters that were dropped.
    """
    overflow = len(text) - char_limit
    if overflow <= 0:
        # Nothing to cut: hand the input back untouched.
        return text

    kept = text[:char_limit]
    marker = f"...(+{overflow} chars)"
    return kept + marker

def _call_and_enqueue(result_queue, func, args, kwargs):
    """Child-process entry point: run ``func`` and send the outcome to the parent.

    The return value of ``func(*args, **kwargs)`` is put on ``result_queue``.
    If the call raises an ``Exception``, the exception object itself is put
    on the queue instead.
    """
    try:
        result_queue.put(func(*args, **kwargs))
    except Exception as exc:
        result_queue.put(exc)


def run_with_timeout(func, *args, timeout_duration=3, **kwargs):
    """Run ``func(*args, **kwargs)`` in a separate process with a time limit.

    Args:
        func: Callable to execute in the child process.
        *args: Positional arguments forwarded to ``func``.
        timeout_duration: Seconds to wait for a result (keyword-only).
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        A ``(result, status)`` tuple.

        * ``(value, "success")`` when ``func`` returned ``value`` in time.
        * ``(exc, "success")`` when ``func`` raised ``exc``: the exception
          object is returned as the result, not re-raised, so callers must
          check for it themselves.
        * ``("timeout", "timeout")`` when no result arrived within
          ``timeout_duration`` seconds; the child process is terminated.
    """
    # Function-scope import so this block needs no new module-level import.
    from queue import Empty

    result_queue = multiprocessing.Queue()
    # The target is a module-level function (not a closure) so that it can be
    # pickled when the multiprocessing start method is "spawn"/"forkserver".
    process = multiprocessing.Process(
        target=_call_and_enqueue,
        args=(result_queue, func, args, kwargs),
    )
    process.start()

    # Read from the queue BEFORE joining. A child that has put a large object
    # on the queue cannot exit until the parent drains the pipe, so joining
    # first would make a fast call with a big result look like a timeout.
    try:
        result = result_queue.get(timeout=timeout_duration)
        status = "success"
    except Empty:
        result = "timeout"
        status = "timeout"

    if status == "success":
        # The child only has to flush and exit now; give it a short grace period.
        process.join(1)
    if process.is_alive():
        process.terminate()
    process.join()

    return result, status


def remove_function(code_str, func_name):
    """Delete every definition of function ``func_name`` from ``code_str``.

    A definition is matched from ``def <func_name>(`` up to (but excluding)
    the newline that precedes the next line starting with a non-whitespace
    character, or up to the end of the text when the function comes last.

    Limitations of the pattern: the signature must end in ``):`` (so a
    ``-> type`` return annotation is not matched) and the parameter list
    must not contain a ``)`` character.

    Args:
        code_str: Source text to clean.
        func_name: Name of the function to remove; treated literally.

    Returns:
        The source with matching definitions removed and surrounding
        whitespace stripped.
    """
    # re.escape keeps regex metacharacters in the name from being interpreted.
    # The "|\Z" alternative lets the match end at end-of-text, so a function
    # that is the last thing in the source is removed too.
    pattern = r"def\s+" + re.escape(func_name) + r"\([^)]*\):[\s\S]*?(?=\n\S|\Z)"
    return re.sub(pattern, "", code_str).strip()

def remove_after_function(text, func_name):
    """Cut ``text`` off right after the first definition of ``func_name``.

    The definition is located with a regex that starts at
    ``def <func_name>(`` and stops just before the newline preceding the
    next line that begins with a non-whitespace character. ``func_name`` is
    inserted into the pattern as-is.

    Returns:
        Everything up to the end of that match, or ``text`` unchanged when
        the pattern does not match.
    """
    definition_pattern = r"def\s+"+func_name+r"\([^)]*\):[\s\S]*?(?=\n\S)"
    found = re.search(definition_pattern, text)
    if found is None:
        # No such definition: nothing to trim.
        return text
    return text[:found.end()]

def remove_sol_call_and_after(text):
    """Drop the first ``answer = sol(`` assignment and everything after it.

    Returns:
        The part of ``text`` that precedes the assignment, or ``text``
        unchanged when no such assignment is present.
    """
    # With DOTALL the trailing ".*" runs to the end of the text, so the match
    # always covers the call plus the whole remainder.
    call_site = re.search(r'answer\s*=\s*sol\(.*', text, flags=re.DOTALL)
    if call_site is None:
        return text
    return text[:call_site.start()]

def extract_answer_and_sat(text):
    """Pull the ``Answer:`` and ``SAT:`` sections out of ``text``.

    The answer is whatever follows ``Answer:`` up to the next ``SAT:`` marker
    (or the end of the text); the SAT value is whatever follows ``SAT:`` up
    to the end of the text. Both are whitespace-stripped, and the answer is
    additionally passed through ``compress_text`` with its default limit.

    Returns:
        An ``(answer, sat)`` tuple of strings; a section that is not found
        yields ``""``.
    """
    def _section(match):
        # Captured group without surrounding whitespace, or "" when absent.
        return match.group(1).strip() if match else ""

    answer_text = _section(re.search(r'Answer:\s*(.*?)\s*(?=SAT:|$)', text, re.DOTALL))
    sat_text = _section(re.search(r'SAT:\s*(.*?)\s*$', text, re.DOTALL))

    return compress_text(answer_text), sat_text

class LLM:
    """Callable chat model backed by a local pipeline or a remote API client.

    When no ``api_key`` is supplied, a ``transformers`` text-generation
    pipeline is built for ``model_id``; otherwise an ``OpenAI`` client is
    created with the given key and base URL.
    """

    def __init__(self, model_id, api_key=None, api_base=None):
        """Set up the backend.

        Args:
            model_id: Model identifier passed to the pipeline or to the API.
            api_key: API key; ``None`` selects the local pipeline.
            api_base: Base URL handed to the ``OpenAI`` client.
        """
        self.model_id = model_id
        # A missing API key is what selects local inference.
        self.local = api_key is None

        if self.local:
            self.pipeline = transformers.pipeline(
                "text-generation",
                model=self.model_id,
                model_kwargs={"torch_dtype": torch.bfloat16},
                device_map='auto',
            )
        else:
            self.client = OpenAI(api_key=api_key, base_url=api_base)

    def __call__(self, system_prompt, user_prompt):
        """Send one system + user exchange and return the reply text."""
        conversation = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]

        if self.local:
            return self._generate_locally(conversation)

        completion = self.client.chat.completions.create(
            model=self.model_id,
            messages=conversation,
        )
        # Only the first choice is used.
        return completion.choices[0].message.content

    def _generate_locally(self, conversation):
        """Generate a reply with the local pipeline and return only the new text."""
        tokenizer = self.pipeline.tokenizer
        rendered_prompt = tokenizer.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=True,
        )

        # Generation stops on the tokenizer's EOS id or on "<|eot_id|>"
        # (presumably the model's end-of-turn token -- confirm per model).
        stop_ids = [
            tokenizer.eos_token_id,
            tokenizer.convert_tokens_to_ids("<|eot_id|>"),
        ]

        generations = self.pipeline(
            rendered_prompt,
            max_new_tokens=1024,
            eos_token_id=stop_ids,
            do_sample=True,
            temperature=0.4,
            top_p=0.9,
        )

        # The generated text begins with the prompt; slice it off.
        return generations[0]["generated_text"][len(rendered_prompt):]
        
class Workflow:
    """End-to-end question-solving pipeline driven by an LLM.

    ``run`` chains the stages: distill the question, retrieve a solution
    template from the vector store, ask the LLM for a ``sat`` checker, ask the
    LLM for a ``sol`` implementation and debug it until the checker accepts
    the answer, optionally store the solution as a new template, and finally
    ask the LLM to select/format the result.
    """

    # Driver appended to the generated code when a `sat` checker is available.
    _DRIVER_WITH_SAT = """answer = sol()
print("Answer:")
print(answer)
print("SAT:")
print(sat(answer))
"""

    # Driver used when there is no usable `sat` checker: the SAT section
    # always reports False. (This used to be a bare `False` expression, which
    # printed nothing and left the SAT section empty.)
    _DRIVER_WITHOUT_SAT = """answer = sol()
print("Answer:")
print(answer)
print("SAT:")
print(False)
"""

    # Result text substituted for the execution output when a run times out.
    _TIMEOUT_MESSAGE = "Execution timeout. There might be infinite loops or recursions."

    def __init__(self, model_id=None, api_key=None, api_base=None):
        """Create the LLM backend and open the template vector store.

        Args:
            model_id: Model identifier forwarded to ``LLM``.
            api_key: API key forwarded to ``LLM``.
            api_base: API base URL forwarded to ``LLM``.
        """
        self.model_id = model_id
        self.api_key = api_key
        self.api_base = api_base

        self.llm = LLM(self.model_id, self.api_key, self.api_base)
        # TODO (carried over from the original; presumably: make the
        # vector-store settings below configurable instead of hard-coded).
        self.vdb = ChromaVdb(
            collection_name="solution_templates",
            model_id="intfloat/e5-mistral-7b-instruct", api_key="EMPTY", api_base='http://localhost:8001/v1/')
        print(f"Loaded Documents: {len(self.vdb)}")

    @staticmethod
    def _is_error(result):
        """Return True when ``result`` is empty or contains 'An error occurred'."""
        return not result or 'An error occurred' in result

    def _run_code(self, code, *exec_args):
        """Execute ``code`` through ``execute_py_code`` under ``run_with_timeout``.

        Returns:
            ``(result_text, code_after_run, status)``. On timeout the result
            text is the timeout message and the code is returned unchanged.
        """
        output, status = run_with_timeout(execute_py_code, code, *exec_args)
        if status == "timeout":
            return self._TIMEOUT_MESSAGE, code, status
        if isinstance(output, Exception):
            # run_with_timeout hands back an exception raised in the child as
            # the result; report it as an error so the retry loops handle it
            # instead of crashing on the tuple unpacking below.
            return f"An error occurred: {output}", code, status
        result_text, code_after_run = output
        return result_text, code_after_run, status

    def problem_distillation(self, question):
        """Ask the LLM to distill ``question`` using ``distiller_prompt``."""
        return self.llm(distiller_prompt, question)

    def template_retrieve(self, distilled_question):
        """Look up a solution template for ``distilled_question`` in the vector store."""
        # One of the imported bot_template_* constants can be returned here
        # instead to pin a fixed template while debugging.
        return self.vdb.search(distilled_question)

    def reasoner_verification(self, question, distilled_question, check_code=False):
        """Ask the LLM for a ``sat`` checker for the question.

        Args:
            question: Original question text.
            distilled_question: Distilled form of the question.
            check_code: When False (default, the behavior of the previous
                revision) the generated checker is returned without being
                executed. When True, a second prompt produces test code that
                is executed and repaired by the LLM until it runs cleanly.
                NOTE(review): the ``check_code=True`` path was unreachable in
                the previous revision and has not been exercised recently.

        Returns:
            ``(result, solution, code, sat_code, failed)``. With
            ``check_code=False``, ``result`` and ``code`` are empty strings
            and ``failed`` is False.
        """
        combined_verification = combined_verification_template.format(
            question=question,
            distilled_question=distilled_question,
        )
        print(f"SAT Prompt Len: {len(verification_prompt)+len(combined_verification)}")
        temp_solution = self.llm(verification_prompt, combined_verification)
        print(f"SAT Gen Len: {len(temp_solution)}")
        sat_code = peel_py_mk_wrapper(temp_solution)
        if not check_code:
            # The SAT code is deliberately not executed by default.
            return "", temp_solution, "", sat_code, False  # succeeded

        combined_verification2 = combined_verification_template2.format(
            question=question,
            distilled_question=distilled_question,
            sat_code=sat_code
        )
        print(f"SAT Prompt Len2: {len(verification_prompt2)+len(combined_verification2)}")
        temp_solution = self.llm(verification_prompt2, combined_verification2)
        print(f"SAT Gen Len2: {len(temp_solution)}")
        temp_code = peel_py_mk_wrapper(temp_solution)
        print("Verification initialized")

        start = time.time()
        print("Start Verification Exec")
        temp_result, temp_code, status = self._run_code(temp_code)
        print(f"Verifyication Exec time: {time.time()-start}")

        cnt = 0
        while self._is_error(temp_result) or status == "timeout":
            if status == "timeout":
                print("REPEAT: Timeout")
            else:
                print("REPEAT: Error occurred")

            if cnt > 3:
                print("Error unsolved")
                # Same 5-tuple shape as the success return, so callers can
                # unpack both outcomes identically.
                return temp_result, temp_solution, temp_code, sat_code, True  # failed

            code_error_input = code_error_template.format(code=temp_code, result=temp_result)
            temp_solution = self.llm(code_error_prompt, code_error_input)
            temp_code = peel_py_mk_wrapper(temp_solution)

            start = time.time()
            print("Start Verification Exec")
            temp_result, temp_code, status = self._run_code(temp_code)
            print(f"Verification Exec time: {time.time()-start}")

            # Keep only the text up to the end of the `sat` definition.
            sat_code = remove_after_function(temp_code, "sat")
            cnt = cnt + 1

        return temp_result, temp_solution, temp_code, sat_code, False  # succeeded

    def reasoner_instantiation(self, question, distilled_question, template, sat_code, sat_failed):
        """Generate ``sol`` code from the template and debug it against ``sat``.

        The generated code is executed together with ``sat_code`` and a small
        driver that prints an ``Answer:`` and a ``SAT:`` section. While the
        SAT section lacks ``True``, the run errors, or it times out, the LLM
        is asked to debug the code and the run is repeated; the loop gives up
        once the retry counter exceeds 6.

        Args:
            question: Original question text.
            distilled_question: Distilled form of the question.
            template: Retrieved solution template.
            sat_code: Source of the ``sat`` checker.
            sat_failed: True when no usable checker exists; the code is then
                run once without the checker and reported as failed.

        Returns:
            On success or when ``sat_failed`` is set:
            ``(result_text, llm_solution, executed_code, solution_code, failed)``.
            When the retries are exhausted:
            ``(answer, sat, executed_code, solution_code, True)``.
            NOTE(review): the first two elements differ between these shapes;
            kept as in the original because callers may rely on it -- confirm.
        """
        combined_input = combined_input_template.format(
            question=question,
            distilled_question=distilled_question,
            template=template)
        print(f"Reasoning Prompt Len: {len(instantiation_prompt)+len(combined_input)}")
        temp_solution = self.llm(instantiation_prompt, combined_input)
        print(f"Reasoning GEN Len: {len(temp_solution)}")
        print("Solution initialized")
        sol_code = peel_py_mk_wrapper(temp_solution)

        start = time.time()
        print("Start Exec")
        if sat_failed:
            exec_code = f"{sol_code}\n\n" + self._DRIVER_WITHOUT_SAT
        else:
            exec_code = f"{sol_code}\n\n{sat_code}\n\n" + self._DRIVER_WITH_SAT
        # The line count of the solution is forwarded as an extra positional
        # argument to execute_py_code (its meaning is defined there).
        temp_result, temp_code, status = self._run_code(exec_code, len(sol_code.split('\n')))
        print(f"Exec time: {time.time()-start}")
        answer, sat = extract_answer_and_sat(temp_result)
        solution_code = sol_code
        if sat_failed:
            return temp_result, temp_solution, temp_code, solution_code, True  # failed

        cnt = 0
        while not sat or 'True' not in sat \
                or self._is_error(temp_result) \
                or status == "timeout":

            # Timeout is checked first: a timed-out run has an empty SAT
            # section and would otherwise be logged as an incorrect answer.
            if status == "timeout":
                print("REPEAT: Timeout")
            elif self._is_error(temp_result):
                print("REPEAT: Error occurred")
            else:
                print("REPEAT: Answer incorrect")

            if cnt > 6:
                print("Error unsolved")
                return answer, sat, temp_code, solution_code, True  # failed

            code_verification_input = code_debug_template.format(question=question, code=sol_code, result=temp_result)
            temp_solution = self.llm(code_debug_prompt, code_verification_input)
            sol_code = peel_py_mk_wrapper(temp_solution)
            exec_code = f"{sol_code}\n\n{sat_code}\n\n" + self._DRIVER_WITH_SAT

            start = time.time()
            print("Start Exec")
            temp_result, temp_code, status = self._run_code(exec_code, len(sol_code.split('\n')))
            print(f"Exec time: {time.time()-start}")
            answer, sat = extract_answer_and_sat(temp_result)
            solution_code = sol_code
            cnt = cnt + 1

        return temp_result, temp_solution, temp_code, solution_code, False  # succeeded

    def result_selector(self, distilled_question, result, solution):
        """Ask the LLM to produce the final result from the run output and code."""
        result_solution_user_prompt = result_solution_user.format(
            distilled_question=distilled_question, result=result, code=solution)
        return self.llm(result_solution_system, result_solution_user_prompt)

    def need_update_template(self, distilled_question, solution, template):
        """Ask the LLM whether ``solution`` should be stored as a new template.

        Returns:
            True when the reply contains "yes" (case-insensitive).
        """
        new_template_system = new_template_prompt.format(template=template)
        new_template_user = new_template_format.format(distilled_question=distilled_question, solution=solution)
        resp = self.llm(new_template_system, new_template_user)
        return "yes" in resp.lower()

    def run(self, question):
        """Solve ``question`` end to end, printing the time of each stage.

        Returns:
            ``(final_result, failed)`` where ``final_result`` is the selector
            output with any ``sat`` definition and trailing ``answer = sol(``
            call removed, and ``failed`` is the instantiation failure flag.
        """
        start = time.time()
        distilled_question = self.problem_distillation(question)
        print(f'distill time: {time.time()-start}')

        start = time.time()
        template = self.template_retrieve(distilled_question)
        print(f'template retrieval time: {time.time()-start}')

        start = time.time()
        _, _, _, sat_code, sat_failed = self.reasoner_verification(question, distilled_question)
        print(f'sat time: {time.time()-start}')

        start = time.time()
        result, solution, _, code, failed = self.reasoner_instantiation(
            question, distilled_question, template, sat_code, sat_failed)
        print(f'inst time: {time.time()-start}')

        start = time.time()
        if not failed:
            if self.need_update_template(distilled_question, solution, template):
                self.vdb.add([distilled_question], [{"keywords": "solvable by codes;", "template": solution}])  # TODO: not just code
                print(f"New Document Size: {len(self.vdb)}")
        print(f'new template time: {time.time()-start}')

        start = time.time()
        # NOTE(review): the raw question is passed where result_selector names
        # its parameter `distilled_question` -- confirm this is intended.
        final_result = self.result_selector(question, result, code)
        final_result = remove_sol_call_and_after(remove_function(final_result, "sat"))
        print(f'selector time: {time.time()-start}')

        print(sat_code)
        return final_result, failed
