import os
import json
import torch
import numpy as np
import ast
from PIL import Image
from pathlib import Path
from torch.utils.data import Dataset
from datasets import accuracy as general_accuracy


class ProgramDataset(Dataset):
    """Dataset of program/question records loaded from ``results.json``.

    Each record is a dict with an ``"origin_data"`` entry and, optionally, an
    ``"extend_data_1"`` entry. ``__getitem__`` flattens the record into a
    single dict and attaches the loaded image(s) under ``"image"``.
    """

    def __init__(self, split, data_path="", input_type='image', image_transforms=None,
                 fps=30, max_num_frames=30, max_samples=None, start_sample=0, **kwargs):
        """
        Args:
            split (str): Name of the subset in use (e.g. "train" / "val").
                It is only stored; the file that is read is always
                ``<data_path>/results.json``.
            data_path (str): Root directory of the data.
            input_type (str): "image" / "video".
            image_transforms (callable, optional): Transform applied to every
                loaded PIL image.
            fps (int): Video frame rate (stored; kept for compatibility).
            max_num_frames (int): Maximum number of frames to read (stored).
            max_samples (int, optional): Keep at most this many samples.
            start_sample (int): Index of the first sample to keep (handy for
                debugging). Applied whether or not ``max_samples`` is given.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.split = split
        self.data_path = Path(data_path)
        self.input_type = input_type
        self.image_transforms = image_transforms
        self.fps = fps
        self.max_num_frames = max_num_frames

        json_path = self.data_path / "results.json"
        with open(json_path, 'r', encoding='utf-8') as f:
            raw = json.load(f)
        self.data = raw["data"]

        # Select the window [start_sample, start_sample + max_samples).
        # Previously start_sample was dropped when max_samples was None.
        if max_samples is not None or start_sample:
            stop = None if max_samples is None else start_sample + max_samples
            self.data = self.data[start_sample:stop]

        self.n_samples = len(self.data)

    def get_sample_path(self, index):
        """Return ``<data_path>/<input_type>s/<name>`` for the sample at ``index``.

        The name is read from the record's ``"<input_type>_name"`` field.

        Raises:
            KeyError: If the record has no ``"<input_type>_name"`` field.
        """
        # NOTE(review): the original looked this up in ``self.df``, which this
        # class never defines, so the call always failed. The records used by
        # __getitem__ keep their path under origin_data["image"]; confirm that
        # a top-level "<input_type>_name" field is really present in the data.
        sample_name = self.data[index][f"{self.input_type}_name"]
        return self.data_path / f"{self.input_type}s" / sample_name

    def get_image(self, image_path):
        """Load ``image_path`` as an RGB image and apply the transform, if any.

        Returns:
            The first three entries along the leading axis of the transform's
            output when ``image_transforms`` is set, otherwise the RGB PIL
            image itself.
        """
        with open(image_path, "rb") as f:
            # convert() produces a new image object from the decoded data
            # while the file handle is still open.
            pil_image = Image.open(f).convert("RGB")
        if self.image_transforms is not None:
            return self.image_transforms(pil_image)[:3]
        return pil_image

    @staticmethod
    def _parse_image_paths(image_field):
        """Normalise a record's image field to a list of paths.

        The field may already be a list, a string holding a Python list
        literal (e.g. ``"['a.png', 'b.png']"``), or a single plain path.
        A string that looks like a list but cannot be parsed is treated as a
        single path.
        """
        if isinstance(image_field, list):
            return image_field
        if isinstance(image_field, str) and image_field.startswith("["):
            try:
                parsed = ast.literal_eval(image_field)
            except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
                # Best effort: fall back to using the raw string as the path.
                return [image_field]
            return parsed if isinstance(parsed, list) else [parsed]
        return [image_field]

    def __getitem__(self, index):
        """Return the flattened record at ``index`` with its image(s) loaded.

        Returns:
            dict: origin/extend fields of the record plus ``"image"``, which
            is always a list of loaded images (one entry for a single path).
        """
        sample = self.data[index]
        origin = sample["origin_data"]
        # extend_data_1 is optional; every field taken from it defaults to "".
        extend = sample.get("extend_data_1", {})

        out_dict = {
            "index": index,
            "query_type": self.input_type,
            "sample_id": origin["source_id"],
            "origin_program": origin["program"],
            "origin_question": origin["question"],
            "origin_image_path": origin["image"],
            "origin_answer": origin.get("golden_answer", ""),
            "extend_program": extend.get("program", ""),
            "extend_question": extend.get("question", ""),
            "extend_answer": extend.get("program_answer", ""),
        }

        # Images come from origin_data; the value is kept as a list.
        image_paths = self._parse_image_paths(origin["image"])
        out_dict["image"] = [self.get_image(p) for p in image_paths]

        return out_dict

    def __len__(self):
        """Return the number of samples kept after windowing."""
        return len(self.data)

    @classmethod
    def accuracy(cls, *args, **kwargs):
        """Forward all arguments to the shared ``accuracy`` function."""
        return general_accuracy(*args, **kwargs)
