
from collections import defaultdict
from datasets.UCF101 import UCF101
from datasets.HMDB51 import HMDB51
import numpy as np
import os
import random
import pickle
from samplers import DistributedSampler, UniformClipSampler, RandomClipSampler
import sklearn
from sklearn.neighbors import KNeighborsClassifier
from svm import train_svm
import time
import torch
import utils
from retrieval_utils import average_features, load_or_get_features, get_model
from utils import Flatten, print_or_log

# One random seed per process. It is embedded in the feature-cache names built
# in ``main`` (``..._seed_{SEED}``), so apply it to every RNG this script may
# touch (Python ``random``, NumPy and torch) to make a recorded seed meaningful.
SEED = random.randint(0, 50000)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


def load_dataset(args, transforms=None, mode=None, num_shots=-1, logger=None):
    """Return the UCF101/HMDB51 dataset for one split, using an on-disk cache if enabled.

    Args:
        args: parsed namespace; reads ``dataset``, ``fold``, ``clip_len``,
            ``steps_bet_clips`` and ``cache_dataset``.
        transforms: transform attached to the dataset. It is also re-attached
            to a dataset loaded from the cache.
        mode: split name. ``'train'`` builds the dataset with ``train=True``;
            any other value builds it with ``train=False``.
        num_shots: forwarded to the dataset constructor and part of the cache key.
        logger: optional logger forwarded to ``print_or_log``.

    Returns:
        The dataset object (freshly built or loaded from ``cache_path``).

    Raises:
        ValueError: if ``args.dataset`` is neither ``'ucf101'`` nor ``'hmdb51'``.
    """
    cache_path = utils._get_cache_path(
        args.dataset, mode, str(args.fold) + str(num_shots),
        args.clip_len, args.steps_bet_clips)

    if args.cache_dataset and os.path.exists(cache_path):
        print_or_log(f"Loading {args.dataset} dataset {mode} from {cache_path}", logger=logger)
        try:
            # The cache holds a whole pickled dataset object. torch>=2.6
            # refuses to load that under its ``weights_only=True`` default.
            dataset = torch.load(cache_path, weights_only=False)
        except TypeError:
            # Older torch releases do not accept the ``weights_only`` keyword.
            dataset = torch.load(cache_path)
        # Always use the caller's transforms rather than the pickled ones.
        dataset.transform = transforms
        return dataset

    dataset_classes = {'ucf101': UCF101, 'hmdb51': HMDB51}
    if args.dataset not in dataset_classes:
        raise ValueError(f"Dataset is not supported: {args.dataset}")

    # NOTE(review): clips are always built with step_between_clips=1 even though
    # the cache key includes args.steps_bet_clips -- confirm this is intended.
    dataset = dataset_classes[args.dataset](
        frames_per_clip=args.clip_len,
        step_between_clips=1,
        transform=transforms,
        fold=args.fold,
        subsample=False,
        train=(mode == 'train'),
        num_shots=num_shots,
    )
    if args.cache_dataset:
        print_or_log(f"Saving {args.dataset} dataset {mode} to {cache_path}", logger=logger)
        utils.mkdir(os.path.dirname(cache_path))
        utils.save_on_master(dataset, cache_path)

    return dataset
    


def init(args, logger=None):
    """Build the train/validation datasets and the video encoder.

    Args:
        args: parsed command-line namespace.
        logger: optional logger. It is forwarded to ``print_or_log`` and to
            ``load_dataset`` so that cache load/save messages reach it too.

    Returns:
        Tuple ``(model, dataset, dataset_test)``:
        - the model from ``get_model(args, get_video_encoder_only=True)``;
        - the ``'train'`` split limited to ``args.num_shots``;
        - the ``'val'`` split built with ``num_shots=-1``.
    """
    # Get transforms (the third value, ``subsample``, is not used here).
    transform_train, transform_test, _subsample = utils.get_transforms(args)

    # Loading Train data
    print_or_log("Loading training data", logger=logger)
    dataset = load_dataset(
        args, transforms=transform_train, mode='train',
        num_shots=args.num_shots, logger=logger)

    # Loading Validation data
    print_or_log("Loading validation data", logger=logger)
    # NOTE(review): the validation split is built with ``transform_train``
    # while ``transform_test`` goes unused -- confirm this is intentional.
    dataset_test = load_dataset(
        args, transforms=transform_train, mode='val',
        num_shots=-1, logger=logger)

    # Load model
    model = get_model(args, get_video_encoder_only=True)
    return model, dataset, dataset_test


def main(args, logger=None):
    """Run one few-shot evaluation: extract features, fit a classifier, report accuracy.

    Features for the train and validation splits come from
    ``load_or_get_features``. They are optionally averaged per video
    (``args.avg_feats``) and then fed to the classifier selected by
    ``args.classifier``.

    Args:
        args: parsed command-line namespace (see ``parse_args``).
        logger: optional logger forwarded to ``print_or_log`` and helpers.

    Returns:
        The validation accuracy reported by the chosen classifier.

    Raises:
        ValueError: if ``args.classifier`` is not 'knn', 'svm' or 'sgd'.
    """
    # init datasets and models
    model, dataset, dataset_test = init(args, logger=logger)

    # Identifiers passed as ``pretext`` to load_or_get_features; they include
    # the module-level SEED.
    # NOTE(review): the val identifier uses train_clips_per_video rather than
    # val_clips_per_video -- confirm this is intended.
    pretext_train = f"{args.vid_base_arch}_{args.dataset}_{args.train_clips_per_video}_train_shots{args.num_shots}_seed_{SEED}"
    pretext_val = f"{args.vid_base_arch}_{args.dataset}_{args.train_clips_per_video}_val_shots{args.num_shots}_seed_{SEED}"

    # Get train / val features
    train_features, train_vid_indices, train_labels = load_or_get_features(
        args, dataset, model, logger=logger, mode='train', pretext=pretext_train)
    print_or_log(
        f"{train_features.shape} {train_vid_indices.shape} {train_labels.shape}",
        logger=logger)

    val_features, val_vid_indices, val_labels = load_or_get_features(
        args, dataset_test, model, logger=logger, mode='val', pretext=pretext_val)
    print_or_log(
        f"{val_features.shape} {val_vid_indices.shape} {val_labels.shape}",
        logger=logger)

    # Get average features
    if args.avg_feats:
        train_features, train_vid_indices, train_labels = average_features(
            args, train_features, train_vid_indices, train_labels, logger=logger)
        val_features, val_vid_indices, val_labels = average_features(
            args, val_features, val_vid_indices, val_labels, logger=logger)

    # Independent numpy copies, in the dict layout train_svm consumes.
    train_svm_dict = {
        'features': np.copy(train_features),
        'labels': np.copy(train_labels),
        'indices': np.copy(train_vid_indices),
    }
    val_svm_dict = {
        'features': np.copy(val_features),
        'labels': np.copy(val_labels),
        'indices': np.copy(val_vid_indices),
    }

    if args.classifier == 'knn':
        neigh = KNeighborsClassifier(n_neighbors=args.k)
        neigh.fit(train_svm_dict['features'], train_svm_dict['labels'])
        acc = neigh.score(val_svm_dict['features'], val_svm_dict['labels'])
    elif args.classifier == 'svm':
        # Train SVM
        _clf, _train_metrics, _valid_metrics, test_metrics = train_svm(
            train_data=train_svm_dict,
            valid_data=None,
            test_data=val_svm_dict,
            model_save_dir='.',
            C=1,
            fit_intercept=False,
            max_iterations=1000,
            kernel='linear',
            num_classes=args.val_num_classes,
            val_indices=val_svm_dict['indices'],
            train_classes=train_svm_dict['labels']
        )
        acc = test_metrics['accuracy']
    elif args.classifier == 'sgd':
        # ``import sklearn`` alone does not guarantee the linear_model
        # submodule is loaded, so import it explicitly.
        from sklearn.linear_model import SGDClassifier
        # NOTE(review): scikit-learn renamed loss='log' to 'log_loss' (1.1) and
        # removed 'log' in 1.3 -- update once the pinned version is confirmed.
        clf = SGDClassifier(loss='log', max_iter=1000)
        clf.fit(train_svm_dict['features'], train_svm_dict['labels'])
        acc = clf.score(val_svm_dict['features'], val_svm_dict['labels'])
    else:
        raise ValueError(f"Classifier is not supported: {args.classifier}")

    print_or_log(f'Fold accuracy {args.dataset}-{args.num_shots} is: ' + str(acc), logger=logger)
    return acc


def parse_args():
    """Build the command-line parser for the few-shot evaluation and parse ``sys.argv``.

    Boolean flags use the custom ``'bool'`` type registered below. It accepts
    yes/true/t/1 and no/false/f/0, case-insensitively.

    Returns:
        argparse.Namespace with the parsed arguments.
    """
    def str2bool(v):
        v = v.lower()
        if v in ('yes', 'true', 't', '1'):
            return True
        elif v in ('no', 'false', 'f', '0'):
            return False
        # argparse turns this ValueError into a usage error for the flag.
        raise ValueError('Boolean argument needs to be true or false. '
            'Instead, it is %s.' % v)

    import argparse
    parser = argparse.ArgumentParser(description='Few shot Learning')
    parser.register('type', 'bool', str2bool)

    ### Few Shot Params
    parser.add_argument(
        '--k',
        default=1,
        type=int,
        help='Num neighbors (kNN classifier)'
    )
    parser.add_argument(
        '--batch-size',
        default=10,
        type=int,
        help='Size of batch'
    )
    parser.add_argument(
        '--classifier',
        default='knn',
        type=str,
        choices=['svm', 'knn', 'sgd'],
        help='Type of classifier: [svm, knn, sgd]'
    )
    parser.add_argument(
        '--num-shots',
        default=1,
        type=int,
        help='Number of shots / class'
    )
    parser.add_argument(
        '--train-num-classes',
        default=31,
        type=int,
        help='Number of training classes'
    )
    parser.add_argument(
        '--val-num-classes',
        default=10,
        type=int,
        help='Number of validation classes (passed to the SVM classifier)'
    )
    parser.add_argument(
        '--save-pkl',
        default='False',
        type='bool',
        help='save pickled feats'
    )

    ### Retrieval params
    parser.add_argument(
        '--use-cache-feats',
        default='False',
        type='bool',
        help='use cache features'
    )
    parser.add_argument(
        '--norm-feats',
        default='True',
        type='bool',
        help='L2 normalize features of video'
    )
    parser.add_argument(
        '--avg-feats',
        default='True',
        type='bool',
        help='Average features of video'
    )
    parser.add_argument(
        '--pool-op',
        default='max',
        type=str,
        choices=['max', 'avg'],
        help='Type of pooling operation: [max, avg]'
    )
    parser.add_argument(
        "--flatten",
        type='bool',
        default='True',
        help="Flatten embeddings",
    )
    parser.add_argument(
        '--sampling',
        default='uniform',
        type=str,
        choices=['uniform', 'random'],
        help='Type of sampling operation: [uniform, random]'
    )

    ### Dataset params
    parser.add_argument(
        '--dataset',
        default='ucf101',
        choices=['ucf101', 'hmdb51'],
        help='name of dataset'
    )
    parser.add_argument(
        '--fold',
        default=1,
        type=int,
        help='dataset split (fold) to use'
    )
    parser.add_argument(
        '--clip-len',
        default=30,
        type=int,
        metavar='N',
        help='number of frames per clip'
    )
    parser.add_argument(
        '--augtype',
        default=1,
        type=int,
        help='augmentation type (default: 1)'
    )
    parser.add_argument(
        '--use-scale-jittering',
        default='False',
        type='bool',
        help='scale jittering as augmentations'
    )
    parser.add_argument(
        '--steps-bet-clips',
        default=1,
        type=int,
        metavar='N',
        help='number of steps between clips in video'
    )
    parser.add_argument(
        '--train-clips-per-video',
        default=10,
        type=int,
        metavar='N',
        help='maximum number of clips per video to consider for training'
    )
    parser.add_argument(
        '--val-clips-per-video',
        default=10,
        type=int,
        metavar='N',
        help='maximum number of clips per video to consider for testing'
    )
    parser.add_argument(
        "--cache-dataset",
        type='bool',
        default='True',
        help="Cache the datasets for quicker initialization. It also serializes the transforms",
    )
    parser.add_argument(
        '-j', '--workers',
        default=10,
        type=int,
        metavar='N',
        help='number of data loading workers (default: 10)'
    )
    parser.add_argument(
        '--colorjitter',
        default='False',
        type='bool',
        help='Apply random color jitter'
    )

    ### MODEL
    parser.add_argument(
        '--model',
        default='av_gdt',
        help='model',
        choices=['r2plus1d_18', 'av_gdt']
    )
    parser.add_argument(
        '--weights-path',
        default='',
        help='Path to weights file',
    )
    parser.add_argument(
        '--vid-base-arch',
        default='r2plus1d_18',
        help='Video Base Arch for A-V model',
        choices=['r2plus1d_18', 'mc3_18', 's3d', 'r2plus1d_34', 'r2plus1d_50']
    )
    parser.add_argument(
        '--aud-base-arch',
        default='resnet9',
        help='Audio Base Arch for A-V model',
        # 'resnet9' is the default, so it must also be a selectable choice.
        choices=['resnet9', 'resnet18', 'vgg_audio', 'resnet34', 'resnet50']
    )
    parser.add_argument(
        '--pretrained',
        type='bool',
        default='False',
        help="Use pre-trained models from the modelzoo",
    )
    parser.add_argument(
        '--use-mlp',
        default='True',
        type='bool',
        help='Use MLP projection head'
    )
    parser.add_argument(
        '--mlptype',
        default=1,
        type=int,
        help='MLP type (default: 1)'
    )
    parser.add_argument(
        '--headcount',
        type=int,
        default=1,
        help='how many heads each modality has'
    )

    parser.add_argument(
        '--use-cuda',
        type='bool',
        default='True',
        help="Use CUDA",
    )
    parser.add_argument(
        '--distributed',
        type='bool',
        default='False',
        help="ddp mode",
    )
    parser.add_argument(
        '--global-rank',
        type=int,
        default=0,
        help='master rank of GPU'
    )
    parser.add_argument(
        '--world-size',
        type=int,
        default=1,
        help='world size'
    )
    parser.add_argument(
        '--output-dir',
        default='.',
        help='output directory (the log file is written here)'
    )
    parser.add_argument(
        '--filename',
        default='logger.out',
        help='log file name inside --output-dir'
    )
    parser.add_argument(
        '--full-run',
        default='True',
        type='bool',
        help='Sweep several shot counts on both hmdb51 and ucf101 '
             '(overrides --num-shots and --dataset)'
    )

    args = parser.parse_args()
    return args


if __name__ == '__main__':
    # Parse arguments
    args = parse_args()

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(args.output_dir, exist_ok=True)

    logger = utils.setup_logger(
        "Video_reader, classification",
        args.output_dir,
        True,
        logname=args.filename
    )

    # Honour --fold (default 1) instead of silently forcing fold 1. Coerce it
    # to int because it is passed straight to the dataset classes as ``fold=``.
    args.fold = int(args.fold)

    # Run main script
    if args.full_run:
        # Sweep shot counts on both datasets, overriding --num-shots and
        # --dataset. -1 is the same value init() passes for the validation split.
        for num_shots in [1, 5, 10, 20, -1]:
            args.num_shots = num_shots
            for dataset in ['hmdb51', 'ucf101']:
                args.dataset = dataset
                main(args, logger=logger)
    else:
        main(args, logger=logger)
