from torch.utils.data import DataLoader, SubsetRandomSampler
import torchvision.transforms as image_transforms
import torchvision.transforms.functional as F
import torchvideo 
import torch
import os
import inspect
import pickle

import utils
from datasets.raw_video_dataset import *
from datasets.frame_dataset import *

home = os.path.expanduser('~')
# Folder emptied by clear_cache(); built from the user's home directory.
preview_folder = os.path.join(home, 'video-augmentation-experiments/src/previews')


def clear_cache():
    """Delete every regular file and symlink directly inside ``preview_folder``.

    Subdirectories are left in place, and a missing folder is a no-op.
    Deletion is best-effort: any failure is printed and the loop continues.
    """
    if not os.path.isdir(preview_folder):
        return
    for entry in os.listdir(preview_folder):
        target = os.path.join(preview_folder, entry)
        try:
            removable = os.path.isfile(target) or os.path.islink(target)
            if removable:
                os.unlink(target)
        except Exception as e:
            print('Failed to delete {}: {}'.format(target, e))

# Values of args.dataset that load_data() knows how to build.
supported_datasets = [
    'scale',
    'transition',
    'conversation',
    'hmdb51',
    'ucf101',
]

def action_recognition_kwargs(args):
    """Collect the keyword arguments passed to the HMDB51/UCF101 datasets.

    ``'transform'`` is the training transform returned by ``get_transform``;
    every other entry is copied from the same-named field of ``args``.
    """
    return dict(
        max_frames=args.max_frames,
        transform=get_transform(args),
        transform_mode=args.transform_mode,
        mean_subtract=args.mean_subtract,
        apply_prob=args.apply_prob,
        load_classes=args.load_classes,
        keep_class_order=args.keep_class_order,
    )

def get_val(kwds, args):
    """Return a copy of ``kwds`` whose ``'transform'`` is the validation transform.

    The input dict is left unmodified, so kwargs already prepared for the
    training split are not silently switched to the validation transform.

    Args:
        kwds: dataset keyword arguments (e.g. from ``action_recognition_kwargs``).
        args: options namespace forwarded to ``get_transform``.

    Returns:
        A new dict with the same keys, in the same order, as ``kwds``, plus
        ``'transform'`` set to ``get_transform(args, mode='val')``.
    """
    return {**kwds, 'transform': get_transform(args, mode='val')}

def _train_val_transforms(args):
    """Return the ``(train, val)`` transform pair built by ``get_transform``."""
    return get_transform(args, mode='train'), get_transform(args, mode='val')


def load_data(args):
    """Build the train and validation datasets selected by ``args.dataset``.

    Args:
        args: options namespace. ``dataset``, ``base_path`` and ``img_folder``
            are always read; the remaining fields depend on the dataset.

    Returns:
        A ``(train, val)`` tuple of dataset objects.

    Raises:
        ValueError: if ``args.dataset`` is not one of ``supported_datasets``.

    Side effects:
        For 'hmdb51' and 'ucf101', sets ``args.mode = 'sequence'`` (image mode
        is not supported for those datasets).
    """
    fpath = os.path.join(args.base_path, args.img_folder)
    if args.dataset == 'scale':
        # These dataset classes take their transforms explicitly, so build them
        # here (previously `transform`/`val_transform` were undefined names).
        transform, val_transform = _train_val_transforms(args)
        # NOTE(review): pickle.load can execute arbitrary code; only point
        # label_file at trusted files.
        with open(os.path.join(args.base_path, args.label_file), "rb") as f:
            info = pickle.load(f)
        train = ShotScaleDataset(fpath, info, args.max_frames, args.load_width,
            args.load_height, mode=args.mode, transform=transform,
            load_first=args.load_first, recurrent=args.recurrent,
            transform_mode=args.transform_mode, normalize=args.normalize,
            mean_subtract=args.mean_subtract, apply_prob=args.apply_prob)
        val = ShotScaleDataset(fpath, info, args.max_frames, args.load_width,
            args.load_height, mode=args.mode, transform=val_transform,
            load_first=args.load_first, recurrent=args.recurrent,
            transform_mode=args.transform_mode, normalize=args.normalize,
            mean_subtract=args.mean_subtract)
    elif args.dataset == 'transition':
        transform, val_transform = _train_val_transforms(args)
        image_dir = os.path.join(fpath, "images")
        # NOTE(review): both splits read data/val.txt — confirm this is intended.
        label_fpath = os.path.join(fpath, "data/val.txt")
        train = ShotTransitionDataset(image_dir, label_fpath,
            args.load_width, args.load_height, mode=args.mode, transform=transform,
            load_first=args.load_first, recurrent=args.recurrent, max_frames=args.max_frames,
            transform_mode=args.transform_mode, normalize=args.normalize,
            mean_subtract=args.mean_subtract, apply_prob=args.apply_prob)
        val = ShotTransitionDataset(image_dir, label_fpath,
            args.load_width, args.load_height, mode=args.mode, transform=val_transform,
            load_first=args.load_first, recurrent=args.recurrent, max_frames=args.max_frames,
            transform_mode=args.transform_mode, normalize=args.normalize,
            mean_subtract=args.mean_subtract)
    elif args.dataset == 'conversation':
        transform, val_transform = _train_val_transforms(args)
        image_dir = os.path.join(fpath, "images")
        # NOTE(review): both splits read data/train.txt — confirm this is intended.
        label_fpath = os.path.join(fpath, "data/train.txt")
        train = ConversationDataset(image_dir, label_fpath,
            args.load_width, args.load_height, max_frames=args.max_frames, mode=args.mode,
            transform=transform, recurrent=args.recurrent, transform_mode=args.transform_mode,
            normalize=args.normalize, mean_subtract=args.mean_subtract, apply_prob=args.apply_prob)
        val = ConversationDataset(image_dir, label_fpath,
            args.load_width, args.load_height, max_frames=args.max_frames, mode=args.mode,
            transform=val_transform, recurrent=args.recurrent, transform_mode=args.transform_mode,
            normalize=args.normalize, mean_subtract=args.mean_subtract)
    elif args.dataset == 'hmdb51':
        # action_recognition_kwargs / get_val build the train / val transforms.
        dataset_args = action_recognition_kwargs(args)
        train = HMDB51Dataset(os.path.join(fpath, "hmdb51"),
            os.path.join(fpath, "train_test_splits"), 'train', args.load_width,
            args.load_height, **dataset_args)
        val = HMDB51Dataset(os.path.join(fpath, "hmdb51"),
            os.path.join(fpath, "train_test_splits"), 'test', args.load_width,
            args.load_height, **get_val(dataset_args, args))
        args.mode = 'sequence'  # image mode not supported
    elif args.dataset == 'ucf101':
        dataset_args = action_recognition_kwargs(args)
        train = UCF101Dataset(os.path.join(fpath, "videos"), os.path.join(fpath, "data"),
            'train', args.load_width, args.load_height, **dataset_args)
        val = UCF101Dataset(os.path.join(fpath, "videos"), os.path.join(fpath, "data"),
            'test', args.load_width, args.load_height, **get_val(dataset_args, args))
        args.mode = 'sequence'  # image mode not supported
    else:
        raise ValueError("Dataset '{}' not recognized; must be one of {}.".format(
            args.dataset, supported_datasets))
    return train, val

def init_transform(transform_fn, args):
    """Call ``transform_fn`` with the fields of ``args`` it accepts by keyword.

    Before matching, ``args.keypoints`` is set from ``args.transform_mode``
    ('all' -> 1, 'frame' -> 16), so a transform with a ``keypoints``
    parameter receives it. This mutates ``args``.

    Args:
        transform_fn: a callable (function or class) to instantiate.
        args: options namespace; its attributes whose names match a
            keyword-passable parameter of ``transform_fn`` are forwarded.

    Returns:
        Whatever ``transform_fn(**matching_kwargs)`` returns.
    """
    if args.transform_mode == 'all':
        args.keypoints = 1
    if args.transform_mode == 'frame':
        args.keypoints = 16
    # inspect.getargspec was removed in Python 3.11 and rejects keyword-only
    # parameters/annotations; inspect.signature handles both (and omits `self`
    # for classes). Only parameters that can be passed by keyword are kept.
    keyword_kinds = (inspect.Parameter.POSITIONAL_OR_KEYWORD,
                     inspect.Parameter.KEYWORD_ONLY)
    accepted = {
        name
        for name, param in inspect.signature(transform_fn).parameters.items()
        if param.kind in keyword_kinds
    }
    kwds = {key: value for key, value in vars(args).items() if key in accepted}
    return transform_fn(**kwds)

def get_transform(args, mode='train'):
    """Build the transform configured for ``mode``.

    ``'train'`` reads ``args.transform`` and ``'val'`` reads
    ``args.val_transform``; any other mode means no transform. Each entry
    names an attribute of ``torchvideo`` that is instantiated through
    ``init_transform``. A single entry is returned as-is, and several entries
    are chained with ``torchvideo.Compose``. The result is printed, then
    returned (``None`` when nothing is configured).
    """
    if mode == 'train':
        names = args.transform
    elif mode == 'val':
        names = args.val_transform
    else:
        names = None

    transform = None
    if names:
        stages = [init_transform(getattr(torchvideo, name), args) for name in names]
        transform = stages[0] if len(stages) == 1 else torchvideo.Compose(stages)
    print("Transformation function:", transform)
    return transform

def build_dataloader(data, args):
    """Wrap a ``(train, val)`` dataset pair in DataLoaders.

    The training loader shuffles and uses batch size ``args.bs``. The
    validation loader keeps its order and uses batch size 1, because
    ``validation_collator`` only handles single-sample batches. Both loaders
    use ``args.num_workers`` workers.

    Returns:
        ``(train_loader, val_loader, len(train), len(val))``.
    """
    train_set, val_set = data
    workers = args.num_workers
    loaders = (
        DataLoader(train_set, batch_size=args.bs, num_workers=workers, shuffle=True),
        DataLoader(val_set, batch_size=1, num_workers=workers, shuffle=False,
                   collate_fn=validation_collator),
    )
    return (*loaders, len(train_set), len(val_set))

def validation_collator(batch):
    """Collate a single-sample validation batch.

    ``build_dataloader`` creates the validation loader with ``batch_size=1``.
    This collator depends on that, and it rejects any other batch size
    instead of silently dropping samples.

    Args:
        batch: list holding exactly one ``(X, y, meta)`` sample. ``meta`` must
            be an iterable of iterables (it is transposed with ``zip``).

    Returns:
        ``(X.unsqueeze(0), torch.LongTensor([y]), res)``, where ``res`` is
        ``list(zip(*meta))``.

    Raises:
        ValueError: if ``batch`` does not contain exactly one sample.
    """
    if len(batch) != 1:
        raise ValueError(
            "validation_collator expects batches of size 1, got {}".format(len(batch)))
    X, y, meta = batch[0]
    res = list(zip(*meta))
    return X.unsqueeze(0), torch.LongTensor([y]), res


def cache_batch(batch, labels=None, path_fmt="previews/vid{}", vid_ext=".mp4", verbose=True, annotate=True):
    """Write every clip of ``batch`` to disk through ``utils.save_tensor_to_video``.

    Clip ``idx`` is ``batch`` split along dim 0 with that dim squeezed away.
    When ``annotate`` is true, it is first passed through
    ``utils.annotate_video`` together with ``labels[idx]``. It is then saved
    with the name ``path_fmt.format(idx)`` and ``vid_ext=vid_ext``.

    Args:
        batch: tensor whose dim 0 indexes clips.
        labels: per-clip labels; required, with ``len(labels) == len(batch)``,
            when ``annotate`` is true.
        path_fmt: format string taking the clip index.
        vid_ext: extension forwarded to ``utils.save_tensor_to_video``.
        verbose: print a progress line after each clip.
        annotate: overlay labels before saving.
    """
    if annotate:
        assert labels is not None, "Label tensor must be provided when annotate=True"
        assert len(batch) == len(labels), "Data and label tensors must have same length"
    for idx, chunk in enumerate(torch.split(batch, 1, dim=0)):
        clip = chunk.squeeze(0)
        if annotate:
            clip = utils.annotate_video(clip, labels[idx])
        utils.save_tensor_to_video(clip, path_fmt.format(idx), vid_ext=vid_ext)
        if verbose:
            print("Cached data point", idx + 1)


if __name__ == '__main__':
    # Small testing harness: load the configured dataset, report its size and
    # class balance, time loading one training batch, and (unless
    # args.lightweight is set) write that batch out as preview videos.
    import time

    # Import explicitly rather than relying on the wildcard dataset imports
    # to provide `np`.
    import numpy as np

    from options import get_args, prettyprint_args

    args = get_args()
    print(prettyprint_args(args))
    data = load_data(args)
    dataloader, _, _, _ = build_dataloader(data, args)
    train_data, _ = data  # load_data always returns a (train, val) pair
    print("Instance of {} has {} examples".format(train_data.__class__, len(train_data)))
    # TODO: the label column index is hard-coded; confirm it holds the class
    # label in every dataset's `data_labels`.
    label_col = -1
    for i in range(args.n_classes):
        class_count = np.count_nonzero(train_data.data_labels[:, label_col].astype(int) == i)
        print("{} examples of class {}".format(class_count, i))
    print("Test-loading a batch of size", args.bs)
    start = time.time()
    test_seq_batch, labels = next(iter(dataloader))
    print("Loaded data:", test_seq_batch.shape)
    if train_data.mode == 'sequence':
        # Reduce over every dim except the channel one: dim 1 when not
        # recurrent, dim 2 when recurrent (per the index tuples below).
        non_channel_dims = [(0, 2, 3, 4), (0, 1, 3, 4)][args.recurrent]
        print("Batch channel means:", torch.mean(test_seq_batch, dim=non_channel_dims))
        print("Batch channel variances:", torch.var(test_seq_batch, dim=non_channel_dims))
    print("Labels:", labels)
    end = time.time()
    print("Took {:.4f}s, {:.4f}s/it".format(end - start, (end - start) / args.bs))
    if not args.lightweight:
        print("Caching batch...")
        clear_cache()
        # cache_batch annotates by default and asserts that labels are given,
        # so pass the batch labels along.
        # NOTE(review): clear_cache() empties the absolute preview_folder while
        # cache_batch writes under the relative 'previews/' path; they only
        # coincide when run from the src directory.
        cache_batch(test_seq_batch, labels)

