from multiprocessing import Pool
import numpy as np
import cv2
import pathlib
import os
import pandas as pd
from PIL import Image
from itertools import chain, takewhile
from torch.utils.data import Dataset
import torch
from torchvision.transforms import ToPILImage
import torchvision.transforms.functional as F
import random
from .data_utils import mean_sub
import utils

"""
    A (unfortunately) hacky way to parallelize the data-indexing process on these massive
    video datasets.
"""

def tp_init_(obj_):
    """
        Pool-worker initializer: publish `obj_` under the module-level name `self`.

        It is handed to `multiprocessing.Pool(initializer=...)` so that code running
        inside a worker process can look the dataset object up as a global.
    """
    globals()['self'] = obj_

class RawVideoFileDataset(Dataset):
    """
        Special data loader for when you don't have frame-by-frame files, only videos.
        Works with any format that OpenCV can deal with. Assumed directory structure:

        root_dir
            class1
                video0.mp4/*.(avi|mp4|mov|etc.)
                video1
                ...
            class_N
                video0
                ...

        All video processing takes place over OpenCV (cv2).

        This base class is abstract: subclasses must override `create_job` (build the
        argument tuple for one category) and `load_single_class` (index one category).
        The index is built once, in parallel, when the dataset is constructed.
    """
    def __init__(self, root_dir, splits_dir, split, load_width, load_height, load_classes=None, max_frames=16, transform=None, transform_mode='frame', mean_subtract=False, apply_prob=0.5, seed=42, keep_class_order=False):
        """
            Build the class mapping and index every video of the requested split.

            Args:
                root_dir: dataset root; must contain 'ucf101' or 'hmdb51' in its path,
                    which selects the hard-coded class order below.
                splits_dir: directory holding the split definition files.
                split: 'train' or 'test' (the only splits the loading code handles).
                load_width, load_height: size every decoded frame is resized to.
                load_classes: optional subset of class names to load.
                max_frames: number of frames per clip.
                transform: optional callable invoked as `transform(clip, meta=...)`.
                transform_mode: stored on the instance; not used by this class.
                mean_subtract: if true, `mean_sub` is applied to every clip.
                apply_prob: probability with which `transform` is applied to a clip.
                seed: seeds both `numpy.random` and `random` (global state).
                keep_class_order: with `load_classes`, keep the label ids of the
                    hard-coded list instead of renumbering by `load_classes` order.

            Raises:
                NotImplementedError: if `root_dir` names neither supported dataset.
        """
        self.root_dir = root_dir
        self.splits_dir = splits_dir
        self.split = split # 'train', 'test', 'extra', etc.
        self.n_frames = max_frames
        self.load_width = load_width
        self.load_height = load_height
        np.random.seed(seed)
        random.seed(seed)
        # The model was trained on a hard-coded class order and the original data
        # was later deleted by accident. Let this comment be a lesson to the reader
        # to USE A DETERMINISTIC ORDER. Furthermore, NEVER use os.listdir to get
        # the order of something important.
        if 'ucf101' in root_dir:
            classes = ['FrontCrawl', 'Mixing', 'BoxingPunchingBag', 'BodyWeightSquats', 'Knitting', 'JumpingJack', 'Bowling', 'PlayingPiano', 'HorseRiding', 'TableTennisShot', 'Archery', 'BabyCrawling', 'TaiChi', 'Rafting', 'BenchPress', 'CricketShot', 'MoppingFloor', 'ThrowDiscus', 'Rowing', 'PoleVault', 'ApplyEyeMakeup', 'RockClimbingIndoor', 'Shotput', 'HulaHoop', 'BlowingCandles', 'PlayingFlute', 'GolfSwing', 'Skiing', 'VolleyballSpiking', 'ParallelBars', 'PullUps', 'Punch', 'YoYo', 'SalsaSpin', 'IceDancing', 'SoccerJuggling', 'RopeClimbing', 'PlayingViolin', 'Typing', 'FrisbeeCatch', 'PlayingGuitar', 'JugglingBalls', 'Fencing', 'Nunchucks', 'Billiards', 'TrampolineJumping', 'HighJump', 'StillRings', 'ApplyLipstick', 'BoxingSpeedBag', 'PizzaTossing', 'FieldHockeyPenalty', 'Surfing', 'CuttingInKitchen', 'BandMarching', 'Basketball', 'SoccerPenalty', 'Swing', 'Lunges', 'Haircut', 'SkateBoarding', 'PlayingSitar', 'PlayingTabla', 'BaseballPitch', 'JavelinThrow', 'JumpRope', 'PlayingDhol', 'WallPushups', 'HandstandWalking', 'WalkingWithDog', 'SumoWrestling', 'CricketBowling', 'CliffDiving', 'SkyDiving', 'Drumming', 'PlayingCello', 'PushUps', 'PommelHorse', 'HandstandPushups', 'HorseRace', 'WritingOnBoard', 'FloorGymnastics', 'PlayingDaf', 'HeadMassage', 'Kayaking', 'Skijet', 'BrushingTeeth', 'HammerThrow', 'Hammering', 'LongJump', 'MilitaryParade', 'UnevenBars', 'Diving', 'TennisSwing', 'Biking', 'BasketballDunk', 'BreastStroke', 'BlowDryHair', 'ShavingBeard', 'BalanceBeam', 'CleanAndJerk']
        elif 'hmdb51' in root_dir:
            classes = ['chew', 'clap', 'smoke', 'hug', 'drink', 'cartwheel', 'sit', 'fall_floor', 'wave', 'dive', 'pour', 'kick_ball', 'punch', 'kick', 'golf', 'turn', 'throw', 'ride_bike', 'sword', 'laugh', 'shake_hands', 'fencing', 'stand', 'eat', 'climb_stairs', 'climb', 'situp', 'ride_horse', 'sword_exercise', 'brush_hair', 'somersault', 'run', 'handstand', 'draw_sword', 'smile', 'jump', 'dribble', 'shoot_bow', 'flic_flac', 'pick', 'swing_baseball', 'shoot_gun', 'walk', 'pullup', 'hit', 'pushup', 'push', 'kiss', 'catch', 'talk', 'shoot_ball']
        else:
            raise NotImplementedError()
        if load_classes is None:
            self.categories = {cat:i for i, cat in enumerate(classes)}
        else:
            if keep_class_order:
                # Label ids stay those of the hard-coded list (possibly non-contiguous).
                self.categories = {cat:i for i, cat in enumerate(classes) if cat in load_classes}
            else:
                # Label ids are renumbered by position in load_classes.
                self.categories = {cat:i for i, cat in enumerate(load_classes) if cat in classes}
        self.names = {i:cat for cat, i in self.categories.items()}

        self.data_labels = self.get_data_labels()
        # Sort by path first so the seeded shuffle below does not depend on the order
        # in which the worker processes returned their results. An empty index is
        # one-dimensional and has no path column to sort by.
        if self.data_labels.ndim == 2:
            self.data_labels = self.data_labels[self.data_labels[:, 0].argsort(kind='stable')]
        self.transform = transform
        self.mode = 'sequence'
        self.mean_subtract = mean_subtract
        self.transform_mode = transform_mode
        self.apply_prob = apply_prob
        np.random.shuffle(self.data_labels)

    def create_job(self, cat):
        """Return the argument tuple `load_single_class` needs for category `cat`."""
        raise NotImplementedError("This method must be overridden in the subclass.")

    def get_data_labels(self):
        """
            Index every category in parallel and return all rows as one numpy array.

            One job per category is dispatched to a process pool; each job's arguments
            come from `create_job` and are consumed by `load_single_class`.
        """
        jobs = [self.create_job(cat) for cat in self.categories]
        with Pool(initializer=tp_init_, initargs=(self,)) as pool:
            results = pool.starmap(self.load_single_class, jobs)
        arr = np.array(list(chain.from_iterable(results)))
        return arr

    def load_single_class(*args):
        """Index the videos of one category; receives the instance plus a `create_job` tuple."""
        raise NotImplementedError("This method must be overridden in the subclass.")

    def safe_capture(self, file, max_attempts=50):
        """
            Open `file` with cv2.VideoCapture, retrying up to `max_attempts` times.

            Returns:
                An opened cv2.VideoCapture; the caller is responsible for releasing it.

            Raises:
                RuntimeError: if the capture is still not opened after the last attempt.
        """
        cap = cv2.VideoCapture(file)
        attempts = 1
        while not cap.isOpened() and attempts < max_attempts:
            cap.release()
            cap = cv2.VideoCapture(file)
            attempts += 1
        # Test the capture itself: the attempt counter can never exceed max_attempts
        # here, so comparing the counter would never report a failure.
        if not cap.isOpened():
            cap.release()
            raise RuntimeError("Failed to open {} with cv2.VideoCapture after {} tries".format(file, attempts))
        return cap

    def _count_frames(self, file):
        """Return the frame count OpenCV reports for `file`, or 0 if it is missing or unopenable."""
        with utils.stderr_suppress():
            if not os.path.isfile(file):
                return 0
            try:
                cap = self.safe_capture(file)
            except RuntimeError:
                return 0
            try:
                return int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            finally:
                cap.release()

    def load_videos(self, video_files, label, nframes=16):
        """
            Build the index rows for a list of video files that share one label.

            For the 'train' split, videos with fewer than `nframes` frames are skipped
            and each row is `[file, start, label]` with a randomly drawn start frame.
            For the 'test' split each row is `[file, label]`.

            Raises:
                ValueError: if the dataset's split is neither 'train' nor 'test'.
        """
        data_labels = []
        for file in video_files:
            tot = self._count_frames(file)
            if self.split == 'train':
                if tot < nframes: continue
                # Clamp the upper bound: when tot == nframes the unclamped bound is -1
                # and random.randint(0, -1) raises ValueError.
                start = random.randint(0, max(0, tot - nframes - 1))
                data_labels.append([file, start, label])
            elif self.split == 'test':
                data_labels.append([file, label]) # don't pick start location on validation
            else:
                raise ValueError("Only the 'train' and 'test' split are supported at this time; split name '{}' is not recognized.".format(self.split))
        return data_labels

    def get_classname(self, id):
        """Return the class name for label id `id`, or -1 if the id is unknown."""
        return self.names.get(id, -1)

    def __len__(self):
        """Number of indexed videos."""
        return len(self.data_labels)

    def _read_frames(self, path, start=None, limit=None):
        """
            Decode frames from `path` into a list of tensors, one per frame.

            Reading begins at frame `start` when given and stops after `limit` frames
            when given; otherwise it runs until OpenCV stops returning frames. A video
            that cannot be opened yields an empty list. The capture is always released.
        """
        frames = []
        try:
            cap = self.safe_capture(path)
        except RuntimeError:
            return frames
        try:
            if start is not None:
                cap.set(cv2.CAP_PROP_POS_FRAMES, float(start))
            while True:
                ret, img = cap.read()
                if not ret or img is None:
                    break
                try:
                    img = cv2.resize(img, (self.load_width, self.load_height))
                except cv2.error:
                    # Treat an undecodable frame as the end of the usable stream.
                    break
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                frames.append(F.to_tensor(img))
                if limit is not None and len(frames) == limit:
                    break
        finally:
            cap.release()
        return frames

    def __getitem__(self, idx):
        """
            Load the video at index `idx`.

            Returns a 3-tuple:
                'train': (one clip, int label, the raw index row as a list).
                'test':  (all clips of the video stacked along a new leading axis,
                          int label, per-clip metadata lists).
            If the file is missing, or no frame could be decoded, the first element is
            an empty tensor.

            Raises:
                ValueError: if the dataset's split is neither 'train' nor 'test'.
        """
        import warnings  # function-scope import: not part of this module's import block

        if self.split == 'train':
            path, start, label = self.data_labels[idx]
        elif self.split == 'test':
            path, label = self.data_labels[idx]
            start = None
        else:
            raise ValueError("Only the 'train' and 'test' split are supported at this time; split name '{}' is not recognized.".format(self.split))
        training = self.split == 'train'

        if not os.path.isfile(path):
            if training: warnings.warn(str(path) + " was not found during training.")
            return torch.Tensor(0), int(label), [(path, 0, label)]

        metadata = []
        with utils.stderr_suppress():
            # Training reads a single clip from the pre-drawn start frame; testing
            # reads the whole video from the beginning.
            frames = self._read_frames(
                path,
                start=start if training else None,
                limit=self.n_frames if training else None,
            )
            clips = torch.Tensor(0)
            if frames:
                frames = torch.stack(frames, dim=0)
                # Move channels in front of time; presumably (C, T, H, W), assuming
                # F.to_tensor yields channel-first frames.
                frames = frames.permute(1, 0, 2, 3)
                splits = list(torch.split(frames, self.n_frames, dim=1))
                # Pad a short final clip by repeating it, then trim to n_frames.
                while splits[-1].size(1) < self.n_frames:
                    splits[-1] = torch.cat([splits[-1]] * 2, dim=1)
                if splits[-1].size(1) != self.n_frames:
                    splits[-1] = splits[-1][:, :self.n_frames, ...]
                for i, clip in enumerate(splits):
                    clip_meta = (path, self.n_frames * i, label)
                    if self.transform and random.random() < self.apply_prob:
                        splits[i] = self.transform(clip, meta=(path, i * self.n_frames))
                    if self.mean_subtract:
                        splits[i] = mean_sub(splits[i])
                    if self.transform and hasattr(self.transform, 'cache'):
                        # Attach the most recent cache value, if the cache has any.
                        cached = list(self.transform.cache.items())
                        if cached:
                            _, last_val = cached[-1]
                            clip_meta += (last_val,)
                    metadata.append(list(clip_meta))
                if training:
                    return splits[0], int(label), list(self.data_labels[idx])
                # Keep only the leading clips that are still 4-dimensional.
                splits = list(takewhile(lambda x: x.ndim == 4, splits))
                if len(splits):
                    clips = torch.stack(splits, dim=0) # adds a leading clip axis
            else:
                metadata.append([(path, 0, label)])
        return clips, int(label), metadata


class HMDB51Dataset(RawVideoFileDataset):
    """
        HMDB51 variant of the raw-video dataset.

        Each category has its own split file named
        "<category>_test_split<split_id>.txt" inside `splits_dir`. Column 0 of that
        file is the video file name and column 1 a numeric tag that is matched
        against `split_type` (0 = 'extra', 1 = 'train', 2 = 'test').
    """
    def __init__(self, root_dir, splits_dir, split, load_width, load_height, split_id=1,  split_mode=None, **kwargs):
        """
            Args:
                split: which tagged subset of the split files to load
                    ('extra', 'train' or 'test'); anything else raises KeyError.
                split_id: number of the split file set to read.
                split_mode: if given, replaces `split` as the split name handed to
                    the base class, while the files are still selected by `split`.
                **kwargs: forwarded unchanged to the base class constructor.
        """
        self.split_id = split_id
        self.split_type = {'extra':0, 'train':1, 'test':2}[split]
        if split_mode is not None:
            split = split_mode
        super(HMDB51Dataset, self).__init__(root_dir, splits_dir, split, load_width, load_height, **kwargs)

    def create_job(self, cat):
        """Return `(video_dir, split_file, label_id, n_frames)` for category `cat`."""
        path = os.path.join(self.root_dir, cat)
        split_path = self.get_split(cat)
        cat_id = self.categories[cat]
        n_frames = self.n_frames
        return path, split_path, cat_id, n_frames

    def get_split(self, category):
        """Return the path of the split file that belongs to `category`."""
        return os.path.join(self.splits_dir, "{}_test_split{}.txt".format(category, self.split_id))

    def load_single_class(self, path, splitpath, label, nframes):
        """
            Index the videos of one category that carry this dataset's split tag.

            Uses the instance the method is bound to rather than a module-level
            global, so it also works when called outside a pool worker.

            Args:
                path: directory containing the category's videos.
                splitpath: split file listing the videos and their split tags.
                label: integer label id assigned to every video of the category.
                nframes: clip length handed to `load_videos`.

            Returns:
                The rows produced by `load_videos` for the selected files.
        """
        items = pd.read_csv(splitpath, delimiter=" ", encoding=None, header=None, usecols=[0,  1])
        split_items = items[items[1] == self.split_type][0].values
        files = [os.path.join(path, f) for f in split_items]
        return self.load_videos(files, label, nframes)

class UCF101Dataset(RawVideoFileDataset):
    """
        UCF101 variant of the raw-video dataset.

        The split is described by one list file, "<split>list0<split_id>.txt" inside
        `splits_dir`, whose first column holds paths of the form "<Class>/<video>".
        The entries are grouped by their lower-cased class prefix so that each
        category can be indexed as a separate job.
    """
    def __init__(self, root_dir, splits_dir, split, load_width, load_height, split_id=1, split_mode=None, **kwargs):
        """
            Args:
                split: name used to pick the list file (e.g. 'train' or 'test').
                split_id: number of the list file to read.
                split_mode: if given, replaces `split` as the split name handed to
                    the base class, while the list file is still chosen by `split`.
                **kwargs: forwarded unchanged to the base class constructor.
        """
        self.df = pd.read_csv(os.path.join(splits_dir, "{}list0{}.txt".format(split, split_id)), sep=" ", header=None, usecols=[0])
        self.groups = self.df[0].groupby(self.get_prefix)
        if split_mode is not None:
            split = split_mode
        super(UCF101Dataset, self).__init__(root_dir, splits_dir, split, load_width, load_height, **kwargs)
        # The list file is only needed while the index is being built.
        del self.df
        del self.groups

    def get_prefix(self, x):
        """Group key for row `x` of the list file: its lower-cased class directory."""
        return self.df.iloc[x, 0].split("/")[0].lower()

    def create_job(self, cat):
        """Return `(video_dir, split_entries, label_id, n_frames)` for category `cat`."""
        path = os.path.join(self.root_dir, cat)
        splitfiles = self.groups.get_group(cat.lower()).tolist()
        label = self.categories[cat]
        nframes = self.n_frames
        return path, splitfiles, label, nframes

    def load_single_class(self, path, splitfiles, label, nframes):
        """
            Index the videos of one category.

            Uses the instance the method is bound to rather than a module-level
            global, so it also works when called outside a pool worker.

            Args:
                path: category directory; unused because the split entries already
                    carry the class directory, kept to match the `create_job` tuple.
                splitfiles: "<Class>/<video>" entries, joined onto `root_dir`.
                label: integer label id assigned to every video of the category.
                nframes: clip length handed to `load_videos`.

            Returns:
                The rows produced by `load_videos` for the listed files.
        """
        files = [os.path.join(self.root_dir, f) for f in splitfiles]
        return self.load_videos(files, label, nframes)
