from functools import wraps
from typing import Dict, List, Any, Tuple
import json
from experiment.utils import (
    extract_json,
    extract_xml,
    extract_boxed,
    calculate_depth,
    score_math,
    score_mc,
    score_mh,
    score_code,
    clean_code
)
from llm import (
    gen,
    get_model,
    set_model,
    set_log,
    need_extracted,
)
from experiment.prompter import math, multichoice, multihop, code

# Module-level state shared by the strategy classes below.
failure_count = 0  # not referenced in this file's visible code; kept for external users — TODO confirm
iteration_count = 0  # incremented by ``atom`` when the top-level call (depth 0) recurses on a contracted question
MAX_RETRIES = 3  # number of attempts made by each LLM request/validate loop
ATOM_DEPTH = 2  # ``atom`` only recurses while depth < ATOM_DEPTH

# Bound by ``set_module`` to the prompt module / scoring function of the chosen task family.
score = None
module = None
prompter = None


def get_iter_count():
    """Return the current value of the module-level ``iteration_count`` counter."""
    # Reading a module global needs no ``global`` declaration; only rebinding does.
    return iteration_count

def set_module(module_name):  # math, multi-choice, multi-hop, code
    """Select the task family and bind the module-level ``prompter`` and ``score``.

    Args:
        module_name: One of ``"math"``, ``"multi-choice"``, ``"multi-hop"`` or
            ``"code"``.

    Raises:
        ValueError: if ``module_name`` is not one of the supported names. An
            unknown name would otherwise leave ``prompter``/``score`` unset (or
            stale from an earlier call) and fail much later, far from the cause.
    """
    global module, prompter, score
    # Built lazily so the project-level prompt/score objects are looked up at call time.
    registry = {
        "math": (math, score_math),
        "multi-choice": (multichoice, score_mc),
        "multi-hop": (multihop, score_mh),
        "code": (code, score_code),
    }
    if module_name not in registry:
        raise ValueError(
            f"Unknown module {module_name!r}; expected one of {sorted(registry)}"
        )
    module = module_name
    prompter, score = registry[module_name]

def save_dag(new_log=None, path="dag.json"):
    """Append ``new_log`` as a new entry of the JSON array stored at ``path``.

    The file holds a JSON list whose i-th element is ``{i: new_log}``. JSON
    object keys are strings, so the index is serialized as ``"i"``. A missing
    file, a file that is not valid JSON, or a file whose top-level value is not
    a list is treated as an empty history and overwritten.
    """
    try:
        with open(path, 'r', encoding='utf-8') as handle:
            existing = json.load(handle)
    except (FileNotFoundError, json.JSONDecodeError):
        existing = None

    history = existing if isinstance(existing, list) else []
    history.append({len(history): new_log})

    with open(path, 'w', encoding='utf-8') as handle:
        json.dump(history, handle, ensure_ascii=False, indent=4)


async def label(dag_text) -> List[Any]:
    """Ask the labelling model to annotate a reasoning trace.

    Args:
        dag_text: Raw model output(s) to be labelled.

    Returns:
        The ``"thoughts"`` list of the first JSON reply accepted by
        ``prompter.check("label", ...)``, or an empty list when all
        ``MAX_RETRIES`` attempts fail validation. The explicit empty list
        avoids silently returning ``None``, so callers always get a list.
    """
    for _ in range(MAX_RETRIES):
        prompt = prompter.label(dag_text)
        response = await gen(prompt, model="gpt-4o-mini", response_format="json_object", log_token=False)
        result = extract_json(response)
        if prompter.check("label", result):
            return result.get("thoughts", [])
    return []


class ModuleMath:
    """Reasoning strategies for math problems.

    ``direct``, ``ensemble``, ``decompose`` and ``contract`` are single LLM
    steps. ``atom`` chains them and may recurse on the contracted question.
    All steps use the module-level ``prompter`` and ``score`` chosen by
    ``set_module``.
    """

    async def direct(self, question: str, **kwargs):
        """Solve ``question`` in a single LLM call.

        Up to ``MAX_RETRIES`` attempts are made. An attempt succeeds when
        ``extract_boxed`` finds an answer and ``prompter.check("direct", ...)``
        accepts the result.

        Returns:
            ``{"answer", "response", "message"}`` on success, otherwise ``{}``.
        """
        for _ in range(MAX_RETRIES):
            prompt = prompter.direct(question)
            response = await gen(prompt, response_format="text")
            answer = extract_boxed(response)
            result = {"answer": answer} if answer else {}
            if prompter.check("direct", result):  # answer key in result
                result["response"] = response
                result["message"] = [
                    {"role": "user", "content": prompt},
                    {"role": "assistant", "content": response},
                ]
                return result
        return {}

    async def ensemble(self, question: str, solutions: list, **kwargs):
        """Ask the model to reconcile several candidate ``solutions`` into one answer.

        Same retry/validation scheme as ``direct``, checked with
        ``prompter.check("ensemble", ...)``.

        Returns:
            ``{"answer", "response", "message"}`` on success, otherwise ``{}``.
        """
        for _ in range(MAX_RETRIES):
            prompt = prompter.ensemble(question, solutions)
            response = await gen(prompt, response_format="text")
            answer = extract_boxed(response)
            result = {"answer": answer} if answer else {}
            if prompter.check("ensemble", result):  # answer key in result
                result["response"] = response
                result["message"] = [
                    {"role": "user", "content": prompt},
                    {"role": "assistant", "content": response},
                ]
                return result
        return {}

    async def decompose(self, direct_result: dict = None, question: str = None, **kwargs):
        """Continue the ``direct`` conversation, asking for a decomposition.

        If ``direct_result`` is empty, ``direct`` is first run on ``question``.

        Returns:
            The XML fields parsed from the reply, plus ``"response"``,
            ``"message"`` (the extended conversation) and the direct ``"answer"``.
        """
        if not direct_result:
            direct_result = await self.direct(question=question, **kwargs)
        message = direct_result["message"].copy()
        message += [{"role": "user", "content": prompter.decompose()}]
        response = await gen(message, response_format="text")
        message += [{"role": "assistant", "content": response}]
        result = extract_xml(response)
        result["response"] = response
        result["message"] = message
        if direct_result:
            result["answer"] = direct_result["answer"]
        return result

    async def contract(self, decompose_result: dict = None, **kwargs):
        """Ask the model to contract the decomposition into a single simpler question.

        Only the last two turns of the decomposition conversation (the
        decompose prompt and its reply) are kept as context.

        Returns:
            Parsed XML fields (including ``"question"``) plus ``"response"`` and
            ``"message"`` on success, otherwise ``{}``.
        """
        if not decompose_result:
            decompose_result = await self.decompose(**kwargs)
        message = decompose_result["message"][-2:]
        message += [{"role": "user", "content": prompter.contract()}]
        for _ in range(MAX_RETRIES):
            response = await gen(message, response_format="text")
            result = extract_xml(response)
            if prompter.check("contract", result):  # question key in result
                result["response"] = response
                result["message"] = message
                return result
        return {}

    async def atom(self, question: str, **kwargs) -> Tuple[Dict[str, Any], List[dict]]:
        """Run one direct/decompose/contract round, pick the best, and maybe recurse.

        Optional keyword args:
            direct_result, decompose_result: precomputed steps to reuse.
            depth: current recursion level (default 0).
            logs, dags: accumulators shared across recursive calls.

        Each candidate answer is scored against the ensemble answer, and the
        highest-scoring method wins; ties go to direct, then decompose. If
        "contract" wins and ``depth < ATOM_DEPTH``, the method recurses on the
        contracted question. Otherwise the collected ``dags`` are saved via
        ``save_dag``.

        Returns:
            ``({"method", "response", "answer"}, logs)``.
        """
        log = {}
        logs = kwargs.get("logs", [])
        dags = kwargs.get("dags", [])
        depth = kwargs.get("depth", 0)

        direct_result = kwargs.get("direct_result")
        if not direct_result:
            direct_result = await self.direct(question=question, **kwargs)

        decompose_result = kwargs.get("decompose_result")
        if not decompose_result:
            kwargs_copy = {k: v for k, v in kwargs.items() if k != "direct_result"}
            # Forward the question: if the direct step above failed ({}), decompose
            # re-runs direct itself and must not do so with question=None.
            decompose_result = await self.decompose(
                direct_result=direct_result, question=question, **kwargs_copy
            )

        dag_list = await label(direct_result.get('response', '') + '\n' + decompose_result.get('response', ''))
        dags.append({f"iter_{depth} question": question, "thoughts": dag_list})

        kwargs["decompose_result"] = decompose_result
        kwargs["direct_result"] = direct_result
        contracted_question = (await self.contract(**kwargs))["question"]
        contract_result = await self.direct(contracted_question, **kwargs)

        # Ensemble: let the model reconcile the three candidate solutions.
        candidates = [direct_result, decompose_result, contract_result]
        ensemble_result = await self.ensemble(question, candidates)
        ensemble_answer = ensemble_result.get("answer", "")

        # Score every candidate's answer against the ensemble answer.
        if all(r.get("answer", "") == ensemble_answer for r in candidates):
            scores = [1, 1, 1]
        else:
            scores = [score(r.get("answer", ""), ensemble_answer) for r in candidates]

        log.update({
            "iter": depth,
            "scores": scores,
            "direct": direct_result,
            "decompose": decompose_result,
            "contract": contract_result
        })

        # scores is aligned with `candidates`; index() returns the first maximum,
        # so ties prefer direct, then decompose, then contract.
        methods = [
            ("direct", direct_result),
            ("decompose", decompose_result),
            ("contract", contract_result),
        ]
        method, result = methods[scores.index(max(scores))]
        log["method"] = method
        logs.append(log)

        if depth < ATOM_DEPTH and method == "contract":
            if depth == 0:
                global iteration_count
                iteration_count += 1
            kwargs["direct_result"] = contract_result
            kwargs["decompose_result"] = None
            kwargs["depth"] = depth + 1
            kwargs["dags"] = dags
            kwargs["logs"] = logs
            return await self.atom(question=contracted_question, **kwargs)

        save_dag(new_log=dags)
        return {
            "method": method,
            "response": result.get("response"),
            "answer": result.get("answer"),
        }, logs


class ModuleCode:
    """Reasoning strategies for code-generation problems.

    ``direct``, ``multistep``, ``ensemble``, ``decompose`` and ``contract`` are
    single LLM steps. ``atom`` chains them and may recurse on the contracted
    question. All steps use the module-level ``prompter`` chosen by
    ``set_module``.
    """

    async def direct(self, question: str, contexts: str = None, **kwargs):
        """Write a solution for ``question`` in one LLM call.

        ``contexts``, if truthy, is forwarded to ``prompter.direct``. Up to
        ``MAX_RETRIES`` attempts are made. The first XML reply accepted by
        ``prompter.check("direct", ...)`` is returned.

        Returns:
            Parsed XML fields plus ``"response"`` and ``"message"``, or ``{}``.
        """
        for _ in range(MAX_RETRIES):
            if contexts:
                prompt = prompter.direct(question=question, contexts=contexts)
            else:
                prompt = prompter.direct(question)
            response = await gen(prompt, response_format="text")
            result = extract_xml(response)
            if prompter.check("direct", result):  # answer key in result
                result["response"] = response
                result["message"] = [
                    {"role": "user", "content": prompt},
                    {"role": "assistant", "content": response},
                ]
                return result
        return {}

    async def multistep(self, question: str, **kwargs):
        """Solve ``question`` with a step-by-step prompt (same retry scheme as ``direct``).

        Returns:
            Parsed XML fields plus ``"response"`` and ``"message"``, or ``{}``.
        """
        for _ in range(MAX_RETRIES):
            prompt = prompter.multistep(question)
            response = await gen(prompt, response_format="text")
            result = extract_xml(response)
            if prompter.check("multistep", result):  # answer key in result
                result["response"] = response
                result["message"] = [
                    {"role": "user", "content": prompt},
                    {"role": "assistant", "content": response},
                ]
                return result
        return {}

    async def ensemble(self, question: str, solutions: list, **kwargs):
        """Ask the model to choose among candidate ``solutions`` (same retry scheme).

        Returns:
            Parsed XML fields plus ``"response"`` and ``"message"``, or ``{}``.
        """
        for _ in range(MAX_RETRIES):
            prompt = prompter.ensemble(question, solutions)
            response = await gen(prompt, response_format="text")
            result = extract_xml(response)
            if prompter.check("ensemble", result):  # answer key in result
                result["response"] = response
                result["message"] = [
                    {"role": "user", "content": prompt},
                    {"role": "assistant", "content": response},
                ]
                return result
        return {}

    async def decompose(self, direct_result: dict = None, question: str = None, **kwargs):
        """LLM-based decomposition of a step-by-step solution.

        Runs ``multistep`` on ``question``, then continues that conversation
        with ``prompter.decompose()``. ``direct_result`` does not shape the
        prompt. If it is given, its ``"answer"`` is copied into the result.

        Returns:
            ``(decompose_result, multistep_result)``.
        """
        multistep_result = await self.multistep(question=question, **kwargs)
        message = multistep_result["message"].copy()
        message += [{"role": "user", "content": prompter.decompose()}]
        response = await gen(message, response_format="text")
        message += [{"role": "assistant", "content": response}]
        result = extract_xml(response)
        result["response"] = response
        result["message"] = message
        if direct_result:
            result["answer"] = direct_result["answer"]
        return result, multistep_result

    async def contract(self, decompose_result: dict = None, **kwargs):
        """Ask the model to contract the decomposition into a simpler question plus tests.

        The prompt is built from the decomposition text and ``kwargs["test_cases"]``.
        Each retry sends only the user prompt, so earlier rejected replies never
        leak into later attempts or into the returned transcript.

        Returns:
            Parsed XML fields (``"question"``, ``"test"``, ...) plus ``"response"``
            and ``"message"`` on success, otherwise ``{}``.
        """
        if not decompose_result:
            # decompose() returns (decompose_result, multistep_result).
            decompose_result, _ = await self.decompose(**kwargs)
        prompt = prompter.contract(decompose_result.get("response", ""), kwargs.get("test_cases"))
        request = [{"role": "user", "content": prompt}]
        for _ in range(MAX_RETRIES):
            response = await gen(request, response_format="text")
            result = extract_xml(response)
            if prompter.check("contract", result):  # question key in result
                result["response"] = response
                result["message"] = request + [{"role": "assistant", "content": response}]
                return result
        return {}

    async def atom(self, question: str, **kwargs) -> Tuple[Dict[str, Any], List[dict], List[str]]:
        """Run one direct/multistep/contract round, let the ensemble pick, and maybe recurse.

        Optional keyword args:
            direct_result, decompose_result, multistep_result: precomputed steps.
            test_cases: the problem's test-case lines.
            contexts: forwarded to ``direct``.
            depth: current recursion level (default 0).
            logs, dags: accumulators shared across recursive calls.

        If the ensemble picks "contract" and ``depth < ATOM_DEPTH``, the method
        recurses on the contracted question with its translated test cases.

        Returns:
            ``({"method", "response", "answer"}, logs, test_cases)``, where
            ``answer`` is the cleaned code with the test cases appended.

        Raises:
            ValueError: if the ensemble reply names no candidate index 0-2.
        """
        log = {}
        logs = kwargs.get("logs", [])
        dags = kwargs.get("dags", [])
        depth = kwargs.get("depth", 0)

        direct_result = kwargs.get("direct_result")
        if not direct_result:
            direct_result = await self.direct(question=question, **kwargs)

        decompose_result = kwargs.get("decompose_result")
        # Only decompose() produces a multistep result; a caller that supplies
        # decompose_result may pass the matching multistep_result too.
        multistep_result = kwargs.get("multistep_result") or {}
        if not decompose_result:
            kwargs_copy = {k: v for k, v in kwargs.items() if k != "direct_result"}
            # Forward the question, otherwise multistep() would be prompted with None.
            decompose_result, multistep_result = await self.decompose(
                direct_result=direct_result, question=question, **kwargs_copy
            )

        dag_list = await label(decompose_result['response'])
        dags.append({f"iter_{depth} question": question, "thoughts": dag_list})

        kwargs["decompose_result"] = decompose_result
        kwargs["direct_result"] = direct_result
        contracted = await self.contract(**kwargs)
        next_iter_question = contracted["question"]
        trans_test_cases = [line.strip() for line in contracted["test"].strip().split('\n')]
        # Solve the contracted question directly, using its first test case as context.
        contract_result = await self.direct(next_iter_question, contexts=trans_test_cases[0])

        # Ensemble: let the model pick the best of the candidate solutions by index.
        ensemble_result = await self.ensemble(question, [
            direct_result.get("answer", ""),
            multistep_result.get("answer", ""),
            contract_result.get("answer", "")
        ])
        # A failed ensemble ({}) falls through to the ValueError below.
        ensemble_index = str(ensemble_result.get("answer", ""))

        log.update({
            "iter": depth,
            "direct": direct_result,
            "decompose": decompose_result,
            "contract": contract_result
        })

        if '0' in ensemble_index:
            method, result, test_cases = ("direct", direct_result, kwargs.get("test_cases"))
        elif '1' in ensemble_index:
            method, result, test_cases = ("decompose", multistep_result, kwargs.get("test_cases"))
        elif '2' in ensemble_index:
            method, result, test_cases = ("contract", contract_result, trans_test_cases)
        else:
            raise ValueError("select index out of range(3)")
        log["method"] = method
        logs.append(log)

        if depth < ATOM_DEPTH and method == "contract":
            if depth == 0:
                global iteration_count
                iteration_count += 1
            kwargs["direct_result"] = contract_result
            kwargs["decompose_result"] = None
            kwargs["depth"] = depth + 1
            kwargs["test_cases"] = trans_test_cases
            kwargs["contexts"] = trans_test_cases[0]
            kwargs["dags"] = dags
            kwargs["logs"] = logs
            return await self.atom(question=next_iter_question, **kwargs)

        save_dag(new_log=dags)

        # test_cases already holds the case list matching the chosen method.
        answer_code = result.get("answer") or ""
        cases = test_cases or []
        file_str = f"{answer_code}\n\n"
        # Append the tests unless the answer already contains one assert per case.
        if answer_code.count("assert") != len(cases):
            file_str += "# Test cases\n\n" + "\n".join(cases)
        answer_code = await clean_code(file_str)

        return {
            "method": method,
            "response": result.get("response"),
            "answer": answer_code,
        }, logs, test_cases
        

class ModuleMhop:

    async def direct(question: str, contexts: str = None, **kwargs):
        for _ in range(MAX_RETRIES):
            if contexts:
                prompt = prompter.direct(question=question, contexts=contexts)
            else:
                raise ValueError("Multi-hop task need contexts")
            response = await gen(prompt, response_format="text")
            result = extract_xml(response)
            if prompter.check("direct", result):  # answer key in result
                result["response"] = response
                result["message"] = [
                    {"role": "user", "content": prompt},
                    {"role": "assistant", "content": response},
                ]
                return result
        return {}
    
    async def ensemble(self, question: str, solutions: list, **kwargs):
        """Ask the model to reconcile candidate ``solutions`` to ``question``.

        Each of the ``MAX_RETRIES`` attempts rebuilds the prompt and parses the
        reply as XML. The first parse accepted by
        ``prompter.check("ensemble", ...)`` is returned with the raw reply under
        ``"response"`` and the two-turn transcript under ``"message"``. If no
        attempt is accepted, ``{}`` is returned.
        """
        for _attempt in range(MAX_RETRIES):
            ensemble_prompt = prompter.ensemble(question, solutions)
            raw_reply = await gen(ensemble_prompt, response_format="text")
            parsed = extract_xml(raw_reply)
            if not prompter.check("ensemble", parsed):  # answer key missing -> retry
                continue
            user_turn = {"role": "user", "content": ensemble_prompt}
            assistant_turn = {"role": "assistant", "content": raw_reply}
            parsed["response"] = raw_reply
            parsed["message"] = [user_turn, assistant_turn]
            return parsed
        return {}

