from openai import OpenAI

import json
import os
from prompt_generate.sample_generate import GeneratorBasedOnRules, SingleGeneratorBasedOnRules

from pydantic import BaseModel
from collections.abc import Sequence

from util.interface import RulesDict
from prompt_generate.util import logger

class QAPair(BaseModel):
    # One entry of the model's structured answer: a rule factor and the probe question for it.
    # NOTE: despite the "QA" name, only the question is part of this schema (no answer field).
    # `factor` is compared by SingleScenarioProbeGenerator.check_probes against the rule file's
    # root/non-root factor names, so it must reproduce those names exactly.
    factor: str
    question: str

class Probes(BaseModel):
    # Structured-output schema passed as `text_format` to `client.responses.parse`; its field
    # names and types define the JSON the model is asked to return, so renaming a field changes
    # the request. The parsed answer is stored/saved as `model_dump()` (a plain dict).
    scenario: str
    # Expected to contain one pair per factor of the rule file (checked as a set, not by order).
    factor_question_pairs: list[QAPair]

class SingleScenarioProbeGenerator(SingleGeneratorBasedOnRules):
    """Generate one probe question per rule factor for the scenario of a single rule file.

    The model is queried through the OpenAI Responses API with ``Probes`` as the structured
    output schema. An answer that fails ``check_probes`` is sent back to the model as feedback,
    up to ``feedback_times`` times. Valid probes are written to ``save_file`` and, optionally,
    stored under the ``"probes"`` key of same-named samples files.
    """

    def __init__(self, model: str, scenario: str | None, load_file: str, save_file: str, trajectory_file: str, feedback_times: int = 3, add_probes_into_samples: bool = False, samples_path: Sequence[str] | str | None = None, load_exist_probes: bool = False, reasoning_effort: str = "minimal"):
        """
        Args:
            model: OpenAI model name used for generation.
            scenario: Fallback scenario; the scenario found in ``load_file`` takes precedence.
            load_file: Path of the rule file, read with the inherited ``load_file`` method.
            save_file: JSON file the validated probes are written to.
            trajectory_file: Stored on the instance; ``save_trajectory`` currently writes nothing.
            feedback_times: Maximum number of follow-up requests after an invalid answer.
            add_probes_into_samples: If true, also store the probes in
                ``<dir>/<basename of save_file>`` for every ``dir`` in ``samples_path``.
            samples_path: One directory or several directories holding samples files.
            load_exist_probes: If true and ``save_file`` exists, reuse it instead of calling the model.
            reasoning_effort: Value sent as ``reasoning.effort`` to the API.

        Raises:
            ValueError: If no scenario is available, or if ``add_probes_into_samples`` is set
                without a non-empty ``samples_path``.
        """
        super().__init__()
        # Fail early instead of crashing in `save` after the probes were already written.
        if add_probes_into_samples and not samples_path:
            raise ValueError("If `add_probes_into_samples`, the `samples_path` is needed.")
        self.model = model
        self.client = OpenAI()
        # `self.load_file` is the inherited loader method, so the path lives in `_load_file`.
        self._load_file = load_file
        rulesjson, self.scenario = self.load_file(load_file)
        if self.scenario is None and scenario is None:
            raise ValueError("When the scenario is not provided in the load file, the scenario should be provided.")
        if self.scenario is None:
            self.scenario = scenario  # The scenario from the load file is preferred.
        assert isinstance(self.scenario, str)
        self.scenario = self._clean_scenario(self.scenario)

        self.roots = rulesjson.roots
        self.non_roots = rulesjson.non_roots
        self.rules: RulesDict = rulesjson.rules_convert_to_dict()
        self.roots_name = rulesjson.get_roots_names()
        self.non_roots_name = rulesjson.get_non_roots_names()

        self.save_file = save_file
        self.trajectory_file = trajectory_file
        self.add_probes_into_samples = add_probes_into_samples
        # Normalise a single directory into a one-element list.
        self.samples_path = [samples_path] if isinstance(samples_path, str) else samples_path
        self.load_exist_probes = load_exist_probes
        self.feedback_times = feedback_times
        self.reasoning_effort = reasoning_effort

    def _init_message(self):
        """Reset the conversation to just the developer (system) prompt."""
        from prompt_generate.prompts.prompts_probe import developer_message
        self.messages = [
            {
                "role": "developer",
                "content": developer_message
            }
        ]

    def request_and_record(self, message: str) -> None | dict:
        """Send ``message`` as a user turn and append the model's reply to the conversation.

        Returns:
            The parsed ``Probes`` answer as a plain dict, or ``None`` when the response has no
            structured output (the raw text, or a fallback note, is recorded instead).
        """
        self.messages += [{"role": "user", "content": message}]

        completion = self.client.responses.parse(
            model=self.model,
            input=self.messages,
            text_format=Probes,
            reasoning={"effort": self.reasoning_effort},
            max_output_tokens=12800,
        )
        self.update_usage(completion)

        parsed = getattr(completion, "output_parsed", None)
        if parsed is None:
            # Keep whatever the model said in the history so the feedback turn has context.
            error_info = getattr(completion, "output_text", None) or "Model did not return a valid structured response."
            logger.warning(f"Failed to generate probes for scenario {self.scenario}. Error info: {error_info}")
            probes = None
            self.messages += [{"role": "assistant", "content": error_info}]
        else:
            probes = parsed.model_dump()
            self.messages += [{"role": "assistant", "content": json.dumps(probes, indent=4)}]

        self.save_trajectory()
        return probes

    def check_probes(self, probes: dict | None, factors: list[str]) -> tuple[bool, str | None]:
        """Validate a probes answer against the expected factor names.

        Returns:
            ``(True, None)`` when valid, otherwise ``(False, message)``. The message is worded
            so it can be sent back to the model as feedback.
        """
        if probes is None:
            return False, "No valid format found."
        if not isinstance(probes, dict):
            return False, f"The generated answer should be a dictionary but get {type(probes)}."
        if "scenario" not in probes:
            return False, "The scenario is not included in the answer."
        if "factor_question_pairs" not in probes:
            return False, "The factor_question_pairs is not included in the answer."
        pairs = probes["factor_question_pairs"]
        # Probes loaded from disk bypass the pydantic schema, so verify the container type too.
        if not isinstance(pairs, list):
            return False, f"The factor_question_pairs should be a list but get {type(pairs)}."
        for pair in pairs:
            if not isinstance(pair, dict):
                return False, f"The pair should be a dictionary but get {type(pair)}."
            if "factor" not in pair:
                return False, "The factor is not included in the pair."
            if "question" not in pair:
                return False, "The question is not included in the pair."
            # A non-string (possibly unhashable) factor would crash the set comparison below.
            if not isinstance(pair["factor"], str):
                return False, f"The factor should be a string but get {type(pair['factor'])}."
        return_factors = [pair["factor"] for pair in pairs]
        # Compared as sets: ordering and duplicates are not checked.
        if set(return_factors) != set(factors):
            return False, f"The factors in the answer should be {factors} but get {return_factors}."
        return True, None

    def run(self):
        """Produce and save the probes for this scenario.

        If ``load_exist_probes`` is set and ``save_file`` exists, the stored probes are validated
        and re-saved (which also propagates them into samples files) without calling the model.
        Otherwise the model is queried, with up to ``feedback_times`` follow-up requests that carry
        the validation error. Invalid final results are logged and nothing is saved.
        """
        self._init_message()
        from prompt_generate.prompts.prompts_probe import get_user_message

        expected_factors = self.roots_name + self.non_roots_name

        if self.load_exist_probes and os.path.exists(self.save_file):
            logger.info(f"The probes for the rule file {self._load_file} (scenario: \"{self.scenario}\") have already been generated. The existing probes will be loaded.")
            with open(self.save_file, "r", encoding="utf-8") as f:
                probes = json.load(f)
            pass_check, error_info = self.check_probes(probes, expected_factors)
            if not pass_check:
                logger.error(f"The existing probes in {self.save_file} for scenario {self.scenario} are invalid. Error info: {error_info}")
                return
            self.save(probes)
            return

        assert self.scenario is not None
        probes = self.request_and_record(get_user_message(scenario=self.scenario, factors=self.roots + self.non_roots))
        pass_check, error_info = self.check_probes(probes, expected_factors)
        have_feedback = 0
        while not pass_check and have_feedback < self.feedback_times:
            have_feedback += 1
            # The validation message itself is the feedback turn.
            probes = self.request_and_record(error_info)
            pass_check, error_info = self.check_probes(probes, expected_factors)
        if not pass_check:
            logger.error(f"Failed to generate valid probes for scenario {self.scenario}. Error info: {error_info}")
            return
        self.save(probes)

    def save_trajectory(self):
        """Persist the conversation; currently a no-op, so ``trajectory_file`` is never written."""
        pass

    def save(self, probes):
        """Write ``probes`` to ``save_file`` and, if enabled, into the matching samples files.

        Each samples file is ``<samples dir>/<basename of save_file>``. Missing files are skipped
        with a warning, and an existing ``"probes"`` entry is overwritten.
        """
        save_dir = os.path.dirname(self.save_file)
        # os.makedirs("") raises FileNotFoundError for a bare file name.
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        with open(self.save_file, "w", encoding="utf-8") as f:
            json.dump(probes, f, indent=4)
        logger.info(f"Probes for the rule file {self._load_file} (scenario: \"{self.scenario}\") have been saved in {self.save_file}.")
        if not self.add_probes_into_samples:
            return
        file_name = os.path.basename(self.save_file)
        for samples_dir in self.samples_path:
            sample_file = os.path.join(samples_dir, file_name)
            if not os.path.exists(sample_file):
                logger.warning(f"The samples file {sample_file} does not exist.")
                continue
            with open(sample_file, "r", encoding="utf-8") as f:
                samples = json.load(f)
            if "probes" in samples:
                logger.warning("The probes have already been added into the samples file. The newly generated probes will overwrite the old ones.")
            samples["probes"] = probes
            with open(sample_file, "w", encoding="utf-8") as f:
                json.dump(samples, f, indent=4)
            logger.info(f"The probes have been added into the samples file {sample_file}.")
                
class ProbeGenerator(GeneratorBasedOnRules):
    """Generate probes for many rule files, one ``SingleScenarioProbeGenerator`` per file.

    Paths, the file list and the shared options are handled by the base class. Work is spread
    over a ``multiprocessing`` pool.
    """

    def __init__(self, model: str, load_path: str, save_path: str, trajectory_path: str, file_name: Sequence[str] | str | None = None, feedback_times: int = 3, add_probes_into_samples: bool = False, samples_path: Sequence[str] | str | None = None, load_exist_probes: bool = False, reasoning_effort: str = "minimal"):
        """
        Args:
            model: OpenAI model name.
            load_path: Directory of the rule files.
            save_path: Directory the probe files are written to (same file names).
            trajectory_path: Directory for trajectory files (same file names).
            file_name: Optional subset of file names, forwarded to the base class.
            feedback_times: Maximum feedback rounds per file.
            add_probes_into_samples: Also store probes in same-named samples files.
            samples_path: Directory or directories of samples files.
            load_exist_probes: Reuse already-saved probes instead of calling the model.
            reasoning_effort: Value sent as ``reasoning.effort`` to the API.

        Raises:
            ValueError: If ``add_probes_into_samples`` is set without a non-empty ``samples_path``.
        """
        super().__init__(model = model, load_path=load_path, save_path=save_path, file_name=file_name, trajectory_path=trajectory_path, feedback_times=feedback_times)
        self.add_probes_into_samples = add_probes_into_samples
        # `--samples_path` with no values yields [], which is as unusable as None.
        if self.add_probes_into_samples and not samples_path:
            raise ValueError("If `add_probes_into_samples`, the `samples_path` is needed.")
        self.samples_path = samples_path
        self.load_exist_probes = load_exist_probes
        self.reasoning_effort = reasoning_effort

    def _generate(self, file_name: str):
        """Build and run the single-scenario generator for ``file_name`` (executed in a worker)."""
        scenario = self.load_scenario(file_name)
        generator = SingleScenarioProbeGenerator(
            model = self.model,
            scenario = scenario,
            load_file = os.path.join(self.load_path, file_name),
            save_file = os.path.join(self.save_path, file_name),
            trajectory_file = os.path.join(self.trajectory_path, file_name),
            feedback_times = self.feedback_times,
            add_probes_into_samples = self.add_probes_into_samples,
            samples_path = self.samples_path,
            load_exist_probes = self.load_exist_probes,
            reasoning_effort = self.reasoning_effort
        )
        logger.info(f"Generate probes for the rule file {file_name}")
        generator.run()

    def generate(self, num_works: int = 4):
        """Process every file in ``self.file_list`` with a pool of ``num_works`` processes.

        Exceptions raised in a worker propagate out of ``Pool.map``; validation failures are
        only logged by the workers.
        """
        from multiprocessing import Pool
        with Pool(num_works) as p:
            p.map(self._generate, self.file_list)
            
def parse_args():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", type=str, default="gpt-4o-mini", help="The model name.")
    parser.add_argument("-l", "--load_path", type=str, default="dataset/rules", help="The path to load the files.")
    parser.add_argument("-s", "--save_path", type=str, default="dataset/probes", help="The path to save the files.")
    parser.add_argument("-f", "--file_name", type=str, nargs="*", help="The file names to generate the probes.")
    parser.add_argument("-t", "--trajectory_path", type=str, default="dataset/probe_trajectories", help="The path to save the trajectory.")
    parser.add_argument("--feedback_times", type=int, default=3, help="The times to request feedback.")
    parser.add_argument("-n", "--num_works", type=int, default=4, help="The number of processes to generate the probes.")
    parser.add_argument("--load_exist_probes", action="store_true", help="If true, the existing probes will be loaded.")
    parser.add_argument("--add_probes_into_samples", action="store_true", help="If true, the generated probes will be added into corresponding samples files (with the same name).")
    parser.add_argument("--samples_path", type=str, nargs="*", help="Used only if `add_probes_into_samples` is true.")
    parser.add_argument("--reasoning_effort", type=str, default="minimal",choices=["minimal", "medium", "high"],help="Reasoning effort level passed to the API.")
    return parser.parse_args()

if __name__ == "__main__":
    # CLI entry point: parse options, then generate probes for all selected rule files.
    args = parse_args()
    generator = ProbeGenerator(
        model = args.model,
        load_path = args.load_path,
        save_path = args.save_path,
        file_name = args.file_name,
        trajectory_path = args.trajectory_path,
        feedback_times = args.feedback_times,
        load_exist_probes = args.load_exist_probes,
        add_probes_into_samples = args.add_probes_into_samples,
        samples_path = args.samples_path,
        reasoning_effort = args.reasoning_effort
    )
    generator.generate(args.num_works)
    # Per-file validation failures are only logged by the workers, so this marks completion,
    # not success of every file.
    logger.info("Probe generation finished for all rule files; check the log above for any failures.")