import json
import shutil
from pathlib import Path

import pandas as pd

from util.scenario import Scenario
from sample.sampler import Sampler
from util.interface import CausalModel, PromptGenerator, VideoGenerator, VideoAsker
from util.util import get_video_name

class Handler:
    """Run the prompt -> video -> question pipeline for one scenario.

    The handler asks its ``Sampler`` which samples have to be processed,
    fills in missing prompts, renders missing videos into the scenario
    directory and finally stores the variables observed in each video back
    into the sample tables.

    Two kinds of samples are handled throughout:

    * "rule" samples (level 2 and level 3), addressed without a
      ``select_sample`` argument in the prompt/video stages;
    * "text" samples (level 1), addressed with ``select_sample="text"``.
    """

    def __init__(self, model: CausalModel,
                 prompt_generator: PromptGenerator = None,
                 video_generator: VideoGenerator = None,
                 video_asker: VideoAsker = None,
                 sample_n=1,
                 llm_name: str = "sample"):
        """
        model: causal model; its ``roots``, ``non_roots``, ``scenario`` and
            ``topo_sorted_non_roots()`` are read here.
        prompt_generator: used by ``generate_prompts``.
        video_generator: used by ``generate_videos`` and
            ``generate_videos_multi_prompts``.
        video_asker: used by ``ask_videos``.
        sample_n: number of samples for each value of each non_root.
            Totally there are 2 * sample_n * len(non_roots) samples.
        llm_name: forwarded to the ``Sampler`` and used as the name of the
            sub-directory of ``database`` that holds this run's files.
        """
        self.model = model
        self.roots = model.roots
        if model.non_roots:
            self.non_roots = model.topo_sorted_non_roots()
            self.variables = model.roots + self.non_roots
        else:
            # A model without non-root variables must still leave the
            # handler fully initialised; every other method reads
            # ``non_roots`` and ``variables``.
            self.non_roots = []
            self.variables = list(model.roots)
        self.sampler = Sampler(model=model, llm_name=llm_name)
        self.sample_n = sample_n
        self.prompt_generator = prompt_generator
        self.video_generator = video_generator
        self.video_asker = video_asker

        self.llm_name = llm_name
        # Two levels above the directory containing this file.
        project_root = Path(__file__).resolve().parent.parent
        self.database_path = project_root / 'database' / llm_name
        self.database_path.mkdir(parents=True, exist_ok=True)
        scenario_index = Scenario.get_index(self.model.scenario)
        self.scenario_path = self.database_path / str(scenario_index)
        self.scenario_path.mkdir(parents=True, exist_ok=True)
        self.scenario_index = scenario_index

        self.prompts_path = project_root / "prompts"
        self.prompts_path.mkdir(parents=True, exist_ok=True)

    def _rule_sample_ids(self):
        """Return the sorted, de-duplicated ids of the level 2 and level 3 samples.

        The ids are the union of the false/true index lists the sampler
        returns for every non-root variable and the ``sample_id`` column of
        the sampler's level 2 index table.
        """
        sample_ids = set()
        for non_root in self.non_roots:
            false_index, true_index = self.sampler.get_sample_index(
                non_root=non_root, repeat=self.sample_n)
            sample_ids.update(false_index)
            sample_ids.update(true_index)
        group_indexs = self.sampler.get_sample_index_level_2()
        sample_ids.update(group_indexs["sample_id"].tolist())
        return sorted(sample_ids)

    def _iter_pending_videos(self, do_level_1=True):
        """Lazily yield ``(prompt, file_path)`` for every video not on disk yet.

        Rule samples come first; text (level 1) samples follow when
        ``do_level_1`` is true. The text sample ids are only requested from
        the sampler once the rule samples have been exhausted.

        Raises:
            ValueError: when a visited sample has no prompt.
        """
        sources = [(self._rule_sample_ids, {})]
        if do_level_1:
            sources.append((self.sampler.get_sample_indexs_text_consistency,
                            {"select_sample": "text"}))
        for fetch_ids, selector in sources:
            for sample_id in fetch_ids():
                row = self.sampler.read_sample(sample_id=sample_id, **selector)
                if pd.isna(row["prompt"]):
                    raise ValueError(f"Prompt not exist for scenario {self.scenario_index} and sample_id {sample_id}.")
                video_name = get_video_name(scenario_id=self.scenario_index,
                                            sample_id=sample_id, **selector)
                file_path = self.scenario_path / video_name
                if file_path.exists():
                    continue
                yield row["prompt"], file_path

    def generate_prompts(self, **kwargs):
        """Generate and store a prompt for every sample that has none yet.

        Prompts of rule samples are built from the ``true_*`` values of the
        root variables only; prompts of text samples from the ``true_*``
        values of all variables. ``kwargs`` are forwarded to
        ``prompt_generator.generate_prompt``.
        """
        passes = (
            # level 2 and level 3 samples
            (self._rule_sample_ids, self.roots, {}),
            # level 1 samples
            (self.sampler.get_sample_indexs_text_consistency, self.variables,
             {"select_sample": "text"}),
        )
        for fetch_ids, names, selector in passes:
            for sample_id in fetch_ids():
                row = self.sampler.read_sample(sample_id=sample_id, **selector)
                if not pd.isna(row["prompt"]):
                    continue  # prompt already stored
                values = {name: row["true_" + name] for name in names}
                prompt = self.prompt_generator.generate_prompt(
                    scenario=self.model.scenario, variables=values, **kwargs)
                self.sampler.update_sample(sample_id, "prompt", prompt, **selector)

    def generate_videos(self, **kwargs):
        """Generate every missing video, one generator call per prompt.

        ``kwargs`` are forwarded to
        ``video_generator.generate_video_from_prompt``.

        Raises:
            ValueError: when a sample has no prompt; videos generated before
                that sample was reached are kept.
        """
        for prompt, file_path in self._iter_pending_videos():
            self.video_generator.generate_video_from_prompt(prompt=prompt,
                                                            save_path=file_path, **kwargs)

    def generate_videos_multi_prompts(self, do_level_1=True, **kwargs):
        """Generate all missing videos with a single batched generator call.

        The pending prompts and their target paths are written to
        ``prompts_<scenario>.json`` and ``save_paths_<scenario>.json`` in the
        scenario directory, and the prompts file is also copied into the
        shared prompts directory. ``kwargs`` are forwarded to
        ``video_generator.generate_video_from_multi_prompts``.

        do_level_1: include the level 1 (text) samples as well.

        Raises:
            ValueError: when a sample has no prompt; nothing is written or
                generated in that case.
        """
        pending = list(self._iter_pending_videos(do_level_1=do_level_1))
        prompts = [prompt for prompt, _ in pending]
        save_paths = [str(file_path) for _, file_path in pending]

        # Save as JSON to avoid parsing problems caused by newlines, escape
        # characters and punctuation inside the prompts.
        prompts_file_path = self.scenario_path / f"prompts_{self.scenario_index}.json"
        save_paths_path = self.scenario_path / f"save_paths_{self.scenario_index}.json"
        with open(prompts_file_path, "w", encoding="utf-8") as prompts_file:
            json.dump(prompts, prompts_file, ensure_ascii=False, indent=2)
        shutil.copy(src=prompts_file_path, dst=self.prompts_path / prompts_file_path.name)
        with open(save_paths_path, "w", encoding="utf-8") as save_paths_file:
            json.dump(save_paths, save_paths_file, ensure_ascii=False, indent=2)
        self.video_generator.generate_video_from_multi_prompts(prompts_file=prompts_file_path,
                                                               save_paths=save_paths, **kwargs)

    def ask_videos(self, **kwargs):
        """Ask the video asker for the variables of every unanswered sample.

        A sample is skipped when all of its ``observed_*`` values are already
        present. Otherwise each item of the asker's result is stored under
        ``observed_<name>``. ``kwargs`` are forwarded to
        ``video_asker.ask_for_variables``.
        """
        # level 2 and level 3 samples
        targets = [(sample_id, "rule") for sample_id in self._rule_sample_ids()]
        # level 1 samples
        targets += [(sample_id, "text")
                    for sample_id in self.sampler.get_sample_indexs_text_consistency()]
        for sample_id, select_sample in targets:
            row = self.sampler.read_sample(sample_id=sample_id, select_sample=select_sample)
            if all(not pd.isna(row["observed_" + name]) for name in self.variables):
                continue
            video_name = get_video_name(scenario_id=self.scenario_index, sample_id=sample_id,
                                        select_sample=select_sample)
            video_path = self.scenario_path / video_name
            ask_result = self.video_asker.ask_for_variables(video_path=video_path,
                                                            variable_list=self.variables,
                                                            sample_id=sample_id,
                                                            select_sample=select_sample, **kwargs)
            for name, value in ask_result.items():
                self.sampler.update_sample(sample_id=sample_id, key="observed_" + name,
                                           value=value, select_sample=select_sample)

    def process(self, multi=True):
        """Run the whole pipeline: prompts, then videos, then questions.

        multi: render all videos with one batched call when true, otherwise
            one call per video.
        """
        self.generate_prompts()
        if multi:
            self.generate_videos_multi_prompts()
        else:
            self.generate_videos()
        self.ask_videos()
            