import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import os
import numpy as np
import cv2
from torchvision import transforms


class frame_interp_dataset(Dataset):
    """Dataset that decodes whole videos into normalized frame tensors.

    Training mode (``is_test=False``):
        The relative video paths are read from ``<data_path>/videos.txt``.
        The per-video integer lists are read from ``<data_path>/idx.txt``.
        Non-blank line *k* of one file pairs with non-blank line *k* of the
        other.

    Test mode (``is_test=True``):
        Every ``.mp4`` entry directly inside ``data_path`` is used, in sorted
        order. Each is paired with the scalar ``0``.

    Each item is ``(video_tensor, idx_tensor, video_path)``.
    """

    def __init__(self, data_path, is_test=False):
        """Index the videos under ``data_path``.

        Args:
            data_path: Directory containing the videos and, in training mode,
                ``videos.txt`` / ``idx.txt``.
            is_test: If True, scan ``data_path`` for ``.mp4`` files instead of
                reading the text files.

        Raises:
            ValueError: If the number of videos and index entries differ.
        """
        self.data_path = data_path
        self.video_txt = os.path.join(data_path, 'videos.txt')
        self.idx_txt = os.path.join(data_path, 'idx.txt')
        self.is_test = is_test
        if not is_test:
            self.videos = self._read_video_list(self.video_txt)
            self.idx = self._read_index_file(self.idx_txt)
        else:
            # os.listdir order is arbitrary; sort so item i is reproducible.
            self.videos = sorted(
                name for name in os.listdir(data_path) if name.endswith('.mp4')
            )
            # No idx.txt at test time; 0 is presumably a placeholder.
            self.idx = [0] * len(self.videos)
        # Explicit check (not assert) so it survives `python -O`.
        if len(self.videos) != len(self.idx):
            raise ValueError(
                f'{len(self.videos)} videos but {len(self.idx)} index entries '
                f'under {data_path!r}'
            )
        # standard transform for resnet50 (ImageNet mean/std normalization)
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    @staticmethod
    def _read_video_list(path):
        """Return the stripped, non-blank lines of the text file at *path*."""
        with open(path, 'r') as f:
            return [line.strip() for line in f if line.strip()]

    @staticmethod
    def _read_index_file(path):
        """Parse *path* into one list of ints per non-blank line.

        Values may be separated by any run of whitespace (spaces or tabs).
        """
        with open(path, 'r') as f:
            return [[int(tok) for tok in line.split()] for line in f if line.strip()]

    def __len__(self):
        """Return the number of videos in the dataset."""
        return len(self.videos)

    def __getitem__(self, idx):
        """Decode video number ``idx``.

        Returns:
            video_tensor: Frames stacked along dim 0, shape
                ``(num_frames, 3, H, W)``. Each frame is converted BGR->RGB
                and then passed through ``self.transform``.
            idx_tensor: This entry's index data as a float tensor. It is 1-D
                in training mode and 0-d in test mode (scalar ``0``).
            video_path: Full path of the decoded file.

        Raises:
            RuntimeError: If no frame could be decoded (missing or unreadable
                file).
        """
        video_path = os.path.join(self.data_path, self.videos[idx])
        frame_idx = self.idx[idx]
        frames = []
        capture = cv2.VideoCapture(video_path)
        try:
            while capture.isOpened():
                ok, frame = capture.read()
                if not ok:
                    break
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(self.transform(frame))
        finally:
            # Release the decoder even if cvtColor/transform raises.
            capture.release()
        if not frames:
            # torch.stack([]) would fail with an opaque message instead.
            raise RuntimeError(f'no frames could be decoded from {video_path!r}')
        video_tensor = torch.stack(frames, dim=0)
        idx_tensor = torch.from_numpy(np.array(frame_idx)).float()
        return video_tensor, idx_tensor, video_path
    

class embedding_dataset(Dataset):
    """Dataset of precomputed per-video embeddings cached as ``.pt`` files.

    Data sources:
        ``<data_path>/videos.txt``: one relative video path per non-blank line.
        ``<data_path>/idx.txt``: one whitespace-separated list of ints per
        non-blank line. Line *k* pairs with video *k*.

    The embedding for a video is loaded with ``torch.load`` from
    ``<embedding_path>/<stem>.pt``.

    Each item is ``(embedding, idx_tensor, file_name)``.
    """

    def __init__(self, data_path, embedding_path, map_location=None):
        """Read the video list and index lists.

        Args:
            data_path: Directory containing ``videos.txt`` and ``idx.txt``.
            embedding_path: Directory containing the cached ``.pt`` files.
            map_location: Forwarded to ``torch.load``. For example, pass
                ``'cpu'`` to load GPU-saved caches on a CPU-only machine.
                The default ``None`` keeps torch's default behavior.

        Raises:
            ValueError: If ``videos.txt`` and ``idx.txt`` have a different
                number of non-blank lines.
        """
        self.data_path = data_path
        self.embedding_path = embedding_path
        self.map_location = map_location
        self.video_txt = os.path.join(data_path, 'videos.txt')
        self.idx_txt = os.path.join(data_path, 'idx.txt')
        self.videos = self._read_video_list(self.video_txt)
        self.idx = self._read_index_file(self.idx_txt)
        # Without this check a mismatch silently drops videos (len() uses idx)
        # or fails later with IndexError.
        if len(self.videos) != len(self.idx):
            raise ValueError(
                f'{self.video_txt!r} lists {len(self.videos)} videos but '
                f'{self.idx_txt!r} lists {len(self.idx)} index entries'
            )
        # NOTE(review): split('.')[0] cuts at the FIRST dot, so 'a.b.mp4' maps
        # to 'a.pt'. It is kept as-is, presumably to match how the cache files
        # were named. Verify against the cache writer before changing it.
        self.embedding_files = [
            video.split('/')[-1].split('.')[0] + '.pt' for video in self.videos
        ]

    @staticmethod
    def _read_video_list(path):
        """Return the stripped, non-blank lines of the text file at *path*."""
        with open(path, 'r') as f:
            return [line.strip() for line in f if line.strip()]

    @staticmethod
    def _read_index_file(path):
        """Parse *path* into one list of ints per non-blank line.

        Values may be separated by any run of whitespace (spaces or tabs).
        """
        with open(path, 'r') as f:
            return [[int(tok) for tok in line.split()] for line in f if line.strip()]

    def __len__(self):
        """Return the number of entries in the dataset."""
        return len(self.idx)

    def __getitem__(self, idx):
        """Load cached embedding number ``idx``.

        Returns:
            embedding: Whatever object ``torch.load`` returns for the file.
            idx_tensor: This entry's index list as a 1-D float tensor.
            file_name: Base name of the loaded ``.pt`` file.
        """
        file_name = self.embedding_files[idx]
        # torch.load unpickles data; only point embedding_path at trusted caches.
        embedding = torch.load(
            os.path.join(self.embedding_path, file_name),
            map_location=self.map_location,
        )
        idx_tensor = torch.from_numpy(np.array(self.idx[idx])).float()
        return embedding, idx_tensor, file_name
    


if __name__ == '__main__':
    # Smoke test: pull one shuffled batch of cached embeddings and print its
    # shapes, the gaps between consecutive index values, and the file names.
    # Replace the placeholder paths below before running.
    data_path = '/path to libero_dataset/finetune_dataset/libero_90_rpd17'
    cache_path = '/path to libero_dataset/interpolation_model/cache/libero_90_rpd17'
    dataset = embedding_dataset(data_path, cache_path)
    dataloader = DataLoader(dataset, batch_size=3, shuffle=True)
    first_batch = next(iter(dataloader), None)
    if first_batch is not None:
        embedding, idx, file_name = first_batch
        print(embedding.shape)
        print(idx.shape)
        # Differences between neighbouring entries of each row of idx.
        print(idx[:, 1:] - idx[:, :-1])
        print(file_name)