import re
import transformers
import torch
from openai import OpenAI


from utils.promptsg1 import distiller, \
    problem_solver, problem_solver_prompt, general_code_template, \
    evaluator, evaluator_prompt, \
    eval_distiller, eval_distiller_prompt, \
    final_evaluator, final_evaluator_prompt
from utils import execute_py_code, peel_py_mk_wrapper, extract_field, eval_for_GameOf24
from utils.bot_templates import bot_template_game24, bot_template_checkmate, bot_template_word_sorting
from utils.chroma import ChromaVdb

import multiprocessing
import time

# Upper bound on generated tokens per LLM call (passed as `max_tokens` to the
# API backend and as `max_new_tokens` to the local pipeline).
MAXTOK = 1024
# Sampling temperature shared by both backends.
TEMP = 0.4

# 4.4%, any: 24%, all: 4.9%, tp: 1.1%
def eval_for_CheckmateInOne(output: str, target: str) -> bool:
    """Return True when ``output`` contains ``target`` minus its last character.

    The final character of ``target`` is dropped before the substring test
    (presumably a trailing annotation such as ``#`` -- TODO confirm against the
    dataset), so an output that omits that last character still counts.

    Args:
        output: Text to search. Anything that is not a string (for example a
            failed extraction that yielded ``None``) is treated as a miss
            instead of raising ``TypeError``.
        target: Reference answer; must support slicing.

    Returns:
        ``True`` if the truncated target occurs in ``output``, else ``False``.
    """
    if not isinstance(output, str):
        return False

    needle = target[:-1]
    # An empty needle is a substring of every string, which would mark every
    # output as correct; a target that short carries no information to match.
    if not needle:
        return False

    return needle in output

class LLM:
    """Single-turn chat wrapper over either a remote API or a local pipeline.

    When ``api_key`` is ``None`` the model is loaded locally through
    ``transformers.pipeline``; otherwise requests go through an ``OpenAI``
    client pointed at ``api_base``. Instances are callable:
    ``llm(system_prompt, user_prompt)`` returns the generated text.
    """

    def __init__(self, model_id, api_key=None, api_base=None):
        """Set up the backend.

        Args:
            model_id: Model name passed to the pipeline or to the API.
            api_key: API key; ``None`` selects the local pipeline backend.
            api_base: Base URL for the API client (ignored when local).
        """
        self.model_id = model_id
        # The absence of an API key is what selects local inference.
        self.local = api_key is None

        if self.local:
            self.pipeline = transformers.pipeline(
                "text-generation",
                model=self.model_id,
                model_kwargs={"torch_dtype": torch.bfloat16},
                device_map='auto',
            )
        else:
            self.client = OpenAI(
                api_key=api_key,
                base_url=api_base,
            )

    def __call__(self, system_prompt, user_prompt):
        """Run one chat turn and return the assistant's text.

        Args:
            system_prompt: Optional system message; skipped when falsy.
            user_prompt: The user message.

        Returns:
            The generated reply as a string.
        """
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": user_prompt})

        if self.local:
            return self._generate_local(messages)
        return self._chat_remote(messages)

    def _chat_remote(self, messages):
        """Send ``messages`` to the chat-completions API and return the reply."""
        response = self.client.chat.completions.create(
            model=self.model_id,
            messages=messages,
            max_tokens=MAXTOK,
            temperature=TEMP,
        )
        content = response.choices[0].message.content
        # The API may return no content; hand callers an empty string so
        # downstream text parsing does not receive None.
        return content if content is not None else ""

    def _generate_local(self, messages):
        """Render ``messages`` with the chat template and sample a completion."""
        tokenizer = self.pipeline.tokenizer
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        terminators = []
        if tokenizer.eos_token_id is not None:
            terminators.append(tokenizer.eos_token_id)
        # "<|eot_id|>" only exists in some vocabularies. For tokenizers that
        # lack it the lookup can yield None or the unknown-token id; neither
        # should be used as a stop token.
        eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
        unk_id = getattr(tokenizer, "unk_token_id", None)
        if eot_id is not None and eot_id != unk_id and eot_id not in terminators:
            terminators.append(eot_id)

        generation_kwargs = {
            "max_new_tokens": MAXTOK,
            "do_sample": True,
            "temperature": TEMP,
        }
        # Only override the stop tokens when at least one valid id was found.
        if terminators:
            generation_kwargs["eos_token_id"] = terminators

        outputs = self.pipeline(prompt, **generation_kwargs)

        # The pipeline output starts with the prompt; keep only the new text.
        return outputs[0]["generated_text"][len(prompt):]
        
class Workflow:
    """Sample, revise and judge answers to one question, tallying the outcomes.

    ``run`` repeatedly asks the model for an answer, asks it to revise that
    answer, and finally asks it to judge every collected answer. Each answer is
    compared with the reference ``target`` through ``eval_for_CheckmateInOne``
    and the judge's verdicts are scored against that comparison.
    """

    # Each loop yields one solver answer plus ``_RETRY`` revisions, so the
    # number of loops is ``_TRIALS // (_RETRY + 1)``.
    _RETRY = 1
    _TRIALS = 20
    # Judge calls per answer; an answer is accepted on a strict majority.
    _JUDGE_VOTES = 1

    # Prompt texts are runtime data sent to the model and are kept verbatim.
    _SOLVER_SYSTEM = 'You are a brilliant problem solver, try to solve the user questions correctly and cleverly. Do not do it brute forcely! Give only one most possible trial. If the problem is divisible, give critical reasoning steps.'

    _ANSWER_FORMATTER_SYSTEM = 'You are a answer formatter. Your goal is to give user a clean formatted result from a unstructred plain text result. The desirable result should be in the format as the following example:\nAnswer:\nRg5#\n\nThe answer should only contain one SAN movement. Extract the result from the plain text answer exactly into this format, no more no less. Do not answer the question yourself.'
    _EVAL_FORMATTER_SYSTEM = 'You are a answer formatter. Your goal is to give user a clean formatted result from a unstructred plain text result. The desirable result should be in the format as the following example:\nEVAL: CORRECT\n\nThe answer should be only one line. Extract the result from the plain text answer exactly into this format, no more no less. Do not answer the question yourself.'
    # Shared user-side template for both formatter fallbacks.
    _FORMATTER_USER = "Question:\n{question}\nPlain text answer:\n{res}. Give the formatted answer:"

    _REVISER_SYSTEM = """
You are a evaluator who is good at chess.

Given a question and a solution to this problem, please verify whether this solution solves the problem 100 percent correctly.

Try to modify the answer if there are any errors.
Strictly follow the requirements in the question.

Verify step by step and at last give your modified result in the following format, e.g.:

Answer:
Rg5#
"""

    _REVISER_USER = """
User's question:
{question}

Problem Solver's solution:
{res}

Give your analysis and modification:
"""

    _JUDGE_SYSTEM = """
You are a evaluator who is good at chess.

Given a question and a solution to this problem, please verify whether this solution solves the problem 100 percent correctly.

DO NOT try to solve or modify the answer.

Evaluate that solution step by step and at last give your evaluation result in two of the following options:

EVAL: CORRECT

or

EVAL: WRONG
"""

    _JUDGE_USER = """
User's question:
{question}

Problem Solver's solution:
{distilled_answer}

Eval whether this solution 100 percent answer the question correctly?
"""

    def __init__(self, model_id=None, api_key=None, api_base=None):
        """Store the connection settings and build the ``LLM`` backend."""
        self.model_id = model_id
        self.api_key = api_key
        self.api_base = api_base

        self.llm = LLM(self.model_id, self.api_key, self.api_base)

    def _ask_for_field(self, system_prompt, user_prompt, field_name,
                       formatter_system, formatter_question):
        """Query the model and extract ``field_name`` from its reply.

        If the field cannot be extracted from the first reply, a formatter
        prompt is sent once to restate the reply in the expected layout and
        extraction is retried on that.

        Returns:
            ``(raw_reply, field)`` where ``field`` is whatever
            ``extract_field`` produced on the last attempt (possibly falsy).
        """
        reply = self.llm(system_prompt, user_prompt)
        value = extract_field(reply, field_name)
        if not value:
            reformatted = self.llm(
                formatter_system,
                self._FORMATTER_USER.format(question=formatter_question, res=reply),
            )
            value = extract_field(reformatted, field_name)
        return reply, value

    @staticmethod
    def _matches_target(field, target):
        """Return whether an extracted answer matches the reference target."""
        # Extraction can still fail after the formatter fallback; a missing
        # answer is simply wrong rather than a reason to abort the run.
        if not isinstance(field, str):
            return False
        return bool(eval_for_CheckmateInOne(field, target))

    @staticmethod
    def _is_positive_verdict(verdict):
        """Return whether an extracted EVAL field says the answer is correct."""
        if not isinstance(verdict, str):
            return False
        # "INCORRECT" contains "CORRECT", so a plain substring test would
        # count an explicit rejection as an acceptance.
        return "CORRECT" in verdict and "INCORRECT" not in verdict

    def run(self, question, query=None, target=None):
        """Collect answers for ``question`` and score them against ``target``.

        Args:
            question: Question text sent to the model.
            query: Unused; kept for interface compatibility.
            target: Reference answer. Required despite the ``None`` default.

        Returns:
            A dict of tallies:
              * ``success``: answers matching the target.
              * ``tp`` / ``tn`` / ``fp`` / ``fn``: judge verdict versus the
                target match.
              * ``non-sat``: incremented together with ``fp``.
              * ``corrected`` / ``wronged`` / ``others``: how each revision
                changed correctness relative to the answer it revised.
              * ``num_attempts``: number of answers judged.
              * ``error`` / ``unsure``: always 0; nothing here updates them.
              * ``selected``: True only if the first answer the judge accepted
                matches the target.
              * ``loop``: number of sampling loops executed.

        Raises:
            TypeError: If ``target`` is ``None``.
        """
        if target is None:
            # The target comparison cannot work without a reference; fail
            # before spending any model calls (same exception type as before).
            raise TypeError("run() requires a reference `target` string")

        d = {
            'success': 0,
            'non-sat': 0,
            'error': 0,

            'num_attempts': 0,
            'corrected': 0,
            'wronged': 0,
            'others': 0,

            'tp': 0,
            'tn': 0,
            'fp': 0,
            'fn': 0,
            'unsure': 0,
            'selected': False
        }

        loops = self._TRIALS // (self._RETRY + 1)

        # Phase 1: sample answers and revisions.
        # Each entry is (extracted_answer, matches_target).
        attempts = []
        for _ in range(loops):
            res, field = self._ask_for_field(
                self._SOLVER_SYSTEM, question,
                'Answer', self._ANSWER_FORMATTER_SYSTEM, question,
            )
            cgt = self._matches_target(field, target)
            attempts.append((field, cgt))

            for _ in range(self._RETRY):
                # Every revision starts from the original solver reply, and is
                # compared against the original answer's correctness.
                _, logic_field = self._ask_for_field(
                    self._REVISER_SYSTEM,
                    self._REVISER_USER.format(question=question, res=res),
                    'Answer', self._ANSWER_FORMATTER_SYSTEM, question,
                )
                new_cgt = self._matches_target(logic_field, target)
                attempts.append((logic_field, new_cgt))

                if cgt and not new_cgt:
                    d['wronged'] += 1
                elif not cgt and new_cgt:
                    d['corrected'] += 1
                else:
                    d['others'] += 1

        # Phase 2: have the model judge every collected answer.
        majority = (self._JUDGE_VOTES - 1) / 2
        wrong_selected = False
        for field, cgt in attempts:
            judge_text = self._JUDGE_USER.format(question=question, distilled_answer=field)

            votes = 0
            correct = None  # last raw judge reply, printed for diagnostics
            for _ in range(self._JUDGE_VOTES):
                # The formatter fallback is shown the full judge prompt as its
                # "question".
                correct, verdict = self._ask_for_field(
                    self._JUDGE_SYSTEM, judge_text,
                    'EVAL', self._EVAL_FORMATTER_SYSTEM, judge_text,
                )
                if self._is_positive_verdict(verdict):
                    votes += 1
            accepted = votes > majority

            if cgt:
                d['success'] += 1
                if accepted:
                    d['tp'] += 1
                    d['selected'] = True

            if not accepted and not cgt:
                d['tn'] += 1

            if accepted and not cgt:
                print("FP")
                if not d['selected']:
                    # A wrong answer was accepted before any correct one.
                    print("wrong selected")
                    wrong_selected = True
                print(correct)
                d['non-sat'] += 1
                d['fp'] += 1
            elif not accepted and cgt:
                print("FN")
                print(correct)
                d['fn'] += 1

            d['num_attempts'] += 1

        if wrong_selected:
            d['selected'] = False
        d['loop'] = loops
        return d