from util.interface import CausalModel
from util.scenario import Scenario
from sample.sampler import Sampler
from pathlib import Path
import pandas as pd
import json
from collections import defaultdict

class VBenchResultAsker:
    """Collect VBench scores for one scenario and merge them into stored results.

    Paths are resolved relative to the directory two levels above this file:

    * ``database/<llm_name>/<scenario_index>/`` holds ``evaluate_results.csv``
      (first column = sample id, plus a ``video_path`` column) and
      ``sample_results/<sample_id>.json``.
    * ``vbench_output/<llm_name>/<scenario_index>/<video folder>/`` holds the
      ``results_eval_results.json`` file for the videos in that folder.
    """

    def __init__(self, model: "CausalModel", llm_name: str):
        """Bind the reader to ``model``'s scenario and ``llm_name``.

        Args:
            model: Causal model whose ``scenario`` selects the result folders.
            llm_name: LLM name; used as a path component and passed to ``Sampler``.
        """
        self.model = model
        self.llm_name = llm_name

        self.scenario = self.model.scenario
        # Resolve the scenario index once; it is exposed under both attribute
        # names (scenario_id / scenario_index) for existing callers.
        scenario_index = Scenario.get_index(scenario=self.scenario)
        self.scenario_id = scenario_index
        self.scenario_index = scenario_index
        self.sampler = Sampler(model=model, llm_name=llm_name)

        project_root = Path(__file__).resolve().parent.parent
        self.database_path = project_root / "database" / llm_name
        self.scenario_path = self.database_path / str(scenario_index)
        self.vbench_scenario_path = project_root / "vbench_output" / llm_name / str(scenario_index)

    def get_video_paths(self) -> dict[int, str]:
        """Return ``{sample_id: video_path}`` read from ``evaluate_results.csv``.

        The CSV's first column is used as the index (sample id); the paths are
        taken verbatim from its ``video_path`` column.
        """
        evaluate_result_file = self.scenario_path / "evaluate_results.csv"
        evaluate_result = pd.read_csv(evaluate_result_file, index_col=0)
        return evaluate_result["video_path"].to_dict()

    def get_vbench_result_paths(self, video_paths: dict[int, str]) -> dict[int, Path]:
        """Map each sample id to the VBench result JSON for its video.

        The video's parent folder is mirrored under ``vbench_scenario_path`` and
        the result file inside it is ``results_eval_results.json``.
        NOTE: if a video path is absolute, pathlib's ``/`` discards
        ``vbench_scenario_path`` and the absolute folder is used as-is.
        """
        return {
            sample_id: self.vbench_scenario_path / Path(video_rpath).parent / "results_eval_results.json"
            for sample_id, video_rpath in video_paths.items()
        }

    @staticmethod
    def _iter_video_entries(result_items):
        """Yield the per-video dict entries of one metric's result list.

        Accepts a flat list of ``{"video_path", "video_results"}`` dicts as well
        as a nested ``[aggregate_score, [per-video dicts]]`` pair; non-dict items
        such as the aggregate score are skipped.
        NOTE(review): VBench's ``evaluate()`` appears to store the nested pair
        per dimension — confirm against the installed VBench version.
        """
        for item in result_items:
            if isinstance(item, dict):
                yield item
            elif isinstance(item, list):
                yield from (entry for entry in item if isinstance(entry, dict))

    def analysis_result(self, video_path: "str | Path", result_path: Path) -> dict[str, float]:
        """Extract ``{metric: score}`` for one video from a VBench result file.

        Videos are matched by file name only (the directory part is ignored).
        For each metric the first matching entry wins; metrics without an entry
        for this video are omitted from the returned dict.

        Raises:
            FileNotFoundError: if ``result_path`` does not exist.
        """
        video_name = Path(video_path).name
        with open(result_path, "r", encoding="utf-8") as fp:
            vbench_result = json.load(fp)
        result = {}
        for metric, result_items in vbench_result.items():
            for entry in self._iter_video_entries(result_items):
                if Path(entry["video_path"]).name == video_name:
                    result[metric] = float(entry["video_results"])
                    break
        return result

    def get_vbench_results(self) -> dict[int, dict[str, float]]:
        """Return ``{sample_id: {metric: score}}`` for every sample of the scenario."""
        video_paths = self.get_video_paths()
        result_paths = self.get_vbench_result_paths(video_paths=video_paths)
        return {
            sample_id: self.analysis_result(video_path=video_path, result_path=result_paths[sample_id])
            for sample_id, video_path in video_paths.items()
        }

    def save_vbench_results(self, vbench_results: dict[int, dict[str, float]]) -> None:
        """Write VBench scores into the per-sample JSONs and ``evaluate_results.csv``.

        Each ``sample_results/<sample_id>.json`` gets the sample's metrics added;
        existing keys with the same name are overwritten. Each metric then
        becomes a column of ``evaluate_results.csv``, aligned on the sample-id
        index. That replaces any existing column of that name, and rows without
        a score are written as NaN (empty cells).

        Args:
            vbench_results: ``{sample_id: {metric: score}}`` as returned by
                ``get_vbench_results``.
        """
        sample_results_folder = self.scenario_path / "sample_results"
        result_dict = defaultdict(dict)
        for sample_id, metrics in vbench_results.items():
            sample_results_file = sample_results_folder / f"{sample_id}.json"
            with open(sample_results_file, "r", encoding="utf-8") as fp:
                sample_results = json.load(fp)
            # .items() is required: iterating the dict directly yields only names.
            for metric, value in metrics.items():
                sample_results[metric] = value
                result_dict[metric][sample_id] = value
            with open(sample_results_file, "w", encoding="utf-8") as fp:
                json.dump(sample_results, fp, indent=4)

        evaluate_results_file = self.scenario_path / "evaluate_results.csv"
        evaluate_results = pd.read_csv(evaluate_results_file, index_col=0)
        for metric, value_dict in result_dict.items():
            # Explicit Series so values align on the sample-id index.
            evaluate_results[metric] = pd.Series(value_dict)
        evaluate_results.to_csv(evaluate_results_file)
            


