# coding:utf-8

from abc import ABC, abstractmethod
from functools import partial
from sentence_transformers import SentenceTransformer, util
import copy
import json
import numpy as np
import re
import spacy
from typing import List
import logging
import traceback
import sys
import random

from GeneralLLM import LargeLanguageModel, Qwen, ChatGPT
from question import MultipleChoiceQuestion


def fix_json_string_1(json_string):
    """Repair common defects in a ``reasoning``/``answer`` JSON object string.

    Three repairs are applied, in this order:

    1. Bare ``reasoning:`` / ``answer:`` keys are wrapped in double quotes.
    2. A ``reasoning`` value that is not already a string, list or object is
       wrapped in double quotes; unescaped double quotes inside it are
       escaped so the wrapped value stays a single JSON string.
    3. The separator between the ``reasoning`` value and the ``"answer"``
       key is normalised to exactly one comma, whether it was missing or
       already present.

    Args:
        json_string: Text of the (possibly malformed) JSON object.

    Returns:
        The repaired text. It is not validated here; the caller is expected
        to parse it and handle a remaining parse failure.
    """
    # Ensure keys are properly quoted
    repaired = re.sub(r"(\breasoning\b|\banswer\b):", r'"\1":', json_string)

    def _wrap_value(match):
        # Escape inner quotes before wrapping, otherwise the first one would
        # terminate the JSON string early.
        body = re.sub(r'(?<!\\)"', r'\\"', match.group(2).strip())
        return f'{match.group(1)}"{body}"'

    # Wrap unquoted reasoning values with double quotes. The key and any
    # amount of following whitespace are captured (instead of a fixed-width
    # lookbehind) so that "reasoning":value and "reasoning":  "value" are
    # both handled correctly.
    repaired = re.sub(
        r'("reasoning":\s*)([^{\[\"\s].*?)(?=,?\s*"?answer")',
        _wrap_value,
        repaired,
        flags=re.DOTALL,
    )

    # Normalise the separator between reasoning and answer. The existing
    # comma (if any) is matched outside group 1 so it is replaced, not
    # duplicated.
    repaired = re.sub(
        r'("reasoning":.*?[^,\s])\s*,?\s*("answer":)',
        r"\1, \2",
        repaired,
        flags=re.DOTALL,
    )

    return repaired


def fix_json_string(json_string):
    """Repair common defects in an LLM-produced ``reasoning``/``answer`` object.

    The repairs run in a fixed order:

    1. Quote bare ``reasoning:`` / ``answer:`` keys.
    2. Quote a ``reasoning`` value that does not already start with ``"``,
       ``[`` or ``{``, escaping any unescaped double quotes it contains.
    3. Make sure exactly one comma separates the ``reasoning`` value from
       the ``"answer"`` key. A comma that is already there is kept as the
       single separator instead of being doubled.

    Args:
        json_string: Text of the (possibly malformed) JSON object.

    Returns:
        The repaired text. No validation is done here; callers should still
        be prepared for ``json.loads`` to fail on input that cannot be fixed
        by these three rules.
    """
    # Ensure keys are properly quoted
    json_string = re.sub(r"(\breasoning\b|\banswer\b):", r'"\1":', json_string)

    def _quote_reasoning(match):
        value = match.group(2).strip()
        # An inner '"' would end the JSON string prematurely; escape the
        # ones that are not escaped already.
        value = re.sub(r'(?<!\\)"', r'\\"', value)
        return f'{match.group(1)}"{value}"'

    # Wrap unquoted reasoning values with double quotes. Capturing the key
    # plus optional whitespace avoids a fixed-width lookbehind, which made
    # the rule misfire when the value was preceded by zero or several
    # whitespace characters.
    json_string = re.sub(
        r'("reasoning":\s*)([^"\[{\s].*?)(?=,?\s*"?answer")',
        _quote_reasoning,
        json_string,
        flags=re.DOTALL,
    )

    # Add a comma if missing between reasoning and answer. Group 1 ends on
    # the last character of the value; an existing comma and surrounding
    # whitespace are consumed and re-emitted as a single ", ".
    json_string = re.sub(
        r'("reasoning":.*?[^,\s])\s*,?\s*("answer":)',
        r"\1, \2",
        json_string,
        flags=re.DOTALL,
    )

    return json_string


def _judgement_to_bool(value):
    """Interpret one judgement value from the model's reply as a bool.

    Accepts real booleans (JSON ``true``/``false``) and the strings
    ``"True"``/``"False"`` in any letter case.

    Raises:
        ValueError: if ``value`` is anything else.
    """
    if isinstance(value, bool):
        return value
    if isinstance(value, str):
        text = value.strip().lower()
        if text in ("true", "false"):
            return text == "true"
    raise ValueError(f"Unrecognised judgement value: {value!r}")


def _decode_selection(mcq, json_text):
    """Turn the JSON object text of one reply into a fresh selection list.

    A new list is built on every call so that a reply which fails half-way
    through decoding cannot leave stale flags behind for the next attempt.
    """
    selected = [False] * len(mcq.correct)
    if mcq.text_type == "choice":
        try:
            parsed = json.loads(json_text)
        except json.JSONDecodeError:
            # Second chance: repair the usual formatting slips and re-parse.
            parsed = json.loads(fix_json_string(json_text))
        answer = parsed["answer"]
        for pos, option_id in enumerate(mcq.option_ids):
            if option_id in answer:
                selected[pos] = True
    elif mcq.text_type == "judgement":
        # Python-style bare True/False are not valid JSON; quote them first.
        json_text = re.sub(r"\s+True", ' "True"', json_text)
        json_text = re.sub(r"\s+False", ' "False"', json_text)
        parsed = json.loads(json_text)
        position_of = {
            option_id: pos for pos, option_id in enumerate(mcq.option_ids)
        }
        for option_id, verdict in parsed.items():
            # NOTE(security): the reply is untrusted model output, so it is
            # mapped to a bool explicitly instead of being passed to eval().
            selected[position_of[option_id]] = _judgement_to_bool(verdict)
    else:
        logging.error(
            f"get_mcq_llm_answer: Invalid text type '{mcq.text_type}'"
        )
    return selected


def get_mcq_llm_answer(mcq: MultipleChoiceQuestion, llm: LargeLanguageModel) -> tuple:
    """Get the answer of an LLM to a multiple-choice question.

    The model is first sent ``mcq.get_prompt()``. Whenever a reply cannot be
    decoded, the error is logged and the model is told that no proper JSON
    block could be extracted; at most three attempts are made in total.

    Args:
        mcq: MultipleChoiceQuestion, the question to answer. Its
            ``text_type`` selects the expected reply layout: for
            ``"choice"`` the chosen option ids are read from the ``"answer"``
            key, for ``"judgement"`` every key is an option id mapped to a
            true/false value.
        llm: LargeLanguageModel, the model that answers the question.

    Returns:
        A ``(selected, original_response)`` tuple.

        selected (List[bool]): one flag per entry of ``mcq.correct`` telling
            whether that option was selected by the model. All ``False`` if
            every attempt failed or ``mcq.text_type`` is not supported.
        original_response (str): the raw text of the last reply (empty if
            the last call to the model itself raised).
    """
    prompt = mcq.get_prompt()
    max_retry = 3
    result = [False] * len(mcq.correct)
    original_response = ""
    llm.refresh()
    for n_retry in range(1, max_retry + 1):
        try:
            # Reset first so the log below never shows a previous reply if
            # the call to the model raises.
            original_response = ""
            if n_retry > 1:
                original_response = llm.listen_and_response(
                    "Failed to extract proper json block as requested"
                )
            else:
                original_response = llm.listen_and_response(prompt)
            # Only a ~5% sample of the raw replies is logged.
            if random.random() < 0.05:
                logging.info("original response: " + original_response)
            flattened = re.sub(r"\n", " ", original_response)
            # First brace-delimited span that starts with a quoted key and
            # contains no nested '{'.
            json_block = re.search(r'[{]\s*"[^{]*[}]', flattened)
            if json_block is None:
                raise ValueError("no JSON object found in the model response")
            result = _decode_selection(mcq, json_block.group(0))
            break
        except Exception:
            # Deliberately broad: any failure (model call, extraction,
            # parsing, unknown option id) is logged and triggers a retry.
            logging.error(f"original_llm_answer = {original_response}")
            logging.error(traceback.format_exc())
            logging.error(
                f"get_mcq_llm_answer: Format error, try again. n_retry = {n_retry}"
            )
    return result, original_response


if __name__ == "__main__":
    # Manual demo: build one sample question and a ChatGPT client, then print
    # the prompt generated for the question. The client is only constructed
    # here; it is not queried in this snippet.
    question_text = (
        "__________ memory is the aspect of memory that is involved in "
        "the recall of information acquired within the past few hours to days."
    )
    option_texts = ["Working", "Sensory", "Long-term", "Prospective"]

    mcq = MultipleChoiceQuestion(
        question=question_text,
        option_ids=["A", "B", "C", "D"],
        options=option_texts,
        question_first=True,
        correct=[False, False, False, True],
    )

    llm = ChatGPT(
        name="chatgpt",
        description="The chatgpt assistant.",
        model="gpt-4-turbo",
        temperature=1.0,
    )

    prompt_text = mcq.get_prompt()
    print(f"Original question: {prompt_text}")
