import numpy as np
import os
import pandas as pd
import torch
import decord
import timeit
import pickle
import json
import random
import cv2
import ffmpeg
import glob

from PIL import Image
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torch.autograd.variable import Variable
from torchvision import transforms
from decord import VideoReader, cpu, gpu
from datetime import timedelta
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data._utils.collate import default_collate

# import logging
# logging.basicConfig(filename="error.txt",level=logging.DEBUG)
# logging.captureWarnings(True)


# print('imported!!!')
# os.environ['DECORD_EOF_RETRY_MAX'] = '40960'
# Switch decord's output bridge to torch so VideoReader.get_batch returns torch
# tensors; the dataset code below calls tensor methods such as .permute on them.
decord.bridge.set_bridge('torch')
# test_remove = ['084k_RL3ApU_000109_000119.mp4', '2xWiEVNUvhE_000064_000074.mp4', '305P2f9_lko_004145_004155.mp4',
#                'B4bn9G6__sY_000086_000096.mp4', 'BvBVQmm2RcM_000082_000092.mp4', 'CxjipYE57Yo_000199_000209.mp4',
#                'IhanWvpHGu8_001243_001253.mp4', 'Lw14NH9kAqE_000759_000769.mp4',
#                ' XFkykETgkoo_002967_002977.mp4', 'jJFqy6yiXzQ_000024_000034.mp4',
#                'kinMMqkswUk_000120_000130.mp4', 'y7cYaYX4gdw_000047_000057.mp4']




def multiple_samples_collate(batch):
    """
    Collate a batch of ``(input, label)`` pairs, skipping failed samples.

    A sample is dropped when its input or its label is ``None``. This handles,
    for example, a dataset that signals an unreadable item by returning
    ``(None, None)``. Inputs and labels are filtered together, so the two
    collated outputs always stay aligned.

    Args:
        batch (tuple or list): ``(input, label)`` pairs produced by the dataset.
    Returns:
        (tuple): ``(inputs, labels)``, each collated with ``default_collate``.
    Raises:
        ValueError: if no usable sample remains in the batch.
    """
    # Filter pairs jointly; filtering each list separately could shift labels
    # relative to inputs when only one side of a pair is None.
    kept = [(x, y) for x, y in batch if x is not None and y is not None]
    if not kept:
        raise ValueError('multiple_samples_collate: batch contains no usable samples')

    inputs = [x for x, _ in kept]
    labels = [y for _, y in kept]
    return default_collate(inputs), default_collate(labels)


class CustomBatchSampler(DistributedSampler):
    """Batch sampler that draws a fresh clip length (and resolution) per batch.

    Iterating yields lists of exactly ``batch_size`` entries. Each entry is
    ``[index, num_frames]``; when ``spatial_flex`` or ``static_tokens`` is set,
    it is ``[index, num_frames, resolution]`` instead. All entries of one batch
    share the same ``num_frames`` and resolution. A trailing batch shorter than
    ``batch_size`` is never yielded.

    Args:
        data (list): Sequence whose positions ``0 .. len(data) - 1`` become the
            yielded indices. When ``shuffle`` is true it is shuffled in place.
        num_replicas (int): Accepted for ``DistributedSampler`` compatibility
            but unused, because the parent ``__init__`` is not called.
        rank (int): Accepted for compatibility but unused.
        batch_size (int): Number of entries per batch.
        shuffle (bool): Shuffle ``data`` once, at construction time.
        spatial_flex (bool): Draw the resolution from (96, 128, 224, 384).
        static_tokens (bool): Draw the resolution from (98, 126, 224, 392).
            Only consulted when ``spatial_flex`` is false.
    """

    _FRAME_CHOICES = (4, 8, 16, 32, 64)
    _FLEX_RESOLUTIONS = (96, 128, 224, 384)
    # Resolutions divisible by 14, so a patch size that results in 14 tokens
    # can be determined.
    _STATIC_TOKEN_RESOLUTIONS = (98, 126, 224, 392)

    def __init__(self, data, num_replicas, rank, batch_size, shuffle, spatial_flex, static_tokens):
        # NOTE(review): DistributedSampler.__init__ is not called, so no
        # per-rank sharding happens and num_replicas / rank are ignored.
        self.data = data
        self.batch_size = batch_size
        self.spatial_flex = spatial_flex
        self.static_tokens = static_tokens
        if shuffle:
            # NOTE(review): yielded indices are always 0..len-1 in order.
            # Shuffling only changes what an index refers to when `data` is
            # the same list object the dataset itself indexes into.
            random.shuffle(self.data)

    def _draw_batch_config(self):
        """Sample the ``(num_frames, resolution)`` pair shared by one batch.

        ``resolution`` is ``None`` when neither resolution mode is enabled.
        """
        frames = random.choice(self._FRAME_CHOICES)
        size = None
        if self.spatial_flex:
            size = random.choice(self._FLEX_RESOLUTIONS)
        elif self.static_tokens:
            size = random.choice(self._STATIC_TOKEN_RESOLUTIONS)
        return frames, size

    def __iter__(self):
        frames, size = self._draw_batch_config()
        pending = []
        for idx, _ in enumerate(self.data):
            with_size = self.spatial_flex or self.static_tokens
            pending.append([idx, frames, size] if with_size else [idx, frames])
            if len(pending) != self.batch_size:
                continue
            yield pending
            # Every new batch gets its own clip length / resolution.
            frames, size = self._draw_batch_config()
            pending = []

    def __len__(self):
        # Only complete batches are yielded, so round down.
        full_batches, _ = divmod(len(self.data), self.batch_size)
        return full_batches


class K700(Dataset):
    """Kinetics-700 clips restricted to classes absent from a Kinetics-400 class list.

    Each split CSV row has five comma-separated fields,
    ``label,youtube_id,time_start,time_end,<extra>``, and the first line is a
    header. The clip for a row lives at
    ``<root><split_dir>/<label>/<youtube_id>_<start:06>_<end:06>.mp4``.
    ``split_dir`` is ``train`` for the training split and ``val`` otherwise.

    ``__getitem__`` returns ``(frames, label)``:

    - ``frames`` holds ``num_frames`` frames sampled evenly over the clip. They
      are scaled to [0, 1], normalized with mean (0.485, 0.456, 0.406) and
      std (0.229, 0.224, 0.225), and resized to
      ``resolution x resolution``.
    - The result is permuted to (C, T, H, W), assuming decord yields
      (T, H, W, C) frames.
    - ``label`` is the class's position in ``self.labels``.

    Args:
        data_split: ``'train'`` reads ``train_csv``, keeping only clips that
            exist on disk. Any other value reads ``val_csv``.
        num_frames: Number of frames sampled per clip.
        resolution: Side length of the square output frames.
        root: Dataset root. It must end with a path separator, because the
            split directory is appended by string concatenation.
        train_csv: CSV listing the training clips.
        val_csv: CSV listing the validation clips. Its ``label`` column also
            defines the K700 class vocabulary.
        k400_classes_csv: CSV whose first column lists Kinetics-400 class
            names; those classes are excluded.
    """

    def __init__(self, data_split, num_frames, resolution,
                 root='/share/datasets/Kinetics700/',
                 train_csv='k700train.csv',
                 val_csv='k700val.csv',
                 k400_classes_csv='/squash/kinetics400_dataset-lz4/clsIdx.csv'):
        self.num_frames = num_frames
        self.resolution = resolution
        print('resolution: ', resolution)

        self.root = root
        self.data_split = data_split

        if data_split == 'train':
            # Drop rows whose clip is missing on disk, so __getitem__ rarely
            # needs its fallback.
            self.data = [row for row in self._read_rows(train_csv)
                         if os.path.exists(self._resolve_row(row, 'train')[1])]
        else:
            self.data = self._read_rows(val_csv)
        print(len(self.data))

        # Kinetics-400 class names (first CSV column), as a set for O(1) lookups.
        k400_labels = {row.split(',')[0] for row in self._read_rows(k400_classes_csv)}

        # The K700 class vocabulary comes from the validation CSV's 'label' column.
        k700_labels = pd.read_csv(val_csv)['label'].unique().tolist()
        print(len(k700_labels))
        print(type(self.data), self._first(self.data))
        self.labels = [label for label in k700_labels if label not in k400_labels]
        print(len(self.labels))

        print(len(self.data), self._first(self.data), len(self.labels), self._first(self.labels))
        # Keep only clips whose label is one of the retained classes.
        # NOTE(review): this compares the raw first CSV field, while
        # __getitem__ strips surrounding double quotes before its lookup. If
        # any CSV labels are quoted, those rows are dropped here — confirm
        # against the CSV format.
        kept_labels = set(self.labels)
        self.data = [row for row in self.data if row.split(',')[0] in kept_labels]
        print(len(self.data), self._first(self.data), len(self.labels), self._first(self.labels))

        self.transforms = transforms.Compose([
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
            transforms.Resize([self.resolution, self.resolution]),
        ])
        # Number of clips that failed to open, for the train and non-train splits.
        self.trcount = 0
        self.tscount = 0

    @staticmethod
    def _read_rows(path):
        """Return the lines of ``path`` after the header, with newlines stripped."""
        with open(path, 'r') as f:
            return [line.strip('\n') for line in f.readlines()[1:]]

    @staticmethod
    def _first(seq):
        """Return ``seq[0]``, or ``None`` when ``seq`` is empty (debug prints only)."""
        return seq[0] if seq else None

    def _resolve_row(self, row, split_dir):
        """Split a CSV row into ``(label, clip_path)`` under ``root + split_dir``."""
        label, vid, start, stop, _ = row.split(',')
        clip_name = f'{vid}_{start.zfill(6)}_{stop.zfill(6)}.mp4'
        return label, os.path.join(self.root + f'{split_dir}/{label}', clip_name)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        is_train = self.data_split == 'train'
        # Mirrors __init__: every non-train split was loaded from the val CSV.
        split_dir = 'train' if is_train else 'val'

        label, vid_path = self._resolve_row(self.data[index], split_dir)
        try:
            vr = VideoReader(vid_path)
        except Exception as e:
            # Unreadable clip: log it and substitute another row. Modulo keeps
            # the fallback index inside the dataset at either end.
            if is_train:
                self.trcount += 1
                print(e, self.trcount, flush=True)
                fallback = (index - random.randrange(5, 10)) % len(self.data)
            else:
                self.tscount += 1
                print(e, self.tscount, flush=True)
                fallback = (index + 1) % len(self.data)
            label, vid_path = self._resolve_row(self.data[fallback], split_dir)
            vr = VideoReader(vid_path)

        # Evenly spaced frame positions over the whole clip.
        frame_indexer = np.linspace(0, len(vr) - 1, self.num_frames)
        frames = vr.get_batch(frame_indexer)
        frames = frames.permute(0, 3, 1, 2) / 255.
        frames = self.transforms(frames).permute(1, 0, 2, 3)
        label = self.labels.index(label.strip('\"'))
        return frames, label


if __name__ == '__main__':
    # Smoke test: build both splits, then print the shape of every test batch.
    shuffle = False  # currently unused
    loader_kwargs = dict(num_workers=8, batch_size=16, collate_fn=multiple_samples_collate)

    tr_dataloader_gen = K700('train', 2, 128)
    trdataloader = DataLoader(tr_dataloader_gen, **loader_kwargs)
    # To iterate the training split as well:
    # for frames, label in tqdm(trdataloader):
    #     print(frames.shape)

    ts_dataloader_gen = K700('test', 2, 128)
    tsdataloader = DataLoader(ts_dataloader_gen, **loader_kwargs)

    for frames, label in tqdm(tsdataloader):
        print(frames.shape)