import os
import torch
import numpy as np
from torch.utils.data import Dataset
import cv2
from torchvision import transforms
import random

class frame_interpolation_dataset(Dataset):
    """Dataset of (keyframe1, keyframe2, middle_frame, time_interval) samples.

    ``keyframe_path`` must contain two text files:

    * ``videos.txt`` -- one video path per line; the frame-file prefix is the
      last ``/``-separated component up to its first ``.``.
    * ``idx.txt`` -- one line per video (same order) with whitespace-separated
      integer keyframe indices.

    Frames are read from ``data_path`` as ``{video_name}_frame_{index}.png``.
    For every pair of consecutive keyframes, each frame strictly between them
    whose normalised position lies in ``[0.32, 0.68]`` becomes one sample.

    Args:
        data_path: Directory holding the extracted frame images.
        keyframe_path: Directory holding ``videos.txt`` and ``idx.txt``.
        downsample: Integer divisor applied to the 1360x768 output size.
        data_size: If not ``None``, keep only this many samples after the
            shuffle.

    Raises:
        IndexError: If ``idx.txt`` has fewer entries than ``videos.txt``.
    """

    # Full-resolution output size (width, height) before down-sampling.
    _BASE_WIDTH = 1360
    _BASE_HEIGHT = 768
    # Only middle frames whose relative position lies in this closed range
    # are kept as samples.
    _MIN_INTERVAL = 0.32
    _MAX_INTERVAL = 0.68
    # cv2.resize interpolates in float32, so allow a tiny rounding overshoot
    # above 1.0 in the range sanity check.
    _RANGE_TOLERANCE = 1e-5

    def __init__(self, data_path, keyframe_path, downsample=1, data_size=None):
        self.data_path = data_path
        self.keyframe_path = keyframe_path
        self.down_sample = downsample

        self.video_txt = os.path.join(keyframe_path, 'videos.txt')
        self.idx_txt = os.path.join(keyframe_path, 'idx.txt')

        # Blank lines (e.g. a trailing empty line) are skipped in both files
        # so they neither crash int() nor create phantom entries.
        with open(self.video_txt, 'r') as f:
            self.videos = [line.strip() for line in f if line.strip()]
        with open(self.idx_txt, 'r') as f:
            # split() without an argument tolerates repeated spaces and tabs.
            self.idx = [
                [int(token) for token in line.split()]
                for line in f
                if line.strip()
            ]

        if len(self.idx) < len(self.videos):
            raise IndexError(
                f"{self.idx_txt} has {len(self.idx)} entries but "
                f"{self.video_txt} lists {len(self.videos)} videos"
            )

        ## data: [keyframe1_path, keyframe2_path, middle_frame_path, time_interval \in [0, 1]]
        self.data = []
        for video, keyframes in zip(self.videos, self.idx):
            video_name = video.split('/')[-1].split('.')[0]
            for start, end in zip(keyframes, keyframes[1:]):
                keyframe1_path = os.path.join(self.data_path, f"{video_name}_frame_{start}.png")
                keyframe2_path = os.path.join(self.data_path, f"{video_name}_frame_{end}.png")
                for k in range(start + 1, end):
                    time_interval = (k - start) / (end - start)
                    if time_interval < self._MIN_INTERVAL or time_interval > self._MAX_INTERVAL:
                        continue
                    middle_frame_path = os.path.join(self.data_path, f"{video_name}_frame_{k}.png")
                    self.data.append([keyframe1_path, keyframe2_path, middle_frame_path, time_interval])

        # Shuffle before truncating so data_size selects a random subset.
        random.shuffle(self.data)
        if data_size is not None:
            self.data = self.data[:data_size]

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    @staticmethod
    def _load_frame(path, size):
        """Read ``path`` and return a float32 RGB array in channel-first layout.

        Pixel values are scaled by 1/255 and the image is resized to
        ``size`` given as ``(width, height)``.

        Raises:
            FileNotFoundError: If OpenCV cannot read the image.
        """
        image = cv2.imread(path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image (missing or unreadable): {path}")
        image = image.astype("float32") / 255.0
        image = cv2.resize(image, size)
        # OpenCV loads BGR in HxWxC; convert to RGB and move channels first.
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB).transpose(2, 0, 1)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            Tuple ``(keyframe1, keyframe2, middle_frame, time_interval)``:
            three float32 numpy arrays in channel-first RGB layout and a
            scalar float32 tensor holding the middle frame's relative
            position between the two keyframes.
        """
        keyframe1_path, keyframe2_path, middle_frame_path, time_interval = self.data[idx]
        size = (self._BASE_WIDTH // self.down_sample, self._BASE_HEIGHT // self.down_sample)

        keyframe1 = self._load_frame(keyframe1_path, size)
        keyframe2 = self._load_frame(keyframe2_path, size)
        middle_frame = self._load_frame(middle_frame_path, size)
        time_interval = torch.tensor(time_interval, dtype=torch.float32)

        assert keyframe1.shape == keyframe2.shape == middle_frame.shape
        upper_bound = 1 + self._RANGE_TOLERANCE
        for frame in (keyframe1, keyframe2, middle_frame):
            # Every pixel must be in range; checking min/max covers all of them.
            assert frame.min() >= 0 and frame.max() <= upper_bound
        assert time_interval >= 0 and time_interval <= 1

        return keyframe1, keyframe2, middle_frame, time_interval
    

if __name__ == '__main__':
    # Smoke test: build the dataset from the configured directories and print
    # the shapes and time interval of the first five samples.
    picture_dir = "/path to libero_dataset/finetune_dataset/libero_10_picture"
    keyframe_dir = "/path to libero_dataset/finetune_dataset/libero_10_rpd17"
    dataset = frame_interpolation_dataset(data_path=picture_dir, keyframe_path=keyframe_dir)

    for sample_index in range(5):
        first_key, second_key, middle, t = dataset[sample_index]
        print(first_key.shape, second_key.shape, middle.shape, t)