import re
import transformers
import torch
from openai import OpenAI
import chess


from utils.promptsc2 import distiller_prompt, general_code_template, \
    combined_input_template, instantiation_prompt, generate_debug_distiller, debug_prompt, \
    answer_selection, answer_selection_prompt
from utils import execute_py_code, peel_py_mk_wrapper, extract_field
from utils.bot_templates import bot_template_game24, bot_template_checkmate, bot_template_word_sorting
from utils.chroma import ChromaVdb

import multiprocessing
import time

from collections import Counter

# NOTE(review): this flag is not read anywhere in the visible part of the file;
# presumably it toggles an evaluation code path elsewhere — confirm before
# changing or removing it.
HAS_EVAL = False

def _first_illegal_variant(san, legal_sans):
    """Return a corrupted copy of ``san`` that is not a legal move, or None.

    The first number in ``san`` is replaced by 0..7 in turn (rank 0 does not
    exist, so any SAN containing a digit yields a result immediately).  SAN
    strings without digits, such as castling ("O-O"), cannot be corrupted this
    way and produce None.
    """
    for digit in range(8):
        candidate = replace_number(san, digit)
        if candidate not in legal_sans:
            return candidate
    return None


def get3moves(query):
    """Build three example moves for the position reached after ``query``.

    :param query: Space-separated SAN move list (move numbers such as "1."
        are ignored by ``preprocess_san_input``).
    :return: Tuple ``(check_move, valid_move, invalid_move)``:
        ``check_move`` is a legal move that delivers checkmate (the last one
        found, or ``''`` if there is none); ``valid_move`` is the first legal
        move that differs from ``check_move``; ``invalid_move`` is a SAN-like
        string that is guaranteed not to be a legal move.
    :raises IndexError: If the position has no legal move other than
        ``check_move`` (kept as IndexError for backward compatibility, but now
        with an explicit message).
    :raises ValueError: If no illegal variant can be derived from any legal
        move (no legal SAN contains a digit).
    """
    board = preprocess_san_input(query)
    legal_moves = list(board.legal_moves)
    valid_moves = [board.san(move) for move in legal_moves]

    # Keep the last mating move found, matching the historical behaviour.
    check_move = ''
    for move, san in zip(legal_moves, valid_moves):
        board_copy = board.copy()
        board_copy.push(move)
        if board_copy.is_checkmate():
            check_move = san

    # SAN strings are unique within a position, so filtering by value is safe.
    non_mating = [san for san in valid_moves if san != check_move]
    if not non_mating:
        raise IndexError(
            "position has no legal non-mating move to use as the valid example"
        )
    valid_move = non_mating[0]

    # Prefer corrupting valid_move itself; fall back to the other legal moves
    # when it contains no digit (e.g. castling), so that the returned
    # "invalid" move is never actually legal.
    legal_sans = set(valid_moves)
    invalid_move = None
    for san in [valid_move] + valid_moves:
        invalid_move = _first_illegal_variant(san, legal_sans)
        if invalid_move is not None:
            break
    if invalid_move is None:
        raise ValueError("could not derive an illegal move from the legal moves")

    return check_move, valid_move, invalid_move

def replace_number(s, n):
    """Return ``s`` with its first run of digits replaced by ``str(n)``.

    Only the first number is touched; a string without digits is returned
    unchanged.
    """
    # A callable replacement is used so that str(n) is inserted literally.
    return re.sub(r'\d+', lambda _match: str(n), s, count=1)

def get_score(distilled):
    """Map a verdict text to a score: 1 for "YES", -1 for "NO", 0 otherwise.

    The check is a plain substring test and "YES" takes precedence when both
    markers are present.
    """
    for marker, score in (("YES", 1), ("NO", -1)):
        if marker in distilled:
            return score
    return 0

def most_common_ties(lst):
    """Return every element that shares the highest occurrence count.

    :param lst: Iterable of hashable items.
    :return: List of the most frequent items, in order of first appearance;
        an empty list when ``lst`` is empty (previously this raised
        ``ValueError`` from ``max()``).
    """
    counts = Counter(lst)
    if not counts:
        return []

    top_count = max(counts.values())
    return [item for item, freq in counts.items() if freq == top_count]

def preprocess_san_input(input_str):
    """
    Replay a SAN move list from the initial position and return the board.

    Whitespace-separated tokens ending in '.' (move numbers such as "1.") are
    skipped; every other token is parsed as a SAN move and pushed in order.

    :param input_str: The input string of chess moves.
    :return: The ``chess.Board`` reached after all moves have been applied.
    :raises AssertionError: If a token is not a legal SAN move in the current
        position.  Raised explicitly (rather than via ``assert``) so that the
        check still fires when Python runs with ``-O``.
    """
    board = chess.Board()

    for token in input_str.split():
        if token.endswith('.'):
            continue
        try:
            move = board.parse_san(token)
        except ValueError as err:
            raise AssertionError("Invalid move in input") from err
        board.push(move)

    return board

def compress_text(text, char_limit=500):
    """Truncate ``text`` to ``char_limit`` characters, noting what was cut.

    Text that already fits is returned unchanged; otherwise the kept prefix is
    followed by a marker stating how many characters were dropped.
    """
    overflow = len(text) - char_limit
    if overflow <= 0:
        return text
    return f"{text[:char_limit]}...(+{overflow} chars)"

def _call_and_enqueue(result_queue, func, args, kwargs):
    """Child-process entry point: run ``func`` and send its outcome to the parent."""
    try:
        result_queue.put(func(*args, **kwargs))
    except Exception as exc:
        # Exceptions are shipped back as values, matching the historical contract.
        result_queue.put(exc)


def run_with_timeout(func, *args, timeout_duration=3, **kwargs):
    """Run ``func(*args, **kwargs)`` in a separate process with a time limit.

    :param func: Callable to execute in the child process.
    :param args: Positional arguments forwarded to ``func``.
    :param timeout_duration: Seconds to wait for a result (``None`` waits
        indefinitely).
    :param kwargs: Keyword arguments forwarded to ``func``.
    :return: ``(result, status)``.  On success ``status`` is ``"success"`` and
        ``result`` is the return value of ``func`` — or the exception instance
        if ``func`` raised.  If no result arrives in time both elements are
        the string ``"timeout"`` and the child is terminated.
    """
    from queue import Empty

    result_queue = multiprocessing.Queue()
    # A module-level target keeps the call working under the "spawn" and
    # "forkserver" start methods, which must pickle the target.
    process = multiprocessing.Process(
        target=_call_and_enqueue,
        args=(result_queue, func, args, kwargs),
    )
    process.start()

    # Read the result BEFORE joining: a child that has queued a large item
    # does not exit until the item is consumed, so join-then-get can misreport
    # a finished call as a timeout.  The bounded get() also avoids blocking
    # forever when the child dies without producing anything.
    try:
        result = result_queue.get(timeout=timeout_duration)
        status = "success"
    except Empty:
        result = "timeout"
        status = "timeout"

    if status == "success":
        # Give the child a moment to exit on its own once its result is read.
        process.join(1)
    if process.is_alive():
        process.terminate()
    process.join()

    return result, status


def remove_function(code_str, func_name):
    """Strip the definition of ``func_name`` out of ``code_str``.

    A definition is taken to run from its ``def`` line up to (but excluding)
    the next newline that is followed by a non-whitespace character, or to the
    end of the text when the function is the last thing in it.

    :param code_str: Source text to clean.
    :param func_name: Name of the function to remove; it is matched literally.
    :return: ``code_str`` without the matching definition(s), stripped of
        leading/trailing whitespace.

    Limitation: the signature is matched with ``[^)]*``, so parameter lists
    containing a ``)`` (e.g. tuple defaults) are not recognised.
    """
    pattern = (
        r"def\s+" + re.escape(func_name) + r"\([^)]*\):[\s\S]*?(?=\n\S|\Z)"
    )
    return re.sub(pattern, "", code_str).strip()

def remove_after_function(text, func_name):
    """Cut ``text`` off right after the first definition of ``func_name``.

    The definition is considered to end just before the next newline that is
    followed by a non-whitespace character.  When no such definition is found
    the text is returned unchanged.
    """
    definition = re.compile(
        r"def\s+"+func_name+r"\([^)]*\):[\s\S]*?(?=\n\S)"
    )
    found = definition.search(text)
    if found is None:
        return text
    return text[:found.end()]

def remove_sol_call_and_after(text):
    """Drop the first ``answer = sol(`` call and everything that follows it.

    Text without such a call is returned unchanged.
    """
    call = re.search(r'answer\s*=\s*sol\(.*', text, flags=re.DOTALL)
    if call is None:
        return text
    # With DOTALL the match always extends to the end of the text, so removing
    # it is the same as keeping the prefix before it.
    return text[:call.start()]

def extract_answer_and_sat(text):
    """Pull the "Answer:" and "SAT:" sections out of execution output.

    :param text: Output text containing optional "Answer:" and "SAT:" markers.
    :return: Tuple ``(answer, sat)``.  ``answer`` is the text between
        "Answer:" and "SAT:" (or the end), shortened with ``compress_text``;
        ``sat`` is the text after "SAT:".  A missing section yields ``""``.
    """
    def section(pattern):
        # Stripped capture group of the first match, or "" when absent.
        found = re.search(pattern, text, re.DOTALL)
        if found is None:
            return ""
        return found.group(1).strip()

    answer = compress_text(section(r'Answer:\s*(.*?)\s*(?=SAT:|$)'))
    sat = section(r'SAT:\s*(.*?)\s*$')
    return answer, sat

def extract_checkmate_answer(text):
    """Return whatever follows the first "FINAL ANSWER:" marker in ``text``.

    The result runs to the end of the text with surrounding whitespace
    removed; ``""`` is returned when the marker is absent.
    """
    found = re.search(r'FINAL ANSWER:\s*(.*?)\s*$', text, re.DOTALL)
    if found is None:
        return ""
    return found.group(1).strip()

class LLM:
    """Chat wrapper over a local ``transformers`` pipeline or an OpenAI-style API.

    When ``api_key`` is ``None`` a text-generation pipeline is loaded for
    ``model_id``; otherwise an ``OpenAI`` client is created with the given key
    and base URL.  Calling the instance with a system and a user prompt
    returns the generated reply text.
    """

    def __init__(self, model_id, api_key=None, api_base=None):
        self.max_tokens = 3000
        self.model_id = model_id
        # No API key means the model is run in-process.
        self.local = api_key is None

        if not self.local:
            self.client = OpenAI(api_key=api_key, base_url=api_base)
            return

        self.pipeline = transformers.pipeline(
            "text-generation",
            model=self.model_id,
            model_kwargs={"torch_dtype": torch.bfloat16},
            device_map='auto',
        )

    def __call__(self, system_prompt, user_prompt):
        """Generate a reply to ``user_prompt`` under ``system_prompt``."""
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        if self.local:
            return self._generate_locally(messages)
        return self._generate_remotely(messages)

    def _generate_remotely(self, messages):
        """Send ``messages`` to the chat-completions endpoint; return the reply text."""
        completion = self.client.chat.completions.create(
            model=self.model_id,
            messages=messages,
            max_tokens=self.max_tokens,
            temperature=0.4,
            top_p=0.9,
        )
        return completion.choices[0].message.content

    def _generate_locally(self, messages):
        """Run ``messages`` through the local pipeline; return only the new text."""
        tokenizer = self.pipeline.tokenizer
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        # NOTE(review): "<|eot_id|>" looks like a model-specific end-of-turn
        # token — confirm it exists for the tokenizer in use.
        stop_ids = [
            tokenizer.eos_token_id,
            tokenizer.convert_tokens_to_ids("<|eot_id|>"),
        ]

        generations = self.pipeline(
            prompt,
            max_new_tokens=self.max_tokens,
            eos_token_id=stop_ids,
            do_sample=True,
            temperature=0.4,
            top_p=0.9,
        )

        # The pipeline output starts with the prompt; slice it off.
        full_text = generations[0]["generated_text"]
        return full_text[len(prompt):]
        
class Workflow:
    """LLM-driven pipeline: distil a problem, instantiate a solution, verify it.

    The instance owns one ``LLM`` built from the given model id and API
    settings; every method that talks to the model goes through ``self.llm``.
    """

    def __init__(self, model_id=None, api_key=None, api_base=None):
        self.model_id = model_id
        self.api_key = api_key
        self.api_base = api_base

        self.llm = LLM(self.model_id, self.api_key, self.api_base)
        # TODO
        #self.vdb = ChromaVdb(
        #    collection_name="solution_templates",
        #    model_id="intfloat/e5-mistral-7b-instruct", api_key="EMPTY", api_base='http://localhost:8001/v1/')
        #print(f"Loaded Documents: {len(self.vdb)}")

    def problem_distillation(self, question):
        """Ask the model to distil ``question`` using ``distiller_prompt``."""
        return self.llm(distiller_prompt, question)

    def template_retrieve(self, distilled_question):
        """Return the solution template for a distilled question.

        Retrieval from the vector store is disabled, so the general code
        template is always returned and ``distilled_question`` is unused.
        """
        #template = self.vdb.search(distilled_question)
        return general_code_template

    def reasoner_instantiation(self, question, distilled_question, template):
        """Instantiate ``template`` for the question and extract the final answer.

        :return: Tuple ``(raw_model_output, extracted_answer)``.
        """
        # NOTE(review): ``run`` formats this same template with fen/san/turn
        # keys instead — confirm which placeholder set the template defines.
        combined_input = combined_input_template.format(
            question=question,
            distilled_question=distilled_question,
            template=template)

        temp_solution = self.llm(instantiation_prompt, combined_input)
        result = extract_checkmate_answer(temp_solution)
        return temp_solution, result

    def checkmate1_check(self, san, move):
        """Check whether ``move`` mates in the position reached after ``san``.

        :param san: SAN move list leading to the position.
        :param move: Candidate move; anything up to the last '.' is dropped.
        :return: Tuple ``(is_checkmate, is_legal)``.  ``(False, False)`` is
            returned when the move list or the candidate cannot be applied.
        """
        try:
            board = preprocess_san_input(san)
            move = move.split('.')[-1].strip()
            board.push(board.parse_san(move))
        except Exception:
            # Any parsing failure counts as "not a legal move"; unlike a bare
            # ``except`` this lets KeyboardInterrupt/SystemExit propagate.
            return False, False
        return board.is_checkmate(), True

    def code_execution(self, sol_code, sat_code):
        """Run generated ``sol``/``sat`` code and classify the outcome.

        :param sol_code: Source defining ``sol()`` (and possibly ``sat``).
        :param sat_code: Source defining ``sat(answer)``; ignored when
            ``sol_code`` already contains a ``def sat``.
        :return: Tuple ``(mode, output_text)`` where mode 0 means success,
            1 means ``sat`` did not report True, and 2 means an execution
            error or timeout.
        """
        if "from typing" not in sol_code:
            sol_code = f"from typing import *\n{sol_code}"
        if "def sat" in sol_code:
            sat_code = ""

        sol_code = sol_code.replace("List[", "list[")
        sat_code = sat_code.replace("List[", "list[")

        exec_code = (
            f"{sol_code}\n\n{sat_code}\n\n"
            "answer = sol()\n"
            'print("Answer:")\n'
            "print(answer)\n"
            'print("SAT:")\n'
            "print(sat(answer))\n"
        )
        output, status = run_with_timeout(execute_py_code, exec_code, len(sol_code.split('\n')))

        if status == "timeout":
            temp_result = "Execution timeout. There might be infinite loops or recursions."
        elif isinstance(output, Exception):
            # run_with_timeout hands back exceptions raised in the child as
            # the result value; report them as an execution error instead of
            # failing on the tuple unpacking below.
            temp_result = f"An error occurred: {output!r}"
        else:
            temp_result, _executed_code = output

        _answer, sat = extract_answer_and_sat(temp_result)

        # 0 means success, 1 means sat false, 2 means code err including infinite loops
        if not temp_result or 'An error occurred' in temp_result or status == "timeout":
            mode = 2
        elif "True" not in sat:
            mode = 1
        else:
            mode = 0
        return mode, temp_result

    def debugger(self, question, distilled_question, code, err_msg=None):
        """Ask the model to debug ``code``; ``err_msg`` is accepted but unused."""
        question = debug_prompt.format(question=question, distilled_question=distilled_question, code=code)
        return self.llm(generate_debug_distiller, question)

    def result_selection(self, question, candidates):
        """Let the model pick among ``candidates`` and return its final answer."""
        answer = self.llm(answer_selection, answer_selection_prompt.format(question=question, candidates='\n'.join(candidates)))
        return extract_checkmate_answer(answer)

    def helper_classify(self, gt, predicted, d2):
        """Update confusion counters in ``d2`` for one prediction.

        :param gt: Ground truth (truthy means positive).
        :param predicted: 1 for a positive prediction, -1 for negative; any
            other value is counted as ``'unsure'``.
        :param d2: Dict with integer counters under the keys ``'tp'``,
            ``'fn'``, ``'fp'``, ``'tn'`` and ``'unsure'``; mutated in place.

        Exactly one counter is incremented per call.  (The earlier chain of
        independent ``if`` statements also bumped ``'unsure'`` for every
        tp/fn/fp case.)
        """
        if predicted == 1:
            d2['tp' if gt else 'fp'] += 1
        elif predicted == -1:
            d2['fn' if gt else 'tn'] += 1
        else:
            d2['unsure'] += 1

    def run(self, question, query=None, repeats=6):
        """Replay ``query`` move by move, asking the model for each new FEN.

        After every move the model's 'FEN' field is compared with the FEN
        computed by the chess library; mismatches are printed.

        :param question: Unused here.
        :param query: SAN move list; required despite the ``None`` default.
        :param repeats: Unused here.
        :return: The board after all moves have been applied.
        :raises AssertionError: If a token in ``query`` is not a legal SAN move.
        """
        moves_san = [move for move in query.split() if not move.endswith('.')]

        # Initialize a new chess board
        board = chess.Board()
        prev_fen = board.fen()

        # Apply the given moves to the board
        for move_san in moves_san:
            # Only move parsing is guarded, so a ValueError raised elsewhere
            # is not mislabelled as an invalid move.  Raised explicitly so
            # the check survives ``python -O``.
            try:
                move = board.parse_san(move_san)
            except ValueError as err:
                raise AssertionError("Invalid move in input") from err
            board.push(move)

            prev_fen_turn = 'White' if prev_fen.split()[1] == 'w' else 'Black'
            # NOTE(review): ``reasoner_instantiation`` formats this template
            # with question/distilled_question/template keys — confirm which
            # placeholder set the template actually defines.
            prompt = combined_input_template.format(fen=prev_fen, san=move_san, turn=prev_fen_turn)
            res = self.llm(instantiation_prompt, prompt)
            ans = extract_field(res, 'FEN')
            fen = board.fen()

            if fen != ans:
                # Report the disagreement and keep going (previously this
                # dropped into an interactive pdb session).
                print(ans)
                print(fen)

            prev_fen = fen

        return board
        
        #prompt = combined_input_template.format(query=query, san=check_move)
        #res1 = self.llm(instantiation_prompt, prompt)
        #import pdb;pdb.set_trace()
