from collections import defaultdict
import json
import os
import time
import glob
import pickle
import requests
from pathlib import Path
from fractions import Fraction

import clip
import pandas as pd
import torch
import torchvision
from clip.simple_tokenizer import SimpleTokenizer
from einops.layers.torch import Rearrange
from PIL import Image
from rake_nltk import Rake
from torch.utils.data.dataset import Dataset
from torchvision import transforms

from dataset_loaders.video_retrieval_videodatasets import (VideoDatasetMSRVTT,  # noqa
                                                           VideoDatasetMSVD)

# Preprocessing applied to a single PIL image (or a single video frame
# converted to PIL): bicubic resize to 224, 224x224 centre crop,
# conversion to a float tensor and per-channel normalisation.
# NOTE(review): the mean/std values are presumably the statistics the CLIP
# image encoder was trained with -- confirm against the CLIP release.
CLIP_TRANSFORM = transforms.Compose([
    transforms.Resize(224, interpolation=Image.BICUBIC),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                         (0.26862954, 0.26130258, 0.27577711)),
])

# Transforms on [t,h,w,c] uint8 videos
# "Deterministic or random transformations applied on the batch of Tensor Images
# identically transform all the images of the batch."
#
# The clip is moved to [t,c,h,w] so the torchvision transforms treat the time
# axis as a batch axis (same random crop/flip/jitter for every frame), and is
# moved back to [t,h,w,c] at the end so the output layout matches the input.
VIDEO_AUG = transforms.Compose([
    Rearrange("t h w c -> t c h w"),
    transforms.RandomResizedCrop(size=256, scale=(0.5, 1.0)),
    transforms.RandomHorizontalFlip(),
    # One of two colour jitters is chosen per call; they differ only in the
    # last argument (hue): 0.1 for the first, 0 (no hue jitter) for the second.
    transforms.RandomChoice([
        transforms.ColorJitter(.4,.4,.4, .1),
        transforms.ColorJitter(.4,.4,.4, 0.)
    ]),
    Rearrange("t c h w -> t h w c"),
])

# Lower-case substrings that mark a comment as boilerplate (bot output,
# moderator removal notices, deleted placeholders).  Comments are lower-cased
# before being matched against these, so every entry must be lower-case.
BOT_TEXT_TO_AVOID = [
    "i am a bot",
    "i'm a bot",
    "this is a bot",
    "redditspeedbot",
    "this bot",
    "look at my programming",
    "look at my source code on github",
    "this is a manual removal by a *human moderator*",
    "your post was removed",
    "this post was removed",
    "your post has been removed",
    "community moderation bot",
    "unfortunately it has been removed",
    "thank you for your submission",
    "your submission has been removed",
    "if you feel this was done in error",
    "your post breaks",
    "has been removed for the following reasons",
    "downvote this comment if",
    "redditdownloader",
    "repostsleuthbot",
    "vreddit",
    "savethisvideo",
    "stabbot",
    "[removed]",
    "[deleted]",
    "[excluído]",
    "savevideo",
    "this comment",
]


def random_blank(strs, p):
    """Independently replace each entry of ``strs`` by ``''`` with probability ``p``.

    One uniform sample is drawn from ``torch.rand`` per entry, in order.
    The list is modified in place and the same list object is returned.
    """
    for position, _ in enumerate(strs):
        drop = torch.rand([]) < p
        if drop:
            strs[position] = ''
    return strs

def partition_dataframe(df, root=None, split=None):
    """
    Partition into train/test/val

    The partition a row belongs to is decided by the last character of the
    video file stem (the part of ``video_path`` after the final ``/`` and
    before the first ``.``): ``0-3`` -> test, ``4-7`` -> val, ``8-9a-z`` ->
    train.  Rows whose stem is empty or ends in any other character belong to
    no partition and are dropped.

    df: Pandas dataframe from CSV, must have a ``video_path`` column
    root: optional directory; when given, it is searched recursively for
        ``*.mp4`` files and rows whose stem has no matching file are dropped
    split: 'train', 'test' or 'val'

    Returns a new dataframe with the selected rows (all columns preserved,
    also when no row is selected).

    Raises KeyError if ``split`` is not one of the three partition names.
    """

    # The least significant digit of the base36 id is quasi-random
    # so use it to partition into train and test
    digits = "0123456789abcdefghijklmnopqrstuvwxyz"
    digit_split = {
        "test": set(digits[0:4]),
        "val": set(digits[4:8]),
        "train": set(digits[8:]),
    }

    if split not in digit_split:
        raise KeyError(
            "split must be one of %s, got %r" % (sorted(digit_split), split))
    wanted_digits = digit_split[split]

    video_ids = [path.split("/")[-1].split(".")[0]
                 for path in df.video_path.tolist()]

    # vid[-1:] (rather than vid[-1]) is '' for an empty stem, which is simply
    # not a member of any partition instead of raising IndexError.
    keep = [vid[-1:] in wanted_digits for vid in video_ids]

    if root is not None:
        # Check for missing files
        available_mp4s = glob.glob(os.path.join(root, '**/*.mp4'), recursive=True)
        available_ids = {path.split("/")[-1].split(".")[0]
                         for path in available_mp4s}

        print('CSV: %d Available on Disk: %d'
              % (len(video_ids), len(set(video_ids) & available_ids)))

        keep = [flag and vid in available_ids
                for flag, vid in zip(keep, video_ids)]

    # A plain Python list is interpreted by pandas as a list of column labels
    # when it is empty, which would strip every column from an empty frame.
    # An explicitly boolean array is always treated as a row mask.
    mask = pd.Series(keep, dtype=bool).to_numpy()
    return df.loc[mask]


def load_features(df, path):
    """Load precomputed embeddings from ``path`` and order them like ``df``.

    Two file layouts are understood:

    - comment features, recognised by a ``"reddit_id_to_comment_id"`` key:
      ``"embeddings"`` is indexed in the iteration order of that dict and the
      result is a list with one entry per row of ``df``
    - dense features: ``"reddit_ids"`` must be an int64 tensor and
      ``"embeddings"`` a float32 tensor; the result is ``"embeddings"``
      indexed by the rows matching ``df.reddit_id``

    Raises KeyError if ``df`` contains a reddit id that is not in the file.
    """
    # NOTE: torch.load unpickles the file, so only load trusted feature files.
    stored = torch.load(path)

    if "reddit_id_to_comment_id" not in stored:
        # Dense layout: one embedding row per reddit id
        assert stored["reddit_ids"].dtype is torch.int64
        assert stored["embeddings"].dtype is torch.float32
        row_of = {}
        for row, reddit_id in enumerate(stored["reddit_ids"]):
            row_of[int(reddit_id)] = row
        rows = [row_of[reddit_id] for reddit_id in df.reddit_id]
        selected = stored["embeddings"][rows]
        assert selected.shape[0] == len(df)
        return selected

    # Comment layout: embeddings is a list of lists of zero-or-more tensors,
    # in the same order as the keys of reddit_id_to_comment_id
    per_post = stored["embeddings"]
    row_of = {}
    for row, reddit_id in enumerate(stored["reddit_id_to_comment_id"]):
        row_of[int(reddit_id)] = row
    selected = [per_post[row_of[reddit_id]] for reddit_id in df.reddit_id]
    assert len(selected) == len(df)
    return selected


class VisionTitleCommentDatasetBase(Dataset):
    """Shared text and video helpers for the title/comment datasets.

    The class defines no ``__init__``; its methods read attributes that the
    concrete subclass has to set up before calling them:

    - ``tokenizer`` and ``rake``: used by :meth:`_tokenise`
    - ``train``: used by :meth:`preprocess_imlabels` and :meth:`_read_video`
    - ``imlabels_synonyms``: used by :meth:`_augment_imlabel`
    - ``root`` / ``kinetics_root`` and the per-item lists ``ids``,
      ``filenames``, ``titles``, ``video_lengths``, ``labels`` and
      ``comments``: appended to by :meth:`_load_reddit` / :meth:`_load_kinetics`
    - ``nframes``, ``reference_fps``, ``frame_strides``, ``video_read_width``,
      ``video_read_height`` and ``video_tfm``: used by :meth:`_read_video`
    """

    def split_dataset(self, df, train, test):
        """Return the 'test', 'train' or 'val' partition of ``df``.

        ``test`` takes precedence and must not be combined with ``train``;
        otherwise ``train`` selects between 'train' and 'val'.
        """
        if test:
            assert not train
            return partition_dataframe(df, split="test")
        else:
            return partition_dataframe(df, split="train" if train else "val")

    def should_add_comments(self, add_comments, train):
        """Translate the ``add_comments`` policy into a bool for this split.

        add_comments: 'always', 'train_only' or 'never' (KeyError otherwise)
        train: whether this dataset is the training split
        """
        # Each entry is [value when not training, value when training]
        cases = {"always": [True, True],
                 "train_only": [False, True],
                 "never": [False, False]}

        return cases[add_comments][int(train)]

    def _tokenise(self, texts, max_len=77):
        """Tokenise ``texts`` into a zero-padded ``[len(texts), max_len]`` long tensor.

        Every row is ``<|startoftext|>`` + tokens + ``<|endoftext|>``.  A text
        whose token sequence does not fit in fewer than ``max_len`` tokens is
        first replaced by its RAKE key phrases (joined with spaces); if that
        is still too long it is cut so that the last position holds the
        end-of-text token.

        texts: a string or a list of strings
        """
        if isinstance(texts, str):
            texts = [texts]
        sot_token = self.tokenizer.encoder["<|startoftext|>"]
        eot_token = self.tokenizer.encoder["<|endoftext|>"]
        all_tokens = [[sot_token] + self.tokenizer.encode(text) + [eot_token] for text in texts]
        result = torch.zeros(len(all_tokens), max_len, dtype=torch.long)

        for i, tokens in enumerate(all_tokens):
            if len(tokens) >= max_len:
                # summarise text by extracting keywords
                self.rake.extract_keywords_from_text(texts[i])
                a = self.rake.get_ranked_phrases()
                tokens = [sot_token] + self.tokenizer.encode(' '.join(a)) + [eot_token]
                if len(tokens) >= max_len:
                    # Hard truncation, keeping the end-of-text marker last
                    result[i, :max_len] = torch.tensor(tokens[:max_len - 1] + [eot_token])
                else:
                    result[i, :len(tokens)] = torch.tensor(tokens)
            else:
                result[i, :len(tokens)] = torch.tensor(tokens)
        return result

    def preprocess_comments(self, comments, sampling=None, num_comms=2):
        """Filter boilerplate comments and select/pad to ``num_comms`` entries.

        Comments containing (case-insensitively) any entry of
        ``BOT_TEXT_TO_AVOID`` are dropped.  With ``sampling='random'`` a
        random subset of ``num_comms`` comments is taken, with
        ``sampling=None`` the first ``num_comms``.  Missing entries are padded
        with empty strings.  The input list is not modified.

        NOTE(review): any other ``sampling`` value leaves the filtered list
        untruncated, so more than ``num_comms`` comments can be returned --
        confirm whether that is intended.
        """
        if num_comms == 0:
            return []

        comments = [comm for comm in comments if
            all(s not in comm.lower() for s in BOT_TEXT_TO_AVOID)]

        if len(comments) >= num_comms:
            if sampling == 'random':
                # Sampling all indices without replacement == random permutation
                idxs = torch.multinomial(torch.ones(len(comments)), len(comments))
                comments = [comments[idx] for idx in idxs[:num_comms]]

            elif sampling is None:
                comments = comments[:num_comms]
        while len(comments) < num_comms:
            comments.append("")

        return comments

    def preprocess_imlabels(self, imlabels, num_imlabels, augment_imlabels=False):
        """Turn image-label dicts into exactly ``num_imlabels`` description strings.

        At eval time the first ``num_imlabels`` descriptions are used.  At
        train time a random-length prefix of the labels is shuffled before
        the first ``num_imlabels`` are taken, and each description is
        optionally replaced via :meth:`_augment_imlabel`.  Missing entries
        are padded with empty strings.

        imlabels: list of dicts with a ``"description"`` key
        """
        if num_imlabels == 0:
            return []

        if self.train:
            if not imlabels:
                idxs = []
            elif len(imlabels) == 1:
                idxs = [0]
            else:
                # NOTE(review): randint's upper bound is exclusive, so the
                # prefix length is at most len(imlabels) - 1 and the last
                # label is never selected -- confirm this is intended.
                nlabels = torch.randint(1, len(imlabels), [])
                idxs = torch.multinomial(torch.ones(nlabels), nlabels)

            descs = [imlabels[i]["description"] for i in idxs]
            descs = descs[:num_imlabels]

            if augment_imlabels:
                descs = [self._augment_imlabel(x) for x in descs]
        else:
            descs = [iml["description"] for iml in imlabels][:num_imlabels]

        while len(descs) < num_imlabels:
            descs.append("")

        return descs

    def _augment_imlabel(self, dispname):
        """Replace a label name by a random alternative wording.

        The candidates are the canonical display name, its description, its
        synonyms and one random sentence (longer than 60 characters) of its
        wiki paragraph; fields that are missing (NaN) are skipped.  Labels
        without an entry in ``self.imlabels_synonyms`` are returned unchanged.
        """
        meta = self.imlabels_synonyms.get(dispname.lower(), None)

        if meta is None:
            return dispname

        display = meta["DisplayName"]
        desc = meta["description"]
        syns = meta["synonyms"]
        wikipara = meta["wiki_paragraph"]

        candidates = [display]
        if not pd.isna(desc):
            candidates.append(desc)

        # An empty CSV cell is read as NaN, which json.loads cannot parse
        if not pd.isna(syns):
            candidates.extend(json.loads(syns))

        if not pd.isna(wikipara):
            wikisentences = [x.strip()
                             for x in wikipara.split('.') if len(x) > 60]
            if len(wikisentences):
                ri = torch.randint(0, len(wikisentences), [])
                candidates.append(wikisentences[int(ri)])

        idx = torch.randint(0, len(candidates), [])

        return candidates[int(idx)]

    def _load_imlabels_synonyms(self):
        """Read the synonyms CSV into a dict keyed by lower-cased ``DisplayName``.

        Each value is the CSV row as a plain dict.
        """
        df = pd.read_csv("/data/REDACTED_project/imlabels_synonyms.csv")
        syndict = {x["DisplayName"].lower(): x.to_dict()
                   for _, x in df.iterrows()}
        return syndict


    def _load_reddit(self, df, file_extension=".mp4"):
        """Append the reddit posts in ``df`` to the per-item lists.

        File names are derived from ``video_path`` by dropping the leading
        ``len("results/")`` characters and the last four characters, then
        appending ``file_extension`` and prefixing ``self.root``.  The
        ``imlabels`` and ``comments`` columns are parsed as JSON.
        """
        files = [x[len("results/"):-4] + file_extension for x in df.video_path]
        self.filenames += [os.path.join(self.root, x) for x in files]
        self.ids += df.reddit_id.to_list()
        self.titles += df.title.to_list()
        self.video_lengths += df.video_length.to_list()
        self.labels += [json.loads(s) for s in df.imlabels]
        self.comments += [json.loads(c) for c in df.comments]

        print(len(self.ids), "reddit videos")

    def _load_kinetics(self, df):
        """Append the kinetics training videos in ``df`` to the per-item lists.

        A row is used when it is in the k700 train split, is not in a
        non-train k400 split, and its file exists under ``self.kinetics_root``.
        Kinetics items get id ``-1``; keywords become labels (minus a few
        generic ones) and long description sentences are added as comments.

        Asserts that more than 400000 videos were found.
        """
        exclude_kw = {"video", "gopro", "webcam"}
        nk = 0
        for ki in range(len(df)):
            row = df.iloc[ki]
            vp = os.path.join(self.kinetics_root, row.video_path)
            split_k700 = row.split_k700
            split_k400 = row.split_k400

            istrain = split_k700 == "train" and (
                split_k400 == "train" or pd.isna(split_k400))

            if istrain and os.path.exists(vp):
                self.filenames.append(vp)
                self.ids.append(-1)
                self.titles.append(row.title_en)
                self.video_lengths.append(row.video_length)
                kw = [] if pd.isna(row.keywords) else json.loads(row.keywords)
                self.labels.append([{"description": k} for k in kw if k.lower() not in exclude_kw])

                comms = [] if pd.isna(row.comments) else json.loads(row.comments)

                if not pd.isna(row.description_en):
                    # Only sentences longer than 60 characters are kept
                    desc_sentences = [x.strip() for x in row.description_en.split('.') if len(x) > 60]
                    comms.extend(desc_sentences)

                self.comments.append(comms)

                nk += 1
        print(nk, "kinetics videos")
        assert nk > 400000

    def _read_video(self, idx):
        """Decode a segment of video ``idx`` and return ``self.nframes`` frames.

        A frame stride is drawn at random from ``self.frame_strides`` and
        determines the segment duration.  At train time the segment start is
        drawn at random, at eval time the segment starts at 0.  If nothing
        can be decoded, the first 5 seconds are tried, and finally black
        frames are used.  ``self.nframes`` evenly spaced frames are selected
        and passed through ``self.video_tfm``.
        """
        item_id = self.ids[idx]
        video_path = self.filenames[idx]
        video_length = min(60, self.video_lengths[idx])
        frame_stride = self.frame_strides[torch.randint(0, len(self.frame_strides), [])]

        segment_duration_sec = self.nframes / (self.reference_fps / frame_stride)

        # Often the reddit videos have an offset of 1.4s to the
        # start time meaning that there are no video frames
        # in the range (0, 1.4).
        # The offset can be obtained with ffprobe, eg:
        #
        #     prob = ffmpeg.probe(video_path)
        #     start_time = float(prob["streams"][0]["start_time"])
        #
        # unfortunately this is not exposed in the torchvision api
        # and calling ffprobe makes things slow, so for now just
        # assume 1.4 for all videos (TODO: precompute for all videos)
        #
        # Kinetics items carry id -1 and are assumed to have no offset.
        ffmpeg_start_time = 0 if item_id == -1 else 1.4

        # For simplicity just use milliseconds as the timebase
        # (which defines the unit used in video_pts_range, given
        # that the range must be integers)
        #
        # The video's native timebase is obtained with
        #
        #     prob = ffmpeg.probe(video_path)
        #     tb = Fraction(prob["streams"][0]["time_base"])
        #
        # although the given timebase is converted to ffmpeg's
        # internal AV_TIME_BASE so the choice is somewhat arbitrary
        tb = Fraction(1, 1000)

        if self.train:
            # Uniform draw between the two bounds; this also works when the
            # video is shorter than the segment and start_upper < start_lower.
            start_lower = ffmpeg_start_time
            start_upper = max(0, video_length - segment_duration_sec)
            segment_start_sec = (start_lower - start_upper) * torch.rand([]).item() + start_upper
        else:
            segment_start_sec = 0

        segment_end_sec = segment_start_sec + segment_duration_sec

        video_start = int(segment_start_sec / tb)
        video_end = int(segment_end_sec / tb)

        # For now use this private method since it allows resizing
        # on the ffmpeg side which is faster.
        # A large seek_frame_margin seems to be needed
        # to seek accurately
        vid, _, _ = torchvision.io._read_video_from_file(
            video_path,
            seek_frame_margin=5,
            video_width=self.video_read_width,
            video_height=self.video_read_height,
            read_audio_stream=False,
            video_timebase=tb,
            video_pts_range=(video_start, video_end),
        )

        if vid.shape[0] == 0:
            print('Zero len vid, trying fallback', video_path)
            vid, _, _ = torchvision.io._read_video_from_file(
                video_path,
                video_width=self.video_read_width,
                video_height=self.video_read_height,
                read_audio_stream=False,
                video_timebase=Fraction(1),
                video_pts_range=(0, 5),
            )

        if vid.shape[0] == 0:
            print("Fallback failed", video_path)
            vid = torch.zeros(8,300,300,3, dtype=torch.uint8)

        # Pick nframes evenly spaced frames (repeating frames if too few)
        idxs = torch.floor(torch.linspace(0, len(vid)-1,self.nframes)).to(torch.int64)
        vid = torch.index_select(vid, 0, idxs)

        vid = self.video_tfm(vid)

        return vid


class VideoDatasetSegments(VisionTitleCommentDatasetBase):
    """
    Video dataset that reads a random segment of each video at train time
    and augments it with cropping, variable playback speed (frame stride)
    and color jitter.  At eval time the segment and stride are fixed and no
    augmentation is applied.

    The segment is decoded as frames in [t h w c] order; every frame is then
    passed through CLIP_TRANSFORM and the frames are stacked.

    Each item is ``(vid, title_tok, comments_labels_tok, meta)`` where
    ``meta`` is ``{"id": <item id>}``.
    """

    def __init__(self, csv_file, root, train=True, test=False,
                 add_comments="train_only",
                 num_comms=2, comment_sampling="random", num_imlabels=0,
                 use_kinetics_train=None,
                 kinetics_csv=None,
                 kinetics_root=None):

        # Locations and split
        self.train = train
        self.root = root
        self.kinetics_root = kinetics_root

        # Text side configuration
        self.num_comms = num_comms
        self.num_imlabels = num_imlabels
        self.comment_sampling = comment_sampling if train else None
        self.add_comments = self.should_add_comments(add_comments, train)

        # Video decoding configuration
        self.video_read_height = 300
        self.video_read_width = 0
        self.nframes = 8
        self.reference_fps = 30

        self.frame_tfm = CLIP_TRANSFORM
        self.tokenizer = SimpleTokenizer()
        self.rake = Rake()

        # Augmentation and variable speed only while training
        if not self.train:
            self.video_tfm = transforms.Compose([])
            self.frame_strides = (16,)
        else:
            self.video_tfm = VIDEO_AUG
            self.frame_strides = (4, 8, 16, 32)

        # Per-item lists, filled by the _load_* helpers
        self.ids, self.filenames, self.titles = [], [], []
        self.video_lengths, self.labels, self.comments = [], [], []

        # Reddit is skipped only when training on kinetics alone;
        # kinetics is only ever added to the training split.
        kinetics_only = train and use_kinetics_train == "only"
        with_kinetics = train and use_kinetics_train in ("combine", "only")

        if not kinetics_only:
            reddit_df = self.split_dataset(pd.read_csv(csv_file), train, test)
            self._load_reddit(reddit_df)

        if with_kinetics:
            self._load_kinetics(pd.read_csv(kinetics_csv))

        if num_imlabels:
            self.imlabels_synonyms = self._load_imlabels_synonyms()

    def __len__(self):
        """Number of videos loaded."""
        return len(self.filenames)

    def __getitem__(self, idx):
        """Return ``(vid, title_tok, comments_labels_tok, meta)`` for item ``idx``."""
        item_id = self.ids[idx]
        title = self.titles[idx]
        imlabels = self.labels[idx]
        comments = self.comments[idx]

        # Decode, then run every frame through the per-image transform
        frames = self._read_video(idx)
        vid = torch.stack(
            [self.frame_tfm(Image.fromarray(frame.numpy())) for frame in frames])

        title_tok = self._tokenise([title])[0]

        if not self.add_comments:
            # A single empty text stands in for the comments and labels
            comments_labels_tok = self._tokenise([""])
        else:
            comments = self.preprocess_comments(
                comments, sampling=self.comment_sampling, num_comms=self.num_comms)
            imlabels = self.preprocess_imlabels(
                imlabels, num_imlabels=self.num_imlabels)

            comments_labels_tok = self._tokenise(comments + imlabels)

            # Occasionally log a sample for eyeballing the text pipeline
            if torch.rand([]) < 0.0001:
                print("Debug dataloader -- title:", title, "comms:", comments, "imlabels:", imlabels)

        return vid, title_tok, comments_labels_tok, {"id": item_id}


class VideoDatasetFirst32(Dataset):
    """A simple video loader that just returns the first 32 frames
    rescaled to 128x171 (height x width) ignoring aspect ratio, without
    any frame rate resampling.

    The clip is padded with black frames if fewer than 32 frames
    could be read.

    With ``clip_preprocess=False`` frames are normalised for ig65m,
    returned in [c t h w] order and paired with the precomputed text
    feature of the item (``text_features`` must then be given).
    With ``clip_preprocess=True`` every frame goes through CLIP_TRANSFORM,
    the frames are stacked in [t c h w] order and the title is tokenised
    with ``clip.tokenize``.

    Each item is ``(vid, text, meta)`` with ``meta = {"id": reddit_id}``.
    """

    def __init__(self, csv_file, root, text_features=None, train=True, 
                 should_partition_dataframe=True, clip_preprocess=False):

        self.train = train

        # Geometry of the returned clip
        self.height = 128
        self.width = 171
        self.nframes = 32

        df = pd.read_csv(csv_file)
        if should_partition_dataframe:
            split_name = "train" if train else "val"
            df = partition_dataframe(df, root=root, split=split_name)

        # video_path entries start with "results/", which is replaced by root
        prefix_len = len("results/")
        self.video_files = [os.path.join(root, rel_path[prefix_len:])
                            for rel_path in df.video_path]

        self.ids = df.reddit_id.to_list()
        self.titles = df.title.to_list()
        self.clip_preprocess = clip_preprocess

        if not clip_preprocess:
            # For ig65m https://github.com/moabitcoin/ig65m-pytorch/blob/master/ig65m/cli/extract.py#L64
            ig65m_normalize = transforms.Normalize(
                mean=[0.43216, 0.394666, 0.37645], std=[0.22803, 0.22145, 0.216989]
            )
            self.tfms = transforms.Compose([
                Rearrange("t h w c -> t c h w"),
                transforms.ConvertImageDtype(torch.float32),
                ig65m_normalize,
            ])
        else:
            self.tfms = CLIP_TRANSFORM

        # Ordering used in ig65m and torchvision
        # https://github.com/pytorch/vision/tree/master/torchvision/models/video
        self.final_rearrange = Rearrange("t c h w -> c t h w")

        if text_features is not None:
            self.text_feats = load_features(df, text_features)

    def __len__(self):
        """Number of videos in the (partitioned) CSV."""
        return len(self.video_files)

    def __getitem__(self, idx):
        """Return ``(vid, text, meta)`` for item ``idx``."""
        video_path = self.video_files[idx]
        reddit_id = self.ids[idx]
        title = self.titles[idx]

        # For now use this private method since it allows resizing
        # on the ffmpeg side which is faster
        # Get the first 4 seconds which should get us at least 32
        # frames at reasonable frame rates
        decoded = torchvision.io._read_video_from_file(
            video_path,
            video_width=self.width,
            video_height=self.height,
            read_audio_stream=False,
            video_timebase=Fraction(1),
            video_pts_range=(0, 4),
        )
        vid = decoded[0][0 : self.nframes, ...]

        n_read = vid.shape[0]
        if n_read < self.nframes:
            # Pad with black frames up to nframes
            padded = vid.new_full((self.nframes, self.height, self.width, 3), 0.0)
            if n_read != 0:
                padded[:n_read, ...] = vid
            else:
                print("Zero length video!", video_path)
            vid = padded

        if not self.clip_preprocess:
            vid = self.final_rearrange(self.tfms(vid))
            text = self.text_feats[idx]
        else:
            vid = torch.stack([
                self.tfms(Image.fromarray(frame.numpy()).convert("RGB"))
                for frame in vid
            ])
            try:
                text = clip.tokenize(title)
            except Exception as e:
                # Fall back to a short prefix of the title
                print(f'Failed to tokenize {title}', str(e))
                text = clip.tokenize(title[:20])

        return vid, text, {"id": reddit_id}


class VideoDatasetFirst1800(Dataset):
    """A simple video loader that returns up to the first 1800 frames
    (the first minute at 30fps).  Shorter videos give a shorter tensor,
    but never fewer than 32 frames: clips below that are padded with
    black frames.

    Framerate is not resampled

    To emulate preprocessing used in collab experts, videos
    are first resized to height 256 (preserving aspect ratio)
    and then to smaller edge 128 (preserving aspect ratio) and
    then a 112x112 center crop is taken.

    Each item is ``(vid, {})`` with frames in [c t h w] order
    """

    def __init__(self, csv_file, root, train=True, should_partition_dataframe=True):
        self.train = train

        # Decode height, resize target, crop size and frame limits
        self.video_read_height = 256
        self.height = 128
        self.crop_size = 112
        self.nframes = 1800
        self.min_nframes = 32

        df = pd.read_csv(csv_file)
        if should_partition_dataframe:
            split_name = "train" if train else "val"
            df = partition_dataframe(df, root=root, split=split_name)

        # video_path entries start with "results/", which is replaced by root
        prefix_len = len("results/")
        self.video_files = [os.path.join(root, rel_path[prefix_len:])
                            for rel_path in df.video_path]

        # For ig65m https://github.com/moabitcoin/ig65m-pytorch/blob/master/ig65m/cli/extract.py#L64
        ig65m_normalize = transforms.Normalize(
            mean=[0.43216, 0.394666, 0.37645], std=[0.22803, 0.22145, 0.216989]
        )
        self.tfms = transforms.Compose([
            Rearrange("t h w c -> t c h w"),
            transforms.Resize(self.height),
            transforms.CenterCrop(self.crop_size),
            transforms.ConvertImageDtype(torch.float32),
            ig65m_normalize,
        ])

        # Ordering used in ig65m and torchvision
        # https://github.com/pytorch/vision/tree/master/torchvision/models/video
        self.final_rearrange = Rearrange("t c h w -> c t h w")

    def __len__(self):
        """Number of videos in the (partitioned) CSV."""
        return len(self.video_files)

    def __getitem__(self, idx):
        """Return ``(vid, {})`` for item ``idx``."""
        video_path = self.video_files[idx]

        # For now use this private method since it allows resizing
        # on the ffmpeg side which is faster
        seconds_to_read = self.nframes // 15  # time if it was 15fps
        decoded = torchvision.io._read_video_from_file(
            video_path,
            video_width=0,
            video_height=self.video_read_height,
            read_audio_stream=False,
            video_timebase=Fraction(1),
            video_pts_range=(0, seconds_to_read),
        )

        vid = decoded[0][: self.nframes]
        n_read = vid.size(0)

        # The transforms cannot run on an empty clip; only fix the dtype then
        vid = vid.float() if n_read == 0 else self.tfms(vid)

        if n_read < self.min_nframes:
            # Pad with black frames up to the minimum length
            padded = vid.new_full(
                (self.min_nframes, 3, self.crop_size, self.crop_size), 0.0
            )
            if n_read != 0:
                padded[:n_read, ...] = vid
            else:
                print("Zero length video!", video_path)
            vid = padded

        return self.final_rearrange(vid), {}


def sample_instance(feature_list, sampling):
    """ Pick embedding(s) out of a list of tensors

    Args:
        feature_list: List of 1D tensors
        sampling: One of ``['all', 'first', 'random']``

    Returns: Depending on ``sampling``:
        - 'all': [list_len, embedding_size] tensor containing all the
            embeddings stacked
        - 'first': the first element of the list
        - 'random': a uniformly chosen element of the list

    Raises:
        AssertionError: if ``feature_list`` is not a list
        Exception: if ``sampling`` is not one of the supported methods
    """
    assert isinstance(feature_list, list)

    if sampling == "all":
        # NB this won't work when doing batching since
        # tensor sizes will vary
        return torch.stack(feature_list)

    if sampling == "random":
        choice = torch.randint(0, len(feature_list), ())
        return feature_list[choice]

    if sampling == "first":
        return feature_list[0]

    raise Exception("Unknown sampling method")


def sample_if_list(feature_tensor_or_list, sampling):
    """ Convenience wrapper around ``sample_instance``.

    A tensor is returned as-is; a list of tensors is reduced with
    ``sample_instance`` using the given sampling method.  Any other
    input type yields ``None``.

    Args:
        feature_tensor_or_list: Tensor or list of tensors
        sampling: One of ``['all', 'first', 'random']``
    """
    if torch.is_tensor(feature_tensor_or_list):
        return feature_tensor_or_list
    if isinstance(feature_tensor_or_list, list):
        return sample_instance(feature_tensor_or_list, sampling)
    return None


class FeaturesDataset(Dataset):
    """Load precomputed reddit features

    Args:
        csv_file: A csv file of reddit posts, which should at minimum have
        "reddit_id" and "video_path" columns. Returned features will be ordered
        as given in this file.

        input_features: Specification of the input features, with possible forms:
            - "filename.pth" : Load a single file
            - ["file1.pth", "file2.pth", ...] : Load multiple files, each
                will be a separate returned input from __getitem__
            - ["a.pth", ["b.pth", "c.pth"], ...] etc : Features given in nested
                lists will be concatenated into one long feature, eg
                feature_a, feature_bc
            - None : no input features; __getitem__ then returns only ``meta``

            Files are in pytorch format, containing a dict
                {"reddit_ids": (torch.int64, shape N)
                 "embeddings": (torch.float32, shape N x embedding_size)}
            Or for comments (multiple comments per reddit id):
                {"reddit_id_to_comment_id": (dict[int, List[str]])
                 "embeddings": (List[List[torch.float32]])}

        target_features: .pth file of target features (to be used in loss
            function but not passed to network)

        train (bool): Determines how csv is partitioned (see partition_dataframe)
            True for training set, False validation set

        train_comment_sampling (string):
            One of ``['all', 'first', 'random']`` specifying how to sample comments
            at train time, since there are multiple comments per reddit post.

            For the different options the comment embeddings returned by __getitem__ are as follows:

            - 'all': [n_comments, embedding_size] tensor containing embeddings
                of all the captions for the video
            - 'first': [embedding_size] shape tensor of the first comment
            - 'random': [embedding_size] shape tensor of a random comment

        test_comment_sampling (string):
            As above but at test time

    Each item is ``(*inputs, meta)`` with one input per top-level entry of
    ``input_features``; ``meta["target"]`` is set when target features were
    given.
    """

    def __init__(
        self,
        csv_file,
        input_features=None,
        target_features=None,
        train=True,
        train_comment_sampling=None,
        test_comment_sampling=None,
    ):
        self.train = train

        self.feature_sampling = (
            train_comment_sampling if train else test_comment_sampling
        )

        df = pd.read_csv(csv_file)
        df = partition_dataframe(df, split="train" if train else "val")

        # allow string or list of string to load multiple input features
        if input_features is None:
            input_features = []
        elif isinstance(input_features, str):
            input_features = [input_features]
        else:
            input_features = list(input_features)

        # Whether each top-level entry is a nested group to be concatenated.
        # This has to be remembered from the specification: comment features
        # are themselves loaded as lists, so the loaded data cannot tell a
        # nested group apart from a single comment feature.
        self._nested = [isinstance(spec, list) for spec in input_features]

        # Allow up to one level of nesting
        self.feats = [
            [load_features(df, inner) for inner in spec]
            if nested
            else load_features(df, spec)
            for spec, nested in zip(input_features, self._nested)
        ]

        # The dataset has one item per CSV row of the partition, independent
        # of how the first feature happens to be structured.
        self._num_items = len(df)

        self.targets = None
        if target_features:
            self.targets = load_features(df, target_features)

    def __len__(self):
        """Number of posts in the selected partition."""
        return self._num_items

    def __getitem__(self, idx):
        """Return ``(*inputs, meta)`` for item ``idx``."""
        inputs = []
        for feat, nested in zip(self.feats, self._nested):
            if nested:
                # Nested group: concatenate the member features
                inputs.append(
                    torch.cat(
                        [sample_if_list(f[idx], self.feature_sampling) for f in feat]
                    )
                )
            else:
                inputs.append(sample_if_list(feat[idx], self.feature_sampling))

        meta = {}
        if self.targets is not None:
            meta["target"] = self.targets[idx]
        return (*inputs, meta)


class ImTextDataset(VisionTitleCommentDatasetBase):
    """ Load thumbnail images, titles, and comments with CLIP preprocessing.
    TODO: add option for other preprocessing.
    Args:
        csv_file (str): A csv file of reddit posts
        root (str): root directory prefix for the images and videos data
        train (bool): determines how csv is partitioned (see partition_dataframe)
            True for training set, False validation set 
        test (bool): select the test partition instead (must not be combined
            with ``train=True``)
        add_comments (str): ['always', 'train_only', 'never']
        num_comms (int): number of comments per post to add when adding comments
        comment_sampling (str or None): if set to "random" it will sample
            random comments per post
        num_imlabels (int): number of labels per post to add when adding image labels
        cached_vision_features (str or None): path of a precomputed feature
            file; when given, the stored feature is returned instead of the
            preprocessed image
        random_words (bool): replace comments and labels by random words from
            a word list that is downloaded at construction time

    Each item is ``(im, title_tok, comments_labels, meta)`` with
    ``meta = {'id': reddit_id}``.
    """

    # Seconds to wait for the word-list server before giving up
    _WORD_LIST_TIMEOUT = 30

    def __init__(self, csv_file, root='/data/reddit-results', train=True, test=False,
                add_comments="train_only",
                num_comms=0, comment_sampling="random", num_imlabels=0,
                cached_vision_features=None, random_words=False):
        
        self.train = train
        self.root = root
        self.num_comms = int(num_comms)
        self.num_imlabels = int(num_imlabels)
        self.random_words = random_words
        self.comment_sampling = comment_sampling if train else None
        self.cached_vision_features = cached_vision_features

        self.add_comments = self.should_add_comments(add_comments, train)

        # Per-item lists, filled by _load_reddit
        self.ids = []
        self.filenames = []
        self.titles = []
        self.video_lengths = []
        self.labels = []
        self.comments = []

        df = pd.read_csv(csv_file)
        df = self.split_dataset(df, train, test)
        # Thumbnails live next to the videos, with a .jpg extension
        self._load_reddit(df, file_extension=".jpg")

        self.preprocess = CLIP_TRANSFORM
        self.tokenizer = SimpleTokenizer()
        self.rake = Rake()

        if num_imlabels:
            self.imlabels_synonyms = self._load_imlabels_synonyms()
        
        if random_words:
            word_site = "https://www.mit.edu/~ecprice/wordlist.10000"
            # Without a timeout a stalled server blocks construction forever,
            # and without the status check an HTTP error page would silently
            # be used as the vocabulary.
            response = requests.request(
                method='GET', url=word_site, timeout=self._WORD_LIST_TIMEOUT)
            response.raise_for_status()
            txt = response.text
            self.random_word_list = txt.splitlines()
        
        if cached_vision_features is not None:
            self.vision_feats = load_features(df, cached_vision_features)

    def __len__(self):
        """Number of posts in the selected partition."""
        return len(self.filenames)

    def __getitem__(self, idx):
        """Return ``(im, title_tok, comments_labels, meta)`` for item ``idx``."""
        im_path = self.filenames[idx]
        title = self.titles[idx]
        reddit_id = self.ids[idx]
        comments = self.comments[idx]
        imlabels = self.labels[idx]
        
        if self.cached_vision_features is not None:
            im = self.vision_feats[idx]
        else:
            im = self.preprocess(Image.open(im_path).convert("RGB"))

        title_tok = self._tokenise([title])[0]

        if self.add_comments:
            if self.random_words:
                # Random words take the place of both comments and labels
                indices = torch.randint(len(self.random_word_list), (max(self.num_comms, self.num_imlabels),)).numpy().tolist()
                comments = [self.random_word_list[w] for w in indices]
                imlabels = []
            else:
                comments = self.preprocess_comments(comments,
                                                    sampling=self.comment_sampling,
                                                    num_comms=self.num_comms)

                imlabels = self.preprocess_imlabels(
                    imlabels, num_imlabels=self.num_imlabels)
                
            comments_labels = self._tokenise(comments + imlabels)

            # Occasionally log a sample for eyeballing the text pipeline
            if torch.rand([]) < 0.0001:
                print("Debug dataloader -- title:", title, "comms:", comments, "imlabels:", imlabels)
        else:
            # A single empty text stands in for the comments and labels
            comments_labels = self._tokenise([""])

        meta = {'id': reddit_id}
        
        return (im, title_tok, comments_labels, meta)



# Manual debugging harness: builds a VideoDatasetSegments over the reddit and
# kinetics data and either dumps augmented frames as PNGs or benchmarks the
# DataLoader.  Paths are specific to the original development machine.
if __name__ == "__main__":
    import numpy as np
    TEST_CSV = "data_symlink/filtered_reddit_video_dataset_v1_db_2020-11-16_22-14-57_deduplicated_dataset_0.00075_imlabels.csv"
    # ds = VideoDatasetFirst32(TEST_CSV, root='/data/reddit-results')
    # print(ds[123][0].shape)

    ds = VideoDatasetSegments(
        csv_file=TEST_CSV,
        root="/data/reddit-results/",
        num_imlabels=3,
        num_comms=3,
        kinetics_csv="data_symlink/kinetics700_havedescs.csv",
        use_kinetics_train="combine",
        kinetics_root="/data"
        )

    if True:
        # Write pngs
        # 40 random items, each read 4 times so the different random
        # augmentations of the same video can be compared side by side
        for dsidx_ in range(40):
            dsidx = int(torch.randint(0, len(ds), []))
            for n in range(4):
                x = ds[dsidx]
                vid = x[0]
                # Rescale the normalised frames to [0, 1] for saving
                vid = vid - vid.min()
                vid = vid/vid.max()

                augs_dir = '/tmp/frames/frames_%s' % os.path.basename(ds.filenames[dsidx]).split('.')[0]
                os.makedirs(augs_dir, exist_ok=True)

                # Treat a single image like a one-frame video
                if len(vid.shape) != 4:
                    vid = vid.unsqueeze(0)

                for i, im in enumerate(vid):
                    im = Image.fromarray((im.permute(1,2,0) * 255).to(torch.uint8).numpy())
                    os.makedirs(os.path.join(augs_dir, 'aug%d' % n), exist_ok=True)
                    im.save(os.path.join(augs_dir, 'aug%d/frame%03d.png' % (n,i)))

    if False:
        # Disabled: time batch loading for different worker counts
        for nw in [30,40]:
            # Benchmark dataloader
            from torch.utils.data import DataLoader
            data_loader = DataLoader(
            ds,
            batch_size=40,
            num_workers=nw,
            shuffle=True,
            drop_last=True,
            pin_memory=True,
            prefetch_factor=4
            )

            tic = time.time()
            times = []
            nbatches = 0
            for dl in data_loader:
                dltime = time.time() - tic
                # NOTE: on the first batch `times` is still empty, so the
                # printed mean is nan
                print('Batch time %.3f' % dltime, 'mean', np.mean(times))
                # Simulated per-batch compute, excluded from the timing
                time.sleep(0.2)
                times.append(dltime)
                tic = time.time()
                nbatches += 1

                if nbatches > 200:
                    break

            print(nw, np.mean(times), np.std(times))

