import os
import cv2
import torch
import itertools
import numpy as np

from torch import Tensor
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from sklearn.model_selection import StratifiedKFold

# Names exported by ``from <module> import *``: only the two dataset classes;
# the frame-extraction helpers defined below are left out of the star-export.
__all__ = [
    'VideoDataset',
    'CrossValDataset'
]

def listdir(*path: str) -> np.ndarray:
    """Return the entry names of a directory as a sorted NumPy array.

    Args:
        *path: Path components, joined with ``os.path.join`` to form the
            directory to list.

    Returns:
        The names reported by ``os.listdir`` for that directory, sorted
        with ``np.sort``.
    """
    directory = os.path.join(*path)
    entries = os.listdir(directory)
    return np.sort(entries)

def extractFramesByLength(video_path: str, frame_legnth: int = 50) -> Tensor:
    """Sample at most ``frame_legnth`` frames from a video at a fixed stride.

    The stride is ``total_frames // frame_legnth`` (never less than 1), so
    the kept frames are spread across the whole video. Each kept frame is
    resized to 224x224 and transposed from (H, W, C) to (C, H, W).

    Args:
        video_path: Path of the video file to decode.
        frame_legnth: Maximum number of frames to return. The misspelled
            parameter name is kept so existing keyword callers still work.

    Returns:
        A float tensor of shape (N, C, 224, 224) with ``N <= frame_legnth``.
        ``N`` is smaller than ``frame_legnth`` when the video contains fewer
        frames than requested.

    Raises:
        ValueError: If ``frame_legnth`` is not positive or the first frame
            of the video cannot be read.
    """
    if frame_legnth <= 0:
        raise ValueError('frame_legnth must be a positive integer')

    cap = cv2.VideoCapture(video_path)
    try:
        success, data = cap.read()
        if not success:
            raise ValueError(f'Video {video_path} unreadable!')

        # A video shorter than frame_legnth (or a backend that reports a
        # frame count of 0) would give a stride of 0 and make the modulo
        # below raise ZeroDivisionError, so fall back to keeping every frame.
        frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT)
        frame_freq = max(1, int(frame_count // frame_legnth))

        imgs = []
        i = 0
        while success and len(imgs) < frame_legnth:
            if i % frame_freq == 0:
                imgs.append(cv2.resize(data, (224, 224)).transpose(2, 0, 1))
            i += 1
            success, data = cap.read()
    finally:
        # Release the capture on the error paths as well as on success.
        cap.release()

    return torch.tensor(np.stack(imgs), dtype=torch.float)

def extractFramesByFreq(video_path: str, frame_freq: int = 2) -> Tensor:
    """Decode a video and keep every ``frame_freq``-th frame.

    Frames whose index is a multiple of ``frame_freq`` (starting with frame
    0) are resized to 224x224 and transposed from (H, W, C) to (C, H, W).
    The whole video is read, so the number of returned frames depends on
    the video's length.

    Args:
        video_path: Path of the video file to decode.
        frame_freq: Sampling stride in frames.

    Returns:
        A float tensor of shape (N, C, 224, 224), one entry per kept frame.

    Raises:
        ValueError: If the first frame of the video cannot be read.
        ZeroDivisionError: If ``frame_freq`` is 0.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        success, data = cap.read()
        if not success:
            raise ValueError(f'Video {video_path} unreadable!')

        imgs = []
        i = 0
        while success:
            if i % frame_freq == 0:
                imgs.append(cv2.resize(data, (224, 224)).transpose(2, 0, 1))
            i += 1
            success, data = cap.read()
    finally:
        # Release the capture even when decoding fails part-way through.
        cap.release()

    return torch.tensor(np.stack(imgs), dtype=torch.float)

class VideoDataset(Dataset):
    """Video classification dataset that decodes frames on access.

    ``data_path`` must contain one sub-directory per class. Sub-directories
    are visited in sorted order and their position in that order is used as
    the class label; every entry inside a sub-directory is treated as one
    video sample.

    The samples are divided with a shuffled ``StratifiedKFold``
    (``random_state=42``) and the first fold's train or test indices are
    kept, so the same arguments always select the same samples.

    Args:
        data_path: Root directory holding the per-class sub-directories.
        train: Keep the training part of the split if True, otherwise the
            held-out part.
        train_ratio: Approximate fraction of samples used for training. It
            is converted to ``floor(1 / (1 - train_ratio))`` folds, so it
            must lie in [0.5, 1).
        by_length: If True frames are extracted with
            ``extractFramesByLength``, otherwise with
            ``extractFramesByFreq``.

    Raises:
        ValueError: If ``train_ratio`` is outside (0, 1) or too small to
            produce at least two folds.
    """

    def __init__(self, data_path: str, train: bool = True, train_ratio: float = 0.8, by_length: bool = True) -> None:
        super().__init__()
        n_splits = self._n_splits(train_ratio)

        paths, labels = [], []
        for ith_class, folder_name in enumerate(listdir(data_path)):
            for file_name in listdir(data_path, folder_name):
                paths.append(os.path.join(data_path, folder_name, file_name))
                labels.append(ith_class)
        paths = np.array(paths)
        labels = torch.tensor(labels, dtype=torch.long)

        skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
        # Only the first fold is ever used by this class.
        train_index, test_index = next(skf.split(paths, labels))
        sample_index = train_index if train else test_index
        self.x, self.y = paths[sample_index], labels[sample_index]

        self.extractFrames = extractFramesByLength if by_length else extractFramesByFreq
        # NOTE(review): the extractors return raw decoded pixel values as
        # floats without rescaling; confirm that Normalize(0.5, 0.5) on
        # unscaled frames is intended rather than on values in [0, 1].
        self.transform = transforms.Compose([
            transforms.Normalize(mean=[0.5], std=[0.5])
        ])

    @staticmethod
    def _n_splits(train_ratio: float) -> int:
        """Convert ``train_ratio`` into a fold count for ``StratifiedKFold``."""
        if train_ratio <= 0 or train_ratio >= 1:
            raise ValueError('train_ratio must be in the range of (0, 1)')
        # 1 / (1 - 0.95) evaluates to 19.999999999999982 in floating point,
        # so a bare int() would give 19 folds instead of 20. The epsilon
        # absorbs that rounding error before truncation.
        n_splits = int(1 / (1 - train_ratio) + 1e-9)
        if n_splits < 2:
            raise ValueError(
                'train_ratio must be at least 0.5 so that the split has two or more folds'
            )
        return n_splits

    def __getitem__(self, index: int):
        """Decode the video at ``index`` and return ``(normalized_frames, label)``."""
        return self.transform(self.extractFrames(self.x[index])), self.y[index]

    def __len__(self) -> int:
        """Return the number of samples in the selected split."""
        return len(self.y)

class CrossValDataset(Dataset):
    """Video classification dataset for one fold of a stratified k-fold split.

    ``data_path`` must contain one sub-directory per class. Sub-directories
    are visited in sorted order and their position in that order is used as
    the class label; every entry inside a sub-directory is treated as one
    video sample.

    Unlike decoding on access, every selected video is decoded inside the
    constructor with ``extractFramesByFreq`` and the resulting tensors are
    kept in memory. No normalisation is applied to the frames.

    Args:
        data_path: Root directory holding the per-class sub-directories.
        fold: Zero-based index of the fold to use, in
            ``[0, floor(1 / (1 - train_ratio)))``.
        train: Keep the fold's training indices if True, otherwise its
            held-out indices.
        train_ratio: Approximate fraction of samples used for training. It
            is converted to ``floor(1 / (1 - train_ratio))`` folds, so it
            must lie in [0.5, 1).
        by_length: Accepted but currently unused; frames are always
            extracted with ``extractFramesByFreq``.

    Raises:
        ValueError: If ``train_ratio`` is outside (0, 1) or too small to
            produce at least two folds, or if ``fold`` is not a valid fold
            index.
    """

    def __init__(self, data_path: str, fold: int = 0, train: bool = True, train_ratio: float = 0.8, by_length: bool = True) -> None:
        super().__init__()
        n_splits = self._n_splits(train_ratio)
        # Without this check an out-of-range fold exhausts the split
        # generator and surfaces as a bare StopIteration from next().
        if not 0 <= fold < n_splits:
            raise ValueError(f'fold must be in the range of [0, {n_splits}), got {fold}')

        paths, labels = [], []
        for ith_class, folder_name in enumerate(listdir(data_path)):
            for file_name in listdir(data_path, folder_name):
                paths.append(os.path.join(data_path, folder_name, file_name))
                labels.append(ith_class)
        paths = np.array(paths)
        labels = torch.tensor(labels, dtype=torch.long)

        skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
        # Skip ahead to the requested fold of the deterministic split.
        train_index, test_index = next(itertools.islice(skf.split(paths, labels), fold, None))
        sample_index = train_index if train else test_index

        # Decode every selected video up front; the frames stay in memory.
        self.x = [extractFramesByFreq(filename) for filename in paths[sample_index]]
        self.y = labels[sample_index]

    @staticmethod
    def _n_splits(train_ratio: float) -> int:
        """Convert ``train_ratio`` into a fold count for ``StratifiedKFold``."""
        if train_ratio <= 0 or train_ratio >= 1:
            raise ValueError('train_ratio must be in the range of (0, 1)')
        # 1 / (1 - 0.95) evaluates to 19.999999999999982 in floating point,
        # so a bare int() would give 19 folds instead of 20. The epsilon
        # absorbs that rounding error before truncation.
        n_splits = int(1 / (1 - train_ratio) + 1e-9)
        if n_splits < 2:
            raise ValueError(
                'train_ratio must be at least 0.5 so that the split has two or more folds'
            )
        return n_splits

    def __getitem__(self, index: int):
        """Return the pre-decoded ``(frames, label)`` pair at ``index``."""
        return self.x[index], self.y[index]

    def __len__(self) -> int:
        """Return the number of samples in the selected fold."""
        return len(self.y)