from pathlib import Path
import ast
import decord
from decord import cpu, gpu
import numpy as np
from PIL import Image
import pandas as pd
from torch.utils.data import Dataset

from datasets import accuracy as general_accuracy


class MyDataset(Dataset):
    """Query/answer dataset read from ``<data_path>/<split>/queries.csv``.

    Each CSV row has a query, an answer and a ``<input_type>_name`` column
    (``image_name`` or ``video_name``). That cell holds either a single file name
    or a Python-literal list of file names, e.g. ``"['a.jpg', 'b.jpg']"``; the
    files are resolved under ``<data_path>/<input_type>s/``.
    """

    def __init__(self, split, data_path="", input_type='image', image_transforms=None, fps=30, max_num_frames=30,
                 max_samples=None, start_sample=0, **kwargs):
        """
        Args:
            split (str): Data split; the sub-folder of ``data_path`` that contains ``queries.csv``
                (an empty string reads ``<data_path>/queries.csv``).
            data_path (str): Path to the data folder.
            input_type (str): Type of input. One of ["image", "video"].
            image_transforms (callable, optional): Optional transform to be applied on an image. Only used if input_type
                is "image".
            fps (int): Frames per second to sample at. Only used if input_type is "video".
            max_num_frames (int): Maximum number of frames to use. Only used if input_type is "video".
            max_samples (int, optional): Maximum number of samples to load. If None, load all samples.
            start_sample (int, optional): Index of the first sample to load. If None or 0, start from the beginning.
                Applied whether or not ``max_samples`` is given.
            **kwargs: Accepted for call-site compatibility and ignored.
        """
        self.split = split
        self.data_path = Path(data_path)
        self.input_type = input_type
        self.image_transforms = image_transforms
        self.fps = fps
        self.max_num_frames = max_num_frames

        # Load questions, answers, and media names.
        with open(self.data_path / self.split / 'queries.csv', 'r') as f:
            # The csv has the columns [query, answer, image_name or video_name]
            self.df = pd.read_csv(f, index_col=None, keep_default_na=False)

        # Keep the window [start_sample, start_sample + max_samples). start_sample
        # must apply even when max_samples is None (it used to be ignored then).
        start = start_sample or 0
        stop = None if max_samples is None else start + max_samples
        self.df = self.df.iloc[start:stop]

        self.n_samples = len(self.df)

    @staticmethod
    def _parse_sample_names(raw_value, column="image_name"):
        """Return the list of file names stored in one ``*_name`` CSV cell.

        A cell that looks like ``[...]`` is parsed with ``ast.literal_eval``, which
        only accepts Python literals and does not execute code. Any other cell is
        treated as one file name and returned unchanged, in a one-element list.

        Args:
            raw_value: The cell value (normally a str; other scalars are stringified).
            column (str): Column name, used only in the error message.

        Returns:
            list[str]: File names, possibly empty for a ``[]`` cell.

        Raises:
            ValueError: If a ``[...]`` cell is not a valid literal list.
        """
        text = str(raw_value)
        stripped = text.strip()
        if not (stripped.startswith("[") and stripped.endswith("]")):
            return [text]  # plain file name
        try:
            # Pass the stripped text: before Python 3.10, leading spaces make literal_eval fail.
            names = ast.literal_eval(stripped)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as e:
            raise ValueError(f"Failed to parse {column}: {raw_value}") from e
        if not isinstance(names, list):
            raise ValueError(f"Failed to parse {column}: {raw_value}")
        return [str(name) for name in names]

    def get_sample_path(self, index):
        """Return the media file paths (as str) of the sample at positional ``index``.

        Always returns a list, with one entry per file listed in the
        ``<input_type>_name`` cell, resolved under ``<data_path>/<input_type>s/``.
        """
        column = f"{self.input_type}_name"
        names = self._parse_sample_names(self.df.iloc[index][column], column)
        folder = self.data_path / f"{self.input_type}s"
        return [str(folder / name) for name in names]

    def get_image(self, image_path):
        """Load ``image_path`` as RGB and apply ``image_transforms`` if set.

        Returns a PIL image when no transform is configured. Otherwise it returns
        the transform output sliced to its first 3 entries along dim 0 (presumably
        the channel dim of a (C, H, W) tensor; verify for your transforms).
        """
        with open(image_path, "rb") as f:
            # convert() forces the pixel data to load while the file is still open.
            pil_image = Image.open(f).convert("RGB")
        if self.image_transforms:
            image = self.image_transforms(pil_image)[:3]
        else:
            image = pil_image
        return image

    def get_video(self, video_path):
        """Decode ``video_path`` into a uint8 tensor of evenly spaced frames.

        Frames are sampled at roughly ``self.fps``, capped at ``self.max_num_frames``,
        and always at least one frame is returned.
        """
        # If fixed width and height are required, VideoReader takes width and height as arguments.
        video_reader = decord.VideoReader(str(video_path), num_threads=1, ctx=cpu(0))
        decord.bridge.set_bridge('torch')  # make get_batch return torch tensors
        vlen = len(video_reader)
        original_fps = video_reader.get_avg_fps()
        num_frames = int(vlen * self.fps / original_fps)
        # Clamp to [1, max_num_frames]: a very short clip would otherwise yield zero frames.
        num_frames = max(1, min(self.max_num_frames, num_frames))
        # `np.int` was removed in NumPy 1.24; the builtin int is its documented replacement.
        frame_idxs = np.linspace(0, vlen, num_frames, endpoint=False).astype(int)
        video = video_reader.get_batch(frame_idxs).byte()
        video = video.permute(0, 3, 1, 2)  # decord frames are (T, H, W, C) -> (T, C, H, W)
        return video

    def __getitem__(self, index):
        """Return sample ``index`` as a dict built from its CSV row.

        Added keys:
            "index": the positional index passed in.
            "image": always a list. For input_type "image" it holds one loaded
                image per listed file; for "video" it holds one frame tensor per
                listed video.
            "extra_context": '' when the CSV has no such column.

        Raises:
            ValueError: If the name cell looks like a list but cannot be parsed.
        """
        out_dict = self.df.iloc[index].to_dict()
        out_dict["index"] = index

        # Reuse get_sample_path so the column/folder choice follows input_type
        # (previously "image_name" and get_image were hard-coded).
        load = self.get_image if self.input_type == "image" else self.get_video
        out_dict["image"] = [load(path) for path in self.get_sample_path(index)]

        if 'extra_context' not in out_dict:
            out_dict['extra_context'] = ''

        return out_dict

    def __len__(self):
        """Number of samples after applying start_sample/max_samples."""
        return self.n_samples

    @classmethod
    def accuracy(cls, *args, **kwargs):
        """Delegate to the shared ``datasets.accuracy`` implementation."""
        return general_accuracy(*args, **kwargs)

if __name__ == "__main__":
    # Manual smoke test: load a few samples and dump their images to disk.
    import torch

    # queries.csv is read from data_path / split / "queries.csv", so an empty
    # split loads "Datasets/queries.csv" directly (split is a required argument).
    dataset = MyDataset(
        split="",
        data_path="Datasets",  # replace with your real data path
        input_type="image",  # or "video", depending on your setup
    )

    # Inspect a few samples (fewer if the dataset is smaller).
    for i in range(min(2, len(dataset))):
        sample = dataset[i]
        print(f"\nSample {i}")
        print("Query:", sample["query"])
        print("Image type:", type(sample["image"]))
        print("Number of images:", len(sample["image"]))

        # Save every image of the sample for visual inspection.
        for j, img in enumerate(sample["image"]):
            out_path = f"debug_image_{i}_{j}.png"
            if isinstance(img, Image.Image):
                # No transforms configured: get_image returns PIL images.
                img.save(out_path)
            elif isinstance(img, torch.Tensor):
                # Video samples are (T, C, H, W): keep only the first frame.
                frame = img[0] if img.dim() == 4 else img
                frame = frame.detach().cpu()
                if frame.is_floating_point():
                    # NOTE(review): assumes float tensors are in [0, 1] (ToTensor-style); confirm for your transforms.
                    frame = (frame.clamp(0, 1) * 255).round()
                pixels = frame.to(torch.uint8).permute(1, 2, 0).numpy()  # (C, H, W) -> (H, W, C)
                if pixels.shape[-1] == 1:
                    pixels = pixels[..., 0]  # PIL expects a 2-D array for grayscale
                Image.fromarray(pixels).save(out_path)
            else:
                print(f"Skipping image {j}: unsupported type {type(img).__name__}")