import collections
import collections.abc
from pathlib import Path

import einops
import pandas as pd
import torch
import torch.nn.functional as F
import torchvision
from torch.utils.data import Dataset

import utils.fsvisit as fsvisit



class PhysionVideoDataset(Dataset):
    """Dataset of ``.mp4`` videos collected from one or more base paths.

    Every base path is handed to ``fsvisit.FSVisitor`` (presumably a
    recursive file-system walker -- confirm in ``utils.fsvisit``); each file
    it reports whose suffix is ``.mp4`` becomes one dataset item.

    Items are float32 video tensors in ``TCHW`` layout with values in
    ``[0, 1]``, temporally subsampled by ``conf.frameskip`` and spatially
    downsampled by repeated 2x2 average pooling until the last dimension
    (width) is no larger than ``conf.video_resolution``.
    """

    # Resolutions accepted by the constructor.
    _SUPPORTED_RESOLUTIONS = frozenset({256, 128, 64, 32})

    def _add_mp4_to_list(self, fpath: Path, _):
        """File callback for ``FSVisitor``: remember *fpath* if it is an ``.mp4``.

        The second argument is whatever extra value the visitor forwards to
        its callbacks; it is not used here.
        """
        if fpath.suffix == ".mp4":
            self.video_paths.append(fpath)

    def __init__(self, base_paths, conf):
        """Collect all videos below *base_paths*.

        Args:
            base_paths: A single path (``str`` or any non-collection object
                accepted by ``pathlib.Path``) or a collection of such paths.
            conf: Configuration object providing ``video_resolution`` (one of
                256, 128, 64, 32) and ``frameskip`` (keep every n-th frame).

        Raises:
            AssertionError: If no ``.mp4`` file was found or the requested
                resolution is not supported.
        """
        # `collections.Collection` was removed in Python 3.10; the ABC lives
        # in `collections.abc`. A str is itself a Collection, hence the
        # explicit check so a single path string is not iterated per char.
        if isinstance(base_paths, str) or not isinstance(
            base_paths, collections.abc.Collection
        ):
            base_paths = [base_paths]

        self.video_paths = []
        for path in base_paths:
            fsvisit.FSVisitor(
                file_callback=self._add_mp4_to_list
            ).go(Path(path))

        # Raised explicitly instead of via `assert` so the checks survive
        # `python -O`; AssertionError is kept so existing callers observe the
        # same exception type and message as before.
        if len(self) == 0:
            raise AssertionError("Dataset is empty! Did you set the correct path?")
        if conf.video_resolution not in self._SUPPORTED_RESOLUTIONS:
            raise AssertionError("Only power of two resolutions are supported atm")

        self.video_resolution = conf.video_resolution
        self.frameskip = conf.frameskip

    def __len__(self):
        """Return the number of videos that were found."""
        return len(self.video_paths)

    def __getitem__(self, idx):
        """Load, subsample and downscale the video at position *idx*.

        Returns:
            A float32 tensor in ``TCHW`` layout with values in ``[0, 1]``.
        """
        # read_video documents `filename` as a str, so convert the Path.
        video, _audio, _info = torchvision.io.read_video(
            str(self.video_paths[idx]), output_format="TCHW", pts_unit="sec"
        )

        # Keep every frame whose index is a multiple of `frameskip`. Doing
        # this on the raw uint8 frames first means only the kept frames are
        # converted to float below.
        keep = torch.arange(video.shape[0]) % self.frameskip == 0
        video = video[keep]

        # read_video returns bytes in [0, 255]; we want floats in [0, 1].
        video = video.to(torch.float32) / 255

        # Halve the spatial size until the width fits the target resolution.
        while video.shape[-1] > self.video_resolution:
            video = F.avg_pool2d(video, kernel_size=(2, 2), stride=(2, 2))
        return video



class PhisionVideoLabeledDataset(PhysionVideoDataset):
    """Video dataset that additionally yields one label per video.

    Labels come from a CSV file read with its first column as the index; the
    label of a video is the row whose index equals the video's file stem
    (file name without the ``.mp4`` suffix).

    NOTE(review): the class name is spelled "Phision" rather than "Physion";
    the spelling is kept because callers may already import it by this name.
    """

    def _add_mp4_and_label_to_lists(self, video_path: Path, labels_megatable):
        """File callback for ``FSVisitor``: record an ``.mp4`` and its label.

        Args:
            video_path: File reported by the visitor; anything that is not an
                ``.mp4`` is ignored.
            labels_megatable: DataFrame indexed by video stem.

        Raises:
            KeyError: If the table has no row for ``video_path.stem``.
        """
        if video_path.suffix != ".mp4":
            return

        # Look the label up before touching either list so a missing entry
        # cannot leave `video_paths` and `labels` with different lengths.
        # `.item()` requires the selected row to hold exactly one value.
        label = labels_megatable.loc[video_path.stem].item()
        self.video_paths.append(video_path)
        self.labels.append(label)

    def __init__(self, base_video_paths, label_file_path, video_resolution, frameskip=1):
        """Collect all videos below *base_video_paths* together with their labels.

        Args:
            base_video_paths: A single path (``str`` or any non-collection
                object accepted by ``pathlib.Path``) or a collection of paths.
            label_file_path: CSV file whose first column holds the video stems.
            video_resolution: Target resolution, one of 256, 128, 64, 32.
            frameskip: Keep every n-th frame when loading a video. The
                default of 1 keeps all frames. The inherited ``__getitem__``
                reads this attribute, so it must always be set.

        Raises:
            AssertionError: If no ``.mp4`` file was found or the requested
                resolution is not supported.
        """
        # `collections.Collection` was removed in Python 3.10; the ABC lives
        # in `collections.abc`. A str is itself a Collection, hence the
        # explicit check so a single path string is not iterated per char.
        if isinstance(base_video_paths, str) or not isinstance(
            base_video_paths, collections.abc.Collection
        ):
            base_video_paths = [base_video_paths]

        labels_megatable = pd.read_csv(label_file_path, index_col=0)

        self.video_paths = []
        self.labels = []
        # Iterate directly: not every Collection (e.g. a set) supports indexing.
        for vpath in base_video_paths:
            fsvisit.FSVisitor(
                file_callback=self._add_mp4_and_label_to_lists
            ).go(Path(vpath), labels_megatable)

        # Raised explicitly instead of via `assert` so the checks survive
        # `python -O`; AssertionError is kept so existing callers observe the
        # same exception type and message as before.
        if len(self) == 0:
            raise AssertionError("Dataset is empty! Did you set the correct path?")
        if video_resolution not in {256, 128, 64, 32}:
            raise AssertionError("Only power of two resolutions are supported atm")

        self.video_resolution = video_resolution
        self.frameskip = frameskip

    # __len__ can be inherited

    def __getitem__(self, idx):
        """Return ``(video, label)`` for position *idx*.

        The video is produced by the parent class; the label is wrapped in a
        one-element float tensor.
        """
        video = super().__getitem__(idx)
        label_tensor = torch.Tensor([self.labels[idx]])  # label is Boolean
        return video, label_tensor
