from collections import defaultdict
from pathlib import Path
import json
from util.interface import VideoAsker, CausalModel
from sample.sampler import Sampler
from handle.prompt_generator import N_prompts_for_same_value, PromptGeneratorV1
from util.scenario import Scenario
import pandas as pd

class VideoAskerV1(VideoAsker):
    """Read stored answers for a scenario's generated videos.

    Files live under
    ``<two dirs above this file>/database/<llm_name>/<scenario_id>/videos/<select_sample>/<values>/<prompt_index>/``
    where ``<values>`` is the concatenation of the integer factor values and
    ``<prompt_index>`` is the prompt's position in the prompt generator's pool
    for those values. Several samples can share the same (values, prompt)
    pair, so a per-pair counter selects the n-th file (``<n>.mp4`` /
    ``<n>_answer.json``) for the n-th occurrence of that pair.
    """

    def __init__(self, model: CausalModel, llm_name: str):
        super().__init__()
        self.model = model
        self.llm_name = llm_name

        self.scenario = self.model.scenario
        self.scenario_id = Scenario.get_index(scenario=self.scenario)
        self.sampler = Sampler(model=model, llm_name=llm_name)
        # Occurrence counters: select_sample -> factor-values tuple -> per-prompt counts.
        self.video_index = self._new_video_index()
        self.scenario_folder = (
            Path(__file__).resolve().parent.parent
            / "database"
            / self.llm_name
            / str(self.scenario_id)
        )
        self.video_folder = self.scenario_folder / "videos"
        self.prompt_generator = PromptGeneratorV1(model=model)

    @staticmethod
    def _new_video_index():
        """Return zeroed occurrence counters for the "rule" and "text" sample sets."""
        return {
            "rule": defaultdict(lambda: [0] * N_prompts_for_same_value),
            "text": defaultdict(lambda: [0] * N_prompts_for_same_value),
        }

    def ask_for_variables(self, video_path, variable_list, **kwargs):
        """Return the stored answer for each variable in ``variable_list``.

        ``video_path`` is not used here (presumably kept for the ``VideoAsker``
        interface); the answer file is located from the required keyword
        arguments ``sample_id`` and ``select_sample``. Each call advances the
        occurrence counter of the sample's (values, prompt) pair. Every
        variable is ``pd.NA`` when the answer file does not exist.
        """
        sample_id = kwargs["sample_id"]
        select_sample = kwargs["select_sample"]
        read_folder, video_index = self._get_video_info(sample_id=sample_id, select_sample=select_sample)
        json_name = read_folder / f"{video_index}_answer.json"
        return VideoAskerV1.values_from_json(json_file_name=json_name, variable_list=variable_list)

    def _get_video_info(self, sample_id, select_sample):
        """Locate the folder and per-pair file index of a sample's video.

        Returns ``(read_folder, video_index)`` and increments the counter for
        the sample's (values, prompt) pair, so successive calls for samples
        sharing that pair yield consecutive indices. ``select_sample`` must be
        "rule" or "text" (the only counter keys).

        Raises:
            ValueError: if the sample's prompt is not in the generator's pool
                for its factor values.
        """
        row = self.sampler.read_sample(sample_id=sample_id, select_sample=select_sample)
        values_dict = {name: row["true_" + name] for name in self.sampler.variables}
        values_key = tuple(values_dict[factor] for factor in self.prompt_generator.factors)
        values_str = "".join(str(int(value)) for value in values_key)
        prompt = row["prompt"]
        # "rule" samples use `samples`; every other sample set uses `samples_full`.
        pools = self.prompt_generator.samples if select_sample == "rule" else self.prompt_generator.samples_full
        prompts_pool: list[str] = pools[values_key]
        try:
            prompt_index = prompts_pool.index(prompt)
        except ValueError:
            print(f"prompt = {prompt}")
            print(f"prompts_pool = {prompts_pool}")
            raise ValueError(
                f"prompt not in prompts_pool. scenario_id={self.scenario_id}, sample_id={sample_id}, select_sample={select_sample}"
            ) from None
        read_folder = self.video_folder / select_sample / values_str / str(prompt_index)
        counters = self.video_index[select_sample][values_key]
        video_index = counters[prompt_index]
        counters[prompt_index] += 1
        return read_folder, video_index

    def get_video_paths(self, select_sample="rule") -> dict[int, str]:
        """Map every sample id of ``select_sample`` to its video path.

        Paths are relative to ``self.scenario_folder``. All occurrence
        counters are reset first, so the n-th sample sharing a (values, prompt)
        pair maps to ``<n>.mp4``. This also resets the counters used by
        ``ask_for_variables``, and the loop leaves them advanced past every
        sample of ``select_sample``.
        """
        self.video_index = self._new_video_index()
        sample_df = self.sampler.get_sample_df(select_sample=select_sample)
        video_paths_dict = {}
        for sample_id in sample_df["sample_id"].tolist():
            read_folder, video_index = self._get_video_info(sample_id=sample_id, select_sample=select_sample)
            video_path = read_folder / f"{video_index}.mp4"
            video_paths_dict[sample_id] = str(video_path.relative_to(self.scenario_folder))
        return video_paths_dict

    @classmethod
    def values_from_json(cls, json_file_name, variable_list):
        """Return ``{name: value}`` for each name in ``variable_list``.

        Values are read from the JSON object in ``json_file_name``. If the
        file does not exist, every requested variable maps to ``pd.NA``.

        Raises:
            KeyError: if the file exists but lacks a requested variable.
        """
        try:
            # JSON text is UTF-8 (RFC 8259); don't rely on the platform default encoding.
            with open(json_file_name, "r", encoding="utf-8") as json_file:
                data = json.load(json_file)
        except FileNotFoundError:
            # Treat a missing answer file as "no answer" rather than an error.
            data = {name: pd.NA for name in variable_list}
        return {name: data[name] for name in variable_list}
    

class VideoAskerV2_0(VideoAsker):
    """Read stored answers for a scenario's generated videos (strict variant).

    Uses the same folder layout and per-(values, prompt) occurrence counters
    as ``VideoAskerV1``. Unlike that class, it raises ``FileNotFoundError``
    when an answer file is missing instead of returning ``pd.NA``.
    """

    # Answer-file suffixes, tried in order. "_anwser.json" (sic) is the legacy
    # name this class reads, so it stays first to keep existing behaviour.
    # "_answer.json" is the spelling VideoAskerV1 reads; it is used as a
    # fallback when the misspelled file does not exist.
    _ANSWER_SUFFIXES = ("_anwser.json", "_answer.json")

    def __init__(self, model: CausalModel, llm_name: str):
        super().__init__()
        self.model = model
        self.llm_name = llm_name

        self.scenario = self.model.scenario
        self.scenario_id = Scenario.get_index(scenario=self.scenario)
        self.sampler = Sampler(model=model, llm_name=llm_name)
        # Occurrence counters: select_sample -> factor-values tuple -> per-prompt counts.
        self.video_index = self._new_video_index()
        self.video_folder = (
            Path(__file__).resolve().parent.parent
            / "database"
            / self.llm_name
            / str(self.scenario_id)
            / "videos"
        )
        self.prompt_generator = PromptGeneratorV1(model=model)

    @staticmethod
    def _new_video_index():
        """Return zeroed occurrence counters for the "rule" and "text" sample sets."""
        return {
            "rule": defaultdict(lambda: [0] * N_prompts_for_same_value),
            "text": defaultdict(lambda: [0] * N_prompts_for_same_value),
        }

    def ask_for_variables(self, video_path, variable_list, **kwargs):
        """Return the stored answer for each variable in ``variable_list``.

        ``video_path`` is not used here (presumably kept for the ``VideoAsker``
        interface). The answer file is located from the required keyword
        arguments ``sample_id`` and ``select_sample``. Each call advances the
        occurrence counter of the sample's (values, prompt) pair.

        Raises:
            ValueError: if the sample's prompt is not in the generator's pool.
            FileNotFoundError: if no answer file exists.
            KeyError: if the answer file lacks a requested variable.
        """
        sample_id = kwargs["sample_id"]
        select_sample = kwargs["select_sample"]
        read_folder, video_index = self._get_video_info(sample_id=sample_id, select_sample=select_sample)
        json_name = self._answer_file(read_folder, video_index)
        return VideoAskerV2_0.values_from_json(json_file_name=json_name, variable_list=variable_list)

    def _get_video_info(self, sample_id, select_sample):
        """Locate the folder and per-pair file index of a sample's video.

        Returns ``(read_folder, video_index)`` and increments the counter for
        the sample's (values, prompt) pair. ``select_sample`` must be "rule"
        or "text" (the only counter keys).

        Raises:
            ValueError: if the sample's prompt is not in the generator's pool
                for its factor values.
        """
        row = self.sampler.read_sample(sample_id=sample_id, select_sample=select_sample)
        values_dict = {name: row["true_" + name] for name in self.sampler.variables}
        values_key = tuple(values_dict[factor] for factor in self.prompt_generator.factors)
        values_str = "".join(str(int(value)) for value in values_key)
        prompt = row["prompt"]
        # "rule" samples use `samples`; every other sample set uses `samples_full`.
        pools = self.prompt_generator.samples if select_sample == "rule" else self.prompt_generator.samples_full
        prompts_pool: list[str] = pools[values_key]
        try:
            prompt_index = prompts_pool.index(prompt)
        except ValueError:
            print(f"prompt = {prompt}")
            print(f"prompts_pool = {prompts_pool}")
            raise ValueError(
                f"prompt not in prompts_pool. scenario_id={self.scenario_id}, sample_id={sample_id}, select_sample={select_sample}"
            ) from None
        read_folder = self.video_folder / select_sample / values_str / str(prompt_index)
        counters = self.video_index[select_sample][values_key]
        video_index = counters[prompt_index]
        counters[prompt_index] += 1
        return read_folder, video_index

    @classmethod
    def _answer_file(cls, read_folder, video_index):
        """Return the first existing answer file for ``video_index`` in ``read_folder``.

        If none of the candidate names exists, return the legacy
        (first-suffix) path so the subsequent open fails on the same path
        as before.
        """
        candidates = [read_folder / f"{video_index}{suffix}" for suffix in cls._ANSWER_SUFFIXES]
        for candidate in candidates:
            if candidate.is_file():
                return candidate
        return candidates[0]

    @classmethod
    def values_from_json(cls, json_file_name, variable_list):
        """Return ``{name: value}`` for each name in ``variable_list``.

        Values are read from the JSON object in ``json_file_name``.

        Raises:
            FileNotFoundError: if ``json_file_name`` does not exist.
            KeyError: if a requested variable is absent from the file.
        """
        # JSON text is UTF-8 (RFC 8259); don't rely on the platform default encoding.
        with open(json_file_name, "r", encoding="utf-8") as json_file:
            data = json.load(json_file)
        return {name: data[name] for name in variable_list}
