from collections import defaultdict
import numpy as np
import pickle
import os
from samplers import DistributedSampler, UniformClipSampler, RandomClipSampler
from sklearn.neighbors import NearestNeighbors
import time
import torch
from torch.utils.data.dataloader import default_collate
from torch.utils.data.sampler import Sampler, SubsetRandomSampler
from tqdm import tqdm
import utils
from utils import Flatten, print_or_log 


def collate_fn(batch):
    """Collate (video, audio, label, <unused>, video_idx)-style samples into a batch.

    Samples that are ``None``, or whose second element does not have shape
    ``(1, 40, 99)``, are discarded. Only the first five fields of each kept
    sample are collated with ``default_collate``.

    Returns:
        The collated batch, or ``None`` when no sample survives the filter.
    """
    required_shape = (1, 40, 99)
    kept = []
    for sample in batch:
        if sample is None or sample[1].shape != required_shape:
            continue
        kept.append(tuple(sample[k] for k in range(5)))
    if not kept:
        return None
    return default_collate(kept)


def save_pickle(obj, name):
    """Serialize ``obj`` to the file path ``name`` with the highest pickle protocol."""
    with open(name, mode='wb') as out_file:
        print("Dumping data as pkl file", flush=True)
        pickle.dump(obj, out_file, pickle.HIGHEST_PROTOCOL)


def load_pickle(pkl_path):
    """Load and return the object pickled at ``pkl_path``.

    NOTE: unpickling can execute arbitrary code; only load files this project
    produced or that otherwise come from a trusted source.

    Raises:
        FileNotFoundError: if ``pkl_path`` does not exist.
    """
    if not os.path.exists(pkl_path):
        # A bare ``raise`` here (no active exception) would only produce a
        # confusing "No active exception to reraise" RuntimeError.
        raise FileNotFoundError(f"Pickle file not found: {pkl_path}")
    print(f"Loading pickle file: {pkl_path}", flush=True)
    with open(pkl_path, 'rb') as handle:
        return pickle.load(handle)


def get_model(args, get_video_encoder_only=True, logger=None):
    """Build the model described by ``args``, optionally load weights, and prepare it for inference.

    Args:
        args: parsed CLI namespace (see ``parse_args``). Reads the model/arch
            options, ``weights_path``, ``pool_op``, ``flatten`` and ``use_cuda``;
            sets ``args.ckpt_epoch`` when a checkpoint is loaded.
        get_video_encoder_only: if True, reduce the model to its video encoder
            for video-only feature extraction.
        logger: optional logger forwarded to ``print_or_log``.

    Returns:
        The model in eval mode; moved to GPU and wrapped in
        ``torch.nn.DataParallel`` when CUDA is available and ``args.use_cuda`` is set.

    Raises:
        ValueError: if the flattening head is requested with a ``pool_op`` other
            than 'max' or 'avg'.
    """
    r2plus1d_archs = ['r2plus1d_18', 'r2plus1d_34', 'r2plus1d_50']

    # Load model
    model = utils.load_model(
        model_name=args.model,
        vid_base_arch=args.vid_base_arch,
        aud_base_arch=args.aud_base_arch,
        pretrained=args.pretrained,
        num_classes=256,
        norm_feat=False,
        use_mlp=args.use_mlp,
        mlptype=args.mlptype,
        headcount=args.headcount
    )

    # Load model weights
    start = time.time()
    # A missing path may arrive from the CLI as the literal string 'None'.
    if isinstance(args.weights_path, str):
        weight_path_not_none = args.weights_path != 'None'
    else:
        weight_path_not_none = args.weights_path is not None
    if weight_path_not_none:
        print_or_log("Loading model weights", logger=logger)
        if os.path.exists(args.weights_path):
            # Load onto CPU so GPU-saved checkpoints also work on CPU-only hosts;
            # the model is moved to GPU further below when requested.
            ckpt_dict = torch.load(args.weights_path, map_location='cpu')
            model_weights = ckpt_dict["model"]
            args.ckpt_epoch = ckpt_dict['epoch']
            print(f"Epoch checkpoint: {args.ckpt_epoch}", flush=True)
            utils.load_model_parameters(model, model_weights)
        else:
            print_or_log(
                f"Weights file not found, keeping initial weights: {args.weights_path}",
                logger=logger,
            )
    print_or_log(f"Time to load model weights: {time.time() - start}", logger=logger)

    # Put model in eval mode
    model.eval()

    # Get video encoder for video-only retrieval
    if get_video_encoder_only:
        print_or_log("Getting video encoder only", logger=logger)
        # Models other than the plain r2plus1d variants (e.g. 'av_gdt') wrap the
        # video backbone; unwrap it.
        if args.model not in r2plus1d_archs:
            model = model.video_network.base

        if args.vid_base_arch in r2plus1d_archs and args.flatten:
            # Build the pool only here, where it is used. Previously an invalid
            # pool_op slipped past a no-op ``assert("...")`` and crashed later
            # with a NameError.
            if args.pool_op == 'max':
                pool = torch.nn.MaxPool3d((2, 2, 2), stride=(2, 2, 2))
            elif args.pool_op == 'avg':
                pool = torch.nn.AvgPool3d((2, 2, 2), stride=(2, 2, 2))
            else:
                raise ValueError(
                    f"Only 'max' and 'avg' pool operations allowed, got {args.pool_op!r}"
                )
            # Keep only the stem and residual stages, then pool and flatten.
            model = torch.nn.Sequential(*[
                model.stem,
                model.layer1,
                model.layer2,
                model.layer3,
                model.layer4,
                pool,
                Flatten(),
            ])

    if torch.cuda.is_available() and args.use_cuda:
        print_or_log("Transferring model to data parallel", logger=logger)
        model = model.cuda()
        model = torch.nn.DataParallel(model)
    return model


def _load_split(args, mode, transforms, subsample, log_name, logger=None):
    """Load one dataset split, reusing or writing a torch-serialized cache when ``args.cache_dataset`` is set."""
    cache_path = utils._get_cache_path(args.dataset, mode, args.fold, args.clip_len, args.steps_bet_clips)
    if args.cache_dataset and os.path.exists(cache_path):
        print_or_log(f"Loading {log_name} from {cache_path}", logger=logger)
        # NOTE(review): newer torch versions default torch.load to weights_only=True,
        # which may refuse a pickled Dataset object; confirm against the torch in use.
        dataset = torch.load(cache_path)
        # The cache also serialized the old transforms; replace them with the current ones.
        dataset.transform = transforms
        return dataset

    dataset = utils.load_dataset(
        dataset_name=args.dataset,
        fold=args.fold,
        mode=mode,
        frames_per_clip=args.clip_len,
        transforms=transforms,
        subsample=subsample
    )
    if args.cache_dataset:
        print_or_log(f"Saving {log_name} to {cache_path}", logger=logger)
        utils.mkdir(os.path.dirname(cache_path))
        utils.save_on_master(dataset, cache_path)
    return dataset


def init(args, get_video_encoder_only=True, logger=None):
    """Load the train and validation datasets and build the model.

    Args:
        args: parsed CLI namespace (see ``parse_args``).
        get_video_encoder_only: forwarded to ``get_model``.
        logger: optional logger forwarded to ``print_or_log``.

    Returns:
        ``(model, dataset, dataset_test)`` for the 'train' and 'val' splits.
    """
    # Get transforms
    transform_train, transform_test, subsample = utils.get_transforms(args)

    # Loading Train data
    print_or_log("Loading training data", logger=logger)
    st = time.time()
    dataset = _load_split(args, 'train', transform_train, subsample, 'dataset_train', logger=logger)
    print_or_log(f"Took {time.time() - st}", logger=logger)

    # Loading Validation data
    print_or_log("Loading validation data", logger=logger)
    dataset_test = _load_split(args, 'val', transform_test, subsample, 'dataset_test', logger=logger)

    model = get_model(args, get_video_encoder_only=get_video_encoder_only, logger=logger)
    return model, dataset, dataset_test


def get_features(args, dataset, model, get_audio=False, logger=None, mode='train', print_freq=250, pretext=None):
    """Run ``model`` over clips sampled from ``dataset`` and collect per-clip embeddings.

    Args:
        args: parsed CLI namespace. Reads ``sampling``, ``train_clips_per_video``,
            ``batch_size``, ``workers``, ``use_cuda``, ``save_pkl``, ``output_dir``,
            ``vid_base_arch`` and ``dataset``.
        dataset: dataset exposing ``video_clips``, which the clip samplers consume.
        model: feature extractor, called as ``model(video)``, or as
            ``model(video, audio)`` returning ``(video_feats, audio_feats)`` when
            ``get_audio`` is True.
        get_audio: also extract audio embeddings (batches then carry an audio tensor).
        logger: optional logger forwarded to ``print_or_log``.
        mode: split name; only used to build the default ``pretext`` file prefix.
        print_freq: currently unused; kept for API compatibility.
        pretext: file prefix for saved pickles. The special value ``'aot'``
            (case-insensitive, video-only) additionally embeds every clip with
            its dim-2 axis reversed.

    Returns:
        ``(PS_v, PS_a, indices, labels)`` if ``get_audio`` else
        ``(PS_v, indices, labels)``: row-aligned numpy arrays, with ``indices``
        and ``labels`` cast to int32.

    Raises:
        RuntimeError: if the dataloader yields no usable batch.
    """
    # clear cache at beginning
    torch.cuda.empty_cache()

    N = len(dataset)
    print(f"Size of DS: {N}")

    use_cuda = torch.cuda.is_available() and args.use_cuda
    # 'aot' also embeds each clip flipped along dim 2 (presumably time for a
    # (B, C, T, H, W) tensor — confirm). It is only supported without audio so
    # the extra rows cannot desynchronise the audio features. Checking for None
    # first avoids calling .lower() on the default pretext.
    is_aot = (not get_audio) and pretext is not None and pretext.lower() == 'aot'

    # NOTE(review): train_clips_per_video is used for every mode, including 'val'.
    if args.sampling == 'random':
        sampler = RandomClipSampler(dataset.video_clips, args.train_clips_per_video)
    else:
        sampler = UniformClipSampler(dataset.video_clips, args.train_clips_per_video)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=sampler,
        num_workers=args.workers,
        pin_memory=True,
        collate_fn=collate_fn if get_audio else None,
        drop_last=False
    )
    n_batches = len(dataloader)
    print(f"Size of Dataloader: {n_batches}")

    PS_v_np, PS_a_np, indices_np, labels_np = [], [], [], []
    logged_shapes = False
    with torch.no_grad():
        for batch_idx, batch in enumerate(dataloader):
            # collate_fn returns None when every sample of a batch was filtered out.
            if batch is None:
                continue
            if get_audio:
                video, audio, label, _, video_idx = batch
            else:
                video, label, _, video_idx = batch

            # Only the model inputs need to be on the GPU.
            if use_cuda:
                video = video.cuda(non_blocking=True)
                if get_audio:
                    audio = audio.cuda(non_blocking=True)

            # Forward pass
            if get_audio:
                feat_v, feat_a = model(video, audio)
            else:
                feat_v = model(video)

            if not logged_shapes:
                if get_audio:
                    print_or_log((batch_idx, video.shape, audio.shape, feat_v.shape, label.shape), logger=logger)
                else:
                    print_or_log((batch_idx, video.shape, feat_v.shape, label.shape), logger=logger)
                logged_shapes = True

            batch_indices = video_idx.cpu().numpy().astype(np.int32)
            batch_labels = label.cpu().numpy().astype(np.int32)
            PS_v_np.append(feat_v.cpu().numpy())
            indices_np.append(batch_indices)
            labels_np.append(batch_labels)
            if get_audio:
                PS_a_np.append(feat_a.cpu().numpy())
            if is_aot:
                feat_v2 = model(video.flip(2))
                # Repeat indices/labels so every feature row keeps its metadata.
                PS_v_np.append(feat_v2.cpu().numpy())
                indices_np.append(batch_indices)
                labels_np.append(batch_labels)

            print(f'{batch_idx} / {n_batches}', end='\r')
        print_or_log("Done collecting features", logger=logger)

    if not PS_v_np:
        raise RuntimeError("No features were collected: the dataloader yielded no usable batches")

    # Concatenate per-batch arrays
    PS_v = np.concatenate(PS_v_np, axis=0)
    indices = np.concatenate(indices_np, axis=0)
    labels = np.concatenate(labels_np, axis=0)
    PS_a = np.concatenate(PS_a_np, axis=0) if get_audio else None

    if args.save_pkl:
        if pretext is None:
            pretext = f"{args.vid_base_arch}_{args.dataset}_{args.train_clips_per_video}_{mode}"
        os.makedirs(args.output_dir, exist_ok=True)
        save_pickle(PS_v, os.path.join(args.output_dir, f"{pretext}_feats.pkl"))
        save_pickle(indices, os.path.join(args.output_dir, f"{pretext}_indices.pkl"))
        save_pickle(labels, os.path.join(args.output_dir, f"{pretext}_labels.pkl"))
        if get_audio:
            save_pickle(PS_a, os.path.join(args.output_dir, f"{pretext}_feats_aud.pkl"))

    if get_audio:
        return PS_v, PS_a, indices, labels
    return PS_v, indices, labels


def load_or_get_features(args, dataset, model, get_audio=False, logger=None, mode='train', pretext=None):
    """Load cached features from ``args.output_dir`` or compute them with ``get_features``.

    When ``args.use_cache_feats`` is set, the pickles ``{pretext}_feats.pkl``,
    ``{pretext}_indices.pkl`` and ``{pretext}_labels.pkl`` (plus
    ``{pretext}_feats_aud.pkl`` when ``get_audio``) are loaded. If any of them
    cannot be loaded, the features are recomputed instead.

    Returns:
        ``(features, aud_features, vid_indices, labels)`` when ``get_audio`` else
        ``(features, vid_indices, labels)``.
    """
    if pretext is None:
        pretext = f"{args.vid_base_arch}_{args.dataset}_{args.train_clips_per_video}_{mode}"

    if args.use_cache_feats:
        def cache_file(suffix):
            return os.path.join(args.output_dir, f"{pretext}_{suffix}.pkl")

        try:
            features = load_pickle(cache_file("feats"))
            vid_indices = load_pickle(cache_file("indices"))
            labels = load_pickle(cache_file("labels"))
            if get_audio:
                aud_features = load_pickle(cache_file("feats_aud"))
        except Exception as err:
            # Any missing or corrupt cache file falls back to recomputation. Unlike
            # a bare ``except:``, this does not swallow KeyboardInterrupt/SystemExit.
            print_or_log(f"Could not load cached features ({err!r}); recomputing", logger=logger)
        else:
            if get_audio:
                return features, aud_features, vid_indices, labels
            return features, vid_indices, labels

    return get_features(
        args, dataset, model, get_audio=get_audio, logger=logger, mode=mode, pretext=pretext
    )


def average_features(args, features, vid_indices, labels, get_audio=False, aud_features=None, logger=None):
    """Average clip-level features into one feature vector per video.

    Args:
        args: namespace; ``args.norm_feats`` enables per-clip L2 normalisation
            before averaging.
        features: indexable clip features, row-aligned with ``vid_indices`` and ``labels``.
        vid_indices: video index of each clip; clips sharing an index are averaged.
        labels: label of each clip.
        get_audio: also average ``aud_features`` (only if it is not None).
        aud_features: optional audio clip features, row-aligned with ``features``.
        logger: optional logger forwarded to ``print_or_log``.

    Returns:
        ``(avg_features, avg_features_aud, avg_vid_indices, avg_labels)`` when
        audio is averaged, else ``(avg_features, avg_vid_indices, avg_labels)``.
        Feature and label outputs are stacked numpy arrays; ``avg_vid_indices``
        is a list in first-seen order.
    """
    use_audio = get_audio and aud_features is not None

    def _l2_normalize(x):
        # A zero vector has zero norm; dividing would produce NaNs, which the
        # downstream NearestNeighbors fit rejects. Return such vectors unchanged.
        norm = np.sqrt(np.sum(x ** 2))
        return x / norm if norm > 0 else x

    feat_dict = defaultdict(list)
    label_dict = defaultdict(list)
    aud_feat_dict = defaultdict(list)
    n_total = len(features)
    print(f"Total Number of features: {n_total}")
    for i in range(n_total):
        feat = features[i]
        feat_a = aud_features[i] if use_audio else None
        if args.norm_feats:
            feat = _l2_normalize(feat)
            if use_audio:
                feat_a = _l2_normalize(feat_a)
        vid_idx = vid_indices[i]
        feat_dict[vid_idx].append(feat)
        label_dict[vid_idx].append(labels[i])
        if use_audio:
            aud_feat_dict[vid_idx].append(feat_a)
        print(f'{i} / {n_total}', end='\r')

    avg_features, avg_vid_indices, avg_labels, avg_features_aud = [], [], [], []
    for vid_idx, clip_feats in feat_dict.items():
        avg_features.append(np.mean(np.stack(clip_feats), axis=0))
        avg_vid_indices.append(vid_idx)
        # Assumes all clips of a video share its label; the first one is kept.
        avg_labels.append(label_dict[vid_idx][0])
        if use_audio:
            avg_features_aud.append(np.mean(np.stack(aud_feat_dict[vid_idx]), axis=0))

    avg_features = np.stack(avg_features, axis=0)
    avg_indices = np.stack(avg_vid_indices, axis=0)
    avg_labels = np.stack(avg_labels, axis=0)
    if use_audio:
        avg_features_aud = np.stack(avg_features_aud, axis=0)
    print_or_log(f"{avg_features.shape}, {avg_indices.shape}, {avg_labels.shape}", logger=logger)
    if use_audio:
        return avg_features, avg_features_aud, avg_vid_indices, avg_labels
    return avg_features, avg_vid_indices, avg_labels


def retrieval(
    train_features,
    train_labels,
    train_vid_indices,
    val_features,
    val_labels,
    val_vid_indices,
    train_aud_features=None,
    val_aud_features=None,
    task='v-v'
):
    """Retrieve nearest training items for every validation query and print Recall@k.

    Args:
        train_features / val_features: 2D video feature arrays (gallery / queries).
        train_labels / val_labels: labels row-aligned with the respective features.
        train_vid_indices: unused; kept for API compatibility.
        val_vid_indices: keys of the returned dict, row-aligned with the queries.
        train_aud_features / val_aud_features: audio features, required by any
            task involving 'a'.
        task: '<query>-<gallery>' modality pair, each 'v' (video) or 'a' (audio).

    Returns:
        Dict keyed by validation video index. Each entry has 'label', 'recal_acc'
        (per k, the fraction of the k neighbours whose label matches the query)
        and 'neighbors' (per k, row indices into the gallery features).
    """
    assert task in ['v-a', 'a-v', 'v-v', 'a-a']
    if task in ['v-a', 'a-v', 'a-a']:
        assert(train_aud_features is not None)
        assert(val_aud_features is not None)

    query_modality, gallery_modality = task.split('-')
    feat_val = val_features if query_modality == 'v' else val_aud_features
    feat_train = train_features if gallery_modality == 'v' else train_aud_features

    recall_thresholds = [1, 5, 10, 20, 50]
    # kneighbors fails when asked for more neighbours than gallery items exist.
    max_k = min(max(recall_thresholds), len(feat_train))
    # n_neighbors is keyword-only in recent scikit-learn releases.
    neigh = NearestNeighbors(n_neighbors=max_k)
    neigh.fit(feat_train)

    n_val = len(feat_val)
    # One batched query for the largest k. Results are sorted by distance, so
    # the first t columns are the t nearest neighbours.
    all_neighbor_indices = neigh.kneighbors(feat_val, n_neighbors=max_k)[1] if n_val else []

    recall_dict = defaultdict(list)
    retrieval_dict = {}
    for i in range(n_val):
        vid_idx = val_vid_indices[i]
        vid_label = val_labels[i]
        entry = {
            'label': vid_label,
            'recal_acc': {
                '1': 0, '5': 0, '10': 0, '20': 0, '50': 0
            },
            'neighbors': {
                '1': [], '5': [], '10': [], '20': [], '50': []
            }
        }
        retrieval_dict[vid_idx] = entry
        for recall_treshold in recall_thresholds:
            neighbor_indices = all_neighbor_indices[i, :recall_treshold]
            neighbor_labels = [train_labels[j] for j in neighbor_indices]
            recall_dict[recall_treshold].append(100 if vid_label in neighbor_labels else 0)
            # Count over all neighbours, not the set of distinct labels: counting
            # over the set caps matches at one and gives 1/|distinct labels|.
            n_match = sum(1 for neigh_label in neighbor_labels if neigh_label == vid_label)
            entry['recal_acc'][str(recall_treshold)] = n_match / float(len(neighbor_labels))
            entry['neighbors'][str(recall_treshold)] = neighbor_indices
        print(f'{i} / {n_val}', end='\r')

    # Calculate mean recall values
    for recall_treshold in recall_thresholds:
        mean_recall = np.mean(recall_dict[recall_treshold])
        print(f"{task}: Recall @ {recall_treshold}: {mean_recall}")
    return retrieval_dict


def parse_args():
    """Parse command-line options for feature extraction and retrieval.

    Boolean options use the custom ``'bool'`` type (see ``str2bool``) and
    accept yes/no, true/false, t/f or 1/0, case-insensitively.

    Returns:
        argparse.Namespace with the parsed options.
    """
    def str2bool(v):
        # Already-converted booleans pass through unchanged.
        if isinstance(v, bool):
            return v
        v = v.lower()
        if v in ('yes', 'true', 't', '1'):
            return True
        elif v in ('no', 'false', 'f', '0'):
            return False
        # argparse reports a ValueError raised by a type function as a usage error.
        raise ValueError('Boolean argument needs to be true or false. '
            'Instead, it is %s.' % v)

    import argparse
    parser = argparse.ArgumentParser(description='Video Representation Learning')
    parser.register('type', 'bool', str2bool)

    ### Retrieval params
    parser.add_argument(
        '--use-cache-feats',
        default='False',
        type='bool',
        help='use cache features'
    )
    parser.add_argument(
        '--save-pkl',
        default='False',
        type='bool',
        help='save pickled feats'
    )
    parser.add_argument(
        '--avg-feats',
        default='True',
        type='bool',
        help='Average features of video'
    )
    parser.add_argument(
        '--norm-feats',
        default='True',
        type='bool',
        help='L2 normalize features of video'
    )
    parser.add_argument(
        '--pool-op',
        default='max',
        type=str,
        choices=['max', 'avg'],
        help='Type of pooling operation: [max, avg]'
    )
    parser.add_argument(
        '--sampling',
        default='uniform',
        type=str,
        choices=['uniform', 'random'],
        help='Type of sampling operation: [uniform, random]'
    )
    parser.add_argument(
        '--get-audio',
        default='False',
        type='bool',
        help='Get audio features'
    )

    ### Dataset params
    parser.add_argument(
        '--dataset',
        default='hmdb51',
        help='name of dataset'
    )
    parser.add_argument(
        '--batch-size',
        default=64,
        # Without type=int a CLI value would arrive as a str and break the DataLoader.
        type=int,
        help='Size of batch'
    )
    parser.add_argument(
        '--fold',
        default='1',
        type=str,
        help='dataset split (fold) to use'
    )
    parser.add_argument(
        '--clip-len',
        default=30,
        type=int,
        metavar='N',
        help='number of frames per clip'
    )
    parser.add_argument(
        '--augtype',
        default=1,
        type=int,
        help='augmentation type (default: 1)'
    )
    parser.add_argument(
        '--use-scale-jittering',
        default='False',
        type='bool',
        help='scale jittering as augmentations'
    )
    parser.add_argument(
        '--steps-bet-clips',
        default=1,
        type=int,
        metavar='N',
        help='number of steps between clips in video'
    )
    parser.add_argument(
        '--train-clips-per-video',
        default=10,
        type=int,
        metavar='N',
        help='maximum number of clips per video to consider for training'
    )
    parser.add_argument(
        '--val-clips-per-video',
        default=10,
        type=int,
        metavar='N',
        help='maximum number of clips per video to consider for testing'
    )
    parser.add_argument(
        "--cache-dataset",
        type='bool',
        default='True',
        help="Cache the datasets for quicker initialization. It also serializes the transforms",
    )
    parser.add_argument(
        '-j', '--workers',
        default=10,
        type=int,
        metavar='N',
        help='number of data loading workers (default: 10)'
    )
    parser.add_argument(
        '--colorjitter',
        default='False',
        type='bool',
        help='Apply random color jitter'
    )

    ### MODEL
    parser.add_argument(
        '--model',
        default='av_gdt',
        help='model',
        choices=['r2plus1d_18', 'av_gdt']
    )
    parser.add_argument(
        '--weights-path',
        default='',
        help='Path to weights file',
    )
    parser.add_argument(
        '--vid-base-arch',
        default='r2plus1d_18',
        help='Video Base Arch for A-V model',
        choices=['r2plus1d_18', 'mc3_18', 's3d', 'r2plus1d_34', 'r2plus1d_50']
    )
    parser.add_argument(
        '--aud-base-arch',
        default='resnet18',
        help='Audio Base Arch for A-V model',
        choices=['resnet18', 'vgg_audio', 'resnet34', 'resnet50']
    )
    parser.add_argument(
        '--pretrained',
        type='bool',
        default='False',
        help="Use pre-trained models from the modelzoo",
    )
    parser.add_argument(
        '--use-mlp',
        default='True',
        type='bool',
        help='Use MLP projection head'
    )
    parser.add_argument(
        '--mlptype',
        default=1,
        type=int,
        help='MLP type (default: 1)'
    )
    parser.add_argument(
        '--headcount',
        type=int,
        default=1,
        help='how many heads each modality has'
    )
    parser.add_argument(
        "--flatten",
        type='bool',
        default='True',
        help="Flatten embeddings",
    )

    ### Output / runtime params
    parser.add_argument(
        '--output-dir',
        default='.',
        help='path where to save'
    )
    parser.add_argument(
        '--use-cuda',
        type='bool',
        default='True',
        help="Use CUDA",
    )

    args = parser.parse_args()
    return args
