from collections import defaultdict
import datetime
import numpy as np
import os
import sys
import time
import torch
from torch.utils.data.dataloader import default_collate

# Custom imports
import models
from samplers import DistributedSampler, UniformClipSampler, RandomClipSampler
from warmup_scheduler import GradualWarmupScheduler
import transforms as T
import utils
from utils import print_or_log


# Number of classes for each supported dataset
NUM_CLASSES = dict(
    # dataset name -> number of target classes
    hmdb51=51,
    ucf101=101,
    kinetics400=400,
)


# Create Finetune Model
class Finetune_Model(torch.nn.Module):
    """Feature extractor followed by a single linear classification layer.

    Args:
        base_arch: module used as the feature extractor; kept as ``self.base``.
        num_ftrs: input width of the linear classifier.
        num_classes: output width of the linear classifier.
    """

    def __init__(self, base_arch, num_ftrs=512, num_classes=101):
        super().__init__()
        self.base = base_arch
        self.classifier = torch.nn.Linear(
            in_features=num_ftrs,
            out_features=num_classes,
        )

    def forward(self, x):
        """Apply the base network to ``x`` and classify its output."""
        return self.classifier(self.base(x))


# Load finetune model and training params
def load_model_finetune(args, model, model_name, num_ftrs, num_classes):
    """Wrap a backbone in a ``Finetune_Model`` with a fresh linear classifier.

    Args:
        args: parsed command-line namespace (not used by this function).
        model: backbone module to wrap.
        model_name: architecture identifier; selects how the backbone is
            prepared before wrapping.
        num_ftrs: feature width passed to ``Finetune_Model``.
        num_classes: number of output classes passed to ``Finetune_Model``.

    Returns:
        The ``Finetune_Model`` wrapping ``model``.

    Raises:
        ValueError: if ``model_name`` is not one of the supported names.
            (Previously the function fell through and returned ``None``,
            which only surfaced later as an unrelated attribute error.)
    """
    if model_name == 'av_gdt':
        return Finetune_Model(model, num_ftrs, num_classes)
    if model_name in ('r2plus1d_18', 'mc3_18', 'r3d_18'):
        # Replace the backbone's own `fc` head with an identity so that the
        # classifier added by Finetune_Model receives the backbone's features.
        model.fc = models.Identity()
        return Finetune_Model(model, num_ftrs, num_classes)
    raise ValueError(
        f"Unsupported model_name for fine-tuning: {model_name!r}"
    )


def fully_conv_model(model):
    """Turn a fine-tuned model into a fully convolutional ``Sequential``.

    The linear classifier is re-expressed as a 1x1x1 ``Conv3d`` that reuses
    the classifier's weight and bias tensors, and is chained after
    ``model.base.stem`` and ``model.base.layer1`` .. ``layer4``. Only these
    sub-modules are chained; any other sub-module of ``model.base`` is not
    part of the returned network.

    Args:
        model: object exposing ``classifier`` (a linear layer with ``weight``
            and ``bias``) and ``base`` with ``stem`` and ``layer1``-``layer4``.

    Returns:
        A ``torch.nn.Sequential`` ending in the convolutional classifier.
    """
    fc = model.classifier
    # Derive both widths from the classifier itself instead of assuming a
    # 512-wide feature vector, so wider backbones are converted correctly.
    num_classes, num_ftrs = fc.weight.shape
    fc_conv = torch.nn.Conv3d(num_ftrs, num_classes, kernel_size=(1, 1, 1))
    # (out, in) -> (out, in, 1, 1, 1); this is a view of the linear weight.
    fc_conv.weight.data = fc.weight.data[:, :, None, None, None]
    fc_conv.bias.data = fc.bias.data
    return torch.nn.Sequential(
        model.base.stem,
        model.base.layer1,
        model.base.layer2,
        model.base.layer3,
        model.base.layer4,
        fc_conv,
    )


## Aggregate video-level softmaxes into an accuracy score
def aggregrate_video_accuracy(softmaxes, labels, topk=(1,), aggregate="mean"):
    """Compute video-level top-k accuracy from per-clip score vectors.

    The score vectors collected for each video are averaged, and the top-k
    predictions of the averaged vector are compared with the video's label.

    Args:
        softmaxes: dict mapping a video id to a list of 1-D score tensors
            (one per clip of that video).
        labels: dict mapping the same video ids to a label tensor.
        topk: iterable of ``k`` values to report accuracy for.
        aggregate: accepted for interface compatibility but currently
            ignored; clip scores are always averaged.

    Returns:
        A list with one single-element tensor per entry of ``topk`` holding
        the accuracy as a percentage of videos.
    """
    maxk = max(topk)
    video_ids = list(softmaxes)

    # One averaged score vector per video, stacked in `video_ids` order.
    output_batch = torch.stack(
        [torch.stack(softmaxes[video_id]).mean(0) for video_id in video_ids]
    )
    num_videos = output_batch.size(0)
    output_labels = torch.stack([labels[video_id] for video_id in video_ids])

    _, pred = output_batch.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(output_labels.expand_as(pred))

    res = []
    for k in topk:
        # `correct` derives from a transposed tensor and may be
        # non-contiguous, so `view(-1)` can raise for k > 1; `reshape`
        # copies only when it has to.
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / num_videos))
    return res


def evaluate(args, model, criterion, data_loader, device, logger=None, writer=None, epoch=0):
    """Evaluate ``model`` on ``data_loader`` and report clip- and video-level accuracy.

    Clip-level loss and top-1/top-5 accuracy are tracked through the metric
    logger. The per-clip model outputs are also grouped by video index and
    averaged by ``aggregrate_video_accuracy`` to obtain video-level accuracy.
    Note that no softmax is applied here: the values stored per clip are the
    model outputs as-is.

    Args:
        args: namespace; reads ``dataset``, ``fold``, ``ckpt_epoch`` and
            ``use_fcn_testing`` (and is forwarded to the metric logger).
        model: network to evaluate; switched to eval mode.
        criterion: loss called as ``criterion(output, target)``.
        data_loader: yields ``(video, target, _, video_idx)`` batches.
        device: device that inputs and targets are moved to.
        logger: optional logger passed to ``print_or_log``.
        writer: optional summary writer; receives the video-level scalars.
        epoch: step value used when writing scalars.

    Returns:
        Tuple ``(video_acc1, video_acc5)`` as Python floats.
    """
    # Num classes
    num_classes = NUM_CLASSES[args.dataset]

    # Put model in eval mode
    model.eval()

    # Per-video accumulators keyed by video index
    softmaxes = {}
    labels = {}

    metric_logger = utils.MetricLoggerFinetune(delimiter="  ")
    header = 'Test:'

    mode = f'{args.dataset}-{str(args.fold)}/{args.ckpt_epoch}/val'
    with torch.no_grad():
        for batch_idx, batch in metric_logger.log_every(data_loader, 10, header, logger, writer, mode, epoch=epoch, args=args):
            video, target, _, video_idx = batch
            if batch_idx == 0:
                print_or_log(video.shape, logger=logger)
            start_time = time.time()
            video = video.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            batch_size = video.shape[0]

            output = model(video)
            if args.use_fcn_testing:
                # Average the class scores over every remaining position.
                # Using the explicit batch size (rather than a hard-coded
                # 4*7*7 grid with an inferred batch dim) keeps samples from
                # being mixed when the output grid has a different size.
                output = output.reshape(batch_size, num_classes, -1).mean(2)
            output = output.view(batch_size, -1)
            loss = criterion(output, target)

            # Clip level accuracy
            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
            metric_logger.update(loss=loss.item(), lr=0)
            metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
            metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
            metric_logger.meters['clips/s'].update(batch_size / (time.time() - start_time))

            # Video level accuracy: group clip outputs by their video index
            for j, clip_video_idx in enumerate(video_idx):
                video_id = clip_video_idx.item()
                softmaxes.setdefault(video_id, []).append(output[j])
                labels[video_id] = target[j]

    # Get video acc@1 and acc@5 and output to tb writer
    video_acc1, video_acc5 = aggregrate_video_accuracy(
        softmaxes, labels, topk=(1, 5)
    )
    if writer is not None:
        writer.add_scalar(
            f'{mode}/vid_acc1/epoch',
            video_acc1,
            epoch
        )
        writer.add_scalar(
            f'{mode}/vid_acc5/epoch',
            video_acc5,
            epoch
        )

    # Log final results to terminal
    print_or_log(' * Clip Acc@1 {top1.global_avg:.3f} Clip Acc@5 {top5.global_avg:.3f}'
          .format(top1=metric_logger.acc1, top5=metric_logger.acc5), logger=logger)
    print_or_log(f' * Vid Acc@1 {video_acc1.item():.3f} Video Acc@5 {video_acc5.item():.3f}', logger=logger)
    return video_acc1.item(), video_acc5.item()


def main(args, logger, writer):
    """Build the fine-tuned model, load a checkpoint and evaluate it.

    Steps: construct the backbone via ``utils.load_model``, wrap it with a
    linear classifier, build (or load from cache) the validation dataset,
    restore weights from ``args.weights_path`` and run ``evaluate``.

    Args:
        args: parsed command-line namespace (see ``parse_args``). The
            attribute ``start_epoch`` is set from the checkpoint.
        logger: logger passed to ``print_or_log`` and ``evaluate``.
        writer: summary writer passed to ``evaluate``.

    Returns:
        Tuple ``(vid_acc1, vid_acc5)`` as returned by ``evaluate``.
    """
    # Use the first GPU when one is visible, otherwise the CPU
    device = "cuda:0" if torch.cuda.device_count() > 0 else "cpu"

    # Set CudNN benchmark
    torch.backends.cudnn.benchmark = True

    # Load model
    print_or_log("Loading model", logger=logger)
    model = utils.load_model(
        model_name=args.model,
        vid_base_arch=args.vid_base_arch,
        aud_base_arch=args.aud_base_arch,
        pretrained=args.pretrained,
        num_classes=256,
        norm_feat=False,
        use_mlp=args.use_mlp,
        mlptype=args.mlptype,
        headcount=args.headcount
    )

    # Add FC layer to model for fine-tuning or feature extracting
    model = load_model_finetune(
        args,
        model.video_network.base if args.model in ['av_gdt'] else model,
        model_name=args.model,
        num_ftrs=512 if args.vid_base_arch in ['r2plus1d_18', 'r2plus1d_34', 'r2plus1d_50'] else 2048,
        num_classes=NUM_CLASSES[args.dataset],
    )

    # Always place the model on the evaluation device. Previously this only
    # happened when more than one GPU was visible, so on a single-GPU machine
    # the inputs were moved to the GPU while the model was not.
    model.to(device)

    # Create DataParallel model
    model_without_ddp = model
    if torch.cuda.device_count() > 1:
        print_or_log("Create model data parallel", logger=logger)
        print_or_log(f"Let's use {torch.cuda.device_count()} GPUs!", logger=logger)
        model = torch.nn.DataParallel(model)
        model_without_ddp = model.module

    print_or_log("Creating dataset transforms", logger=logger)
    # Only the test-time transform is needed for evaluation
    _transform_train, transform_test, subsample = utils.get_transforms(args)

    # Loading Validation data
    print_or_log("Loading validation data", logger=logger)
    if args.dataset in ['ucf101', 'hmdb51']:
        cache_path = utils._get_cache_path(args.dataset, 'val', args.fold, args.clip_len, args.steps_bet_clips)
    else:
        cache_path = utils._get_cache_path(args.dataset, 'val', 1, args.clip_len, args.steps_bet_clips)
    if args.cache_dataset and os.path.exists(cache_path):
        print_or_log(f"Loading dataset_test from {cache_path}", logger=logger)
        # NOTE(review): torch.load unpickles the cached dataset object; only
        # use cache files produced by a trusted run of this code.
        dataset_test = torch.load(cache_path)
        dataset_test.transform = transform_test
    else:
        dataset_test = utils.load_dataset(
            dataset_name=args.dataset,
            fold=args.fold,
            mode='val',
            frames_per_clip=args.clip_len,
            transforms=transform_test,
            subsample=subsample
        )
        if args.cache_dataset:
            print_or_log(f"Saving dataset_test to {cache_path}", logger=logger)
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset_test), cache_path)

    print_or_log("Creating data samplers", logger=logger)
    test_sampler = UniformClipSampler(dataset_test.video_clips, args.val_clips_per_video)

    print_or_log("Creating data loaders", logger=logger)
    # NOTE(review): drop_last=True discards the final incomplete batch, so up
    # to val_batch_size - 1 clips are not evaluated. Kept as-is so reported
    # numbers do not change; confirm whether this is intended.
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=args.val_batch_size,
        sampler=test_sampler,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=True
    )

    # Loss used for reporting during evaluation
    criterion = torch.nn.CrossEntropyLoss()

    # Checkpointing: restore weights into the unwrapped model
    checkpoint = torch.load(args.weights_path, map_location='cpu')
    model_without_ddp.load_state_dict(checkpoint['model'])
    args.start_epoch = checkpoint['epoch']
    print_or_log(f"Eval @ {args.start_epoch}", logger=logger)

    # FCN model
    if args.use_fcn_testing:
        model = fully_conv_model(model_without_ddp)
        # The converted network contains a newly created layer; make sure the
        # whole module lives on the evaluation device before wrapping it.
        model.to(device)
        model = torch.nn.DataParallel(model)

    vid_acc1, vid_acc5 = evaluate(
        args,
        model,
        criterion,
        data_loader_test,
        device=device,
        logger=logger,
        writer=writer,
        epoch=args.start_epoch
    )
    return vid_acc1, vid_acc5


def parse_args(argv=None):
    """Parse the command-line options of the evaluation script.

    Args:
        argv: optional list of argument strings. When ``None`` (the default)
            the arguments are taken from ``sys.argv`` as before.

    Returns:
        The populated ``argparse.Namespace``. Options declared with
        ``type='bool'`` accept yes/true/t/1 and no/false/f/0
        (case-insensitive) and are stored as Python booleans.
    """
    def str2bool(v):
        v = v.lower()
        if v in ('yes', 'true', 't', '1'):
            return True
        elif v in ('no', 'false', 'f', '0'):
            return False
        raise ValueError('Boolean argument needs to be true or false. '
            'Instead, it is %s.' % v)

    import argparse
    parser = argparse.ArgumentParser(description='Video Representation Learning')
    # Make `type='bool'` resolve to the tolerant string-to-bool converter
    parser.register('type', 'bool', str2bool)

    ### DATA
    parser.add_argument(
        '--dataset',
        default='kinetics400',
        help='name of dataset'
    )
    parser.add_argument(
        '--fold',
        default='1',
        type=str,
        help='comma-separated fold number(s) to evaluate'
    )
    parser.add_argument(
        '-b',
        '--val-batch-size',
        default=32,
        type=int
    )
    parser.add_argument(
        '--clip-len',
        default=32,
        type=int,
        metavar='N',
        help='number of frames per clip'
    )
    parser.add_argument(
        '--augtype',
        default=1,
        type=int,
        help='augmentation type (default: 1)'
    )
    parser.add_argument(
        '--colorjitter',
        default='False',
        type='bool',
        help='color jittering as augmentations'
    )
    parser.add_argument(
        '--steps-bet-clips',
        default=1,
        type=int,
        metavar='N',
        help='number of steps between clips in video'
    )
    parser.add_argument(
        '--num-data-samples',
        default=None,
        type=int,
        help='number of samples in dataset'
    )
    parser.add_argument(
        '--val-clips-per-video',
        default=10,
        type=int,
        metavar='N',
        help='maximum number of clips per video to consider for testing'
    )
    parser.add_argument(
        "--cache-dataset",
        type='bool',
        default='True',
        help="Cache the datasets for quicker initialization. It also serializes the transforms",
    )
    parser.add_argument(
        '-j', '--workers',
        default=16,
        type=int,
        metavar='N',
        help='number of data loading workers (default: 16)'
    )
    parser.add_argument(
        '--use-scale-jittering',
        default='True',
        type='bool',
        help='scale jittering as augmentations'
    )

    ### MODEL
    parser.add_argument(
        '--model',
        default='av_gdt',
        help='model',
        choices=['r2plus1d_18', 'av_gdt']
    )
    parser.add_argument(
        '--weights-path',
        default='kinetics_finetune_ig65m/fin-kinetics400-52-jitter-False-clips-5/checkpoints/checkpoint.pth',
        help='Path to weights file',
    )
    parser.add_argument(
        '--ckpt-epoch',
        default='0',
        type=str,
        help='Epoch of model checkpoint',
    )
    parser.add_argument(
        '--vid-base-arch',
        default='r2plus1d_18',
        help='Video Base Arch for A-V model',
        choices=['r2plus1d_18', 'mc3_18', 's3d', 'r2plus1d_34', 'r2plus1d_50']
    )
    parser.add_argument(
        '--aud-base-arch',
        default='resnet18',
        help='Audio Base Arch for A-V model',
        choices=['resnet18', 'vgg_audio', 'resnet34', 'resnet50', 'resnet9']
    )
    parser.add_argument(
        "--pretrained",
        type='bool',
        default='False',
        help="Use pre-trained models from the modelzoo",
    )
    parser.add_argument(
        '--use-mlp',
        default='True',
        type='bool',
        help='Use MLP projection head'
    )
    parser.add_argument(
        '--mlptype',
        default=0,
        type=int,
        help='MLP type (default: 0)'
    )
    parser.add_argument(
        '--headcount',
        type=int,
        default=1,
        help='how many heads each modality has'
    )
    parser.add_argument(
        '--use-fcn-testing',
        default='False',
        type='bool',
        help='Use FCN testing'
    )

    ### LOGGING
    parser.add_argument(
        '--print-freq',
        default=10,
        type=int,
        help='print frequency'
    )
    parser.add_argument(
        '--output-dir',
        default='.',
        help='path where to save'
    )
    parser.add_argument(
        '--global-rank',
        default=0,
        type=int,
        help='global rank of process'
    )

    # argv=None makes argparse fall back to sys.argv[1:]
    args = parser.parse_args(argv)
    return args


if __name__ == "__main__":
    # Script entry point: parse options, set up logging, evaluate each
    # requested fold and print the averaged video-level accuracies.
    args = parse_args()
    print(args)

    # set multi-processing start method
    import torch.multiprocessing as mp
    try:
        mp.set_start_method('forkserver')
    except RuntimeError:
        # The start method can only be set once; keep whatever is in place.
        pass

    # Make output dir
    tbx_path = os.path.join(args.output_dir, 'tensorboard')
    if args.output_dir:
        utils.mkdir(args.output_dir)

    # Set up logger
    filename = 'logger.out'
    logger = utils.setup_logger(
        "Video_reader, classification",
        args.output_dir,
        True,
        logname=filename
    )

    # Set up tensorboard
    writer = utils.setup_tbx(
        tbx_path,
        True
    )
    writer.add_text("namespace", repr(args))

    # Log version information
    print_or_log(args, logger=logger)
    print_or_log(f"torch version: {torch.__version__}", logger=logger)

    # Run over different folds
    best_accs_1 = []
    best_accs_5 = []
    folds = [int(fold) for fold in args.fold.split(',')]
    print(folds)
    if args.dataset in ['ucf101', 'hmdb51']:
        # These datasets are evaluated once per requested fold
        for fold in folds:
            args.fold = fold
            best_acc1, best_acc5 = main(args, logger, writer)
            best_accs_1.append(best_acc1)
            best_accs_5.append(best_acc5)
        print(best_accs_1)
        # Print the per-fold list (previously only the last fold's value)
        print(best_accs_5)
    else:
        # Other datasets are evaluated once, with args.fold left untouched
        best_acc1, best_acc5 = main(args, logger, writer)
        best_accs_1.append(best_acc1)
        best_accs_5.append(best_acc5)

    avg_acc1 = np.mean(best_accs_1)
    avg_acc5 = np.mean(best_accs_5)
    print(f'3-Fold ({args.dataset}): Vid Acc@1 {avg_acc1:.3f}, Video Acc@5 {avg_acc5:.3f}')
