import ast
import math
import multiprocessing
import operator
import queue
import re
import time

import torch
import transformers
from openai import OpenAI

from utils.promptsg1 import distiller, \
    problem_solver, problem_solver_prompt, general_code_template, \
    evaluator, evaluator_prompt, \
    eval_distiller, eval_distiller_prompt, \
    final_evaluator, final_evaluator_prompt
from utils import execute_py_code, peel_py_mk_wrapper, extract_field
from utils.bot_templates import bot_template_game24, bot_template_checkmate, bot_template_word_sorting
from utils.chroma import ChromaVdb

def compress_text(text, char_limit=500):
    """Shorten ``text`` to at most ``char_limit`` characters plus a truncation marker.

    Text that already fits is returned untouched. Otherwise the first
    ``char_limit`` characters are kept and ``...(+N chars)`` is appended, where
    N is the number of characters that were dropped.
    """
    overflow = len(text) - char_limit
    if overflow <= 0:
        return text
    return f"{text[:char_limit]}...(+{overflow} chars)"

def _run_and_report(result_queue, func, args, kwargs):
    """Child-process entry point: run ``func(*args, **kwargs)`` and send back the outcome.

    An exception raised by ``func`` is sent back as the result object itself.
    Defined at module level (not nested) so it can be pickled when the
    multiprocessing start method is "spawn" or "forkserver".
    """
    try:
        result_queue.put(func(*args, **kwargs))
    except Exception as e:
        result_queue.put(e)


def run_with_timeout(func, *args, timeout_duration=3, **kwargs):
    """Run ``func(*args, **kwargs)`` in a separate process with a wall-clock limit.

    Args:
        func: Callable to run; must be picklable under non-fork start methods.
        *args, **kwargs: Forwarded to ``func``.
        timeout_duration: Seconds to wait for a result (keyword-only).

    Returns:
        ``(result, status)``. ``status`` is ``"success"`` when the child produced
        a value in time. Note that an exception raised by ``func`` is returned as
        ``result`` with status ``"success"``. Otherwise the child is terminated and
        ``("timeout", "timeout")`` is returned. This also covers a child that
        died without producing any result.
    """
    result_queue = multiprocessing.Queue()
    process = multiprocessing.Process(
        target=_run_and_report, args=(result_queue, func, args, kwargs)
    )
    process.start()

    # Read BEFORE joining: a child that has put data on a queue does not exit
    # until that data has been consumed, so join-then-get stalls on large results.
    try:
        result = result_queue.get(timeout=timeout_duration)
        status = "success"
    except queue.Empty:
        result = "timeout"
        status = "timeout"
        process.terminate()

    # Reap the child; after delivering its result it exits promptly.
    process.join(timeout_duration)
    if process.is_alive():
        process.terminate()
        process.join()
    result_queue.close()

    return result, status


def remove_function(code_str, func_name):
    """Delete every ``def func_name(...)`` definition from ``code_str``.

    A definition spans from its ``def`` keyword up to (not including) the next
    line that starts with a non-whitespace character, or to the end of the
    string. A return annotation (``def sol() -> int:``) is allowed. The
    parameter list may not contain ``)``.

    Returns:
        The remaining code, stripped of leading/trailing whitespace.
    """
    # re.escape: the name is matched literally, never as regex syntax, so
    # re.sub cannot fail on odd input.
    pattern = (
        r"def\s+" + re.escape(func_name)
        + r"\([^)]*\)\s*(?:->[^:]*)?:[\s\S]*?(?=\n\S|\Z)"
    )
    return re.sub(pattern, "", code_str).strip()

def remove_after_function(text, func_name):
    """Truncate ``text`` right after the first ``def func_name(...)`` definition.

    The definition is matched as in ``remove_function``. It runs from ``def``
    to just before the next line that starts with a non-whitespace character,
    or to the end of the text. Return annotations are allowed.

    Returns:
        ``text`` up to the end of that definition, or ``text`` unchanged if no
        such definition exists.
    """
    pattern = (
        r"def\s+" + re.escape(func_name)
        + r"\([^)]*\)\s*(?:->[^:]*)?:[\s\S]*?(?=\n\S|\Z)"
    )
    match = re.search(pattern, text)
    return text[:match.end()] if match else text

def remove_sol_call_and_after(text):
    """Cut ``text`` at the first ``answer = sol(`` call, dropping it and everything after.

    Text without such a call is returned unchanged.
    """
    call = re.search(r'answer\s*=\s*sol\(.*', text, flags=re.DOTALL)
    if call is None:
        return text
    # The DOTALL ".*" runs to the end of the string, so only the prefix survives.
    return text[:call.start()]

def extract_answer_and_sat(text):
    """Pull the ``Answer:`` and ``SAT:`` sections out of an execution transcript.

    The answer is everything after ``Answer:`` up to ``SAT:`` (or end of text).
    It is whitespace-stripped and shortened with ``compress_text``. The sat part
    is everything after ``SAT:``, stripped. A missing section yields ``""``.

    Returns:
        ``(answer, sat)`` as strings.
    """
    found_answer = re.search(r'Answer:\s*(.*?)\s*(?=SAT:|$)', text, re.DOTALL)
    found_sat = re.search(r'SAT:\s*(.*?)\s*$', text, re.DOTALL)

    answer = "" if found_answer is None else found_answer.group(1).strip()
    sat = "" if found_sat is None else found_sat.group(1).strip()
    return compress_text(answer), sat

# Decoding settings used by every LLM call (remote API and local pipeline).
MAXTOK: int = 1024  # generation-length cap: max_tokens / max_new_tokens
TEMP: float = 0.4  # sampling temperature

class LLM:
    """Chat wrapper around either a remote OpenAI-compatible API or a local model.

    If ``api_key`` is None, a local ``transformers`` text-generation pipeline is
    built for ``model_id``. Otherwise an ``OpenAI`` client pointed at
    ``api_base`` is used. Call an instance as ``llm(system_prompt, user_prompt)``
    to get the generated reply text.
    """

    def __init__(self, model_id, api_key=None, api_base=None):
        self.model_id = model_id

        if api_key is None:
            self.local = True
            self.pipeline = transformers.pipeline(
                "text-generation",
                model=self.model_id,
                model_kwargs={"torch_dtype": torch.bfloat16},
                device_map='auto'
            )
        else:
            self.local = False
            self.client = OpenAI(
                api_key=api_key,
                base_url=api_base
            )

    def _local_terminators(self):
        """Return the token ids that should stop local generation.

        The tokenizer's EOS id is included when it is set. ``<|eot_id|>`` (the
        Llama-3 end-of-turn token) is added only when the vocabulary really
        contains it. For other tokenizers ``convert_tokens_to_ids`` may return
        the unknown-token id or None, and neither is a valid stop token.
        """
        tokenizer = self.pipeline.tokenizer
        terminators = []
        if tokenizer.eos_token_id is not None:
            terminators.append(tokenizer.eos_token_id)
        eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
        if eot_id is not None and eot_id != tokenizer.unk_token_id:
            terminators.append(eot_id)
        return terminators

    def __call__(self, system_prompt, user_prompt):
        """Generate a reply to ``user_prompt``, optionally preceded by ``system_prompt``.

        Both back-ends use ``MAXTOK`` as the length limit and ``TEMP`` as the
        sampling temperature.

        Returns:
            Only the newly generated text. This is ``""`` if the remote API
            returns no text content.
        """
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": user_prompt})

        if not self.local:
            response = self.client.chat.completions.create(
                model=self.model_id,
                messages=messages,
                max_tokens=MAXTOK,
                temperature=TEMP,
            )
            # ``message.content`` is optional in the chat-completions schema.
            # Callers parse the reply as text, so map a missing body to "".
            return response.choices[0].message.content or ""

        prompt = self.pipeline.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        gen_kwargs = {
            "max_new_tokens": MAXTOK,
            "do_sample": True,
            "temperature": TEMP,
        }
        terminators = self._local_terminators()
        if terminators:
            # Only override stop tokens when we know valid ones; otherwise leave
            # the model's own generation settings in place.
            gen_kwargs["eos_token_id"] = terminators

        outputs = self.pipeline(prompt, **gen_kwargs)
        # The pipeline output starts with the prompt; keep only the continuation.
        return outputs[0]["generated_text"][len(prompt):]
        
class Workflow:
    """LLM problem-solving workflow.

    ``run`` is the active entry point: a Game-of-24 sample-and-verify loop.
    The other public methods implement a template/code-execution pipeline
    (distill -> retrieve template -> generate code -> execute/debug), which
    ``_run_code_pipeline`` drives.
    """

    # Prompts used by ``run`` (verbatim runtime text).
    _SOLVER_SYSTEM_PROMPT = 'You are a brilliant problem solver, try to solve the user questions correctly and cleverly. Do not do it brute forcely! Give only one most possible trial.'
    _ANSWER_FORMATTER_PROMPT = 'You are a answer formatter. Your goal is to give user a clean formatted result from a unstructred plain text result. The desirable result should be in the format as the following example:\nAnswer:\n7 * 8 - 4 * 8\n\nThe answer should not contain equal sign, and should be only one line. Extract the result from the plain text answer exactly into this format, no more no less. Do not answer the question yourself.'
    _EVAL_FORMATTER_PROMPT = 'You are a answer formatter. Your goal is to give user a clean formatted result from a unstructred plain text result. The desirable result should be in the format as the following example:\nEVAL: CORRECT\n\nThe answer should be only one line. Extract the result from the plain text answer exactly into this format, no more no less. Do not answer the question yourself.'
    _EVAL_SYSTEM_PROMPT = """
You are a evaluator who is good at reasoning and math.

Given a question and a solution to this problem, please verify whether this solution solves the problem 100 percent correctly.

DO NOT try to solve or modify the answer.

Evaluate that solution step by step and at last give your evaluation result in two of the following options:

EVAL: CORRECT

or

EVAL: WRONG
"""
    _EVAL_USER_TEMPLATE = """
User's question:
{question}

Problem Solver's solution:
{distilled_answer}

Eval whether this solution 100 percent answer the question correctly?
"""
    # Independent solve attempts per question in ``run``.
    _NUM_TRIALS = 12
    # Value a correct arithmetic answer must evaluate to in ``run``.
    _TARGET_VALUE = 24

    def __init__(self, model_id=None, api_key=None, api_base=None):
        self.model_id = model_id
        self.api_key = api_key
        self.api_base = api_base

        self.llm = LLM(self.model_id, self.api_key, self.api_base)
        # TODO: retrieve solution templates from the vector DB, e.g.
        # self.vdb = ChromaVdb(
        #     collection_name="solution_templates",
        #     model_id="intfloat/e5-mistral-7b-instruct", api_key="EMPTY", api_base='http://localhost:8001/v1/')
        # print(f"Loaded Documents: {len(self.vdb)}")

    def problem_distillation(self, question):
        """Ask the LLM, under the ``distiller`` system prompt, to distill ``question``."""
        return self.llm(distiller, question)

    def template_retrieve(self, distilled_question):
        """Return the solution template for ``distilled_question``.

        Currently always ``general_code_template``; vector-DB retrieval is not wired up.
        """
        return general_code_template

    def reasoner_instantiation(self, question, distilled_question, template, repeats=6):
        """Sample ``repeats`` LLM solutions and collect the code snippets found in them.

        Each reply is passed through ``peel_py_mk_wrapper(reply, -1)``. The
        first element it returns is concatenated into the result list.
        """
        combined_input = problem_solver_prompt.format(
            question=question,
            distilled_question=distilled_question,
            template=template)

        res_codes = []
        for _ in range(repeats):
            temp_solution = self.llm(problem_solver, combined_input)
            sol_codes, _num_codes = peel_py_mk_wrapper(temp_solution, -1)
            res_codes.extend(sol_codes)
        return res_codes

    def code_execution(self, sol_code, sat_code):
        """Execute a candidate ``sol`` together with a ``sat`` checker in a subprocess.

        ``from typing import *`` is prepended to ``sol_code`` if it has no typing
        import, and ``List[`` is rewritten to ``list[`` in both snippets. If
        ``sol_code`` already defines ``sat``, ``sat_code`` is dropped. The
        program prints ``sol()``'s answer and ``sat(answer)``. It runs via
        ``run_with_timeout`` with its default time limit.

        Returns:
            ``(mode, output)``. Mode 0 means ``sat`` printed True; 1 means the code
            ran but ``sat`` was not True; 2 means error, timeout or no output.
            ``output`` is the text result from ``execute_py_code``, or a
            timeout/error message.
        """
        if "from typing" not in sol_code:
            sol_code = f"from typing import *\n{sol_code}"
        if "def sat" in sol_code:
            # The solution ships its own checker; don't define sat twice.
            sat_code = ""

        sol_code = sol_code.replace("List[", "list[")
        sat_code = sat_code.replace("List[", "list[")

        exec_code = f"{sol_code}\n\n{sat_code}\n\n" + """answer = sol()
print("Answer:")
print(answer)
print("SAT:")
print(sat(answer))
"""
        # NOTE(review): the line count is forwarded to execute_py_code as its
        # second positional argument (presumably a line offset) - TODO confirm.
        output, status = run_with_timeout(execute_py_code, exec_code, len(sol_code.split('\n')))
        if status == "timeout":
            temp_result = "Execution timeout. There might be infinite loops or recursions."
        elif isinstance(output, Exception):
            # run_with_timeout hands back an exception raised in the child as the
            # result; report it with the error marker checked below.
            temp_result = f"An error occurred while running execute_py_code: {output!r}"
        else:
            temp_result, _temp_code = output

        if status == "timeout" or not temp_result or 'An error occurred' in temp_result:
            return 2, temp_result
        _answer, sat = extract_answer_and_sat(temp_result)
        return (0 if "True" in sat else 1), temp_result

    def debugger(self, distilled_question, code, err_msg):
        """Ask the LLM (``eval_distiller`` prompt) to re-distill the task given a failing attempt."""
        question = f"User Question:\n{distilled_question}\n\nPrevious Code:\n{code}\n\nError Message of Previous Code:\n{err_msg}"
        return self.llm(eval_distiller, question)

    @staticmethod
    def _new_stats():
        """Return a fresh statistics dict shared by ``run`` and ``_run_code_pipeline``."""
        return {
            'success': 0,
            'non-sat': 0,
            'error': 0,

            'num_attempts': 0,
            'corrected': 0,
            'wronged': 0,
            'others': 0,

            'tp': 0,
            'tn': 0,
            'unsure': 0,
            'selected': False
        }

    @staticmethod
    def _eval_arithmetic(expr):
        """Safely evaluate a plain arithmetic expression.

        Only int/float literals, parentheses, binary ``+ - * /`` and unary
        ``+``/``-`` are accepted. This replaces ``eval`` on LLM output, which
        would execute arbitrary code.

        Raises:
            ValueError: the expression contains anything else.
            SyntaxError: ``expr`` is not a Python expression.
            ArithmeticError: e.g. division by zero.
        """
        binary_ops = {
            ast.Add: operator.add,
            ast.Sub: operator.sub,
            ast.Mult: operator.mul,
            ast.Div: operator.truediv,
        }
        unary_ops = {ast.UAdd: operator.pos, ast.USub: operator.neg}

        def _eval(node):
            if isinstance(node, ast.Expression):
                return _eval(node.body)
            if (isinstance(node, ast.Constant)
                    and isinstance(node.value, (int, float))
                    and not isinstance(node.value, bool)):
                return node.value
            if isinstance(node, ast.BinOp) and type(node.op) in binary_ops:
                return binary_ops[type(node.op)](_eval(node.left), _eval(node.right))
            if isinstance(node, ast.UnaryOp) and type(node.op) in unary_ops:
                return unary_ops[type(node.op)](_eval(node.operand))
            raise ValueError(f"unsupported expression element: {ast.dump(node)}")

        return _eval(ast.parse(str(expr).strip(), mode="eval"))

    def _solve_once(self, question):
        """Ask for one solution and return its ``Answer`` field.

        If the reply has no parsable ``Answer`` field, a second call reformats it.
        """
        res = self.llm(self._SOLVER_SYSTEM_PROMPT, question)
        field = extract_field(res, 'Answer')
        if not field:
            res2 = self.llm(self._ANSWER_FORMATTER_PROMPT, f"Question:\n{question}\nPlain text answer:\n{res}. Give the formatted answer:")
            field = extract_field(res2, 'Answer')
        return field

    def _judge(self, question, field):
        """Let the LLM evaluator judge ``field``.

        Returns:
            ``(judged_correct, raw_eval_reply)``.
        """
        eval_text = self._EVAL_USER_TEMPLATE.format(question=question, distilled_answer=field)
        correct = self.llm(self._EVAL_SYSTEM_PROMPT, eval_text)

        cfield = extract_field(correct, 'EVAL')
        if not cfield:
            correct2 = self.llm(self._EVAL_FORMATTER_PROMPT, f"Question:\n{eval_text}\nPlain text answer:\n{correct}. Give the formatted answer:")
            # The formatter is asked to emit "EVAL: ...", so read that field back.
            cfield = extract_field(correct2, 'EVAL')

        verdict = (cfield or "").upper()
        # "INCORRECT" / "NOT CORRECT" contain "CORRECT" and must not count as a pass.
        judged_correct = (
            "CORRECT" in verdict
            and "INCORRECT" not in verdict
            and "NOT CORRECT" not in verdict
        )
        return judged_correct, correct

    def run(self, question, query=None):
        """Game-of-24 loop: solve repeatedly, self-evaluate, and score against the true value.

        Each of ``_NUM_TRIALS`` attempts is checked two ways. ``cfield`` is the
        LLM evaluator's verdict. ``cgt`` says whether the answer evaluates to
        ``_TARGET_VALUE``. The returned stats count how the two agree:
        'tp'/'tn' plus printed FP/FN. 'selected' ends True only if a verified
        true positive occurred and no false positive came before it.
        ``query`` is accepted for interface compatibility and is unused.
        """
        d = self._new_stats()
        wrong_selected = False

        for _ in range(self._NUM_TRIALS):
            field = self._solve_once(question)
            cfield, correct = self._judge(question, field)

            try:
                # isclose, not int(): 8/(3-8/3) evaluates to 23.99..., and
                # truncation would also accept e.g. 24.9.
                cgt = math.isclose(self._eval_arithmetic(field), self._TARGET_VALUE, abs_tol=1e-6)
            except (ArithmeticError, SyntaxError, TypeError, ValueError, RecursionError):
                cgt = False
                d['error'] += 1

            if cgt:
                d['success'] += 1
                if cfield:
                    d['tp'] += 1
                    d['selected'] = True

            if not cfield and not cgt:
                d['tn'] += 1

            if cfield and not cgt:
                print("FP")
                if not d['selected']:
                    print("wrong selected")
                    wrong_selected = True
                print(correct)
                d['non-sat'] += 1
            elif not cfield and cgt:
                print("FN")
                print(correct)

            d['num_attempts'] += 1
            d['others'] += 1

        if wrong_selected:
            d['selected'] = False
        return d

    def _run_code_pipeline(self, question, sat_code):
        """Template/code-execution pipeline: distill, retrieve a template, generate and test code.

        Every generated solution is executed against ``sat_code``. If none
        succeeds, each erroring attempt is sent to ``debugger`` and regenerated
        once. Not called by ``run``.

        Returns:
            The ``_new_stats`` dict plus 'input', 'output', 'all_errors',
            'non_sat_codes' and 'naive'.
        """
        d = self._new_stats()
        distilled_question = self.problem_distillation(question)
        template = self.template_retrieve(distilled_question)
        sol_codes = self.reasoner_instantiation(question, distilled_question, template)

        last_code = ""
        success_code = ""
        code_msgs = []
        all_errors = ''
        non_sat_codes = []
        for sol_code in sol_codes:
            mode, msg = self.code_execution(sol_code=sol_code, sat_code=sat_code)
            if mode == 0:
                success_code = sol_code
                d['success'] += 1
            elif mode == 1:
                d['non-sat'] += 1
                non_sat_codes.append(sol_code)
            else:
                d['error'] += 1
                code_msgs.append((sol_code, msg))
                all_errors += f'{msg}\n'
            last_code = sol_code
        num_attempts = len(sol_codes)

        if d['success'] == 0:
            # Retry each erroring attempt once with a debugger-rewritten question.
            for code, msg in code_msgs:
                cur_distilled_question = self.debugger(distilled_question, code, msg)
                retry_codes = self.reasoner_instantiation(
                    question, cur_distilled_question, template, repeats=1)
                for sol_code in retry_codes:
                    mode, exec_msg = self.code_execution(sol_code=sol_code, sat_code=sat_code)
                    if mode == 0:
                        success_code = sol_code
                        d['success'] += 1
                    elif mode == 1:
                        d['non-sat'] += 1
                        non_sat_codes.append(sol_code)
                    else:
                        d['error'] += 1
                        all_errors += f'{exec_msg}\n'
                    last_code = sol_code
                num_attempts += len(retry_codes)

        d['input'] = sat_code
        d['output'] = success_code or last_code
        d['num_attempts'] = num_attempts
        d['all_errors'] = all_errors
        d['non_sat_codes'] = non_sat_codes
        d['naive'] = False
        d['selected'] = d['success'] != 0
        return d
