import re
import transformers
import torch
from openai import OpenAI


from utils.promptsg1 import distiller, \
    problem_solver, problem_solver_prompt, general_code_template, \
    evaluator, evaluator_prompt, \
    eval_distiller, eval_distiller_prompt, \
    final_evaluator, final_evaluator_prompt
from utils import execute_py_code, peel_py_mk_wrapper, extract_field, eval_for_exact_matching_with_no_punctuation
from utils.bot_templates import bot_template_game24, bot_template_checkmate, bot_template_word_sorting
from utils.chroma import ChromaVdb

import multiprocessing
import time


# Token budget for one generation: LLM passes it as `max_tokens` to the API
# client and as `max_new_tokens` to the local pipeline.
MAXTOK = 1024
# Sampling temperature used by LLM for both the API client and the local pipeline.
TEMP = 0.4

# NOTE(review): not referenced anywhere in the visible part of this file —
# confirm whether external code reads it before removing.
NO_DEBUGGER = True

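# Recorded experiment metrics (the run/configuration that produced them is not
# stated here; the labels appear to match the stats returned by Workflow.run — TODO confirm):
# 43.6%, naive 33.8%, any 43.6%, tn 61.3%, corrected 3.4%, wronged 7%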
class LLM:
    """Single-turn chat wrapper over a local pipeline or a remote chat API.

    When ``api_key`` is ``None`` the model is loaded locally through
    ``transformers.pipeline("text-generation", ...)``; otherwise requests are
    sent through an ``OpenAI`` client built from ``api_key`` and ``api_base``.
    Both backends use the module-level ``MAXTOK`` and ``TEMP`` settings.
    """

    def __init__(self, model_id, api_key=None, api_base=None):
        """Set up the backend.

        Args:
            model_id: Model name handed to the pipeline or to the API.
            api_key: API key; ``None`` selects the local pipeline.
            api_base: Base URL forwarded to the ``OpenAI`` client.
        """
        self.model_id = model_id
        # No API key is the signal to run the model locally.
        self.local = api_key is None

        if self.local:
            self.pipeline = transformers.pipeline(
                "text-generation",
                model=self.model_id,
                model_kwargs={"torch_dtype": torch.bfloat16},
                device_map="auto",
            )
        else:
            self.client = OpenAI(api_key=api_key, base_url=api_base)

    def __call__(self, system_prompt, user_prompt):
        """Send one system/user exchange and return the assistant's text.

        Args:
            system_prompt: Optional system message; skipped when falsy.
            user_prompt: The user message.

        Returns:
            The generated text (an empty string if the remote API returned
            no text content).
        """
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": user_prompt})

        if self.local:
            return self._complete_local(messages)
        return self._complete_remote(messages)

    def _complete_remote(self, messages):
        """Run *messages* through the chat-completions endpoint."""
        response = self.client.chat.completions.create(
            model=self.model_id,
            messages=messages,
            max_tokens=MAXTOK,
            temperature=TEMP,
        )
        content = response.choices[0].message.content
        # The API may return no text content at all; callers expect a string.
        return content if content is not None else ""

    def _complete_local(self, messages):
        """Run *messages* through the local text-generation pipeline."""
        tokenizer = self.pipeline.tokenizer
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        # Stop on the tokenizer's EOS and, when the vocabulary has it, on the
        # "<|eot_id|>" end-of-turn token.  Tokenizers without that token map
        # it to the unknown-token id (possibly None); such an id must not be
        # used as a terminator.
        terminators = []
        if tokenizer.eos_token_id is not None:
            terminators.append(tokenizer.eos_token_id)
        eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
        if (
            eot_id is not None
            and eot_id != tokenizer.unk_token_id
            and eot_id not in terminators
        ):
            terminators.append(eot_id)

        outputs = self.pipeline(
            prompt,
            max_new_tokens=MAXTOK,
            # Fall back to the model's own generation config if nothing usable.
            eos_token_id=terminators or None,
            do_sample=True,
            temperature=TEMP,
        )

        # The pipeline echoes the prompt; keep only the continuation.
        return outputs[0]["generated_text"][len(prompt):]
        
class Workflow:
    """Solve / revise / judge experiment around a single question.

    ``run`` repeatedly asks the LLM for a solution, asks it to revise each
    solution, then asks it to judge every collected answer, and reports how
    the judge's verdicts compare with the ground truth computed by
    ``eval_for_exact_matching_with_no_punctuation``.

    Class attributes:
        TRIALS: Total number of candidate answers to collect.
        RETRY: Revisions requested per initial solution.
        NUM_EVAL: Judge calls per candidate (majority vote).
    """

    TRIALS = 20
    RETRY = 1
    NUM_EVAL = 1

    # NOTE: prompt texts are runtime behaviour and are kept verbatim.
    _SOLVER_SYSTEM = 'You are a brilliant problem solver, try to solve the user questions correctly and cleverly. Give only one most possible trial.'

    _ANSWER_FORMATTER_SYSTEM = 'You are a answer formatter. Your goal is to give user a clean formatted result from a unstructred plain text result. The desirable result should be in the format as the following example:\nAnswer:\nababdon bell critus\n\n Extract the result from the plain text answer exactly into this format, no more no less. Do not answer the question yourself.'
    _EVAL_FORMATTER_SYSTEM = 'You are a answer formatter. Your goal is to give user a clean formatted result from a unstructred plain text result. The desirable result should be in the format as the following example:\nEVAL: CORRECT\n\nThe answer should be only one line. Extract the result from the plain text answer exactly into this format, no more no less. Do not answer the question yourself.'
    _FORMATTER_USER = "Question:\n{question}\nPlain text answer:\n{res}. Give the formatted answer:"

    _REVISER_SYSTEM = """
You are a evaluator who is good at reasoning and math.

Given a question and a solution to this problem, please verify whether the analysis inside the solution is correctly leading to the correct answer.

Try to modify the answer if there are any errors.
Strictly follow the requirements in the question.

Verify step by step and at last give your modified result in the following format, e.g.:

Answer:
abandon bell critus

"""

    _REVISER_USER = """
User's question:
{question}

Problem Solver's solution:
{res}

Give your analysis and modification:
"""

    _JUDGE_SYSTEM = """
You are a evaluator who is good at reasoning and math.

Given a question and a solution to this problem, please verify whether this solution solves the problem correctly.
You analyze step-by-step and make sure that your final answer is consistent with your analysis!
Do not show your expected answer, just analyze the given answer and the question.

DO NOT try to solve or modify the answer.

At last give your evaluation result in two of the following options:

EVAL: CORRECT

or

EVAL: WRONG
"""

    _JUDGE_USER = """
User's question:
{question}

Problem Solver's solution:
{distilled_answer}

Eval whether this solution answer the question correctly?
"""

    def __init__(self, model_id=None, api_key=None, api_base=None):
        """Store the connection settings and build the ``LLM`` used by ``run``."""
        self.model_id = model_id
        self.api_key = api_key
        self.api_base = api_base

        self.llm = LLM(self.model_id, self.api_key, self.api_base)

    @staticmethod
    def _verdict_is_correct(verdict):
        """Return True when *verdict* contains ``CORRECT`` not preceded by a letter.

        A plain substring test would also accept ``INCORRECT``; a missing
        verdict (``None`` or empty) counts as "not correct" instead of raising.
        """
        if not verdict:
            return False
        return re.search(r"(?<![A-Za-z])CORRECT", str(verdict)) is not None

    def _extract_with_fallback(self, response, field_name, formatter_system, question):
        """Extract *field_name* from *response*, reformatting once on failure.

        If the field cannot be extracted directly, the LLM is asked (with
        *formatter_system*) to restate *response* in the expected format and
        extraction is attempted again on that restated text.
        """
        value = extract_field(response, field_name)
        if not value:
            reformatted = self.llm(
                formatter_system,
                self._FORMATTER_USER.format(question=question, res=response),
            )
            value = extract_field(reformatted, field_name)
        return value

    def _collect_candidates(self, question, query, target, loops, stats):
        """Gather solutions and their revisions; update the revision counters.

        Returns a list of ``(raw_response, extracted_answer, ground_truth)``
        tuples in generation order: each initial solution is followed by its
        ``RETRY`` revisions.
        """
        candidates = []
        for _ in range(loops):
            solution = self.llm(self._SOLVER_SYSTEM, question)
            answer = self._extract_with_fallback(
                solution, 'Answer', self._ANSWER_FORMATTER_SYSTEM, question)
            solved = eval_for_exact_matching_with_no_punctuation(query, answer, target)
            candidates.append((solution, answer, solved))

            for _ in range(self.RETRY):
                # Every revision starts from the initial solution.
                revision = self.llm(
                    self._REVISER_SYSTEM,
                    self._REVISER_USER.format(question=question, res=solution),
                )
                revised_answer = self._extract_with_fallback(
                    revision, 'Answer', self._ANSWER_FORMATTER_SYSTEM, question)
                revised_solved = eval_for_exact_matching_with_no_punctuation(
                    query, revised_answer, target)
                candidates.append((revision, revised_answer, revised_solved))

                if solved and not revised_solved:
                    stats['wronged'] += 1
                elif not solved and revised_solved:
                    stats['corrected'] += 1
                else:
                    stats['others'] += 1
        return candidates

    def _judge(self, question, answer):
        """Ask the LLM ``NUM_EVAL`` times whether *answer* solves *question*.

        Returns ``(accepted, last_response)`` where ``accepted`` is True when
        a strict majority of the verdicts were CORRECT and ``last_response``
        is the raw text of the final judge call (``None`` if none was made).
        """
        judge_input = self._JUDGE_USER.format(question=question, distilled_answer=answer)
        votes = 0
        last_response = None
        for _ in range(self.NUM_EVAL):
            last_response = self.llm(self._JUDGE_SYSTEM, judge_input)
            verdict = self._extract_with_fallback(
                last_response, 'EVAL', self._EVAL_FORMATTER_SYSTEM, judge_input)
            if self._verdict_is_correct(verdict):
                votes += 1
        return votes > (self.NUM_EVAL - 1) / 2, last_response

    def run(self, question, query=None, target=None):
        """Run the experiment for one question and return its statistics.

        Args:
            question: Prompt text sent to the solver.
            query: Forwarded to ``eval_for_exact_matching_with_no_punctuation``.
            target: Forwarded to ``eval_for_exact_matching_with_no_punctuation``.

        Returns:
            A dict with:
              * ``success``: candidates matching the ground truth;
              * ``num_attempts``: candidates judged;
              * ``corrected`` / ``wronged`` / ``others``: revisions that fixed,
                broke, or did not change the correctness of a solution;
              * ``tp`` / ``tn`` / ``fp`` / ``fn``: judge verdict versus ground
                truth (``non-sat`` is incremented together with ``fp``);
              * ``first_correct``: ground truth of the very first candidate;
              * ``selected``: True when the judge accepted a right candidate
                and did not accept a wrong one before the first right one;
                equals ``first_correct`` when the judge accepted nothing;
              * ``loop``: number of solve rounds performed;
              * ``error`` / ``unsure``: never updated here, kept so the
                result shape stays the same.
        """
        stats = {
            'success': 0,
            'non-sat': 0,
            'error': 0,

            'num_attempts': 0,
            'corrected': 0,
            'wronged': 0,
            'others': 0,

            'tp': 0,
            'tn': 0,
            'fp': 0,
            'fn': 0,
            'unsure': 0,
            'selected': False,
            'first_correct': False,
        }

        # Each round yields one solution plus RETRY revisions.
        loops = self.TRIALS // (self.RETRY + 1)
        candidates = self._collect_candidates(question, query, target, loops, stats)

        wrong_selected = False
        any_selected = False

        for index, (_, answer, solved) in enumerate(candidates):
            accepted, judge_response = self._judge(question, answer)

            if index == 0:
                stats['first_correct'] = solved
            any_selected = accepted or any_selected

            if solved:
                stats['success'] += 1

            if accepted and solved:
                stats['tp'] += 1
                stats['selected'] = True
            elif accepted:
                print("FP")
                # A wrong answer accepted before any right one means the
                # judge's first pick is wrong.
                if not stats['selected']:
                    print("wrong selected")
                    wrong_selected = True
                print(answer)
                print(judge_response)
                stats['non-sat'] += 1
                stats['fp'] += 1
            elif solved:
                print("FN")
                print(answer)
                print(judge_response)
                stats['fn'] += 1
            else:
                stats['tn'] += 1

            stats['num_attempts'] += 1

        if wrong_selected:
            stats['selected'] = False
        if not any_selected:
            # The judge accepted nothing: fall back to the first candidate.
            stats['selected'] = stats['first_correct']
        stats['loop'] = loops
        return stats