import datetime
import numpy as np
import os

import time
import torch
from torch.utils.data.dataloader import default_collate
import torchvision
from scipy import stats
# Custom imports
from datasets.AudioDataset import GetAudioDataset

import utils
from utils import print_or_log
from sklearn import metrics

# Number of target classes for each supported dataset, keyed by dataset name
NUM_CLASSES = {
    'vggsound': 309,
}


def accuracy(output, target, topk=(1, 5)):
    """Compute the top-k accuracy (in percent) for each k in ``topk``.

    Args:
        output: 2-D tensor of per-class scores, one row per sample.
        target: 1-D tensor with the ground-truth class index of each sample.
        topk: iterable of k values to evaluate.

    Returns:
        A tuple ``(res, pred)``. ``res`` is a list with one 0-dim tensor per
        k holding the percentage of samples whose target is among the k
        highest-scoring classes. ``pred`` is the ``(max(topk), batch)`` tensor
        of predicted class indices, best prediction first.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        # `correct` inherits the transposed (non-contiguous) layout of `pred`,
        # so `.view(-1)` can raise for k > 1; `.reshape(-1)` copies if needed.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res, pred


def d_prime(auc):
    """Convert an AUC value into a d-prime score.

    The standard normal inverse CDF is evaluated at ``auc`` and the result is
    scaled by sqrt(2).
    """
    z_score = stats.norm.ppf(auc)
    return z_score * np.sqrt(2.0)


def calculate_stats(output, target):
    """Compute per-class ranking statistics (AP, AUC, PR and ROC curves).

    Args:
      output: 2d array, (samples_num, classes_num)
      target: 2d array, (samples_num, classes_num)

    Returns:
      list with one dict per class, holding the keys 'precisions',
      'recalls', 'AP', 'fpr', 'fnr' and 'auc'.
    """
    # Keep only every `stride`-th curve point to reduce the size of the result
    stride = 1000
    per_class = []

    for cls in range(target.shape[-1]):
        y_true = target[:, cls]
        y_score = output[:, cls]

        # Scalar summaries for this class
        ap = metrics.average_precision_score(y_true, y_score, average=None)
        roc_auc = metrics.roc_auc_score(y_true, y_score, average=None)

        # Full curves; the threshold arrays are not needed
        precisions, recalls, _ = metrics.precision_recall_curve(y_true, y_score)
        fpr, tpr, _ = metrics.roc_curve(y_true, y_score)

        per_class.append({
            'precisions': precisions[0::stride],
            'recalls': recalls[0::stride],
            'AP': ap,
            'fpr': fpr[0::stride],
            'fnr': 1. - tpr[0::stride],
            'auc': roc_auc,
        })

    return per_class

def aggregrate_audio_accuracy(softmaxes, labels, topk=(1,), writer=None, epoch=0, str=''):
    """Aggregate clip-level scores per video and report video-level metrics.

    The score vectors collected for each video are averaged, top-k accuracy is
    computed on the averaged scores, and mAP / mAUC / d-prime are printed and,
    when a writer is supplied, logged as scalars.

    Args:
        softmaxes: dict mapping a video id to the list of 1-D score tensors
            collected for that video.
        labels: dict mapping a video id to its target class-index tensor.
        topk: iterable of k values for the top-k accuracy.
        writer: object exposing ``add_scalar(tag, value, step)``; scalar
            logging is skipped when it is None.
        epoch: step value passed to the writer.
        str: prefix prepended to every scalar tag (the name shadows the
            builtin but is kept because callers pass it by keyword).

    Returns:
        List with one 1-element tensor per k: the top-k accuracy in percent.
    """
    maxk = max(topk)
    video_ids = list(softmaxes)
    output_batch = torch.stack(
        [torch.stack(softmaxes[vid]).mean(0) for vid in video_ids])
    num_videos = output_batch.size(0)
    output_labels = torch.stack([labels[vid] for vid in video_ids])

    _, pred = output_batch.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(output_labels.expand_as(pred))

    res = []
    for k in topk:
        # `correct` shares the transposed layout of `pred`; reshape (unlike
        # view) also works on that non-contiguous slice.
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / num_videos))

    # NOTE(review): the one-hot width is the number of distinct labels seen,
    # so this relies on the evaluated labels covering 0..num_classes-1.
    num_classes = len(torch.unique(output_labels))
    # The identity matrix lives on the CPU, so index it with CPU labels.
    one_hot = torch.eye(num_classes)[output_labels.cpu()]
    print(output_batch.shape, one_hot.shape, flush=True)
    class_stats = calculate_stats(output_batch.cpu(), one_hot)
    mAP = np.mean([stat['AP'] for stat in class_stats])
    mAUC = np.mean([stat['auc'] for stat in class_stats])
    dprime = d_prime(mAUC)

    print("mAP: {:.6f}".format(mAP))
    print("mAUC: {:.6f}".format(mAUC))
    print("dprime: {:.6f}".format(dprime))
    if writer is not None:
        writer.add_scalar(str + 'mAP', mAP, epoch)
        writer.add_scalar(str + 'mAUC', mAUC, epoch)
        writer.add_scalar(str + 'dprime', dprime, epoch)
    return res

class Finetune_Model(torch.nn.Module):
    """Backbone module followed by a single linear classification layer.

    Args:
        base_arch: module applied to the input first; stored as ``base``.
        num_ftrs: input size of the linear classifier.
        num_classes: output size of the linear classifier.
    """

    def __init__(self, base_arch, num_ftrs=512, num_classes=101):
        super().__init__()
        self.base = base_arch
        self.classifier = torch.nn.Linear(
            in_features=num_ftrs,
            out_features=num_classes,
        )

    def forward(self, x):
        """Run ``x`` through the backbone, then through the classifier."""
        features = self.base(x)
        return self.classifier(features)


# Load finetune model and training params
def load_model_finetune(args, model, num_ftrs, num_classes):
    """Wrap ``model`` in a ``Finetune_Model`` with a new linear classifier.

    ``args`` is accepted but not used.
    """
    return Finetune_Model(
        base_arch=model,
        num_ftrs=num_ftrs,
        num_classes=num_classes,
    )


def get_dataloader(args, mode='train'):
    """Build the audio dataset for ``mode`` and a DataLoader over it.

    A DistributedSampler is used only for the training split of a distributed
    run with ``args.world_size > 8``; otherwise the loader shuffles by itself.

    Returns:
        Tuple ``(dataset, data_loader, sampler)``; ``sampler`` is None when no
        distributed sampler was created.
    """
    dataset = GetAudioDataset(
        csvpath='./data/vggsound/metadata/',
        datapath=args.datapath,
        mode=mode,
    )
    print(f"LEN DATASET-{mode}: {len(dataset)}", flush=True)

    if args.distributed and mode == 'train' and args.world_size > 8:
        sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    else:
        sampler = None

    if mode == 'train':
        batch_size = args.batch_size
    else:
        batch_size = args.val_batch_size

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=False,
        # A sampler and shuffle=True cannot be combined
        shuffle=sampler is None,
    )
    print(f"LEN LOADER-{mode}: {len(data_loader)}", flush=True)
    return dataset, data_loader, sampler


def train_one_epoch(
        args,
        data_loader,
        model,
        criterion,
        optimizer,
        device,
        epoch,
        print_freq,
        logger=None,
        writer=None,
):
    """Train ``model`` for one pass over ``data_loader``.

    Each batch is moved to ``device``, run through the model, and used for one
    optimizer step. Loss, learning rate, top-1/top-5 accuracy and throughput
    are tracked in a ``utils.MetricLogger``. Nothing is returned.
    """
    model.train()

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))
    metric_logger.add_meter('clips/s', utils.SmoothedValue(window_size=10, fmt='{value:.3f}'))

    header = f'Epoch: [{epoch}]'
    batches = metric_logger.log_every(
        data_loader, print_freq, header, logger, writer, 'train', epoch=epoch, args=args)
    for batch_idx, (audio, target, vid_idx) in batches:
        # Log the input shape once per epoch
        if batch_idx == 0:
            print_or_log(audio.shape, logger=logger)
        tic = time.time()
        audio = audio.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        # Forward pass: loss and clip-level accuracy
        logits = model(audio)
        loss = criterion(logits, target)
        acc1, acc5 = utils.accuracy(logits, target, topk=(1, 5))

        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Book-keeping
        n_clips = audio.shape[0]
        current_lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(loss=loss.item(), lr=current_lr)
        metric_logger.meters['acc1'].update(acc1.item(), n=n_clips)
        metric_logger.meters['acc5'].update(acc5.item(), n=n_clips)
        metric_logger.meters['clips/s'].update(n_clips / (time.time() - tic))


def evaluate(args, data_loader, model, criterion, device, logger=None, writer=None, epoch=0):
    """Evaluate ``model`` on ``data_loader`` at clip level and video level.

    Clip-level loss and top-1/top-5 accuracy are tracked in a
    ``utils.MetricLogger``. The softmax scores of every clip are grouped by
    the id found in the third batch element and handed to
    ``aggregrate_audio_accuracy`` for the aggregated (video-level) accuracy.

    Args:
        args: namespace; ``args.dataset`` and ``args.ckpt_epoch`` are used to
            build the scalar tags.
        data_loader: iterable yielding ``(audio, target, vid_idx)`` batches.
        model: network to evaluate; switched to eval mode.
        criterion: loss called as ``criterion(model_output, target)``.
        device: device the batches are moved to.
        logger: optional logger forwarded to ``print_or_log``.
        writer: optional scalar writer; the scalars written directly by this
            function are skipped when it is None.
        epoch: step value used for logging.

    Returns:
        Tuple ``(aud_acc1, aud_acc5)`` of aggregated accuracies as floats.
    """
    # Put model in eval mode
    model.eval()

    # dicts to store labels and softmaxes, keyed by video id
    softmaxes = {}
    labels = {}

    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Test:'
    torch.cuda.empty_cache()
    with torch.no_grad():
        for batch_idx, batch in metric_logger.log_every(data_loader, 100, header, logger, writer, 'val', epoch=epoch,
                                                        args=args):
            audio, target, vid_idx = batch

            if batch_idx == 0:
                print_or_log((batch_idx, len(data_loader), audio.shape), logger=logger)
            start_time = time.time()
            audio = audio.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)

            batch_size = audio.shape[0]
            logits = model(audio)
            # Feed the criterion the raw model output, as train_one_epoch
            # does; passing softmax scores would normalise twice with
            # CrossEntropyLoss and report a wrong validation loss.
            loss = criterion(logits, target)
            output = torch.nn.functional.softmax(logits, dim=1).view(batch_size, -1)

            # Clip level accuracy
            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
            metric_logger.update(loss=loss.item(), lr=0)
            metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
            metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
            metric_logger.meters['clips/s'].update(batch_size / (time.time() - start_time))

            # Video level accuracy: collect every clip's scores per video id
            for vid, sm, label in zip(vid_idx, output, target):
                video_id = vid.item()
                softmaxes.setdefault(video_id, []).append(sm)
                labels[video_id] = label

    # Get video acc@1 and acc@5 and output to tb writer
    tag_prefix = f'{args.dataset}/{args.ckpt_epoch}/val/'
    audio_acc1, audio_acc5 = aggregrate_audio_accuracy(
        softmaxes, labels, topk=(1, 5), writer=writer, epoch=epoch, str=tag_prefix
    )
    audio_acc1 = audio_acc1.item()
    audio_acc5 = audio_acc5.item()
    # Callers may pass writer=None (e.g. non-master ranks), so guard the calls
    if writer is not None:
        writer.add_scalar(f'{tag_prefix}aud_acc1/epoch', audio_acc1, epoch)
        writer.add_scalar(f'{tag_prefix}aud_acc5/epoch', audio_acc5, epoch)

    # Log final results to terminal
    print_or_log(' * Clip Acc@1 {top1.global_avg:.3f} Clip Acc@5 {top5.global_avg:.3f}'
                 .format(top1=metric_logger.acc1, top5=metric_logger.acc5), logger=logger)
    print_or_log(f' * Aud Acc@1 {audio_acc1:.3f} Aud Acc@5 {audio_acc5:.3f}', logger=logger)
    return audio_acc1, audio_acc5


def main(args, logger, writer):
    """Fine-tune (or only evaluate) the audio network on ``args.dataset``.

    Steps: build the model via ``utils.load_model``, optionally load weights
    from ``args.weights_path``, keep only ``model.audio_network.base`` and add
    a linear classifier, wrap the model for (distributed) data parallelism,
    set up optimizer / scheduler, optionally resume from ``args.resume``, then
    either evaluate once (``args.test_only``) or train and evaluate per epoch.

    Args:
        args: namespace as produced by ``parse_args``; ``args.head_lr``,
            ``args.ckpt_epoch`` and ``args.start_epoch`` may be updated here.
        logger: logger forwarded to ``print_or_log``.
        writer: scalar writer forwarded to training and evaluation.

    Returns:
        Tuple ``(best_aud_acc_1, best_aud_acc_5, best_epoch)``. In test-only
        mode the accuracies of the single evaluation are returned together
        with ``args.start_epoch``.
    """
    device = torch.device(args.device)

    # Set CudNN benchmark
    torch.backends.cudnn.benchmark = True

    # Load model
    print_or_log("Loading model", logger=logger)
    model = utils.load_model(
        model_name=args.model,
        vid_base_arch=args.vid_base_arch,
        aud_base_arch=args.aud_base_arch,
        pretrained=args.pretrained,
        num_classes=256,
        norm_feat=False,
        use_mlp=args.use_mlp,
        mlptype=args.mlptype,
        headcount=args.headcount,
        use_max_pool=False,
    )

    # Load model weights; both None and the literal string 'None' disable it
    start = time.time()
    weights_path = args.weights_path
    if isinstance(weights_path, str):
        has_weights_path = weights_path != 'None'
    else:
        has_weights_path = weights_path is not None
    if has_weights_path:
        print_or_log("Loading model weights", logger=logger)
        if os.path.exists(weights_path):
            # Load onto the CPU first (as the resume path does) so the
            # checkpoint does not depend on the device it was saved from.
            ckpt_dict = torch.load(weights_path, map_location='cpu')
            model_weights = ckpt_dict["model"]
            args.ckpt_epoch = ckpt_dict['epoch']
            print(f"Epoch checkpoint: {args.ckpt_epoch}", flush=True)
            utils.load_model_parameters(model, model_weights)
        else:
            # Make the fallback visible instead of skipping silently
            print_or_log(
                f"Weights file not found, keeping initial weights: {weights_path}",
                logger=logger)
    print_or_log(f"Time to load model weights: {time.time() - start}", logger=logger)

    # Scale lr and batch-size
    if args.use_scaling:
        print_or_log(f"Using scaling: {args.world_size}", logger=logger)
        if args.optim_name in ['sgd']:
            args.head_lr = args.head_lr * args.world_size
    print_or_log(f"Head_lr: {args.head_lr}", logger=logger)
    print_or_log(f"Train BS: {args.batch_size}", logger=logger)
    print_or_log(f"Val BS: {args.val_batch_size}", logger=logger)

    # Add FC layer to model for fine-tuning or feature extracting
    model = load_model_finetune(
        args,
        model.audio_network.base,
        num_ftrs=512,
        num_classes=NUM_CLASSES[args.dataset],
    )

    # Create DataParallel / DistributedDataParallel model
    model_without_ddp = model
    if torch.cuda.device_count() >= 1:
        model.to(device)
        if args.distributed and args.world_size > 8:
            print_or_log("Create model DDP", logger=logger)
            if args.sync_bn:
                print_or_log("Using SYNC BN", logger=logger)
                model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[args.local_rank],
                output_device=args.local_rank,
                broadcast_buffers=False
            )
        else:
            print_or_log("Create model data parallel", logger=logger)
            print_or_log(f"Let's use {torch.cuda.device_count()} GPUs!", logger=logger)
            model = torch.nn.DataParallel(model)
        # Both wrappers expose the wrapped network as `.module`
        model_without_ddp = model.module

    # get lr params: one parameter group per tensor, each with its own lr
    params = []
    if args.feature_extract:  # feature_extract: only the classifier is optimized
        print_or_log("Getting params for feature-extracting", logger=logger)
        for param in model_without_ddp.classifier.parameters():
            params.append({'params': param, 'lr': args.head_lr})
    else:  # finetune
        print_or_log("Getting params for finetuning", logger=logger)
        print(f"Head LR params: {args.head_lr}")
        for param in model_without_ddp.classifier.parameters():
            params.append({'params': param, 'lr': args.head_lr})
        print(f"Base LR params: {args.base_lr}")
        for param in model_without_ddp.base.parameters():
            params.append({'params': param, 'lr': args.base_lr})

    # Set up training params
    criterion = torch.nn.CrossEntropyLoss()

    # set up optimizer
    print_or_log(
        f"Using {args.optim_name} optimization with lr: {args.head_lr}, mom: {args.momentum}, wd: {args.weight_decay}",
        logger=logger)
    optimizer = utils.load_optimizer(
        args.optim_name,
        params,
        lr=args.head_lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay
    )

    # Multi-step LR scheduler
    if args.use_scheduler:
        milestones = [int(milestone) for milestone in args.lr_milestones.split(',')]
        print_or_log(f"Num. of Epochs: {args.epochs}, Milestones: {milestones}", logger=logger)
        print_or_log("Using LR multi-step scheduler w/out warmp", logger=logger)
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=milestones,
            gamma=args.lr_gamma
        )
    else:
        lr_scheduler = None

    # Checkpointing
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])

        args.start_epoch = checkpoint['epoch'] + 1
        if args.use_scheduler:
            # The scheduler state is not restored from the checkpoint; it is
            # advanced step by step up to the resumed epoch instead.
            for _ in range(args.start_epoch):
                lr_scheduler.step()
            print(lr_scheduler.state_dict(), flush=True)

    # Load dataloaders (the dataset objects themselves are not needed here)
    ds, data_loader, train_sampler = get_dataloader(args, mode='train')
    del ds
    ds, data_loader_test, _ = get_dataloader(args, mode='val')
    del ds

    # Only perform evaluation
    if args.test_only:
        aud_acc1, aud_acc5 = evaluate(
            args,
            data_loader_test,
            model,
            criterion,
            device=device,
            logger=logger,
            writer=writer,
            epoch=args.start_epoch
        )
        # Return the same 3-tuple shape as the training path so callers that
        # unpack the result keep working in test-only mode.
        return aud_acc1, aud_acc5, args.start_epoch

    start_time = time.time()
    best_aud_acc_1 = -1
    best_aud_acc_5 = -1
    best_epoch = 0
    for epoch in range(args.start_epoch, args.epochs):
        print_or_log(f'Start training epoch: {epoch}', logger=logger)
        if args.distributed and args.world_size > 8:
            # Mirrors the condition under which get_dataloader builds a sampler
            train_sampler.set_epoch(epoch)
        train_one_epoch(
            args,
            data_loader,
            model,
            criterion,
            optimizer,
            device,
            epoch,
            args.print_freq,
            logger=logger,
            writer=writer,
        )
        print_or_log(f'Start evaluating epoch: {epoch}', logger=logger)
        aud_acc1, aud_acc5 = evaluate(
            args,
            data_loader_test,
            model,
            criterion,
            device=device,
            logger=logger,
            writer=writer,
            epoch=epoch
        )
        if args.use_scheduler:
            lr_scheduler.step()
        if aud_acc1 > best_aud_acc_1:
            best_aud_acc_1 = aud_acc1
            best_aud_acc_5 = aud_acc5
            best_epoch = epoch
        # NOTE(review): nothing in this function writes a checkpoint; only the
        # message is logged — confirm whether saving was meant to happen here.
        print_or_log(f'Saving checkpoint to: {args.output_dir}', logger=logger)
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print_or_log(f'Training time {total_time_str}', logger=logger)
    return best_aud_acc_1, best_aud_acc_5, best_epoch


def parse_args():
    """Parse the command-line options of the audio fine-tuning script.

    Boolean options use the custom ``'bool'`` type, which accepts
    yes/true/t/1 and no/false/f/0 (case-insensitive); their string defaults
    are converted through the same function by argparse.

    Returns:
        The ``argparse.Namespace`` with all parsed options.
    """
    def str2bool(v):
        v = v.lower()
        if v in ('yes', 'true', 't', '1'):
            return True
        elif v in ('no', 'false', 'f', '0'):
            return False
        raise ValueError('Boolean argument needs to be true or false. '
                         'Instead, it is %s.' % v)

    import argparse
    parser = argparse.ArgumentParser(description='Audio Finetuning')
    parser.register('type', 'bool', str2bool)
    parser.add_argument(
        '--datasettype', default='audio', type=str,
        help='audio for using vggsound audio dataset')
    parser.add_argument(
        '--datapath', default='./data/vggsound-audio/', type=str,
        help='path to the audio data handed to the dataset')

    # AUDIO UTILS
    parser.add_argument(
        '--aud-sample-rate', default=24000, type=int,
        help='audio sample rate')
    parser.add_argument(
        '--aud-spec-type', default=2, type=int,
        help='audio spec type')
    parser.add_argument(
        '--use-volume-jittering', default='False', type='bool',
        help='use volume jittering')
    parser.add_argument(
        '--use-temporal-jittering', default='False', type='bool',
        help='use temporal jittering')
    parser.add_argument(
        '--num-sec', default=3, type=int,
        help='Number of seconds')
    parser.add_argument(
        '--aug-audio', default='False', type='bool',
        help='whether to augment audio')
    parser.add_argument(
        '--z-normalize', default='True', type='bool',
        help='normalize audio')

    ### DATA
    parser.add_argument(
        '--dataset', default='vggsound',
        help='name of dataset')
    parser.add_argument(
        '--num-data-samples', default=None, type=int,
        help='number of samples in dataset')
    parser.add_argument(
        '--train-clips-per-video', default=1, type=int,
        help='maximum number of clips per video to consider for training')
    parser.add_argument(
        '--val-clips-per-video', default=1, type=int,
        help='maximum number of clips per video to consider for testing')
    parser.add_argument(
        '--workers', default=0, type=int,
        help='number of data loading workers (default: 0)')

    ### MODEL
    parser.add_argument(
        '--model', default='avc', choices=['avc'],
        help='model')
    parser.add_argument(
        '--weights-path',
        default='ig_model_weights/model_with_corrected_mlp_weights_188790307_e52.pt',
        help='Path to weights file')
    parser.add_argument(
        '--ckpt-epoch', default=0, type=int,
        help='Epoch of model checkpoint')
    parser.add_argument(
        '--vid-base-arch', default='r2plus1d_18',
        choices=['r2plus1d_18', 'mc3_18', 's3d', 'r2plus1d_34', 'r2plus1d_50'],
        help='Video Base Arch for A-V model')
    parser.add_argument(
        '--aud-base-arch', default='resnet9',
        choices=['resnet18', 'resnet34', 'resnet9'],
        help='Audio Base Arch for A-V model')
    parser.add_argument(
        "--pretrained", type='bool', default='False',
        help="Use pre-trained models from the modelzoo")
    parser.add_argument(
        '--use-mlp', default='True', type='bool',
        help='Use MLP projection head')
    # main() reads args.mlptype and args.headcount when calling
    # utils.load_model, so both must exist on the namespace.
    # NOTE(review): type and default values are assumptions — confirm them
    # against utils.load_model. main() keeps only audio_network.base of the
    # loaded model.
    parser.add_argument(
        '--mlptype', default=0, type=int,
        help='MLP head type forwarded to utils.load_model')
    parser.add_argument(
        '--headcount', default=1, type=int,
        help='number of heads forwarded to utils.load_model')
    parser.add_argument(
        "--sync-bn", type='bool', default='True',
        help="Use sync batch norm")

    ### FINETUNE params
    parser.add_argument(
        "--feature-extract", type='bool', default='False',
        help="Use model as feature extractor; if False, finetune entire model")

    ### TRAINING
    parser.add_argument('-b', '--batch-size', default=128, type=int)
    parser.add_argument('--val-batch-size', default=128, type=int)
    parser.add_argument(
        '--epochs', default=30, type=int, metavar='N',
        help='number of total epochs to run')
    parser.add_argument(
        "--use-scaling", type='bool', default='False',
        help="Use LR scaling")
    parser.add_argument(
        '--optim-name', default='adam', type=str, choices=['sgd', 'adam'],
        help='Name of optimizer')
    parser.add_argument(
        '--head-lr', default=1e-3, type=float,
        help='initial learning rate')
    parser.add_argument(
        '--base-lr', default=1e-4, type=float,
        help='initial learning rate')
    parser.add_argument(
        '--momentum', default=0.9, type=float, metavar='M',
        help='momentum')
    parser.add_argument(
        '--weight-decay', default=1e-5, type=float, metavar='W',
        help='weight decay (default: 1e-5)',
        dest='weight_decay')
    parser.add_argument(
        "--use-scheduler", type='bool', default='True',
        help="Use LR scheduler")
    parser.add_argument(
        '--lr-milestones', default='10,20', type=str,
        help='decrease lr on milestones (epochs)')
    parser.add_argument(
        '--lr-gamma', default=0.1, type=float,
        help='decrease lr by a factor of lr-gamma')

    ### LOGGING
    parser.add_argument(
        '--print-freq', default=10, type=int,
        help='print frequency')
    parser.add_argument(
        '--output-dir', default='.',
        help='path where to save')

    ### CHECKPOINTING
    parser.add_argument(
        '--resume', default='',
        help='resume from checkpoint')
    parser.add_argument(
        '--start-epoch', default=0, type=int, metavar='N',
        help='start epoch')
    parser.add_argument(
        "--test-only", type='bool', default='False',
        help="Only test the model")

    # distributed training parameters
    parser.add_argument(
        '--device', default='cuda',
        help='device')
    parser.add_argument(
        '--distributed', type='bool', default='False',
        help="ddp mode")
    parser.add_argument(
        '--dist-backend', default='nccl', type=str,
        help='distributed backend')
    parser.add_argument(
        '--dist-url', default='env://',
        help='url used to set up distributed training')
    parser.add_argument(
        '--world-size', default=1, type=int,
        help='number of distributed processes')
    parser.add_argument(
        '--debug-slurm', type='bool', default='False',
        help="Debug SLURM")
    parser.add_argument(
        '--local-rank', default=-1, type=int,
        help='Local rank of node')
    parser.add_argument(
        '--master-port', default=-1, type=int,
        help='Master port of Job')

    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = parse_args()
    # Expose only the first GPU to this process.
    # NOTE(review): this overrides any CUDA_VISIBLE_DEVICES set by the
    # launcher — confirm that is intended for distributed runs.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    # set multi-processing start method
    import torch.multiprocessing as mp

    try:
        mp.set_start_method('forkserver')
    except RuntimeError:
        # Raised when a start method has already been set; keep that one
        pass

    # Init distributed mode
    if torch.cuda.is_available():
        utils.init_distributed_mode(args)

    print(f'Distributed Mode: {args.distributed}')

    # Make output dir
    tbx_path = os.path.join(args.output_dir, 'tensorboard')
    if args.output_dir:
        utils.mkdir(args.output_dir)

    # Set up logger
    filename = 'logger.out'
    logger = utils.setup_logger(
        "VGGSound, finetune",
        args.output_dir,
        True,
        logname=filename
    )

    # Set up tensorboard on the master process only. The rank is read from
    # SLURM when available and falls back to the RANK variable (then 0) so
    # that distributed runs started without SLURM do not raise a KeyError.
    if args.distributed:
        global_rank = int(os.environ.get('SLURM_PROCID', os.environ.get('RANK', 0)))
    else:
        global_rank = 0
    is_master = global_rank == 0
    writer = None
    if is_master:
        writer = utils.setup_tbx(
            tbx_path,
            is_master
        )
        writer.add_text("namespace", repr(args))

    # Log version information
    print_or_log(args, logger=logger)
    print_or_log(f"torch version: {torch.__version__}", logger=logger)
    print_or_log(f"torchvision version: {torchvision.__version__}", logger=logger)

    # Run over different folds
    args.fold = '1'
    result = main(args, logger, writer)
    # Tolerate a main() that returns nothing instead of failing on the unpack
    if result is not None:
        best_acc1, best_acc5, best_epoch = result
        print(f'Aud Acc@1 {best_acc1:.3f}, Aud Acc@5 {best_acc5:.3f}')