import os
import cv2
import torch
import itertools
import numpy as np

from torch import Tensor
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from sklearn.model_selection import StratifiedKFold

# Public API of this module: only the dataset class is exported by
# `from <module> import *`; `listdir` and the frame-extraction helpers are
# module-level helpers that are not re-exported.
__all__ = [
    'VideoDataset',
]

def listdir(*path: str) -> np.ndarray:
    """List the entries of the directory ``os.path.join(*path)``.

    Args:
        *path: Path components that are joined to form the directory to list.

    Returns:
        A NumPy array of the entry names in ascending sorted order.
    """
    directory = os.path.join(*path)
    entries = np.asarray(os.listdir(directory))
    return np.sort(entries)

def extractFramesInRange(video_path: str, t: Tensor, frame_legnth: int = 50) -> Tensor:
    """Decode the frames of ``video_path`` whose sample number is listed in ``t``.

    The video is subsampled: every ``frame_step``-th decoded frame is a
    *sample*, where ``frame_step = total_frames // frame_legnth`` (clamped to
    at least 1). Samples are numbered 0, 1, 2, ... and a sample is kept when
    its number equals ``t[len(kept)]``. ``t`` is therefore consumed strictly in
    order; an entry smaller than the current sample number is never matched,
    so ``t`` presumably has to be strictly increasing (TODO confirm with caller).

    Args:
        video_path: Path of the video file opened with ``cv2.VideoCapture``.
        t: Sample numbers to keep, consumed in order.
        frame_legnth: Number of sample slots the video is divided into; also
            an upper bound on the number of frames returned.

    Returns:
        A float32 tensor of shape ``(N, C, 224, 224)`` with
        ``N <= min(frame_legnth, len(t))``, channels in the order OpenCV
        decodes them (BGR by default), values not rescaled.

    Raises:
        ValueError: If the first frame cannot be read, or if no frame matched.
    """
    imgs = []
    cap = cv2.VideoCapture(video_path)
    try:
        success, data = cap.read()
        if not success:
            raise ValueError(f'Video {video_path} unreadable!')
        # The reported frame count can be smaller than frame_legnth (or 0/-1 for
        # some containers); clamp so `i % frame_step` never divides by zero.
        frame_step = max(1, int(cap.get(cv2.CAP_PROP_FRAME_COUNT) // frame_legnth))
        limit = min(frame_legnth, len(t))

        i = 0           # index of the decoded frame
        ith_frame = 0   # index of the current sample (every frame_step-th frame)
        while success and len(imgs) < limit:
            if i % frame_step == 0:
                if t[len(imgs)] == ith_frame:
                    imgs.append(cv2.resize(data, (224, 224)).transpose(2, 0, 1))
                ith_frame += 1
            i += 1
            success, data = cap.read()
    finally:
        # Release the decoder even when the unreadable-video error is raised.
        cap.release()

    if not imgs:
        raise ValueError(f'No frames extracted from {video_path}')
    return torch.tensor(np.stack(imgs), dtype=torch.float)

def extractFramesInDelta(video_path: str, t: Tensor, frame_legnth: int = 50) -> Tensor:
    """Decode frames of ``video_path`` separated by the gaps listed in ``t``.

    The video is subsampled: every ``frame_step``-th decoded frame is a
    *sample*, where ``frame_step = total_frames // frame_legnth`` (clamped to
    at least 1). The first sample is always kept. After the k-th frame has
    been kept (k >= 1), the next kept frame is ``t[k]`` samples later, with
    the index clamped to the last element of ``t``.

    ``t`` is not modified.

    Args:
        video_path: Path of the video file opened with ``cv2.VideoCapture``.
        t: Sample gaps between consecutive kept frames.
        frame_legnth: Number of sample slots the video is divided into; also
            an upper bound on the number of frames returned.

    Returns:
        A float32 tensor of shape ``(N, C, 224, 224)`` with
        ``N <= min(frame_legnth, len(t))``, channels in the order OpenCV
        decodes them (BGR by default), values not rescaled.

    Raises:
        ValueError: If the first frame cannot be read, or if no frame was kept.
    """
    imgs = []
    cap = cv2.VideoCapture(video_path)
    try:
        success, data = cap.read()
        if not success:
            raise ValueError(f'Video {video_path} unreadable!')
        # The reported frame count can be smaller than frame_legnth (or 0/-1 for
        # some containers); clamp so `i % frame_step` never divides by zero.
        frame_step = max(1, int(cap.get(cv2.CAP_PROP_FRAME_COUNT) // frame_legnth))
        limit = min(frame_legnth, len(t))

        i = 0         # index of the decoded frame
        curr_del = 1  # samples remaining until the next kept frame
        while success and len(imgs) < limit:
            if i % frame_step == 0:
                if curr_del == 1:
                    imgs.append(cv2.resize(data, (224, 224)).transpose(2, 0, 1))
                    curr_del = t[min(len(imgs), len(t) - 1)]
                else:
                    # Out-of-place on purpose: when `t` is a Tensor, `t[k]` is a
                    # view and `curr_del -= 1` would decrement the caller's `t`.
                    # NOTE(review): a gap <= 0 never reaches 1, so extraction
                    # silently stops there — confirm gaps are always >= 1.
                    curr_del = curr_del - 1
            i += 1
            success, data = cap.read()
    finally:
        # Release the decoder even when the unreadable-video error is raised.
        cap.release()

    if not imgs:
        raise ValueError(f'No frames extracted from {video_path}')
    return torch.tensor(np.stack(imgs), dtype=torch.float)

class VideoDataset(Dataset):
    """Video classification dataset read from a ``data_path/<class>/<video>`` tree.

    Every sub-directory of ``data_path`` is a class, labelled by its position in
    name-sorted order. Every entry inside a class directory is a sample. The
    samples are split with a seeded ``StratifiedKFold``, and the first fold's
    train or held-out indices are kept.

    Args:
        data_path: Root directory whose sub-directories are the classes.
        time_dic: Maps each video path, built as
            ``os.path.join(data_path, class_dir, file_name)``, to the value
            passed as ``t`` to the frame extractor.
        train: Keep the training part of the split if True, else the held-out part.
        train_ratio: Approximate training fraction. The number of folds is
            ``floor(1 / (1 - train_ratio))``, so values below 0.5 are rejected.
        range_form: Extract frames with ``extractFramesInRange`` if True,
            otherwise with ``extractFramesInDelta``.

    Raises:
        ValueError: If ``train_ratio`` is outside (0, 1) or below 0.5.
        KeyError: If a video path is missing from ``time_dic``.
    """

    def __init__(self, data_path: str, time_dic: dict, train: bool = True, train_ratio: float = 0.8, range_form: bool = False) -> None:
        super().__init__()
        if train_ratio <= 0 or train_ratio >= 1:
            raise ValueError('train_ratio must be in the range of (0, 1)')
        # Choose the fold count so that one held-out fold is ~(1 - train_ratio) of
        # the data. The epsilon absorbs float error: 1 / (1 - 0.95) evaluates to
        # 19.99999999999998 and would otherwise truncate to 19 folds.
        n_splits = int(1 / (1 - train_ratio) + 1e-9)
        if n_splits < 2:
            raise ValueError('train_ratio must be at least 0.5 so that at least 2 folds can be formed')

        paths, labels = [], []
        # Only sub-directories are classes. A stray top-level file would
        # otherwise make the inner listdir raise NotADirectoryError.
        class_dirs = [name for name in listdir(data_path)
                      if os.path.isdir(os.path.join(data_path, name))]
        for ith_class, folder_name in enumerate(class_dirs):
            for file_name in listdir(data_path, folder_name):
                paths.append(os.path.join(data_path, folder_name, file_name))
                labels.append(ith_class)
        self.x = np.array(paths)
        self.y = torch.tensor(labels, dtype=torch.long)

        # The fixed seed gives the same split for train=True and train=False, so
        # the two datasets take complementary index sets from the first fold.
        skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
        train_index, test_index = next(skf.split(self.x, self.y))
        sample_index = train_index if train else test_index
        self.x, self.y = self.x[sample_index], self.y[sample_index]

        # NOTE(review): the extractors in this module return raw pixel values
        # (presumably 8-bit, 0-255) cast to float, and nothing divides by 255
        # first. Normalize(0.5, 0.5) therefore maps them to roughly [-1, 509]
        # rather than [-1, 1]. Confirm whether a /255 step is intended.
        self.transform = transforms.Compose([
            transforms.Normalize(mean=[0.5], std=[0.5])
        ])
        self.t = [time_dic[x] for x in self.x]
        self.extractFrames = extractFramesInRange if range_form else extractFramesInDelta

    def __getitem__(self, index):
        """Return ``(t, frames, label)`` for sample ``index``.

        ``t`` is the sample's ``time_dic`` entry, ``frames`` is the extracted
        clip after ``self.transform``, and ``label`` is the class index.
        """
        t = self.t[index]
        frames = self.extractFrames(self.x[index], t)
        return t, self.transform(frames), self.y[index]

    def __len__(self):
        """Return the number of samples in this split."""
        return len(self.y)
