import re
import transformers
import torch
from openai import OpenAI


from utils.prompts3 import distiller_prompt, general_code_template, \
    combined_input_template, instantiation_prompt, generate_debug_distiller
from utils import execute_py_code, peel_py_mk_wrapper
from utils.bot_templates import bot_template_game24, bot_template_checkmate, bot_template_word_sorting
from utils.chroma import ChromaVdb

import multiprocessing
import time

# When True, Workflow.run's retry round feeds the previous code and its raw
# error message straight back to the reasoner; when False, it first asks the
# LLM (Workflow.debugger) to rewrite the question from that code and error.
NO_DEBUGGER = True

def compress_text(text, char_limit=500):
    """Truncate ``text`` to ``char_limit`` characters, noting how much was cut.

    Text no longer than ``char_limit`` comes back unchanged. Longer text keeps
    its first ``char_limit`` characters followed by ``"...(+N chars)"``, where
    N is the number of characters dropped.
    """
    overflow = len(text) - char_limit
    if overflow <= 0:
        return text
    return f"{text[:char_limit]}...(+{overflow} chars)"

def _run_and_report(result_queue, func, args, kwargs):
    """Child-process entry point: put func's return value (or the Exception it raised) on the queue."""
    try:
        result_queue.put(func(*args, **kwargs))
    except Exception as e:
        result_queue.put(e)


def run_with_timeout(func, *args, timeout_duration=3, **kwargs):
    """Run ``func(*args, **kwargs)`` in a child process with a time limit.

    Returns ``(result, status)``:

    * ``status == "success"``: ``result`` is the function's return value, or
      the ``Exception`` instance it raised (returned, not re-raised).
    * ``status == "timeout"``: no result arrived within ``timeout_duration``
      seconds (including when the child died without reporting). The child
      is terminated and ``result`` is the string ``"timeout"``.

    The result must be picklable. Under the "spawn"/"forkserver" start
    methods, ``func`` and its arguments must be picklable too.
    """
    from queue import Empty

    result_queue = multiprocessing.Queue()
    # A module-level target (not a closure) stays picklable under spawn/forkserver.
    process = multiprocessing.Process(
        target=_run_and_report, args=(result_queue, func, args, kwargs)
    )
    process.start()
    try:
        # Drain the queue *before* joining. A child that has put a large object
        # cannot exit until the object is read. Joining first would make big
        # results look like timeouts, and a child that died silently would
        # block a plain get() forever.
        result = result_queue.get(timeout=timeout_duration)
        status = "success"
    except Empty:
        result, status = "timeout", "timeout"
    else:
        # The child exits right after reporting; allow a short grace period.
        process.join(1)

    if process.is_alive():
        process.terminate()
    process.join()
    result_queue.close()
    return result, status


def remove_function(code_str, func_name):
    """Remove every definition of ``func_name`` from ``code_str``.

    A definition spans from ``def func_name(...)`` to the next line that starts
    with a non-whitespace character, or to the end of the string. An optional
    ``-> type`` return annotation is allowed. The result is stripped of
    surrounding whitespace; code without such a definition is returned
    stripped but otherwise unchanged.
    """
    # re.escape makes the name match literally. Regex metacharacters can no
    # longer corrupt the pattern; the old code fell into pdb on re.error.
    # \Z lets a function at the very end of the text be removed as well.
    pattern = (
        r"\bdef\s+" + re.escape(func_name)
        + r"\([^)]*\)(?:\s*->[^:\n]*)?:[\s\S]*?(?=\n\S|\Z)"
    )
    return re.sub(pattern, "", code_str).strip()

def remove_after_function(text, func_name):
    """Return ``text`` cut off right after the first definition of ``func_name``.

    The definition spans from ``def func_name(...)`` to the next line that
    starts with a non-whitespace character, or to the end of the text. An
    optional ``-> type`` return annotation is allowed. Everything after it is
    dropped. Text without such a definition is returned unchanged.
    """
    # Same pattern as remove_function. The name is escaped so it matches
    # literally and cannot raise re.error.
    pattern = (
        r"\bdef\s+" + re.escape(func_name)
        + r"\([^)]*\)(?:\s*->[^:\n]*)?:[\s\S]*?(?=\n\S|\Z)"
    )
    match = re.search(pattern, text)
    return text[:match.end()] if match else text

def remove_sol_call_and_after(text):
    """Truncate ``text`` at the first ``answer = sol(`` assignment.

    Everything from the start of that assignment to the end of the string is
    dropped. Text without such a call is returned unchanged.
    """
    found = re.search(r'answer\s*=\s*sol\(.*', text, flags=re.DOTALL)
    if found is None:
        return text
    return text[:found.start()]

def extract_answer_and_sat(text):
    """Split program output into its ``Answer:`` and ``SAT:`` sections.

    Returns ``(answer, sat)``:

    * ``answer`` is the stripped text after the first ``Answer:``, up to
      ``SAT:`` or the end, shortened with ``compress_text``.
    * ``sat`` is the stripped text after the first ``SAT:``.

    A missing section yields an empty string.
    """
    def _section(regex):
        hit = re.search(regex, text, re.DOTALL)
        return hit.group(1).strip() if hit else ""

    answer = compress_text(_section(r'Answer:\s*(.*?)\s*(?=SAT:|$)'))
    sat = _section(r'SAT:\s*(.*?)\s*$')
    return answer, sat

class LLM:
    """Chat wrapper over a local transformers pipeline or an OpenAI-compatible API.

    With ``api_key=None``, ``model_id`` is loaded locally as a transformers
    ``text-generation`` pipeline (bfloat16 weights, ``device_map='auto'``).
    Otherwise an ``OpenAI`` client is created for ``api_base``. Both backends
    sample with the same ``max_tokens`` / ``temperature`` / ``top_p``. These
    default to the previously hard-coded 1024 / 0.4 / 0.9.
    """

    def __init__(self, model_id, api_key=None, api_base=None,
                 max_tokens=1024, temperature=0.4, top_p=0.9):
        self.model_id = model_id
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.top_p = top_p

        if api_key is None:
            self.local = True
            self.pipeline = transformers.pipeline(
                "text-generation",
                model=self.model_id,
                model_kwargs={"torch_dtype": torch.bfloat16},
                device_map='auto',
            )
            self.terminators = self._stop_token_ids(self.pipeline.tokenizer)
        else:
            self.local = False
            self.client = OpenAI(
                api_key=api_key,
                base_url=api_base,
            )

    @staticmethod
    def _stop_token_ids(tokenizer):
        """Return the ids that end generation: EOS, plus ``<|eot_id|>`` if the vocab has it."""
        terminators = [tokenizer.eos_token_id]
        eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
        # Tokenizers without this end-of-turn token map it to the unknown-token
        # id or to None. Stopping on the unknown token is wrong, and None is not
        # a valid stop id, so only keep a real vocabulary entry.
        if eot_id is not None and eot_id != tokenizer.unk_token_id:
            terminators.append(eot_id)
        return terminators

    def __call__(self, system_prompt, user_prompt):
        """Send one system + user turn and return the assistant's reply text."""
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]

        if not self.local:
            response = self.client.chat.completions.create(
                model=self.model_id,
                messages=messages,
                max_tokens=self.max_tokens,
                temperature=self.temperature,
                top_p=self.top_p,
            )
            # message.content is optional in the API and may be None; callers
            # post-process the reply as text.
            return response.choices[0].message.content or ""

        prompt = self.pipeline.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        outputs = self.pipeline(
            prompt,
            max_new_tokens=self.max_tokens,
            eos_token_id=self.terminators,
            do_sample=True,
            temperature=self.temperature,
            top_p=self.top_p,
        )
        # The pipeline echoes the prompt; keep only the newly generated text.
        return outputs[0]["generated_text"][len(prompt):]
        
class Workflow:
    """Solve a question by distilling it, sampling ``sol()`` programs and verifying them.

    ``run`` does the following:

    1. Distills the question with the LLM and picks a solution template.
    2. Samples candidate programs and executes each against a ``sat()``
       checker.
    3. If none passes, re-prompts once per erroring program with its error
       fed back into the prompt.

    It returns a dict of counters and artifacts.
    """

    def __init__(self, model_id=None, api_key=None, api_base=None):
        self.model_id = model_id
        self.api_key = api_key
        self.api_base = api_base

        self.llm = LLM(self.model_id, self.api_key, self.api_base)
        # TODO: enable template retrieval from the vector store.
        #self.vdb = ChromaVdb(
        #    collection_name="solution_templates",
        #    model_id="intfloat/e5-mistral-7b-instruct", api_key="EMPTY", api_base='http://localhost:8001/v1/')
        #print(f"Loaded Documents: {len(self.vdb)}")

    def problem_distillation(self, question):
        """Return the LLM's distillation of ``question``, prompted with ``distiller_prompt``."""
        return self.llm(distiller_prompt, question)

    def template_retrieve(self, distilled_question):
        """Return the solution template for ``distilled_question``.

        Retrieval is not wired up yet, so ``general_code_template`` is always returned.
        """
        # TODO: template = self.vdb.search(distilled_question)
        return general_code_template

    def reasoner_instantiation(self, question, distilled_question, template, repeats=6):
        """Sample ``repeats`` LLM replies and return every code snippet peeled from them.

        The prompt combines the question, the distilled question and the
        template via ``combined_input_template``.
        """
        combined_input = combined_input_template.format(
            question=question,
            distilled_question=distilled_question,
            template=template)

        res_codes = []
        for _ in range(repeats):
            temp_solution = self.llm(instantiation_prompt, combined_input)
            # peel_py_mk_wrapper returns (snippets, count). The snippets are
            # presumably the Python code blocks in the reply; only they are used.
            sol_codes, _ = peel_py_mk_wrapper(temp_solution, -1)
            res_codes += sol_codes
        return res_codes

    def code_execution(self, sol_code, sat_code):
        """Execute ``sol()`` and check its answer with ``sat()``.

        ``sat_code`` is ignored when ``sol_code`` already defines ``sat``. A
        ``None`` value is treated as empty.

        Returns ``(mode, output)``. ``output`` is the captured program output
        or an error/timeout message. ``mode`` is one of:

        * 0: the ``SAT:`` section contains ``True``.
        * 1: the program ran but ``sat`` was not ``True``.
        * 2: error, timeout or empty output.
        """
        if "from typing" not in sol_code:
            sol_code = f"from typing import *\n{sol_code}"
        # run() passes query=None by default. Treat that as "no separate
        # checker" instead of failing on None.replace below.
        sat_code = sat_code or ""
        if "def sat" in sol_code:
            sat_code = ""

        sol_code = sol_code.replace("List[", "list[")
        sat_code = sat_code.replace("List[", "list[")

        exec_code = f"{sol_code}\n\n{sat_code}\n\n" + """answer = sol()
print("Answer:")
print(answer)
print("SAT:")
print(sat(answer))
"""
        # NOTE(review): the line count is passed positionally, so it becomes an
        # argument to execute_py_code. run_with_timeout's timeout_duration is
        # keyword-only and keeps its default. Confirm this is intended.
        output, status = run_with_timeout(execute_py_code, exec_code, len(sol_code.split('\n')))
        if status == "timeout":
            temp_result = "Execution timeout. There might be infinite loops or recursions."
        elif isinstance(output, Exception):
            # run_with_timeout returns exceptions raised in the child as the
            # result. Unpacking one as a tuple used to crash here.
            temp_result = f"An error occurred while executing the code: {output!r}"
        else:
            temp_result, _ = output

        # Classify before parsing, so empty/None output never reaches the regexes.
        if not temp_result or 'An error occurred' in temp_result or status == "timeout":
            return 2, temp_result
        _, sat = extract_answer_and_sat(temp_result)
        if "True" not in sat:
            return 1, temp_result
        return 0, temp_result

    def debugger(self, distilled_question, code, err_msg):
        """Ask the LLM to rewrite the question given the failing code and its error."""
        question = f"User Question:\n{distilled_question}\n\nPrevious Code:\n{code}\n\nError Message of Previous Code:\n{err_msg}"
        return self.llm(generate_debug_distiller, question)

    def run(self, question, query=None, target=None):
        """Solve ``question`` and return statistics about the attempts.

        Args:
            question: the user question.
            query: source of the ``sat`` checker. Programs that define their
                own ``sat`` ignore it.
            target: unused here; kept for interface compatibility.

        Returns:
            A dict with the following keys:

            * Counters:
              - ``success``, ``non-sat``, ``error``: per-outcome counts over
                all attempts.
              - ``corrected``: successes found in the retry round.
              - ``others``: non-successes in the retry round.
              - ``wronged``, ``tp``, ``tn``, ``fp``, ``fn``, ``unsure``: always
                0 here; presumably filled in by the caller.
            * ``input``: the checker code.
            * ``output``: the last successful program, or else the last
              program tried.
            * ``num_attempts``: number of programs executed.
            * ``all_errors``: error messages, one per line.
            * ``non_sat_codes``: programs whose answer failed ``sat``.
            * ``naive``: always False.
            * ``selected``: True if any program succeeded.
            * ``loop``: initial sampling count.
        """
        sat_code = query
        distilled_question = self.problem_distillation(question)
        template = self.template_retrieve(distilled_question)

        RETRY = 1   # retry-round switch (only checked for truthiness)
        TRIALS = 20  # sampling budget split between the initial and retry rounds
        LOOPS = int(TRIALS / (RETRY + 1))

        sol_codes = self.reasoner_instantiation(
            question, distilled_question, template, repeats=LOOPS)

        last_code = ""
        success_code = ""
        d = {
            'success': 0,
            'non-sat': 0,
            'error': 0,
            "corrected": 0,
            "wronged": 0,
            "others": 0,
            'tp': 0,
            'tn': 0,
            'fp': 0,
            'fn': 0,
            'unsure': 0,
        }
        code_msgs = []  # (program, error message) pairs from the first round
        errors = []
        non_sat_codes = []

        for sol_code in sol_codes:
            mode, msg = self.code_execution(sol_code=sol_code, sat_code=sat_code)
            if mode == 0:
                success_code = sol_code
                d['success'] += 1
            elif mode == 1:
                d['non-sat'] += 1
                non_sat_codes.append(sol_code)
            else:
                d['error'] += 1
                code_msgs.append((sol_code, msg))
                errors.append(msg)
            last_code = sol_code
        num_attempts = len(sol_codes)

        # Retry only when nothing passed. Re-prompt once per erroring program,
        # feeding back its code and error (raw, or rewritten by the debugger).
        if d['success'] == 0 and RETRY:
            for code, msg in code_msgs:
                if NO_DEBUGGER:
                    cur_distilled_question = f"User Question:\n{distilled_question}\n\nPrevious Solution:\n{code}\n\nError Messages of Previous Solution:\n{msg}"
                else:
                    cur_distilled_question = self.debugger(distilled_question, code, msg)
                retry_codes = self.reasoner_instantiation(
                    question, cur_distilled_question, template, repeats=1)
                for sol_code in retry_codes:
                    mode, retry_msg = self.code_execution(sol_code=sol_code, sat_code=sat_code)
                    if mode == 0:
                        success_code = sol_code
                        d['success'] += 1
                        d['corrected'] += 1
                    elif mode == 1:
                        d['non-sat'] += 1
                        non_sat_codes.append(sol_code)
                        d['others'] += 1
                    else:
                        d['error'] += 1
                        errors.append(retry_msg)
                        d['others'] += 1
                    last_code = sol_code
                num_attempts += len(retry_codes)

        d['input'] = sat_code
        d['output'] = success_code or last_code
        d['num_attempts'] = num_attempts
        d['all_errors'] = ''.join(f'{err}\n' for err in errors)
        d['non_sat_codes'] = non_sat_codes
        d['naive'] = False
        d['selected'] = d['success'] != 0
        d['loop'] = LOOPS
        return d
