import decord
from decord import VideoReader
from decord import cpu, gpu
import glob
import os.path as osp
import numpy as np
import torch, torchvision
from tqdm import tqdm
import cv2
import torchvision.transforms as T
from functools import lru_cache
from PIL import Image
import random
import copy
import os
import skvideo
import skvideo.io

# Seed Python's global `random` generator at import time; the random frame
# sampling in the datasets below draws from this generator.
random.seed(42)

# Select decord's "torch" bridge. The loaders below check whether
# `get_batch()` returned a torch.Tensor and convert it to numpy for PIL.
decord.bridge.set_bridge("torch")

# NOTE: samples 8 frames at equal intervals.
class FusionDataset2(torch.utils.data.Dataset):
    """Video dataset that returns ``num_frames`` uniformly spaced frames.

    The annotation file is read line by line. Every non-blank line must hold
    seven comma-separated fields: the first is the video path (relative to
    ``data_prefix``) and the fourth is the float label. The remaining fields
    are ignored by this class.
    """

    def __init__(self, opt):
        """Read the options and load the annotation file.

        Args:
            opt: Mapping providing the keys ``"anno_file"`` (annotation file
                path), ``"data_prefix"`` (directory joined in front of every
                filename) and ``"phase"``.

        Raises:
            KeyError: If a required option is missing.
            OSError: If the annotation file cannot be opened.
            ValueError: If a non-blank annotation line does not have exactly
                seven fields or its label is not a valid float.
        """
        super().__init__()
        self.ann_file = opt["anno_file"]
        self.data_prefix = opt["data_prefix"]
        self.phase = opt["phase"]
        self.samplers = {}
        self.interpolation = T.InterpolationMode.BILINEAR
        # False -> frames are squashed to a square instead of center-cropped.
        self.center_crop = False
        self.image_size = 336
        self.transform = self.get_image_transform()
        # Number of frames sampled from every video.
        self.num_frames = 8
        self.video_infos = self._load_annotations(self.ann_file, self.data_prefix)

    @staticmethod
    def _load_annotations(ann_file, data_prefix):
        """Parse the annotation file into ``{"filename", "label"}`` dicts.

        Blank lines (for example a trailing newline at the end of the file)
        are skipped instead of crashing the seven-field unpack.
        """
        infos = []
        with open(ann_file, "r") as fin:
            for line in fin:
                line = line.strip()
                if not line:
                    continue
                filename, _, _, label, _, _, _ = line.split(",")
                infos.append(
                    dict(filename=osp.join(data_prefix, filename), label=float(label))
                )
        return infos

    @staticmethod
    def _uniform_indices(total_frames, num_frames):
        """Return ``num_frames`` indices spread evenly over ``total_frames``.

        When the video is shorter than ``num_frames`` some indices repeat.
        """
        step = total_frames / num_frames
        return [int(i * step) for i in range(num_frames)]

    def get_image_transform(self):
        """Build the per-frame preprocessing pipeline.

        Returns:
            A ``torchvision.transforms.Compose`` that resizes (and, when
            ``center_crop`` is set, center-crops) a PIL image, converts it to
            RGB, turns it into a tensor and normalises each channel with
            mean 0.5 and std 0.5.
        """
        if self.center_crop:
            crop = [
                T.Resize(self.image_size, interpolation=self.interpolation),
                T.CenterCrop(self.image_size),
            ]
        else:
            # "Squash": resize straight to a square, ignoring aspect ratio.
            crop = [
                T.Resize((self.image_size, self.image_size), interpolation=self.interpolation)
            ]

        return T.Compose(crop + [
            T.Lambda(lambda x: x.convert("RGB")),
            T.ToTensor(),
            T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True),
        ])

    def preprocess_video(self, video_path):
        """Decode uniformly spaced frames of a video and preprocess them.

        Args:
            video_path: Path of the video file handed to ``decord.VideoReader``.

        Returns:
            The transformed frames stacked along a new first dimension
            (presumably ``[num_frames, 3, image_size, image_size]``).
        """
        vr = decord.VideoReader(video_path)
        frame_indices = self._uniform_indices(len(vr), self.num_frames)
        frames = vr.get_batch(frame_indices)
        if isinstance(frames, torch.Tensor):
            # PIL's Image.fromarray needs numpy arrays, not tensors.
            frames = frames.cpu().numpy()
        preprocessed_frames = [self.transform(Image.fromarray(frame)) for frame in frames]
        return torch.stack(preprocessed_frames, dim=0)

    def __getitem__(self, index):
        """Return a dict with the keys ``"video"``, ``"gt_label"`` and ``"name"``."""
        video_info = self.video_infos[index]
        filename = video_info["filename"]
        label = video_info["label"]

        data = {}
        data["video"] = self.preprocess_video(filename)
        data["gt_label"] = label
        data["name"] = osp.basename(filename)
        return data

    def __len__(self):
        """Return the number of annotated videos."""
        return len(self.video_infos)

# NOTE: splits the video into segments and randomly draws one frame from each (8 frames in total).
class FusionDataset4(torch.utils.data.Dataset):
    """Video dataset that draws one random frame from each temporal segment.

    The video is split into ``num_frames`` consecutive segments and one frame
    index is drawn at random (via the global ``random`` module) from each.

    The annotation file is read line by line. Every non-blank line must hold
    seven comma-separated fields: the first is the video path (relative to
    ``data_prefix``) and the fourth is the float label.
    """

    def __init__(self, opt):
        """Read the options and load the annotation file.

        Args:
            opt: Mapping providing the keys ``"anno_file"`` (annotation file
                path), ``"data_prefix"`` (directory joined in front of every
                filename) and ``"phase"``.

        Raises:
            KeyError: If a required option is missing.
            OSError: If the annotation file cannot be opened.
            ValueError: If a non-blank annotation line does not have exactly
                seven fields or its label is not a valid float.
        """
        super().__init__()
        self.ann_file = opt["anno_file"]
        self.data_prefix = opt["data_prefix"]
        self.phase = opt["phase"]
        self.samplers = {}
        self.interpolation = T.InterpolationMode.BILINEAR
        # False -> frames are squashed to a square instead of center-cropped.
        self.center_crop = False
        self.image_size = 336
        self.transform = self.get_image_transform()
        # Number of frames returned for every video.
        self.num_frames = 8
        self.video_infos = self._load_annotations(self.ann_file, self.data_prefix)

    @staticmethod
    def _load_annotations(ann_file, data_prefix):
        """Parse the annotation file into ``{"filename", "label"}`` dicts.

        Blank lines are skipped instead of crashing the seven-field unpack.
        """
        infos = []
        with open(ann_file, "r") as fin:
            for line in fin:
                line = line.strip()
                if not line:
                    continue
                filename, _, _, label, _, _, _ = line.split(",")
                infos.append(
                    dict(filename=osp.join(data_prefix, filename), label=float(label))
                )
        return infos

    def get_image_transform(self):
        """Build the per-frame preprocessing pipeline.

        Returns:
            A ``torchvision.transforms.Compose`` that resizes (and, when
            ``center_crop`` is set, center-crops) a PIL image, converts it to
            RGB, turns it into a tensor and normalises each channel with
            mean 0.5 and std 0.5.
        """
        if self.center_crop:
            crop = [
                T.Resize(self.image_size, interpolation=self.interpolation),
                T.CenterCrop(self.image_size),
            ]
        else:
            # "Squash": resize straight to a square, ignoring aspect ratio.
            crop = [
                T.Resize((self.image_size, self.image_size), interpolation=self.interpolation)
            ]

        return T.Compose(crop + [
            T.Lambda(lambda x: x.convert("RGB")),
            T.ToTensor(),
            T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True),
        ])

    def _sample_frame_indices(self, total_frames):
        """Pick one random frame index per segment, always ``num_frames`` long.

        The segment count is kept in a local variable, so a short video never
        changes ``self.num_frames`` for the samples that follow. When the
        video has fewer frames than ``num_frames`` the last sampled index is
        repeated so every clip has the same length.

        Raises:
            ValueError: If ``total_frames`` is not positive.
        """
        if total_frames <= 0:
            raise ValueError(f"cannot sample frames from a video with {total_frames} frames")

        num_segments = min(self.num_frames, total_frames)
        segment_size = total_frames / num_segments

        frame_indices = []
        for i in range(num_segments):
            start_idx = int(i * segment_size)
            # Clamp in case float rounding pushes the bound past the end.
            end_idx = min(int((i + 1) * segment_size), total_frames)
            if start_idx >= end_idx:
                break  # defensive: never draw from an empty segment
            frame_indices.append(random.randint(start_idx, end_idx - 1))

        # Pad short clips by repeating the last sampled frame.
        frame_indices.extend([frame_indices[-1]] * (self.num_frames - len(frame_indices)))
        return frame_indices

    def preprocess_video(self, video_path):
        """Decode one random frame per segment and preprocess the frames.

        Args:
            video_path: Path of the video file handed to ``decord.VideoReader``.

        Returns:
            The transformed frames stacked along a new first dimension
            (presumably ``[num_frames, 3, image_size, image_size]``).

        Raises:
            ValueError: If the video reports zero frames.
        """
        vr = decord.VideoReader(video_path)
        frame_indices = self._sample_frame_indices(len(vr))

        frames = vr.get_batch(frame_indices)
        if isinstance(frames, torch.Tensor):
            # PIL's Image.fromarray needs numpy arrays, not tensors.
            frames = frames.cpu().numpy()

        preprocessed_frames = [self.transform(Image.fromarray(frame)) for frame in frames]
        return torch.stack(preprocessed_frames, dim=0)

    def __getitem__(self, index):
        """Return a dict with the keys ``"video"``, ``"gt_label"`` and ``"name"``."""
        video_info = self.video_infos[index]
        filename = video_info["filename"]
        label = video_info["label"]

        data = {}
        data["video"] = self.preprocess_video(filename)
        data["gt_label"] = label
        data["name"] = osp.basename(filename)
        return data

    def __len__(self):
        """Return the number of annotated videos."""
        return len(self.video_infos)

# NOTE: separate implementation for LIVE-Qualcomm.
class FusionDataset3(torch.utils.data.Dataset):
    """Dataset that loads pre-extracted PNG frames from one folder per video.

    Each annotation line holds seven comma-separated fields: the first is a
    file name whose last four characters are stripped to obtain the frame
    folder (relative to ``data_prefix``), the fourth is the float label.
    """

    def __init__(self, opt):
        """Store the options and read the annotation file.

        Args:
            opt: Dictionary with the required keys ``"anno_file"``,
                ``"data_prefix"`` and ``"phase"``; ``"data_backend"``,
                ``"augment"`` and ``"random_crop"`` are optional.
        """
        super().__init__()
        self.video_infos = []
        self.ann_file = opt["anno_file"]
        self.data_prefix = opt["data_prefix"]
        self.opt = opt
        self.data_backend = opt.get("data_backend", "disk")
        self.augment = opt.get("augment", False)
        self.phase = opt["phase"]
        self.crop = opt.get("random_crop", False)

        # Stored on the instance; the transform below normalises with
        # 0.5 / 0.5 and does not read these two tensors.
        self.mean = torch.FloatTensor([123.675, 116.28, 103.53])
        self.std = torch.FloatTensor([58.395, 57.12, 57.375])
        self.samplers = {}

        self.interpolation = T.InterpolationMode.BILINEAR
        self.center_crop = False
        self.image_size = 336
        self.transform = self.get_image_transform()
        self.num_frames = 8

        with open(self.ann_file, "r") as fin:
            for raw_line in fin:
                fields = raw_line.strip().split(",")
                name, _, _, score, _, _, _ = fields
                score = float(score)
                # Drop the 4-character suffix (e.g. ".mp4") to get the folder.
                folder = osp.join(self.data_prefix, name[:-4])
                self.video_infos.append({"filename": folder, "label": score})

    def get_image_transform(self):
        """Assemble the preprocessing applied to every frame.

        Returns:
            A composed transform: resize (plus center crop when
            ``center_crop`` is set), RGB conversion, tensor conversion and
            normalisation with mean 0.5 / std 0.5 per channel.
        """
        side = self.image_size
        if self.center_crop:
            steps = [
                T.Resize(side, interpolation=self.interpolation),
                T.CenterCrop(side),
            ]
        else:
            # "Squash": resize directly to a square.
            steps = [T.Resize((side, side), interpolation=self.interpolation)]

        steps.append(T.Lambda(lambda x: x.convert("RGB")))
        steps.append(T.ToTensor())
        steps.append(T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True))
        return T.Compose(steps)

    def preprocess_video(self, folder_path):
        """Load and preprocess the PNG frames stored in ``folder_path``.

        Args:
            folder_path: Folder that must contain exactly ``num_frames``
                ``*.png`` files; they are processed in sorted file-name order.

        Returns:
            The transformed frames stacked along a new first dimension.

        Raises:
            ValueError: If the number of PNG files differs from ``num_frames``.
        """
        pattern = os.path.join(folder_path, "*.png")
        image_files = sorted(glob.glob(pattern))

        if len(image_files) != self.num_frames:
            raise ValueError(f"文件夹 {folder_path} 必须包含{self.num_frames}张png图片，当前数量: {len(image_files)}")

        # The transform itself converts each image to RGB.
        tensors = [self.transform(Image.open(path)) for path in image_files]
        return torch.stack(tensors, dim=0)

    def __getitem__(self, index):
        """Return a dict with the keys ``"video"``, ``"gt_label"`` and ``"name"``."""
        info = self.video_infos[index]
        folder, score = info["filename"], info["label"]
        clip = self.preprocess_video(folder)
        return {"video": clip, "gt_label": score, "name": osp.basename(folder)}

    def __len__(self):
        """Return the number of annotated videos."""
        return len(self.video_infos)


# NOTE: reads eight images.
class FusionDataset(torch.utils.data.Dataset):
    """Dataset that loads pre-extracted PNG frames from one folder per video.

    Each non-blank annotation line holds four comma-separated fields: the
    first is a file name whose last four characters are stripped to obtain
    the frame folder (relative to ``data_prefix``), the fourth is the float
    label.
    """

    def __init__(self, opt):
        """Store the options and read the annotation file.

        Args:
            opt: Dictionary with the required keys ``"anno_file"``,
                ``"data_prefix"`` and ``"phase"``; ``"data_backend"``,
                ``"augment"`` and ``"random_crop"`` are optional.

        Raises:
            KeyError: If a required option is missing.
            OSError: If the annotation file cannot be opened.
            ValueError: If a non-blank annotation line does not have exactly
                four fields or its label is not a valid float.
        """
        super().__init__()
        self.ann_file = opt["anno_file"]
        self.data_prefix = opt["data_prefix"]
        self.opt = opt
        self.data_backend = opt.get("data_backend", "disk")
        self.augment = opt.get("augment", False)

        self.phase = opt["phase"]
        self.crop = opt.get("random_crop", False)
        # Stored on the instance; the transform below normalises with
        # 0.5 / 0.5 and does not read these two tensors.
        self.mean = torch.FloatTensor([123.675, 116.28, 103.53])
        self.std = torch.FloatTensor([58.395, 57.12, 57.375])
        self.samplers = {}
        self.interpolation = T.InterpolationMode.BILINEAR
        # False -> frames are squashed to a square instead of center-cropped.
        self.center_crop = False
        self.image_size = 336
        self.transform = self.get_image_transform()
        # Number of PNG frames expected in every folder.
        self.num_frames = 8
        self.video_infos = self._load_annotations(self.ann_file, self.data_prefix)

    @staticmethod
    def _load_annotations(ann_file, data_prefix):
        """Parse the annotation file into ``{"filename", "label"}`` dicts.

        Blank lines are skipped instead of crashing the four-field unpack.
        """
        infos = []
        with open(ann_file, "r") as fin:
            for line in fin:
                line = line.strip()
                if not line:
                    continue
                filename, _, _, label = line.split(",")
                # Drop the 4-character suffix (e.g. ".mp4") to get the folder.
                folder = osp.join(data_prefix, filename[:-4])
                infos.append(dict(filename=folder, label=float(label)))
        return infos

    def get_image_transform(self):
        """Build the per-frame preprocessing pipeline.

        Returns:
            A ``torchvision.transforms.Compose`` that resizes (and, when
            ``center_crop`` is set, center-crops) a PIL image, converts it to
            RGB, turns it into a tensor and normalises each channel with
            mean 0.5 and std 0.5.
        """
        if self.center_crop:
            crop = [
                T.Resize(self.image_size, interpolation=self.interpolation),
                T.CenterCrop(self.image_size),
            ]
        else:
            # "Squash": resize straight to a square, ignoring aspect ratio.
            crop = [
                T.Resize((self.image_size, self.image_size), interpolation=self.interpolation)
            ]

        return T.Compose(crop + [
            T.Lambda(lambda x: x.convert("RGB")),
            T.ToTensor(),
            T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True),
        ])

    def preprocess_video(self, folder_path):
        """Load and preprocess the PNG frames stored in ``folder_path``.

        Args:
            folder_path: Folder that must contain exactly ``num_frames``
                ``*.png`` files; they are processed in sorted file-name order.

        Returns:
            The transformed frames stacked along a new first dimension.

        Raises:
            ValueError: If the number of PNG files differs from ``num_frames``.
        """
        image_files = sorted(glob.glob(os.path.join(folder_path, "*.png")))

        if len(image_files) != self.num_frames:
            raise ValueError(f"文件夹 {folder_path} 必须包含{self.num_frames}张png图片，当前数量: {len(image_files)}")

        # The transform itself converts each image to RGB.
        preprocessed_frames = [self.transform(Image.open(path)) for path in image_files]
        return torch.stack(preprocessed_frames, dim=0)

    def __getitem__(self, index):
        """Return a dict with the keys ``"video"``, ``"gt_label"`` and ``"name"``."""
        video_info = self.video_infos[index]
        filename = video_info["filename"]
        label = video_info["label"]

        data = {}
        data["video"] = self.preprocess_video(filename)
        data["gt_label"] = label
        data["name"] = osp.basename(filename)
        return data

    def __len__(self):
        """Return the number of annotated videos."""
        return len(self.video_infos)



# NOTE: random start offset, then one frame from each segment at equal intervals.

class TrainDataset(torch.utils.data.Dataset):
    """Training dataset: random start offset, then equally spaced frames.

    ``opt["anno_file"]`` is either a ready-made list of sample dicts (used
    as-is) or the path of an annotation file with seven comma-separated
    fields per line, of which the first is the video path relative to
    ``data_prefix`` and the fourth is the float label.
    """

    def __init__(self, opt):
        """Store the options and collect the sample list.

        Args:
            opt: Dictionary with the required keys ``"anno_file"``,
                ``"data_prefix"`` and ``"phase"``; ``"data_backend"``,
                ``"augment"`` and ``"random_crop"`` are optional.
        """
        super().__init__()
        self.video_infos = []
        self.ann_file = opt["anno_file"]
        self.data_prefix = opt["data_prefix"]
        self.opt = opt
        self.data_backend = opt.get("data_backend", "disk")
        self.augment = opt.get("augment", False)
        self.phase = opt["phase"]
        self.crop = opt.get("random_crop", False)

        # Stored on the instance; the transform below normalises with
        # 0.5 / 0.5 and does not read these two tensors.
        self.mean = torch.FloatTensor([123.675, 116.28, 103.53])
        self.std = torch.FloatTensor([58.395, 57.12, 57.375])
        self.samplers = {}

        self.interpolation = T.InterpolationMode.BILINEAR
        self.center_crop = False
        self.image_size = 336
        self.transform = self.get_image_transform()
        self.num_frames = 8

        if isinstance(self.ann_file, list):
            # A pre-built sample list is adopted without copying.
            self.video_infos = self.ann_file
        else:
            with open(self.ann_file, "r") as fin:
                for raw_line in fin:
                    fields = raw_line.strip().split(",")
                    name, _, _, score, _, _, _ = fields
                    score = float(score)
                    path = osp.join(self.data_prefix, name)
                    self.video_infos.append({"filename": path, "label": score})

    def get_image_transform(self):
        """Assemble the preprocessing applied to every frame.

        Returns:
            A composed transform: resize (plus center crop when
            ``center_crop`` is set), RGB conversion, tensor conversion and
            normalisation with mean 0.5 / std 0.5 per channel.
        """
        side = self.image_size
        if self.center_crop:
            steps = [
                T.Resize(side, interpolation=self.interpolation),
                T.CenterCrop(side),
            ]
        else:
            # "Squash": resize directly to a square.
            steps = [T.Resize((side, side), interpolation=self.interpolation)]

        steps.append(T.Lambda(lambda x: x.convert("RGB")))
        steps.append(T.ToTensor())
        steps.append(T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True))
        return T.Compose(steps)

    def preprocess_video(self, video_path='001BB.mp4'):
        """Decode ``num_frames`` frames of a video and preprocess them.

        Videos shorter than ``num_frames`` contribute every frame once and
        are padded by repeating the last frame. Otherwise the video is cut
        into ``num_frames`` equal segments, a random offset is drawn inside
        the first segment and the same offset is used in every segment.

        Args:
            video_path: Path of the video file handed to ``decord.VideoReader``.

        Returns:
            The transformed frames stacked along a new first dimension.
        """
        reader = decord.VideoReader(video_path)
        n_total = len(reader)
        wanted = self.num_frames

        if n_total >= wanted:
            stride = n_total // wanted
            # Random position inside the first segment.
            offset = random.randint(0, stride - 1)
            picks = [offset]
            picks += [k * stride + offset for k in range(1, wanted)]
        else:
            # Short video: take every frame, then repeat the last one.
            picks = list(range(n_total))
            picks += [n_total - 1] * (wanted - n_total)

        batch = reader.get_batch(picks)
        if isinstance(batch, torch.Tensor):
            # PIL's Image.fromarray needs numpy arrays, not tensors.
            batch = batch.cpu().numpy()

        tensors = []
        for frame in batch:
            tensors.append(self.transform(Image.fromarray(frame)))
        return torch.stack(tensors, dim=0)

    def __getitem__(self, index):
        """Return a dict with the keys ``"video"``, ``"gt_label"`` and ``"name"``."""
        info = self.video_infos[index]
        path, score = info["filename"], info["label"]
        clip = self.preprocess_video(path)
        return {"video": clip, "gt_label": score, "name": osp.basename(path)}

    def __len__(self):
        """Return the number of samples."""
        return len(self.video_infos)

# NOTE: reads 8 images.
class TestDataset(torch.utils.data.Dataset):
    """Test dataset that loads pre-extracted PNG frames, one folder per video.

    Each annotation line holds four comma-separated fields: the first is a
    file name whose last four characters are stripped to obtain the frame
    folder (relative to ``data_prefix``), the fourth is the float label.
    """

    def __init__(self, opt):
        """Store the options and read the annotation file.

        Args:
            opt: Dictionary with the required keys ``"anno_file"``,
                ``"data_prefix"`` and ``"phase"``; ``"data_backend"``,
                ``"augment"`` and ``"random_crop"`` are optional.
        """
        super().__init__()
        self.video_infos = []
        self.ann_file = opt["anno_file"]
        self.data_prefix = opt["data_prefix"]
        self.opt = opt
        self.data_backend = opt.get("data_backend", "disk")
        self.augment = opt.get("augment", False)
        self.phase = opt["phase"]
        self.crop = opt.get("random_crop", False)

        # Stored on the instance; the transform below normalises with
        # 0.5 / 0.5 and does not read these two tensors.
        self.mean = torch.FloatTensor([123.675, 116.28, 103.53])
        self.std = torch.FloatTensor([58.395, 57.12, 57.375])
        self.samplers = {}

        self.interpolation = T.InterpolationMode.BILINEAR
        self.center_crop = False
        self.image_size = 336
        self.transform = self.get_image_transform()
        self.num_frames = 8

        with open(self.ann_file, "r") as fin:
            for raw_line in fin:
                name, _, _, score = raw_line.strip().split(",")
                score = float(score)
                # Drop the 4-character suffix (e.g. ".mp4") to get the folder.
                folder = osp.join(self.data_prefix, name[:-4])
                self.video_infos.append({"filename": folder, "label": score})

    def get_image_transform(self):
        """Assemble the preprocessing applied to every frame.

        Returns:
            A composed transform: resize (plus center crop when
            ``center_crop`` is set), RGB conversion, tensor conversion and
            normalisation with mean 0.5 / std 0.5 per channel.
        """
        side = self.image_size
        if self.center_crop:
            steps = [
                T.Resize(side, interpolation=self.interpolation),
                T.CenterCrop(side),
            ]
        else:
            # "Squash": resize directly to a square.
            steps = [T.Resize((side, side), interpolation=self.interpolation)]

        tail = [
            T.Lambda(lambda x: x.convert("RGB")),
            T.ToTensor(),
            T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True),
        ]
        return T.Compose(steps + tail)

    def preprocess_video(self, folder_path):
        """Load and preprocess the PNG frames stored in ``folder_path``.

        Args:
            folder_path: Folder that must contain exactly ``num_frames``
                ``*.png`` files; they are processed in sorted file-name order.

        Returns:
            The transformed frames stacked along a new first dimension.

        Raises:
            ValueError: If the number of PNG files differs from ``num_frames``.
        """
        pattern = os.path.join(folder_path, "*.png")
        image_files = sorted(glob.glob(pattern))

        if len(image_files) != self.num_frames:
            raise ValueError(f"文件夹 {folder_path} 必须包含{self.num_frames}张png图片，当前数量: {len(image_files)}")

        # The transform itself converts each image to RGB.
        tensors = [self.transform(Image.open(path)) for path in image_files]
        return torch.stack(tensors, dim=0)

    def __getitem__(self, index):
        """Return a dict with the keys ``"video"``, ``"gt_label"`` and ``"name"``."""
        info = self.video_infos[index]
        folder, score = info["filename"], info["label"]
        clip = self.preprocess_video(folder)
        return {"video": clip, "gt_label": score, "name": osp.basename(folder)}

    def __len__(self):
        """Return the number of annotated videos."""
        return len(self.video_infos)


# Test-time augmentation: uniform sampling, per-segment mean, per-segment median, and uniform sampling after sorting by MSE.
class Test4Dataset(torch.utils.data.Dataset):
    """Test-time-augmentation dataset that returns four clips per video.

    Every non-blank annotation line holds seven comma-separated fields:
    video path (relative to ``data_prefix``), two ignored fields, the float
    label and three ``-``-separated lists of precomputed frame indices.
    According to the original author's notes the three lists are: frames
    picked uniformly after sorting by MSE, the frame of each segment whose
    MSE is closest to the segment mean, and the frame of each segment with
    the median MSE.
    """

    def __init__(self, opt):
        """Store the options and read the annotation file.

        Args:
            opt: Dictionary with the required keys ``"anno_file"``,
                ``"data_prefix"`` and ``"phase"``; ``"data_backend"``,
                ``"augment"`` and ``"random_crop"`` are optional.

        Raises:
            KeyError: If a required option is missing.
            OSError: If the annotation file cannot be opened.
            ValueError: If a non-blank annotation line is malformed.
        """
        super().__init__()
        self.ann_file = opt["anno_file"]
        self.data_prefix = opt["data_prefix"]
        self.opt = opt
        self.data_backend = opt.get("data_backend", "disk")
        self.augment = opt.get("augment", False)

        self.phase = opt["phase"]
        self.crop = opt.get("random_crop", False)
        # Stored on the instance; the transform below normalises with
        # 0.5 / 0.5 and does not read these two tensors.
        self.mean = torch.FloatTensor([123.675, 116.28, 103.53])
        self.std = torch.FloatTensor([58.395, 57.12, 57.375])
        self.samplers = {}
        self.interpolation = T.InterpolationMode.BILINEAR
        # False -> frames are squashed to a square instead of center-cropped.
        self.center_crop = False
        self.image_size = 336
        self.transform = self.get_image_transform()
        # Number of frames in the randomly sampled clip.
        self.num_frames = 8
        self.video_infos = self._load_annotations(self.ann_file, self.data_prefix)

    @staticmethod
    def _load_annotations(ann_file, data_prefix):
        """Parse the annotation file into a list of sample dicts.

        Blank lines are skipped instead of crashing the seven-field unpack.
        Each dict has the keys ``filename``, ``label``, ``frame1``,
        ``frame2`` and ``frame3`` (the last three are lists of ints).
        """
        infos = []
        with open(ann_file, "r") as fin:
            for line in fin:
                line = line.strip()
                if not line:
                    continue
                filename, _, _, label, frame1, frame2, frame3 = line.split(",")
                infos.append(dict(
                    filename=osp.join(data_prefix, filename),
                    label=float(label),
                    frame1=[int(i) for i in frame1.strip().split("-")],
                    frame2=[int(i) for i in frame2.strip().split("-")],
                    frame3=[int(i) for i in frame3.strip().split("-")],
                ))
        return infos

    def get_image_transform(self):
        """Build the per-frame preprocessing pipeline.

        Returns:
            A ``torchvision.transforms.Compose`` that resizes (and, when
            ``center_crop`` is set, center-crops) a PIL image, converts it to
            RGB, turns it into a tensor and normalises each channel with
            mean 0.5 and std 0.5.
        """
        if self.center_crop:
            crop = [
                T.Resize(self.image_size, interpolation=self.interpolation),
                T.CenterCrop(self.image_size),
            ]
        else:
            # "Squash": resize straight to a square, ignoring aspect ratio.
            crop = [
                T.Resize((self.image_size, self.image_size), interpolation=self.interpolation)
            ]

        return T.Compose(crop + [
            T.Lambda(lambda x: x.convert("RGB")),
            T.ToTensor(),
            T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True),
        ])

    def _sample_offset_indices(self, total_frames):
        """Return ``num_frames`` indices: random offset, fixed stride.

        The video is cut into ``num_frames`` equal segments; one offset is
        drawn at random inside the first segment and reused in every segment.
        Videos shorter than ``num_frames`` contribute every frame once and
        are padded by repeating the last frame (previously this case crashed
        in ``random.randint`` because the segment length was zero).

        Raises:
            ValueError: If ``total_frames`` is not positive.
        """
        if total_frames <= 0:
            raise ValueError(f"cannot sample frames from a video with {total_frames} frames")

        if total_frames < self.num_frames:
            indices = list(range(total_frames))
            indices.extend([total_frames - 1] * (self.num_frames - total_frames))
            return indices

        segment_length = total_frames // self.num_frames
        first_frame = random.randint(0, segment_length - 1)
        return [i * segment_length + first_frame for i in range(self.num_frames)]

    def _load_clip(self, vr, frame_indices):
        """Decode ``frame_indices`` from ``vr`` and stack the transformed frames."""
        frames = vr.get_batch(frame_indices)
        if isinstance(frames, torch.Tensor):
            # PIL's Image.fromarray needs numpy arrays, not tensors.
            frames = frames.cpu().numpy()
        preprocessed = [self.transform(Image.fromarray(frame)) for frame in frames]
        return torch.stack(preprocessed, dim=0)

    def preprocess_video(self, video_path='001BB.mp4', frame1=None, frame2=None, frame3=None):
        """Build the four clips of one video.

        Args:
            video_path: Path of the video file handed to ``decord.VideoReader``.
            frame1: Precomputed frame indices of the second clip.
            frame2: Precomputed frame indices of the third clip.
            frame3: Precomputed frame indices of the fourth clip.
                ``None`` is treated like an empty list for all three.

        Returns:
            The four clips stacked along a new first dimension. All clips
            must have the same number of frames for the stack to succeed
            (``[4, 8, 3, 336, 336]`` according to the original author's note).

        Raises:
            ValueError: If the video reports zero frames.
        """
        frame1 = [] if frame1 is None else frame1
        frame2 = [] if frame2 is None else frame2
        frame3 = [] if frame3 is None else frame3

        vr = decord.VideoReader(video_path)

        # Clip 0: random start offset, then equally spaced frames.
        sample0 = self._load_clip(vr, self._sample_offset_indices(len(vr)))
        # Clips 1-3: frame indices precomputed in the annotation file.
        sample1 = self._load_clip(vr, frame1)
        sample2 = self._load_clip(vr, frame2)
        sample3 = self._load_clip(vr, frame3)

        return torch.stack([sample0, sample1, sample2, sample3], dim=0)

    def __getitem__(self, index):
        """Return a dict with the keys ``"video"``, ``"gt_label"`` and ``"name"``."""
        video_info = self.video_infos[index]
        filename = video_info["filename"]
        label = video_info["label"]
        frame1 = video_info["frame1"]
        frame2 = video_info["frame2"]
        frame3 = video_info["frame3"]

        data = {}
        data["video"] = self.preprocess_video(filename, frame1, frame2, frame3)
        data["gt_label"] = label
        data["name"] = osp.basename(filename)
        return data

    def __len__(self):
        """Return the number of annotated videos."""
        return len(self.video_infos)

# NOTE: uses the per-segment mean frames.
class AVGDataset(torch.utils.data.Dataset):
    """Dataset that returns the clip built from the ``frame2`` index list.

    Every non-blank annotation line holds seven comma-separated fields:
    video path (relative to ``data_prefix``), two ignored fields, the float
    label and three ``-``-separated lists of precomputed frame indices.
    Only the second list (``frame2``) is decoded; according to the original
    author's notes it holds, for each segment, the frame whose MSE is closest
    to the segment mean.
    """

    def __init__(self, opt):
        """Store the options and read the annotation file.

        Args:
            opt: Dictionary with the required keys ``"anno_file"``,
                ``"data_prefix"`` and ``"phase"``; ``"data_backend"``,
                ``"augment"`` and ``"random_crop"`` are optional.

        Raises:
            KeyError: If a required option is missing.
            OSError: If the annotation file cannot be opened.
            ValueError: If a non-blank annotation line is malformed.
        """
        super().__init__()
        self.ann_file = opt["anno_file"]
        self.data_prefix = opt["data_prefix"]
        self.opt = opt
        self.data_backend = opt.get("data_backend", "disk")
        self.augment = opt.get("augment", False)

        self.phase = opt["phase"]
        self.crop = opt.get("random_crop", False)
        # Stored on the instance; the transform below normalises with
        # 0.5 / 0.5 and does not read these two tensors.
        self.mean = torch.FloatTensor([123.675, 116.28, 103.53])
        self.std = torch.FloatTensor([58.395, 57.12, 57.375])
        self.samplers = {}
        self.interpolation = T.InterpolationMode.BILINEAR
        # False -> frames are squashed to a square instead of center-cropped.
        self.center_crop = False
        self.image_size = 336
        self.transform = self.get_image_transform()
        self.num_frames = 8
        self.video_infos = self._load_annotations(self.ann_file, self.data_prefix)

    @staticmethod
    def _load_annotations(ann_file, data_prefix):
        """Parse the annotation file into a list of sample dicts.

        Blank lines are skipped instead of crashing the seven-field unpack,
        and nothing is printed while parsing. Each dict has the keys
        ``filename``, ``label``, ``frame1``, ``frame2`` and ``frame3`` (the
        last three are lists of ints).
        """
        infos = []
        with open(ann_file, "r") as fin:
            for line in fin:
                line = line.strip()
                if not line:
                    continue
                filename, _, _, label, frame1, frame2, frame3 = line.split(",")
                infos.append(dict(
                    filename=osp.join(data_prefix, filename),
                    label=float(label),
                    frame1=[int(i) for i in frame1.strip().split("-")],
                    frame2=[int(i) for i in frame2.strip().split("-")],
                    frame3=[int(i) for i in frame3.strip().split("-")],
                ))
        return infos

    def get_image_transform(self):
        """Build the per-frame preprocessing pipeline.

        Returns:
            A ``torchvision.transforms.Compose`` that resizes (and, when
            ``center_crop`` is set, center-crops) a PIL image, converts it to
            RGB, turns it into a tensor and normalises each channel with
            mean 0.5 and std 0.5.
        """
        if self.center_crop:
            crop = [
                T.Resize(self.image_size, interpolation=self.interpolation),
                T.CenterCrop(self.image_size),
            ]
        else:
            # "Squash": resize straight to a square, ignoring aspect ratio.
            crop = [
                T.Resize((self.image_size, self.image_size), interpolation=self.interpolation)
            ]

        return T.Compose(crop + [
            T.Lambda(lambda x: x.convert("RGB")),
            T.ToTensor(),
            T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True),
        ])

    def preprocess_video(self, video_path='001BB.mp4', frame1=None, frame2=None, frame3=None):
        """Decode and preprocess the frames listed in ``frame2``.

        Args:
            video_path: Path of the video file handed to ``decord.VideoReader``.
            frame1: Accepted for signature compatibility; not used.
            frame2: Frame indices to decode (``None`` is treated as empty).
            frame3: Accepted for signature compatibility; not used.

        Returns:
            The transformed frames stacked along a new first dimension, i.e.
            one clip with ``len(frame2)`` frames (not a stack of four clips).
        """
        frame2 = [] if frame2 is None else frame2

        vr = decord.VideoReader(video_path)
        frames2 = vr.get_batch(frame2)
        if isinstance(frames2, torch.Tensor):
            # PIL's Image.fromarray needs numpy arrays, not tensors.
            frames2 = frames2.cpu().numpy()
        preprocessed_frames2 = [self.transform(Image.fromarray(frame)) for frame in frames2]
        return torch.stack(preprocessed_frames2, dim=0)

    def __getitem__(self, index):
        """Return a dict with the keys ``"video"``, ``"gt_label"`` and ``"name"``."""
        video_info = self.video_infos[index]
        filename = video_info["filename"]
        label = video_info["label"]
        frame1 = video_info["frame1"]
        frame2 = video_info["frame2"]
        frame3 = video_info["frame3"]

        data = {}
        # A single clip stacked over frames, built from frame2 only.
        data["video"] = self.preprocess_video(filename, frame1, frame2, frame3)
        data["gt_label"] = label
        data["name"] = osp.basename(filename)
        return data

    def __len__(self):
        """Return the number of annotated videos."""
        return len(self.video_infos)


# NOTE: uses the per-segment median frames.
class MedianDataset(torch.utils.data.Dataset):
    """Video dataset that decodes frame indices pre-computed in an annotation file.

    Every non-blank line of the annotation file must hold seven
    comma-separated fields::

        filename, <ignored>, <ignored>, label, frame1, frame2, frame3

    where each ``frameK`` field is a ``-``-separated list of integer frame
    indices (for example ``3-17-42``).  All three lists are parsed and stored,
    but only the ``frame3`` list is decoded by :meth:`preprocess_video`
    (per the original author's note: one frame per segment, the one whose MSE
    is the median of that segment).
    """

    # Number of comma-separated fields expected on every annotation line.
    _NUM_ANNOTATION_FIELDS = 7

    def __init__(self, opt):
        """Read the options and load the annotation file.

        Args:
            opt: Dict-like options.  Required keys: ``"anno_file"`` (path of
                the annotation file), ``"data_prefix"`` (directory joined in
                front of every filename) and ``"phase"``.  Optional keys:
                ``"data_backend"`` (default ``"disk"``), ``"augment"``
                (default ``False``) and ``"random_crop"`` (default ``False``).

        Raises:
            KeyError: If a required option is missing.
            OSError: If the annotation file cannot be opened.
            ValueError: If a non-blank annotation line is malformed.
        """
        super().__init__()
        self.video_infos = []
        self.ann_file = opt["anno_file"]
        self.data_prefix = opt["data_prefix"]
        self.opt = opt
        self.data_backend = opt.get("data_backend", "disk")
        self.augment = opt.get("augment", False)

        self.phase = opt["phase"]
        self.crop = opt.get("random_crop", False)
        # mean / std / samplers are not referenced by any method of this
        # class; they are kept so external code reading them keeps working.
        self.mean = torch.FloatTensor([123.675, 116.28, 103.53])
        self.std = torch.FloatTensor([58.395, 57.12, 57.375])
        self.samplers = {}
        self.interpolation = T.InterpolationMode.BILINEAR
        self.center_crop = False
        self.image_size = 336
        self.transform = self.get_image_transform()
        self.num_frames = 8

        self.video_infos = self._load_annotations()

    def _load_annotations(self):
        """Parse ``self.ann_file`` into a list of per-video info dicts.

        Blank or whitespace-only lines (such as a trailing newline at the end
        of the file) are skipped instead of aborting the load.
        """
        with open(self.ann_file, "r") as fin:
            return [
                self._parse_annotation_line(line)
                for line in fin
                if line.strip()
            ]

    def _parse_annotation_line(self, line):
        """Turn one annotation line into a video info dict.

        Returns:
            A dict with keys ``filename`` (joined with ``self.data_prefix``),
            ``label`` (float) and ``frame1`` / ``frame2`` / ``frame3``
            (lists of int frame indices).

        Raises:
            ValueError: If the line does not have exactly seven fields, or the
                label / a frame index is not numeric.
        """
        fields = line.strip().split(",")
        if len(fields) != self._NUM_ANNOTATION_FIELDS:
            raise ValueError(
                f"expected {self._NUM_ANNOTATION_FIELDS} comma-separated "
                f"fields, got {len(fields)}: {line!r}"
            )
        filename, _, _, label, frame1, frame2, frame3 = fields
        return dict(
            filename=osp.join(self.data_prefix, filename),
            label=float(label),
            frame1=self._parse_frame_indices(frame1),
            frame2=self._parse_frame_indices(frame2),
            frame3=self._parse_frame_indices(frame3),
        )

    @staticmethod
    def _parse_frame_indices(field):
        """Convert a ``-``-separated index field such as ``"3-17-42"`` to ints."""
        return [int(i) for i in field.strip().split("-")]

    def get_image_transform(self):
        """Build the per-frame preprocessing pipeline.

        With ``self.center_crop`` set, the frame is resized (shorter side to
        ``self.image_size``) and center-cropped to a square; otherwise it is
        resized directly to ``image_size x image_size`` ("squash"), ignoring
        the aspect ratio.  The frame is then converted to RGB, turned into a
        tensor and normalized with mean 0.5 and std 0.5 per channel.

        Returns:
            A ``torchvision.transforms.Compose`` that maps a PIL image to a
            normalized tensor.
        """
        if self.center_crop:
            crop = [
                T.Resize(self.image_size, interpolation=self.interpolation),
                T.CenterCrop(self.image_size)
            ]
        else:
            # "Squash": most versatile
            crop = [
                T.Resize((self.image_size, self.image_size), interpolation=self.interpolation)
            ]

        return T.Compose(crop + [
            T.Lambda(lambda x: x.convert("RGB")),
            T.ToTensor(),
            T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True),
        ])

    def preprocess_video(self, video_path='001BB.mp4', frame1=None, frame2=None, frame3=None):
        """Decode the ``frame3`` frames of a video and preprocess them.

        Args:
            video_path: Path of the video file to open with decord.
            frame1: Accepted for call compatibility; not used.
            frame2: Accepted for call compatibility; not used.
            frame3: Frame indices to decode.  ``None`` is treated as an empty
                list.

        Returns:
            A tensor stacking the transformed frames along dim 0, i.e. of
            shape ``(len(frame3), 3, image_size, image_size)`` with the
            transform built by :meth:`get_image_transform`.
        """
        # None stands in for the former mutable ``[]`` defaults.
        frame3 = [] if frame3 is None else frame3

        vr = decord.VideoReader(video_path)
        frames = vr.get_batch(frame3)
        if isinstance(frames, torch.Tensor):
            # PIL needs numpy arrays, so leave torch before Image.fromarray.
            frames = frames.cpu().numpy()
        preprocessed = [self.transform(Image.fromarray(frame)) for frame in frames]
        return torch.stack(preprocessed, dim=0)

    def __getitem__(self, index):
        """Return the sample dict for the video at position ``index``.

        Returns:
            A dict with ``"video"`` (the tensor produced by
            :meth:`preprocess_video`), ``"gt_label"`` (the float label from
            the annotation file) and ``"name"`` (the video's base filename).
        """
        video_info = self.video_infos[index]
        filename = video_info["filename"]
        label = video_info["label"]

        video = self.preprocess_video(
            filename,
            video_info["frame1"],
            video_info["frame2"],
            video_info["frame3"],
        )
        return {
            "video": video,
            "gt_label": label,
            "name": osp.basename(filename),
        }

    def __len__(self):
        """Return the number of videos listed in the annotation file."""
        return len(self.video_infos)