import re
import transformers
import torch
from openai import OpenAI
import chess


from utils.promptsc1 import distiller_prompt, general_code_template, \
    combined_input_template, instantiation_prompt, generate_debug_distiller, debug_prompt, \
    answer_selection, answer_selection_prompt
from utils import execute_py_code, peel_py_mk_wrapper
from utils.bot_templates import bot_template_game24, bot_template_checkmate, bot_template_word_sorting
from utils.chroma import ChromaVdb

import multiprocessing
import time

from collections import Counter

# Scoring switch read by Workflow.run.
# When False, each candidate is reviewed by Workflow.debugger and scored with
# get_score (YES -> 1, NO -> -1, else 0); that feedback is reused as the distilled
# question in the debug round.
# When True, scoring is a placeholder (score = -1, marked TODO), and debugger is
# instead called at the start of each debug-round regeneration.
HAS_EVAL = False

def get_score(distilled):
    """Turn a verifier reply into a numeric verdict.

    Returns 1 if the reply contains "YES", -1 if it contains "NO", and 0 otherwise.
    "YES" wins when both substrings are present. Matching is case-sensitive.
    """
    for marker, verdict in (("YES", 1), ("NO", -1)):
        if marker in distilled:
            return verdict
    return 0

def most_common_ties(lst):
    """Return every element tied for the highest frequency in *lst*.

    Elements come back in first-occurrence order, because ``Counter`` preserves
    insertion order. An empty input yields ``[]`` instead of raising ``ValueError``
    from ``max()`` on an empty sequence.
    """
    counts = Counter(lst)
    if not counts:
        return []
    max_count = max(counts.values())
    return [item for item, n in counts.items() if n == max_count]

def preprocess_san_input(input_str):
    """
    Replay a SAN move list from the standard starting position.

    Move-number tokens such as ``"1."`` or ``"12..."`` are skipped. A move number
    glued to its move (``"1.e4"``, ``"12...Nf6"``) is stripped before parsing.
    :param input_str: Whitespace-separated chess moves in SAN.
    :return: A ``chess.Board`` with all moves applied.
    :raises AssertionError: If a move cannot be parsed or is illegal in the position.
    """
    board = chess.Board()
    for token in input_str.split():
        if token.endswith('.'):
            continue
        move_san = re.sub(r'^\d+\.+', '', token)
        try:
            board.push(board.parse_san(move_san))
        except ValueError as err:
            # Raised explicitly rather than via `assert` so validation still happens
            # under `python -O`. The AssertionError type is kept for existing callers.
            raise AssertionError("Invalid move in input") from err
    return board

def compress_text(text, char_limit=500):
    """Cap *text* at *char_limit* characters, noting how many were dropped.

    Text no longer than the limit is returned untouched. Longer text is cut to
    the limit and suffixed with ``...(+N chars)``, where N is the number removed.
    """
    overflow = len(text) - char_limit
    if overflow > 0:
        return f"{text[:char_limit]}...(+{overflow} chars)"
    return text

def _run_and_report(result_queue, func, args, kwargs):
    """Child-process entry point: run ``func(*args, **kwargs)`` and send back the outcome.

    This is defined at module level, not nested, so it can be pickled under the
    ``spawn``/``forkserver`` start methods. An exception raised by *func* is sent
    back as the result.
    """
    try:
        result_queue.put(func(*args, **kwargs))
    except Exception as e:
        result_queue.put(e)


def run_with_timeout(func, *args, timeout_duration=3, **kwargs):
    """Run ``func(*args, **kwargs)`` in a separate process with a wall-clock limit.

    Returns ``(result, status)``. On success, ``status`` is ``"success"`` and
    ``result`` is the function's return value, or the exception instance it raised.
    If no result arrives within *timeout_duration* seconds, the child is terminated
    and ``("timeout", "timeout")`` is returned. This also covers a child that dies
    without reporting, or whose result could not be pickled.
    """
    from queue import Empty

    result_queue = multiprocessing.Queue()
    process = multiprocessing.Process(
        target=_run_and_report, args=(result_queue, func, args, kwargs))
    process.start()
    try:
        # Read the result *before* joining. A child that put a large object on the
        # queue cannot exit until the data is consumed, so join-then-get deadlocks
        # and misreports a successful run as a timeout.
        result = result_queue.get(timeout=timeout_duration)
        status = "success"
    except Empty:
        result = status = "timeout"
    finally:
        # The child is either finished (result received) or overdue; make sure it is gone.
        if process.is_alive():
            process.terminate()
        process.join()
        result_queue.close()
    return result, status


def remove_function(code_str, func_name):
    """Delete the definition of *func_name* from *code_str* and return the stripped rest.

    A definition starts at ``def func_name(...)``, optionally with a ``-> T`` return
    annotation. It extends up to the next line that begins with a non-whitespace
    character, or to the end of the string when the function is the last thing in it.
    """
    # re.escape keeps the name literal. The `\Z` alternative lets a trailing
    # function (with no following top-level line) be removed too.
    pattern = (r"def\s+" + re.escape(func_name)
               + r"\([^)]*\)\s*(?:->[^:\n]*)?:[\s\S]*?(?=\n\S|\Z)")
    return re.sub(pattern, "", code_str).strip()

def remove_after_function(text, func_name):
    """Truncate *text* right after the definition of *func_name*.

    The definition starts at ``def func_name(...)``, optionally with a ``-> T``
    return annotation. It ends just before the next line that begins with a
    non-whitespace character. If no definition is found, *text* is returned unchanged.
    """
    # `\Z` handles a function at the very end: the match then covers the whole
    # remainder, which yields the same text.
    pattern = (r"def\s+" + re.escape(func_name)
               + r"\([^)]*\)\s*(?:->[^:\n]*)?:[\s\S]*?(?=\n\S|\Z)")
    match = re.search(pattern, text)
    return text[:match.end()] if match else text

def remove_sol_call_and_after(text):
    """Cut *text* at the first ``answer = sol(`` call, dropping that call and all that follows.

    Text without such a call is returned as-is.
    """
    call = re.search(r'answer\s*=\s*sol\(.*', text, flags=re.DOTALL)
    if call is None:
        return text
    return text[:call.start()]

def extract_answer_and_sat(text):
    """Parse the ``Answer:`` and ``SAT:`` sections out of program output.

    Returns ``(answer, sat)``, each stripped of surrounding whitespace. The answer
    is shortened with ``compress_text``. A missing section yields ``""``.
    """
    def section(pattern):
        found = re.search(pattern, text, re.DOTALL)
        return found.group(1).strip() if found else ""

    # The answer section runs until "SAT:" or the end of the text.
    answer = compress_text(section(r'Answer:\s*(.*?)\s*(?=SAT:|$)'))
    sat = section(r'SAT:\s*(.*?)\s*$')
    return answer, sat

def extract_checkmate_answer(text):
    """Return everything after the first ``FINAL ANSWER:`` marker, stripped.

    Returns ``""`` when the marker is absent.
    """
    found = re.search(r'FINAL ANSWER:\s*(.*?)\s*$', text, re.DOTALL)
    if found is None:
        return ""
    return found.group(1).strip()

class LLM:
    """Chat-style text generation through a local HF pipeline or an OpenAI-compatible API.

    With ``api_key=None``, a local ``transformers`` text-generation pipeline is loaded
    for *model_id* (bfloat16, ``device_map='auto'``). Otherwise an ``OpenAI`` client
    pointed at *api_base* is used. Calling the instance with a system prompt and a
    user prompt returns the generated reply text.
    """

    def __init__(self, model_id, api_key=None, api_base=None,
                 max_tokens=1024, temperature=0.4, top_p=0.9):
        self.model_id = model_id
        # Sampling settings shared by both backends. The defaults equal the values
        # that were previously hard-coded.
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.top_p = top_p

        if api_key is None:
            self.local = True
            self.pipeline = transformers.pipeline(
                "text-generation",
                model=self.model_id,
                model_kwargs={"torch_dtype": torch.bfloat16},
                device_map='auto',
            )
        else:
            self.local = False
            self.client = OpenAI(api_key=api_key, base_url=api_base)

    def _stop_token_ids(self):
        """Return token ids that end local generation: EOS plus ``<|eot_id|>`` if the tokenizer has it."""
        tokenizer = self.pipeline.tokenizer
        stop_ids = []
        if tokenizer.eos_token_id is not None:
            stop_ids.append(tokenizer.eos_token_id)
        eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
        # Tokenizers without this (Llama-3 style end-of-turn) token map it to the
        # unk id or None. Neither is a meaningful stop token.
        if eot_id is not None and eot_id != tokenizer.unk_token_id and eot_id not in stop_ids:
            stop_ids.append(eot_id)
        return stop_ids

    def __call__(self, system_prompt, user_prompt):
        """Generate a reply to *user_prompt* under *system_prompt* and return its text."""
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]

        if not self.local:
            response = self.client.chat.completions.create(
                model=self.model_id,
                messages=messages,
                max_tokens=self.max_tokens,
                temperature=self.temperature,
                top_p=self.top_p,
            )
            # `content` is Optional in the API. Return "" so callers that regex
            # the reply never receive None.
            return response.choices[0].message.content or ""

        prompt = self.pipeline.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        outputs = self.pipeline(
            prompt,
            max_new_tokens=self.max_tokens,
            eos_token_id=self._stop_token_ids() or None,
            do_sample=True,
            temperature=self.temperature,
            top_p=self.top_p,
        )
        # generated_text starts with the prompt; keep only the new continuation.
        return outputs[0]["generated_text"][len(prompt):]
        
class Workflow:
    """LLM pipeline that proposes, reviews, regenerates and selects a checkmating move.

    ``run`` proceeds in five steps:

    1. Make one direct baseline attempt.
    2. Distill the question.
    3. Sample several candidates from a solution template.
    4. Review each candidate with the LLM debugger, then regenerate every candidate
       once per debug round using that feedback.
    5. Ask the LLM to pick one move from the distinct candidates.

    Ground truth for the statistics comes from ``checkmate1_check``.
    """

    def __init__(self, model_id=None, api_key=None, api_base=None):
        self.model_id = model_id
        self.api_key = api_key
        self.api_base = api_base

        self.llm = LLM(self.model_id, self.api_key, self.api_base)
        # TODO: enable vector-DB template retrieval. template_retrieve currently
        # returns the fixed general_code_template.
        #self.vdb = ChromaVdb(
        #    collection_name="solution_templates",
        #    model_id="intfloat/e5-mistral-7b-instruct", api_key="EMPTY", api_base='http://localhost:8001/v1/')
        #print(f"Loaded Documents: {len(self.vdb)}")

    def problem_distillation(self, question):
        """Ask the LLM to distill *question* with ``distiller_prompt``; return its reply."""
        return self.llm(distiller_prompt, question)

    def template_retrieve(self, distilled_question):
        """Return the solution template for *distilled_question*.

        Retrieval is not wired up yet, so this always returns ``general_code_template``.
        """
        #template = self.vdb.search(distilled_question)
        return general_code_template

    def reasoner_instantiation(self, question, distilled_question, template):
        """Instantiate *template* for the question; return ``(raw_reply, final_answer)``."""
        combined_input = combined_input_template.format(
            question=question,
            distilled_question=distilled_question,
            template=template)
        temp_solution = self.llm(instantiation_prompt, combined_input)
        return temp_solution, extract_checkmate_answer(temp_solution)

    def checkmate1_check(self, san, move):
        """Play *move* after the SAN game *san*; return ``(is_checkmate, is_valid)``.

        Only the text after the last ``'.'`` in *move* is used, so a move-number
        prefix such as ``"23..."`` is ignored. Returns ``(False, False)`` when the
        game or the move cannot be parsed or played.
        """
        try:
            board = preprocess_san_input(san)
            board.push(board.parse_san(move.split('.')[-1].strip()))
        except Exception:
            # Deliberately best-effort: malformed LLM output counts as an invalid move.
            # Catching Exception (not a bare except) lets KeyboardInterrupt through.
            return False, False
        return board.is_checkmate(), True

    def code_execution(self, sol_code, sat_code):
        """Execute a generated ``sol()`` against a ``sat()`` checker and classify the outcome.

        Returns ``(mode, output)``:

        - mode 0: ``sat(answer)`` printed something containing "True".
        - mode 1: it did not.
        - mode 2: the output is empty, contains "An error occurred", or the run timed out.

        ``output`` is the captured program output, or a timeout/error message.
        """
        if "from typing" not in sol_code:
            sol_code = f"from typing import *\n{sol_code}"
        if "def sat" in sol_code:
            # The solution already defines its own checker; don't append a second one.
            sat_code = ""

        # Rewrite typing-style List[...] annotations as builtin list[...].
        sol_code = sol_code.replace("List[", "list[")
        sat_code = sat_code.replace("List[", "list[")

        exec_code = f"{sol_code}\n\n{sat_code}\n\n" + """answer = sol()
print("Answer:")
print(answer)
print("SAT:")
print(sat(answer))
"""
        # NOTE(review): the second positional argument is the sol_code line count.
        # Its meaning depends on execute_py_code's signature; confirm.
        output, status = run_with_timeout(execute_py_code, exec_code, len(sol_code.split('\n')))
        if status == "timeout":
            temp_result = "Execution timeout. There might be infinite loops or recursions."
        elif isinstance(output, Exception):
            # run_with_timeout returns the exception object when execute_py_code raised.
            temp_result = f"An error occurred: {output!r}"
        else:
            temp_result, _exec_code = output
        _answer, sat = extract_answer_and_sat(temp_result)

        if not temp_result or 'An error occurred' in temp_result or status == "timeout":
            mode = 2
        elif "True" not in sat:
            mode = 1
        else:
            mode = 0
        return mode, temp_result

    def debugger(self, question, distilled_question, code, err_msg=None):
        """Ask the LLM to review *code* for the question and return its raw feedback.

        *err_msg* is accepted for interface compatibility but is currently unused.
        """
        prompt = debug_prompt.format(question=question, distilled_question=distilled_question, code=code)
        return self.llm(generate_debug_distiller, prompt)

    def result_selection(self, question, candidates):
        """Ask the LLM to choose among *candidates* (one per line); return its FINAL ANSWER."""
        reply = self.llm(answer_selection, answer_selection_prompt.format(
            question=question, candidates='\n'.join(candidates)))
        return extract_checkmate_answer(reply)

    def helper_classify(self, gt, predicted, d2):
        """Update the verifier confusion counts in *d2*.

        *gt* says whether the move really mates. *predicted* is the verifier score:
        1 means yes, -1 means no, anything else is unsure. Exactly one counter is
        incremented per call.
        """
        if predicted == 1:
            d2['tp' if gt else 'fp'] += 1
        elif predicted == -1:
            d2['fn' if gt else 'tn'] += 1
        else:
            d2['unsure'] += 1

    def run(self, question, query=None, repeats=6, round_debugs=1):
        """Solve *question* and return a dict of outcome statistics.

        :param question: Prompt text given to the LLM.
        :param query: SAN move list of the position, used to check moves for checkmate.
        :param repeats: Number of independent candidates sampled from the template.
        :param round_debugs: Number of feedback-driven regeneration rounds. Each round
            regenerates every candidate produced by the previous round.
        :return: A dict with these keys:

            - ``success``, ``non-sat``, ``error``: per-check tallies.
            - ``corrected``, ``wronged``, ``others``: debug-round transitions.
            - ``tp``, ``fp``, ``fn``, ``tn``, ``unsure``: verifier confusion counts.
            - ``input``, ``output`` (the selected move), ``num_attempts``.
            - ``all_errors``, ``non_sat_codes``: newline-terminated invalid and
              non-mating moves.
            - ``naive``: whether the baseline attempt mated.
            - ``selected``: whether the selected move mated.
        """
        # Baseline: one direct attempt without distillation or debugging.
        first_attempt = self.llm(instantiation_prompt, question)
        res0 = extract_checkmate_answer(first_attempt)
        is_checkmate0, has_no_err0 = self.checkmate1_check(query, res0)

        distilled_question = self.problem_distillation(question)
        # The template depends only on the distilled question, so fetch it once.
        # This also keeps `template` bound when repeats == 0.
        template = self.template_retrieve(distilled_question)
        sol_codes = [
            self.reasoner_instantiation(question, distilled_question, template)
            for _ in range(repeats)
        ]

        d = {'success': 0, 'non-sat': 0, 'error': 0, 'corrected': 0, 'wronged': 0, 'others': 0}
        d2 = {'fp': 0, 'fn': 0, 'tp': 0, 'tn': 0, 'unsure': 0}
        errors = []
        non_sat = []

        def judge(sol, res):
            """Check *res* for mate, score *sol* with the verifier, and update the tallies."""
            is_checkmate, has_no_err = self.checkmate1_check(query, res)
            if not HAS_EVAL:
                feedback = self.debugger(question, distilled_question, sol)
                score = get_score(feedback)
            else:
                feedback = ""
                score = -1  # TODO
            self.helper_classify(is_checkmate, score, d2)
            if not has_no_err:
                d['error'] += 1
                errors.append(res)
            elif is_checkmate:
                d['success'] += 1
            else:
                d['non-sat'] += 1
                non_sat.append(res)
            return [sol, res, is_checkmate, score, feedback]

        moves = [judge(sol, res) for sol, res in sol_codes]
        num_attempts = len(moves)

        prev_moves = moves
        for _ in range(round_debugs):
            new_moves = []
            for prev_sol, _, prev_is_checkmate, _, prev_feedback in prev_moves:
                if not HAS_EVAL:
                    # Reuse the feedback gathered when prev_sol was judged.
                    cur_distilled_question = prev_feedback
                else:
                    cur_distilled_question = self.debugger(question, distilled_question, prev_sol)
                sol, res = self.reasoner_instantiation(
                    question, cur_distilled_question, template)
                move = judge(sol, res)
                is_checkmate = move[2]

                # Exactly one transition bucket per regenerated candidate.
                if is_checkmate and not prev_is_checkmate:
                    d['corrected'] += 1
                elif prev_is_checkmate and not is_checkmate:
                    d['wronged'] += 1
                else:
                    d['others'] += 1
                new_moves.append(move)
            num_attempts += len(new_moves)
            moves.extend(new_moves)
            prev_moves = new_moves

        # Distinct candidate moves in first-seen order. Unlike set(), this gives
        # the selector a deterministic prompt.
        candidates = list(dict.fromkeys(move[1] for move in moves))
        res = self.result_selection(question, candidates)
        is_checkmate, has_no_err = self.checkmate1_check(query, res)

        d['input'] = query
        d['output'] = res
        d['num_attempts'] = num_attempts
        d['all_errors'] = ''.join(f'{r}\n' for r in errors)
        d['non_sat_codes'] = ''.join(f'{r}\n' for r in non_sat)
        d['naive'] = is_checkmate0 and has_no_err0
        d['selected'] = is_checkmate and has_no_err
        d.update(d2)
        return d
