from collections.abc import Sequence
import os
import json
from openai import OpenAI

from pydantic import BaseModel

from prompt_generate.util import logger

# Pydantic schema for one composition: factor names, their values and generated sentences.
# NOTE(review): not referenced anywhere in this file (possibly meant for the unimplemented
# "separate" mode); note that `value` is list[str] here, unlike the list[bool] used by the
# other schemas — confirm before using it.
class SamplesForOneCompositionSeperate(BaseModel):
    factors: list[str]
    value: list[str]
    sentences: list[str]

# Pydantic schema for one root-value composition when the model does NOT label results.
# (Comments rather than a docstring: pydantic would copy a docstring into the JSON schema
# sent to the API as the structured-output format.)
class SamplesForOneComposition(BaseModel):
    # Truth value of each root factor; presumably aligned with the enclosing "factors" list.
    value: list[bool]
    # Generated sentences for this composition (check_samples expects num_sent of them).
    samples: list[str]
    
# Pydantic schema for one composition when the model also labels the non-root results.
class SamplesForOneCompositionWithResults(BaseModel):
    # Truth values of the root factors.
    factor_value: list[bool]
    # Truth values of the non-root results; check_samples compares them with the causal model.
    result_value: list[bool]
    # Generated sentences for this composition.
    samples: list[str]
    
# Top-level structured-output schema requested by request_and_record when contain_results
# is False: the factor names plus one entry per root-value composition.
class SamplesForAllComposition(BaseModel):
    factors: list[str]
    compositions: list[SamplesForOneComposition]
    
# Top-level structured-output schema requested by request_and_record when contain_results
# is True: factor names, result (non-root) names and one entry per composition.
class SamplesForAllCompositionWithResults(BaseModel):
    factors: list[str]
    results: list[str]
    compositions: list[SamplesForOneCompositionWithResults]
    
from prompt_generate.prompt_generate import TokenUsageMixin
from prompt_generate.util import RulesJson, FactorWithoutType

from util.interface import RulesDict

class SingleGeneratorBasedOnRules(TokenUsageMixin):
    """Base for per-scenario generators: token-usage tracking plus rules-file helpers."""

    def __init__(self):
        super().__init__()
        self._init_usage()  # fix: initiate total_usage

    def load_file(self, file_dir: str) -> tuple[RulesJson, str | None]:
        """Load a rules JSON file.

        Args:
            file_dir: path to the rules JSON file.

        Returns:
            ``(rules, scenario)``. ``rules`` is the file content validated as ``RulesJson``.
            ``scenario`` is the file's top-level ``"scenario"`` value, or ``None`` when the
            key is absent so the caller can fall back to an externally supplied scenario.
        """
        with open(file_dir, encoding="utf-8") as f:
            j = json.load(f)
        # Pop "scenario" so it is returned separately; a missing key is not an error here.
        scenario = j.pop('scenario', None)
        rulesjson = RulesJson.model_validate(j)  # extra fields are ignored automatically
        return rulesjson, scenario

    def _clean_scenario(self, scenario: str) -> str:
        """Return ``scenario`` with every '(' and ')' removed and outer whitespace stripped."""
        return scenario.translate(str.maketrans("", "", "()")).strip()

class SingleScenarioSampleGeneratorBase(SingleGeneratorBasedOnRules):
    """Shared state for generating samples for the scenario described by one rules file."""

    def __init__(self, model:str, scenario: str | None, load_file: str, save_file: str, trajectory_file: str, num_sent: int, feedback_times: int = 3, contain_results: bool = False, reasoning_effort: str = "minimal", force_regen: bool = False):
        """
        Args:
            model: model name used for the OpenAI requests.
            scenario: fallback scenario text, used only when the rules file has none.
            load_file: path of the rules JSON file.
            save_file: output path for the generated samples.
            trajectory_file: output path for the conversation trajectory.
            num_sent: number of sentences requested per composition.
            feedback_times: maximum number of retries after a failed check.
            contain_results: whether the model also labels the non-root (result) values.
            reasoning_effort: reasoning effort level for the API requests.
            force_regen: regenerate even when ``save_file`` already exists.

        Raises:
            ValueError: if neither the rules file nor ``scenario`` provides a scenario.
            TypeError: if the resolved scenario is not a str.
        """
        super().__init__()
        self.model = model
        self.client = OpenAI()
        rulesjson, file_scenario = self.load_file(load_file)
        self.roots = rulesjson.roots
        self.roots_name = rulesjson.get_roots_names()
        self.non_roots = rulesjson.non_roots
        self.non_roots_name = rulesjson.get_non_roots_names()
        self.rules: RulesDict = rulesjson.rules_convert_to_dict()
        # The scenario stored in the rules file takes precedence over the argument.
        self.scenario = file_scenario if file_scenario is not None else scenario
        if self.scenario is None:
            raise ValueError("When the scenario is not provided in the load file, the scenario should be provided.")
        if not isinstance(self.scenario, str):
            raise TypeError(f"The scenario should be a str but get {type(self.scenario)}.")
        self.scenario = self._clean_scenario(self.scenario)
        self.save_file = save_file
        self.trajectory_file = trajectory_file
        self.num_sent = num_sent
        self.feedback_times = feedback_times
        self.contain_results = contain_results
        self.reasoning_effort = reasoning_effort
        self.force_regen = force_regen

class SingleScenarioSampleGeneratorTogether(SingleScenarioSampleGeneratorBase):
    """Generate samples for all root-value compositions of one scenario in a single request.

    The model returns every composition at once. The reply is validated by
    :meth:`check_samples`; on failure, the error text is sent back to the model for up to
    ``feedback_times`` retries (see :meth:`run`).
    """

    @staticmethod
    def _as_name(item) -> str:
        """Return ``item`` if it is a str, else its ``name`` attribute, else ``str(item)``."""
        return item if isinstance(item, str) else getattr(item, "name", str(item))

    @staticmethod
    def _ensure_parent_dir(path: str) -> None:
        """Create the parent directory of ``path`` if it has one."""
        parent = os.path.dirname(path)
        # os.makedirs("") raises FileNotFoundError, so skip bare file names.
        if parent:
            os.makedirs(parent, exist_ok=True)

    def _init_causal_model(self):
        """Lazily create ``self.causal_model`` from this scenario's roots, non-roots and rules."""
        if getattr(self, "causal_model", None) is None:
            from util.interface import CausalModel
            self.causal_model = CausalModel(roots=self.roots_name, non_roots=self.non_roots_name, rules=self.rules, scenario=self.scenario)

    def _init_message(self):
        """Reset the conversation to just the developer message."""
        from prompt_generate.prompts.prompts_sample import developer_message
        self.messages = [{"role": "developer", "content": developer_message}]

    def save_samples(self, samples):
        """Write ``samples`` with the scenario, roots, non-roots and rules to ``self.save_file``.

        Args:
            samples: parsed model response with ``factors``, ``compositions`` and, when
                ``contain_results``, ``results``.

        Raises:
            ValueError: if the factor (or result) names differ from those in the rules file.
        """
        # Project entries to plain string names so that set comparison and JSON
        # serialization never see non-str objects (e.g. Factor instances).
        factor_names = [self._as_name(f) for f in samples.get("factors", [])]
        if set(factor_names) != set(self.roots_name):
            raise ValueError(f"Factors {factor_names} do not match the roots {self.roots_name}.")

        samples_save = {"scenario": self.scenario}
        if self.contain_results:
            result_names = [self._as_name(r) for r in samples.get("results", [])]
            if set(result_names) != set(self.non_roots_name):
                raise ValueError(f"Results {result_names} do not match the non-roots {self.non_roots_name}.")
            samples_save["non_roots"] = result_names
        else:
            samples_save["non_roots"] = self.non_roots_name
        samples_save["roots"] = factor_names
        samples_save["rules"] = self.rules
        # Composition values follow the returned factor order, which is what "roots" stores.
        samples_save["compositions"] = self._add_non_roots_value_if_needed(samples["compositions"], factors=factor_names)

        self._ensure_parent_dir(self.save_file)
        if os.path.exists(self.save_file):
            logger.debug(f"File {self.save_file} will be overwritten.")
        with open(self.save_file, "w") as f:
            json.dump(samples_save, f, indent=4)
        logger.info(f"Samples are saved in {self.save_file}")

    def _add_non_roots_value_if_needed(self, compositions, factors=None):
        """Normalize compositions to the ``factor_value``/``result_value`` layout.

        With ``contain_results`` the model already supplied both fields, so a shallow copy of
        the list is returned. Otherwise each composition's ``value`` is renamed to
        ``factor_value``, and ``result_value`` is derived from the causal model.

        Args:
            compositions: composition dicts from the model response.
            factors: root names in the order of each ``value`` list. Defaults to
                ``self.roots_name``.

        Returns:
            A new list. The input dicts are not modified.
        """
        if self.contain_results:
            return compositions.copy()
        if factors is None:
            factors = self.roots_name
        self._init_causal_model()
        new_compositions = []
        for comp in compositions:
            new_comp = {key: val for key, val in comp.items() if key != "value"}
            # NOTE(review): stored exactly as get_non_root_value returns it (a name-indexed
            # mapping, judging by check_samples), not the list used in with-results mode.
            new_comp["result_value"] = self.causal_model.get_non_root_value(dict(zip(factors, comp["value"])))
            new_comp["factor_value"] = comp["value"]
            new_compositions.append(new_comp)
        return new_compositions

    def save_trajectory(self):
        """Dump the scenario, accumulated token usage and messages to ``self.trajectory_file``."""
        self._ensure_parent_dir(self.trajectory_file)
        with open(self.trajectory_file, 'w') as f:
            json.dump({"scenario": self.scenario, "usage": self.total_usage, "trajectory": self.messages}, f, indent=4)

    def check_samples(self, samples: dict | None) -> tuple[bool, str | None]:
        """Validate a parsed model response against the rules of this scenario.

        The following are checked:
        - the top-level keys and their types;
        - ``factors`` (and ``results`` when ``contain_results``) are exactly the expected
          names;
        - ``compositions`` covers every True/False assignment of the roots exactly once;
        - each composition has well-typed values and ``num_sent`` string samples.

        With results, each ``result_value`` must also match what the causal model derives
        from its ``factor_value``.

        Returns:
            ``(True, None)`` on success, otherwise ``(False, error_message)``. The message
            is fed back to the model by :meth:`run`.
        """
        with_results: bool = self.contain_results

        if samples is None:
            return False, "No valid return json."
        if not isinstance(samples, dict):
            return False, "The return file is not a dict."
        if "factors" not in samples:
            return False, "Key \"factors\" is not found in the return file."
        if with_results and "results" not in samples:
            return False, "Key \"results\" is not found in the return file."
        if "compositions" not in samples:
            return False, "Key \"compositions\" is not found in the return file."
        if not isinstance(samples["factors"], list):
            return False, f"The type of key \"factors\" should be list but get {type(samples['factors'])}."
        if not isinstance(samples["compositions"], list):
            return False, f"The type of key \"compositions\" should be list but get {type(samples['compositions'])}."
        if with_results and not isinstance(samples["results"], list):
            return False, f"The type of key \"results\" should be list but get {type(samples['results'])}."
        if len(samples["factors"]) != len(self.roots_name):
            return False, f"The number of factors should be {len(self.roots_name)} but get {len(samples['factors'])}."
        if with_results and len(samples["results"]) != len(self.non_roots_name):
            return False, f"The number of results should be {len(self.non_roots_name)} but get {len(samples['results'])}."
        for factor in samples["factors"]:
            if not isinstance(factor, str):
                return False, "The type of factors should be str."
            if factor not in self.roots_name:
                return False, f"Factor {factor} is not found in list {self.roots_name}."
        for root in self.roots_name:
            if root not in samples["factors"]:
                return False, f"Factor {root} is not found in the return file \"factors\" list {samples['factors']}."
        if with_results:
            for result in samples["results"]:
                if not isinstance(result, str):
                    return False, "The type of results should be str."
                if result not in self.non_roots_name:
                    return False, f"Result {result} is not found in list {self.non_roots_name}."
            for non_root in self.non_roots_name:
                if non_root not in samples["results"]:
                    return False, f"Result {non_root} is not found in the return file \"results\" list {samples['results']}."

        from itertools import product
        # Every True/False assignment of the roots must appear exactly once. Entries are
        # removed as they are seen, so whatever is left at the end is missing.
        remaining = list(product([True, False], repeat=len(self.roots)))
        # Composition values are aligned with the returned "factors"/"results" lists. These
        # are the orders save_samples stores, and they only need to be permutations of the
        # expected names.
        factor_order = samples["factors"]
        result_order = samples["results"] if with_results else self.non_roots_name

        def bool_error(values, label):
            """Return an error message for the first non-bool entry, else None."""
            for v in values:
                if not isinstance(v, bool):
                    return f"The type of {label} {v} should be bool but get {type(v)}."
            return None

        def samples_error(sents):
            """Return an error message if ``sents`` is not ``num_sent`` strings, else None."""
            if len(sents) != self.num_sent:
                return f"The number of samples should be {self.num_sent} but get {len(sents)}."
            for s in sents:
                if not isinstance(s, str):
                    return f"The type of sample {s} should be str but get {type(s)}."
            return None

        def claim(values) -> bool:
            """Remove ``values`` from ``remaining``; False if it was already claimed."""
            key = tuple(values)
            if key not in remaining:
                return False
            remaining.remove(key)
            return True

        def check_comp(comp):
            if not isinstance(comp, dict):
                return False, "The type of composition should be dict."
            for key in ("value", "samples"):
                if key not in comp:
                    return False, f"Key \"{key}\" is not found in the composition."
            for key in ("value", "samples"):
                if not isinstance(comp[key], list):
                    return False, f"The type of key \"{key}\" should be list but get {type(comp[key])}."
            if len(comp["value"]) != len(self.roots):
                return False, f"The number of values should be {len(self.roots)} but get {len(comp['value'])}."
            # Type-check before the duplicate check so non-bool values are not misreported
            # as duplicates.
            error = bool_error(comp["value"], "value")
            if error:
                return False, error
            if not claim(comp["value"]):
                return False, f"The value {comp['value']} has occurred in the return file more than once."
            error = samples_error(comp["samples"])
            if error:
                return False, error
            return True, None

        def check_comp_with_results(comp):
            if not isinstance(comp, dict):
                return False, "The type of composition should be dict."
            for key in ("factor_value", "result_value", "samples"):
                if key not in comp:
                    return False, f"Key \"{key}\" is not found in the composition."
            for key in ("factor_value", "result_value", "samples"):
                if not isinstance(comp[key], list):
                    return False, f"The type of key \"{key}\" should be list but get {type(comp[key])}."
            if len(comp["factor_value"]) != len(self.roots):
                return False, f"The number of factor values should be {len(self.roots)} but get {len(comp['factor_value'])}."
            if len(comp["result_value"]) != len(self.non_roots):
                return False, f"The number of result values should be {len(self.non_roots)} but get {len(comp['result_value'])}."
            error = bool_error(comp["factor_value"], "factor value") or bool_error(comp["result_value"], "result value")
            if error:
                return False, error
            if not claim(comp["factor_value"]):
                return False, f"The value {comp['factor_value']} has occurred in the return file more than once."
            logger.debug(f"The factor value {comp['factor_value']} has been checked.")
            # Same sample-count requirement as the no-results mode.
            error = samples_error(comp["samples"])
            if error:
                return False, error
            self._init_causal_model()
            expected_results = self.causal_model.get_non_root_value(dict(zip(factor_order, comp["factor_value"])))
            for result, value in zip(result_order, comp["result_value"]):
                if value != expected_results[result]:
                    return False, f"The value of result {result} given condition {comp['factor_value']} should be {expected_results[result]} but get {value}."
            return True, None

        check = check_comp_with_results if with_results else check_comp
        for comp in samples["compositions"]:
            pass_check, error_info = check(comp)
            if not pass_check:
                return False, error_info
        if remaining:
            return False, f"The value composition(s) {remaining} are not found in the return file."

        return True, None

    def request_and_record(self, message: str) -> None | dict:
        """Send ``message`` as a user turn, record the reply and return the parsed samples.

        The request uses the structured-output schema that matches ``contain_results``.
        Token usage is accumulated, and the trajectory is saved after every call.

        Returns:
            The parsed response as a dict. Returns ``None`` if the model gave no structured
            output; in that case its raw text, or a placeholder, is recorded as the
            assistant turn.
        """
        self.messages.append({"role": "user", "content": message})
        schema = SamplesForAllCompositionWithResults if self.contain_results else SamplesForAllComposition
        completion = self.client.responses.parse(
            model=self.model,
            input=self.messages,
            text_format=schema,
            reasoning={"effort": self.reasoning_effort},
            max_output_tokens=25600,
        )

        self.update_usage(completion)

        parsed = getattr(completion, "output_parsed", None)
        if parsed is None:
            error_info = getattr(completion, "output_text", None) or "Model did not return a valid structured response."
            samples = None
            self.messages.append({"role": "assistant", "content": error_info})
            logger.warning(f"Failed to generate samples for scenario \"{self.scenario}\". Error info: {error_info}")
        else:
            samples = parsed.model_dump()
            self.messages.append({"role": "assistant", "content": json.dumps(samples, indent=4)})

        self.save_trajectory()
        return samples

    def run(self) -> None:
        """Generate, validate (with feedback retries) and save samples for this scenario.

        Work is skipped when ``save_file`` already exists, unless ``force_regen`` is set.
        Samples that still fail the check after ``feedback_times`` retries are saved anyway,
        with a warning.
        """
        # force_regen may be assigned by callers after construction; default to not forcing.
        if os.path.exists(self.save_file) and not getattr(self, "force_regen", False):
            logger.info(f"The samples have been generated at {self.save_file}. Skip the generation.")
            return

        self._init_message()

        from human_utils import MathSymbolConverter
        rule_strings = MathSymbolConverter()(self.rules)

        from prompt_generate.prompts.prompts_sample import get_user_message
        assert isinstance(self.scenario, str)
        samples = self.request_and_record(get_user_message(
                        scenario = self.scenario,
                        factors = [root.remove_type() for root in self.roots],
                        non_roots = [non_root.remove_type() for non_root in self.non_roots],
                        num_sent = self.num_sent,
                        together = True,
                        contain_results = self.contain_results,
                        rules = rule_strings))

        pass_check, error_info = self.check_samples(samples)
        regen_time = 0
        while not pass_check and regen_time < self.feedback_times:
            regen_time += 1
            logger.warning(f"Generation for scenario \"{self.scenario}\" fails to pass the check. Error info: {error_info}")
            # Feed the error back in the same conversation and ask again.
            samples = self.request_and_record(f"Your last response is not correct. The error info: {error_info}\n Please try again.")
            pass_check, error_info = self.check_samples(samples)
        if not pass_check:
            logger.warning(f"Generation for scenario \"{self.scenario}\" fails to pass the check at the **LAST** time try. Error info: {error_info}")

        if samples is not None:
            self.save_samples(samples)
        else:
            logger.error(f"Failed to generate samples for scenario \"{self.scenario}\".")
            
class SingleScenarioSampleGeneratorSeperate(SingleScenarioSampleGeneratorBase):
    """Placeholder for the non-"together" generation mode (``--sep``); not implemented."""

    def __init__(self, *args, **kwargs):
        # Accept any arguments so callers using the base-class signature get the intended
        # NotImplementedError instead of a TypeError about unexpected arguments.
        raise NotImplementedError("Separate sample generation is not implemented yet.")

from typing import Literal
class GeneratorBasedOnRules:
    def __init__(self, model: str, load_path: str, save_path: str, trajectory_path: str, file_name: str | Sequence[str] | None, feedback_times: int = 3, scenario_file: str | None = None):
        self.model = model
        self.load_path = os.path.abspath(load_path)
        self.save_path = os.path.abspath(save_path)
        self.trajectory_path = os.path.abspath(trajectory_path)
        self.file_list = self.load_file_list(file_name)
        if scenario_file is None:
            self.scenarios: list[str] | None = None
        else:
            with open(scenario_file) as f:
                self.scenarios = [line.strip() for line in f.readlines() if line.strip()]
        self.feedback_times = feedback_times
                
    def load_file_list(self, file_name: str | Sequence[str] | None) -> list[str]:
        files_in_load_path = [f for f in os.listdir(self.load_path) if f.endswith('.json')]
        if file_name is None:
            return files_in_load_path
        elif isinstance(file_name, str):
            file_name = [file_name]
        elif isinstance(file_name, Sequence):
            file_list = []
            for f in file_name:
                f = f if f.endswith('.json') else f + '.json'
                if not f in files_in_load_path:
                    raise FileNotFoundError(f"The file \"{f}\" is not found in dir {self.load_path}.")
                file_list.append(f)
            return file_list
        else:
            raise TypeError(f"Invalid type of file_name: {type(file_name)}")
        
    def load_scenario(self, file_name: str) -> str | None:
        if ((sce_idx:=file_name[:-5]).isnumeric()) or ((sce_idx:=file_name[:-5].split("_")[0]).isnumeric()) and self.scenarios is not None:
            scenario = self.scenarios[int(sce_idx)]
        else:
            scenario = None
        return scenario

class SampleGenerator(GeneratorBasedOnRules):
    """Generate samples for every selected rules file using a process pool."""

    def __init__(self, model: str, load_path: str, save_path: str, trajectory_path: str, file_name: str | Sequence[str] | None, num_sent: int, feedback_times: int = 3, contain_results: Literal[True, False, 'both']= 'both', together: bool = True, scenario_file: str | None = None, force_regen: bool = False, reasoning_effort: str = "minimal"):
        """If the scenario_file is None, the load_file should contain a key "scenario".

        Args:
            num_sent: number of sentences requested per composition.
            contain_results: ``True``/``False`` generate with/without result labels.
                ``'both'`` first generates without results into ``save_path``, then with
                results into the sibling ``<save_path>_with_results`` directory (likewise
                for trajectories).
            together: only ``True`` (all compositions in one request) is implemented.
            force_regen: regenerate files whose samples already exist.
            reasoning_effort: reasoning effort level for the API requests.
            Other arguments are as in ``GeneratorBasedOnRules``.
        """
        super().__init__(model=model, load_path=load_path, save_path=save_path, trajectory_path=trajectory_path, file_name=file_name, feedback_times=feedback_times, scenario_file=scenario_file)
        self.num_sent = num_sent
        self.contain_results = contain_results
        self.together = together
        self.force_regen = force_regen
        self.reasoning_effort = reasoning_effort

    @staticmethod
    def _with_results_dir(path: str) -> str:
        """Return the sibling directory ``<path>_with_results``."""
        head, tail = os.path.split(path)
        return os.path.join(head, tail + "_with_results")

    def _make_generator(self, file_name, scenario, save_path, trajectory_path, contain_results):
        """Build the per-scenario generator for ``file_name`` writing under the given dirs."""
        generator = SingleScenarioSampleGeneratorTogether(
            model = self.model,
            scenario = scenario,
            load_file = os.path.join(self.load_path, file_name),
            save_file = os.path.join(save_path, file_name),
            trajectory_file = os.path.join(trajectory_path, file_name),
            num_sent = self.num_sent,
            feedback_times = self.feedback_times,
            contain_results = contain_results,
            reasoning_effort = getattr(self, "reasoning_effort", "minimal"),
        )
        generator.force_regen = self.force_regen
        return generator

    def _generate_one(self, file_name):
        """Generate samples for one rules file (run in a worker process by :meth:`generate`).

        Raises:
            NotImplementedError: if ``together`` is False.
        """
        scenario = self.load_scenario(file_name)
        if not self.together:
            raise NotImplementedError
        # With 'both', the first pass is generated without results.
        first_with_results = self.contain_results if self.contain_results != 'both' else False
        generator = self._make_generator(file_name, scenario, self.save_path, self.trajectory_path, first_with_results)
        logger.info(f"Generate samples for scenario \"{scenario}\" "+("with" if first_with_results else "without") +" results.")
        generator.run()
        if self.contain_results == 'both':
            generator = self._make_generator(
                file_name,
                scenario,
                self._with_results_dir(self.save_path),
                self._with_results_dir(self.trajectory_path),
                True,
            )
            logger.info(f"Generate samples for scenario \"{scenario}\" with results.")
            generator.run()

    def generate(self, num_worker: int = 4):
        """Run :meth:`_generate_one` for every file in ``self.file_list`` on ``num_worker`` processes."""
        from multiprocessing import Pool
        with Pool(num_worker) as p:
            # chunksize=1: with larger chunks, one failing file aborts the rest of its chunk.
            # The first error is still re-raised once all tasks have finished.
            p.map(self._generate_one, self.file_list, chunksize=1)
        
def parse_args():
    """Parse the command-line options for batch sample generation.

    ``--contain_results`` is returned as the raw string ``"true"``, ``"false"`` or ``"both"``.
    ``--file_name`` is ``None`` when the option is absent.
    """
    import argparse
    parser = argparse.ArgumentParser(description="Generate samples for given scenarios.")
    parser.add_argument("-m", "--model", type=str, default="gpt-4o", help="The model name.")
    parser.add_argument("-l", "--load_path", type=str, default="dataset/rules", help="The path to load the rules.")
    parser.add_argument("-s", "--save_path", type=str, default="dataset/samples", help="The path to save the samples.")
    parser.add_argument("-t", "--trajectory_path", type=str, default="dataset/trajectories", help="The path to save the trajectory.")
    parser.add_argument("--file_name", type=str, nargs="*", help="The name of the file to generate samples.")
    parser.add_argument("-n", "--num_sent", type=int, default=10, help="The number of sentences to generate for each composition.")
    parser.add_argument("--feedback_times", type=int, default=3, help="The number of feedback times.")
    parser.add_argument("-r", "--contain_results", type=str, default="both", choices=["true", "false", "both"], help="Whether the samples contain results.")
    parser.add_argument("--sep", action="store_true", help="Whether the samples are generated seperately.")
    parser.add_argument("--scenario_file", type=str, default=None, help="The file contains the scenarios.")
    parser.add_argument("-w", "--num_worker", type=int, default=4, help="The number of workers.")
    parser.add_argument("-f", "--force_regen", action="store_true", help="Force to regenerate the samples, otherwise skip the existed samples.")
    # "low" is also a valid API reasoning effort; it was previously rejected by argparse.
    parser.add_argument("--reasoning_effort", type=str, default="minimal", choices=["minimal", "low", "medium", "high"], help="Reasoning effort level for the API.")
    return parser.parse_args()

if __name__ == "__main__":
    # Script entry point: turn the parsed CLI flags into a SampleGenerator and run it.
    cli = parse_args()
    # "true"/"false" become booleans; the only other accepted choice, "both", stays 'both'.
    contain_results_flag = {"true": True, "false": False}.get(cli.contain_results, 'both')
    runner = SampleGenerator(
        model=cli.model,
        load_path=os.path.abspath(cli.load_path),
        save_path=os.path.abspath(cli.save_path),
        trajectory_path=os.path.abspath(cli.trajectory_path),
        file_name=cli.file_name,
        num_sent=cli.num_sent,
        feedback_times=cli.feedback_times,
        contain_results=contain_results_flag,
        together=not cli.sep,
        scenario_file=cli.scenario_file,
        force_regen=cli.force_regen,
        reasoning_effort=cli.reasoning_effort,
    )
    runner.generate(cli.num_worker)
