# Licensed under the MIT license.

import sys
import random
sys.path.append(".")
import numpy as np, os, random, json, math
from tqdm import trange
from typing import List, Dict, Tuple
from copy import deepcopy

try:
    from rapidfuzz import fuzz, process
except:
    pass

from xiaobu_rag import xiaobuRetrieval
from models.IO_System import IO_System
from common.utils import read_txt, read_json,read_json_str
from eval_src.Evaluator import Evaluator, GSM8KEvaluator
from MCTS_backbone import MCTS_Searcher, MCTS_Node
from run_src.utils import (
    Node_Type,
    reach_terminal_subquestion,
    reach_terminal_ost_step,
    concat_subqs_and_subas,
    concat_ost_steps,
    concat_subqs_subas_as_ost_steps,
    make_hint,
    make_response_prefix,
    split_user_question,
    print_tree_from_root,
    stochastic_find_best_solution,
    concat_solution_trace,
)

# Number of children produced per node expansion: `_create_children` runs the
# node's action this many times, and each run appends one child.
COUNT_CHILD = 4

def verbose_print(s: str, verbose: bool):
    """Print ``s`` to stdout only when ``verbose`` is truthy; otherwise do nothing."""
    if not verbose:
        return
    print(s)


class Generator:
    """Generates children nodes: wraps the LLM and holds the DDX-MCTS prompt templates.

    ``Reasoning_MCTS_Node`` fills the templates' ``{{...}}`` placeholders and calls
    ``simpel_generate`` / ``generate_reward`` to obtain a single LLM completion.
    """

    def __init__(self, args, tokenizer, model, evaluator: "Evaluator") -> None:
        self.io = IO_System(args, tokenizer, model)
        self.evaluator = evaluator

        self.max_tokens = args.max_tokens
        self.enable_potential_score = args.enable_potential_score
        self.RagRetrieval = xiaobuRetrieval()
        # Lookup table read by `_extract_from_cache`; starts empty.
        self.reasoning_cache = {}
        # Prompt templates for the root feature extraction and the a1-a4 node actions.
        self.medextractor_prompt = read_txt("prompts/DDX-MCTS/medextractor.txt")
        self.a1_prompt = read_txt("prompts/DDX-MCTS/key_symptoms_extrac.txt")
        self.a2_prompt = read_txt("prompts/DDX-MCTS/hypo_generate.txt")
        self.a3_prompt = read_txt("prompts/DDX-MCTS/evidence_ver.txt")
        self.a4_prompt = read_txt("prompts/DDX-MCTS/deductive_analy.txt")

    def _extract_from_cache(self, subquestion_list: List[str]):
        """Look up subquestions in ``self.reasoning_cache`` using rapidfuzz ``fuzz.ratio``.

        Only exact string matches (ratio 100, i.e. similarity 1.0) count as hits. A hit
        whose cached ``score`` is >= 0.9 is reported with its ``selected_answer``; a
        lower-scoring hit is reported with its full ``answer_list``. Everything else is
        unmatched.

        Returns:
            dict of parallel lists: ``high_score_questions`` / ``selected_answers`` /
            ``values``, ``low_score_questions`` / ``low_score_values`` /
            ``low_score_answers_list``, and ``unmatched_questions``.
        """
        # Tolerate a missing or empty cache: with nothing cached no fuzzy search is
        # needed, which also avoids touching the optional rapidfuzz import.
        cache = getattr(self, "reasoning_cache", None) or {}

        high_score_questions = []
        selected_answers = []  # most likely answer corresponding to each high-score subquestion
        values = []
        low_score_questions = []
        low_score_values = []
        low_score_answers_list = []
        unmatched_questions = []

        for subquestion in subquestion_list:
            best_match = (
                process.extractOne(subquestion, cache.keys(), scorer=fuzz.ratio) if cache else None
            )
            # Only an exact match (similarity == 1) is treated as a cache hit.
            if not best_match or best_match[1] / 100 != 1:
                unmatched_questions.append(subquestion)
                continue

            best_question = best_match[0]
            cache_entry = cache[best_question]
            score = cache_entry["score"]
            if score >= 0.9:
                high_score_questions.append(best_question)
                selected_answers.append(cache_entry["selected_answer"])
                values.append(score)
            else:
                low_score_questions.append(best_question)
                low_score_values.append(score)
                low_score_answers_list.append(cache_entry["answer_list"])

        return {
            "high_score_questions": high_score_questions,
            "selected_answers": selected_answers,
            "values": values,
            "low_score_questions": low_score_questions,
            "low_score_values": low_score_values,
            "low_score_answers_list": low_score_answers_list,
            "unmatched_questions": unmatched_questions,
        }

    def simpel_generate(self, io_input, stop_tokens=None, num_return=1):
        """Query the LLM with ``io_input`` and return the first generated sample.

        Params:
            io_input: full prompt text.
            stop_tokens: stop sequences forwarded to ``IO_System.generate``; ``None``
                (the default) means no stop tokens. A fresh list is created per call
                instead of sharing a mutable default.
            num_return: number of samples requested; only the first is returned.
        """
        if stop_tokens is None:
            stop_tokens = []
        io_output_list = self.io.generate(
            io_input,
            num_return=num_return,
            max_tokens=self.max_tokens,
            stop_tokens=stop_tokens,
        )
        return io_output_list[0]

    def generate_reward(self, io_input, num_return=1, stop_tokens=None):
        """Query the LLM with an evaluation prompt and return the first sample.

        When ``stop_tokens`` is not given, the stop tokens come from
        ``self.fewshot_cot_config["stop_tokens"]`` if such a config has been attached.
        Otherwise no stop tokens are used. This class never assigns that config, and
        unconditionally reading it raised ``AttributeError``.
        """
        if stop_tokens is None:
            config = getattr(self, "fewshot_cot_config", None) or {}
            stop_tokens = config.get("stop_tokens", [])
        return self.simpel_generate(io_input, stop_tokens=stop_tokens, num_return=num_return)

class Reasoning_MCTS_Node(MCTS_Node):
    """MCTS node of the DDX (differential-diagnosis) reasoning tree.

    The node type determines which action ``_create_children`` runs:

    * USER_QUESTION  --a1 (key-feature extraction)-->  KEY_SYMPTOM
    * KEY_SYMPTOM    --a2 (hypothesis generation)-->   HYPO_GENERATE
    * HYPO_GENERATE  --a3 (evidence verification)-->   EVIDENCE_VERI
    * EVIDENCE_VERI  --a4 (deductive analysis)-->      FINAL_ANSWER (confirmed),
      KEY_SYMPTOM (hypothesis ruled out) or HYPO_GENERATE (inconclusive)

    FINAL_ANSWER nodes are leaves and the only nodes that receive a non-zero reward.
    """

    def __init__(
        self,
        parent: "Reasoning_MCTS_Node",
        depth: int,
        node_type: Node_Type,
        verbose: bool = False,
        generator: Generator = None,
        user_question: str = None,
        max_depth_allowed: int = None,
        expected_answer: str = None,
        # --- DDX reasoning state carried by the node ---
        key_features: List[str] = None,
        hypothesis: str = None,
        relevant_symptom: str = None,
        supply_hint: str = "",
        # --- For node selection (not in sanity checks yet) ---
        enable_potential_score: bool = None,
        potential_answers: List[str] = None,
    ) -> None:
        """Create a node; the root additionally extracts patient features via the LLM.

        Params:
            parent: parent node, or None for the root. Non-root nodes inherit verbose,
                user_question, expected_answer, generator, max_depth_allowed,
                enable_potential_score, the extracted features and the accumulated hint
                from the parent (the corresponding arguments are ignored).
            key_features / hypothesis / relevant_symptom: reasoning state produced by
                the action that created this node.
            supply_hint: extra instruction appended (after a newline) to the hint
                inherited from the parent. The accumulated hint is appended to every
                prompt this node sends.
            potential_answers: only used when enable_potential_score is set.
        """
        super().__init__()

        #! attributes
        self.parent = parent  # if parent is None, then the node is the root
        self.children: List["Reasoning_MCTS_Node"] = []
        self.depth = depth
        self.node_type = node_type

        # --- DDX reasoning state ---
        self.key_features = key_features
        self.hypothesis = hypothesis
        self.relevant_symptom = relevant_symptom

        if parent is None:  # root
            self.verbose = verbose
            self.user_question = user_question
            self.expected_answer = expected_answer
            self.generator = generator
            self.max_depth_allowed = max_depth_allowed
            self.enable_potential_score = enable_potential_score

            # Extract general/clinical features from the patient's description once;
            # every descendant inherits them.
            prompt = self.generator.medextractor_prompt.replace(
                "{{Patient's Verbal Description}}", self.user_question
            )
            output = read_json_str(self.generator.simpel_generate(io_input=prompt))
            self.general_features = output["General features"]
            self.clinical_features = output["Clinical features"]
            self.supply_hint = ""
        else:  # inherit from parent
            self.verbose = parent.verbose
            self.user_question = parent.user_question
            self.expected_answer = parent.expected_answer
            self.generator = parent.generator
            self.max_depth_allowed = parent.max_depth_allowed
            self.enable_potential_score = parent.enable_potential_score
            self.general_features = parent.general_features
            self.clinical_features = parent.clinical_features
            self.supply_hint = parent.supply_hint
        # Truthiness (rather than len()) also tolerates supply_hint=None.
        if supply_hint:
            self.supply_hint += "\n" + supply_hint

        #! record the reasoning trace from root to this node, keyed by depth.
        # It is read by `calculate_reward` and dumped to JSON by `search_for_answers`.
        # NOTE(review): any other consumer of `solution_trace` (e.g. helpers in
        # run_src.utils) must agree with this layout — confirm.
        if parent is None:
            self.solution_trace: Dict[int, Dict[str, object]] = {0: {"user_question": self.user_question}}
        else:
            self.solution_trace = deepcopy(parent.solution_trace)
            self.solution_trace[depth] = self._make_trace_step(supply_hint)

        #! potential_score for intermediate nodes (only used for node selection)
        if self.enable_potential_score:
            self.potential_answers = potential_answers
            self.potential_score = 0
            if parent is None:  # root
                self.potential_answers_history = {}
            else:
                self.potential_answers_history = deepcopy(parent.potential_answers_history)
                self.potential_answers_history[self.depth] = potential_answers

    def __str__(self) -> str:
        type2str = {
            Node_Type.USER_QUESTION: "U",
            Node_Type.KEY_SYMPTOM: "KS",
            Node_Type.HYPO_GENERATE: "HG",
            Node_Type.EVIDENCE_VERI: "EV",
            Node_Type.DEDUCTIVE_ANYLY: "DA",
            Node_Type.FINAL_ANSWER: "FA",
        }
        return f"{type2str[self.node_type]}-{self.id}"

    def _make_trace_step(self, supply_hint: str) -> Dict[str, object]:
        """Summarise this node's reasoning state as a JSON-serialisable trace entry."""
        step: Dict[str, object] = {"node_type": getattr(self.node_type, "name", str(self.node_type))}
        for key in ("key_features", "hypothesis", "relevant_symptom"):
            value = getattr(self, key)
            if value is not None:
                step[key] = value
        if supply_hint:
            step["hint"] = supply_hint
        return step

    def _render_solution_trace(self) -> str:
        """Format the reasoning steps (depth >= 1) of ``solution_trace`` as text, one per line."""
        # default=str is deliberate: this is only prompt text, not a round-trippable format.
        return "\n".join(
            f"Step {depth}: {json.dumps(step, ensure_ascii=False, default=str)}"
            for depth, step in sorted(self.solution_trace.items())
            if depth > 0
        )

    def _llm_json(self, prompt: str):
        """Send ``prompt`` plus the accumulated hint to the LLM and parse its JSON reply."""
        output = self.generator.simpel_generate(io_input=prompt + self.supply_hint)
        return read_json_str(output)

    def _add_child(self, node_type: Node_Type, **state) -> None:
        """Append a child of ``node_type`` one level deeper, forwarding ``state`` kwargs."""
        self.children.append(
            Reasoning_MCTS_Node(parent=self, depth=self.depth + 1, node_type=node_type, **state)
        )

    def _do_action_a1(self) -> None:
        """a1: extract the key features from the patient's features -> KEY_SYMPTOM child."""
        verbose_print("do_action_a1", self.verbose)
        prompt = (
            self.generator.a1_prompt
            .replace("{{Patient's General Features}}", str(self.general_features))
            .replace("{{Patient's Clinical Features}}", str(self.clinical_features))
        )
        key_features = self._llm_json(prompt)["Key features"]
        self._add_child(Node_Type.KEY_SYMPTOM, key_features=key_features)

    def _do_action_a2(self) -> None:
        """a2: propose a hypothesis from the key features -> HYPO_GENERATE child."""
        verbose_print("do_action_a2", self.verbose)
        prompt = (
            self.generator.a2_prompt
            .replace("{{Patient's General Features}}", str(self.general_features))
            .replace("{{Patient's Key Features}}", str(self.key_features))
            .replace("{{Retrieved Symptom-Disease Triples from Knowledge Graph}}", "None")
        )
        hypothesis = self._llm_json(prompt)["Hypothesis"]
        # Propagate *this* node's key features: the parent of a KEY_SYMPTOM node may be
        # the root, which carries no key features.
        self._add_child(Node_Type.HYPO_GENERATE, key_features=self.key_features, hypothesis=hypothesis)

    def _do_action_a3(self) -> None:
        """a3: pick a clinical indicator to verify the hypothesis -> EVIDENCE_VERI child."""
        verbose_print("do_action_a3", self.verbose)
        prompt = (
            self.generator.a3_prompt
            .replace("{{Patient's General Features}}", str(self.general_features))
            .replace("{{Patient's Key Features}}", str(self.key_features))
            .replace("{{Current Hypothesis}}", str(self.hypothesis))
            .replace("{{Retrieved Symptom-Disease Triples from Knowledge Graph}}", "None")
        )
        relevant_symptom = self._llm_json(prompt)["Relevant symptom"]
        self._add_child(
            Node_Type.EVIDENCE_VERI,
            key_features=self.key_features,
            hypothesis=self.hypothesis,
            relevant_symptom=relevant_symptom,
        )

    def _do_action_a4(self) -> None:
        """a4: check the indicator against the patient's description and branch on the verdict.

        * Exist + Confident     -> FINAL_ANSWER carrying the hypothesis.
        * Non-exist + Confident -> KEY_SYMPTOM, with a hint that the hypothesis is excluded.
        * anything else         -> HYPO_GENERATE for the same hypothesis, with a hint that
          this indicator was already examined.

        Every child keeps ``key_features`` so the follow-up a2/a3 prompt is not given "None".
        """
        verbose_print("do_action_a4", self.verbose)
        prompt = (
            self.generator.a4_prompt
            # The template placeholder may use a typographic or a straight apostrophe;
            # fill either form.
            .replace("{{Patient’s Verbal Description}}", str(self.user_question))
            .replace("{{Patient's Verbal Description}}", str(self.user_question))
            .replace("{{Current Hypothesis}}", str(self.hypothesis))
            .replace("{{Clinical Indicator to be Verified}}", str(self.relevant_symptom))
        )
        verdict = self._llm_json(prompt)
        existence = verdict["Existence"]
        if existence == "Exist" and verdict["Certainty"] == "Confident":
            self._add_child(
                Node_Type.FINAL_ANSWER,
                key_features=self.key_features,
                hypothesis=self.hypothesis,
                relevant_symptom=self.relevant_symptom,
            )
        elif existence == "Non-exist" and verdict["Certainty"] == "Confident":
            self._add_child(
                Node_Type.KEY_SYMPTOM,
                key_features=self.key_features,
                supply_hint=str(self.hypothesis) + " has been ruled exclude.",
            )
        else:
            self._add_child(
                Node_Type.HYPO_GENERATE,
                key_features=self.key_features,
                hypothesis=self.hypothesis,
                supply_hint=str(self.relevant_symptom)
                + " has already been examined and cannot be diagnosed with the disease.",
            )

    def _create_children(self):
        """Expand this node: run its action ``COUNT_CHILD`` times, one child per run.

        Returns ``self.children``. Node types without an action yield no children.

        Raises:
            ValueError: if called on a FINAL_ANSWER (leaf) node.
        """
        if self.node_type is Node_Type.FINAL_ANSWER:
            raise ValueError("FINAL_ANSWER node cannot create children!!")

        action = {
            Node_Type.USER_QUESTION: self._do_action_a1,
            Node_Type.KEY_SYMPTOM: self._do_action_a2,
            Node_Type.HYPO_GENERATE: self._do_action_a3,
            Node_Type.EVIDENCE_VERI: self._do_action_a4,
        }.get(self.node_type)
        if action is not None:
            # Each run issues its own LLM query; presumably sampling makes the children differ.
            for _ in range(COUNT_CHILD):
                action()
        return self.children

    def is_valid_leaf_node(self):
        """Only FINAL_ANSWER nodes are leaves."""
        return self.node_type is Node_Type.FINAL_ANSWER

    def is_valid_solution_node(self):
        """A node is a valid solution iff it is a valid leaf (always returns a bool)."""
        return self.is_valid_leaf_node()

    def set_potential_score(self, score: float):
        self.potential_score = score

    def find_children(self, rollout_id: int):
        """Expand lazily (only if no children exist yet) and tag children with ``rollout_id``."""
        self.children = self.children or self._create_children()
        for child in self.children:
            child.set_rollout_id(rollout_id)
        return self.children

    def is_terminal(self):
        return self.depth >= self.max_depth_allowed or self.is_valid_leaf_node()

    def calculate_reward(self):
        """Score a FINAL_ANSWER node with the LLM evaluator; other nodes get 0.

        The evaluation prompt is filled with the patient's description and the rendered
        reasoning trace. The last ``{...}`` object in the LLM output is parsed as JSON and
        its values are summed. If parsing fails, or the values are not numbers, a random
        integer in [0, 20] is used instead. The result is returned divided by 4.
        """
        if not self.is_valid_leaf_node():
            return 0

        solution_trace = self._render_solution_trace()
        with open("prompts/evaluation/evalution_prompt_zh.txt", "r", encoding="utf-8") as file:
            eval_template = file.read()
        eval_input = eval_template.replace("{description}", self.user_question).replace("{answer}", solution_trace)
        result = self.generator.generate_reward(eval_input)

        # Keep only the last {...} object in the model output.
        result = "{" + result.split("{")[-1].split("}")[0] + "}"
        try:
            eval_result = json.loads(result)
            reward = sum(eval_result.values())
        except (ValueError, TypeError):  # malformed JSON or non-numeric scores
            print("评估格式识别失败")
            print(result)
            # Fallback keeps the rollout going with a random score.
            reward = random.randint(0, 20)

        # NOTE(review): dividing by 4 presumably normalises a sum over four evaluation
        # criteria (the fallback range 0-20 suggests 0-5 each) — confirm against the prompt.
        verbose_print(f"reward: {reward / 4.0}", self.verbose)
        return reward / 4.0

    def skip_backprop(self):
        # NOTE(review): relies on Node_Type still defining REPHRASED_USER_QUESTION — confirm.
        return self.node_type is Node_Type.USER_QUESTION or self.node_type is Node_Type.REPHRASED_USER_QUESTION


def search_for_answers(args, user_question: str, question_id: int, gt_answer: str, generator: Generator):
    """Run ``args.num_rollouts`` MCTS rollouts for one question and persist the results.

    After each rollout the best solution so far is selected with
    ``stochastic_find_best_solution``. When ``args.save_tree`` is set, the tree is written
    to ``args.answer_sheets_dir``. After the loop, the final solution traces (and, when
    enabled, the potential-answer histories) are dumped as UTF-8 JSON.

    Returns:
        (model_solutions, last_rollout_id, model_all_solutions): the best solution per
        rollout, the index of the last rollout run (-1 when ``num_rollouts`` is 0), and
        all solutions found per rollout.
    """
    verbose_print(
        f"********************* Searching for answers to question {question_id} ********************* ", args.verbose
    )

    #! build an MCTS searcher
    mcts_searcher = MCTS_Searcher(
        exploration_weight=args.mcts_exploration_weight,
        weight_scheduler=args.mcts_weight_scheduler,
        num_rollouts=args.num_rollouts,
        discount=args.mcts_discount_factor,
        verbose=args.verbose,
    )

    #! build the MCTS tree
    root_node = Reasoning_MCTS_Node(
        parent=None,
        depth=0,
        node_type=Node_Type.USER_QUESTION,
        verbose=args.verbose,
        generator=generator,
        user_question=user_question,
        expected_answer=gt_answer,
        max_depth_allowed=args.max_depth_allowed,
        enable_potential_score=args.enable_potential_score,
    )

    model_solutions = []
    model_all_solutions = []
    model_rollout_nodes = []
    # Defaults so the code after the loop also works when num_rollouts == 0.
    all_solution_nodes = []
    last_rollout_id = -1
    for rollout_id in range(args.num_rollouts):
        last_rollout_id = rollout_id
        rollout_node = mcts_searcher.do_rollout(root_node, rollout_id)
        model_rollout_nodes.append(rollout_node)

        _, best_solution, _, chosen_node, all_solution_nodes, all_solutions = stochastic_find_best_solution(
            root_node, generator.evaluator, enable_potential_score=args.enable_potential_score
        )
        model_solutions.append(best_solution)
        model_all_solutions.append(all_solutions)

        if args.save_tree:
            tree_path = os.path.join(
                args.answer_sheets_dir,
                f"Question {question_id:04d} - Rollout {rollout_id}.tree",
            )
            with open(tree_path, "w", encoding="utf-8") as f:
                print_tree_from_root(
                    mcts_searcher=mcts_searcher,
                    rollout_id=rollout_id,
                    root_node=root_node,
                    chosen_node=chosen_node,
                    file=f,
                )

    #! record final traces (solution nodes as of the last rollout)
    # Files are opened as UTF-8 because ensure_ascii=False writes non-ASCII text verbatim.
    if all_solution_nodes:
        js = [{"trace": node.solution_trace, "rollout_id": node.rollout_id} for node in all_solution_nodes]
        with open(
            os.path.join(args.answer_sheets_dir, f"Question {question_id:04d} - Final Solutions.json"),
            "w",
            encoding="utf-8",
        ) as f:
            json.dump(js, f, ensure_ascii=False)

    if args.enable_potential_score:
        js = [node.potential_answers_history for node in all_solution_nodes]
        with open(
            os.path.join(args.answer_sheets_dir, f"Question {question_id:04d} - Potentials.json"),
            "w",
            encoding="utf-8",
        ) as f:
            json.dump(js, f, ensure_ascii=False)

    return model_solutions, last_rollout_id, model_all_solutions
