import datetime
from datasets.AudioDataset import ESC_DCASE
import numpy as np
import os

import time
import torch
import torchvision
import math

# Custom imports
import utils
from utils import print_or_log
import random

# A fresh torch seed is drawn on every run, so results are not reproducible
# across invocations (numpy's RNG is left unseeded).
SEED = random.randint(0, 50000)
torch.manual_seed(SEED)

# Number of target classes for each supported dataset.
DATASET_NUM_CLASSES = dict(esc50=50, dcase2014=10)

# Per-dataset value forwarded to ESC_DCASE as `num_samples`.
DATASET_NUM_CLIPS = dict(esc50=25, dcase2014=60)

class FinetuneModelLinear(torch.nn.Module):
    """Single linear layer trained on top of precomputed (frozen) features.

    Args:
        num_ftrs: size of each input feature vector.
        num_classes: number of output logits.
    """

    def __init__(self, num_ftrs=512, num_classes=50):
        super().__init__()
        self.lin = torch.nn.Linear(num_ftrs, num_classes)
        self._init()

    def forward(self, x):
        # Raw logits; no softmax is applied here.
        return self.lin(x)

    def _init(self):
        # Re-draw weight and bias from U(-1/sqrt(fan_in), 1/sqrt(fan_in)).
        bound = 1. / math.sqrt(self.lin.weight.size(1))
        with torch.no_grad():
            self.lin.weight.uniform_(-bound, bound)
            if self.lin.bias is not None:
                self.lin.bias.uniform_(-bound, bound)


def get_embeddings_batch(
        dataset='esc50',
        model=None,
        mode='train',
        val_fold=1,
        batch_size=128,
        use_cuda=False,
        seconds=None,
        num_workers=10,
):
    """Run ``model`` over one ESC_DCASE split and collect its outputs.

    Args:
        dataset: key into ``DATASET_NUM_CLIPS`` ('esc50' or 'dcase2014').
        model: network applied to each input batch; it is put in eval mode.
        mode: split name forwarded to ESC_DCASE (e.g. 'train' or 'val').
        val_fold: forwarded to ESC_DCASE as ``val_fold``.
        batch_size: DataLoader batch size.
        use_cuda: if True, inputs, labels and indices are moved with
            ``.cuda()`` before the forward pass, so the returned tensors
            live on the GPU.
        seconds: forwarded to ESC_DCASE as ``seconds``. When None, falls back
            to the script-level ``args.seconds`` (backward-compatible
            behaviour; that global only exists when run as ``__main__``).
        num_workers: number of DataLoader worker processes.

    Returns:
        dict with keys 'features', 'labels' and 'indices': the per-batch
        outputs concatenated along dim 0.

    Raises:
        ValueError: if ``model`` is None or the split yields no batches.
    """
    if model is None:
        raise ValueError("get_embeddings_batch() requires a model")
    if seconds is None:
        # Backward-compatible fallback to the global parsed in __main__.
        seconds = args.seconds

    model.eval()

    # Load dataset
    start = time.time()
    mode_dataset = ESC_DCASE(
        val_fold=val_fold,
        mode=mode,
        seconds=seconds,
        num_samples=DATASET_NUM_CLIPS[dataset]
    )
    print(f"Time to load Dataset: {time.time() - start}", flush=True)

    data_loader = torch.utils.data.DataLoader(
        mode_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        collate_fn=None,
    )

    features, labels, indices = [], [], []
    start = time.time()
    print(f"Getting features and labels for {mode}", flush=True)
    with torch.no_grad():
        for i, (inp, label, aud_idx) in enumerate(data_loader):
            if use_cuda:
                inp = inp.cuda()
                label = label.cuda()
                aud_idx = aud_idx.cuda()
            feat = model(inp)
            if i == 0:
                print("inputs shape=", inp.shape)
                print("features shape=", feat.shape)
                print()
            # Under no_grad the outputs carry no graph, so no detach is needed.
            # Raw inputs are intentionally NOT accumulated: the previous code
            # kept every input batch of the split alive and never used them.
            features.append(feat)
            labels.append(label)
            indices.append(aud_idx)
            print(f'{i} / {len(data_loader)}', end='\r')

    print(f"Time to create features and labels obj: {time.time() - start}", flush=True)
    if not features:
        raise ValueError(f"No samples found for mode={mode!r}, val_fold={val_fold}")
    return {
        'features': torch.cat(features, 0),
        'labels': torch.cat(labels, 0),
        'indices': torch.cat(indices, 0),
    }


def aggregrate_audio_accuracy(softmaxes, labels, topk=(1,)):
    """Aggregate clip-level softmaxes into audio-level top-k accuracy.

    The softmax vectors of all clips of one audio are averaged, and the
    averaged distribution is ranked and compared with that audio's label.

    Args:
        softmaxes: dict audio_id -> list of equally-sized 1-D softmax tensors
            (one per clip).
        labels: dict audio_id -> label tensor holding a single class index;
            must contain every key of ``softmaxes``.
        topk: tuple of k values to report.

    Returns:
        list of 1-element float tensors: top-k accuracy in percent over
        audios, one entry per value in ``topk`` (in the same order).

    Raises:
        ValueError: if ``softmaxes`` is empty.
    """
    if not softmaxes:
        raise ValueError("softmaxes is empty; nothing to aggregate")
    maxk = max(topk)
    audio_ids = list(softmaxes.keys())

    # (num_audios, num_classes): mean over each audio's clips.
    output_batch = torch.stack(
        [torch.stack(softmaxes[audio_id]).mean(0) for audio_id in audio_ids])
    output_labels = torch.stack([labels[audio_id] for audio_id in audio_ids])
    num_audios = output_batch.size(0)

    _, pred = output_batch.topk(maxk, 1, True, True)
    pred = pred.t()  # (maxk, num_audios)
    correct = pred.eq(output_labels.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape, not view: `correct` derives from a transposed tensor and
        # may be non-contiguous, in which case .view(-1) raises.
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / num_audios))
    return res


def train_one_epoch(
        args,
        model,
        criterion,
        optimizer,
        epoch,
        print_freq,
        logger=None,
        writer=None,
        features=None,
):
    """Train ``model`` for one epoch on precomputed features.

    Samples are visited in a fresh numpy random order, grouped into batches of
    ``args.batch_size`` by ``utils.PermuteIter``. Loss, learning rate,
    acc@1 / acc@5 and clips/s are recorded in a ``utils.MetricLogger``.

    Args:
        features: dict with 'features' and 'labels' tensors, as produced by
            ``get_embeddings_batch``.
    """
    model.train()
    meters = utils.MetricLogger(delimiter="  ")
    meters.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))
    meters.add_meter('clips/s', utils.SmoothedValue(window_size=10, fmt='{value:.3f}'))

    header = f'Epoch: [{epoch}]'
    all_feats = features['features']
    all_targets = features['labels']
    order = np.random.permutation(range(len(all_feats)))
    batches = utils.PermuteIter(order, args.batch_size)

    for _, idx in meters.log_every(batches, print_freq, header, logger, writer, 'train',
                                   epoch=epoch, args=args):
        inputs = all_feats[idx, :]
        target = all_targets[idx]
        tic = time.time()

        # Forward, loss and clip-level accuracy.
        logits = model(inputs)
        loss = criterion(logits, target)
        acc1, acc5 = utils.accuracy(logits, target, topk=(1, 5))

        # Parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        n = inputs.shape[0]
        meters.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
        meters.meters['acc1'].update(acc1.item(), n=n)
        meters.meters['acc5'].update(acc5.item(), n=n)
        meters.meters['clips/s'].update(n / (time.time() - tic))


def evaluate(args, model, criterion, logger=None, writer=None, epoch=0, features=None):
    """Evaluate ``model`` on precomputed features; return audio-level acc@1.

    Clip-level loss, acc@1 / acc@5 and clips/s are recorded in a
    ``utils.MetricLogger``. Clip softmaxes are grouped by audio index and
    averaged by ``aggregrate_audio_accuracy``.

    Args:
        args: parsed arguments, forwarded to ``MetricLogger.log_every``.
        model: classifier mapping a batch of features to logits.
        criterion: loss applied to (logits, target), e.g. CrossEntropyLoss.
        logger, writer: forwarded to ``MetricLogger.log_every``.
        epoch: epoch number used for logging.
        features: dict with 'features', 'labels' and 'indices' tensors, as
            returned by ``get_embeddings_batch``.

    Returns:
        Audio-level top-1 accuracy in percent, as a Python float.

    Raises:
        ValueError: if ``features`` is None.
    """
    if features is None:
        raise ValueError("evaluate() needs precomputed `features` (see get_embeddings_batch)")
    model.eval()

    # audio_id -> list of clip softmaxes / audio_id -> label
    softmaxes = {}
    labels = {}

    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Test:'
    torch.cuda.empty_cache()
    with torch.no_grad():
        permutation = np.random.permutation(range(len(features['features'])))
        perm = utils.PermuteIter(permutation, 100)
        # NOTE(review): the mode tag passed to log_every is 'train' although
        # this is evaluation — confirm how utils.log_every uses it before changing.
        for _, batch_idx in metric_logger.log_every(perm, 100, header, logger, writer, 'train', epoch=epoch,
                                                    args=args):
            start_time = time.time()
            audio = features['features'][batch_idx]
            target = features['labels'][batch_idx]
            aud_idx = features['indices'][batch_idx]
            batch_size = audio.shape[0]

            logits = model(audio).view(batch_size, -1)
            # CrossEntropyLoss applies log-softmax internally and must see raw
            # logits; feeding it softmax probabilities (as before) double-softmaxes
            # and reports a wrong loss.
            loss = criterion(logits, target)
            output = torch.nn.functional.softmax(logits, dim=1)

            # Clip level accuracy (softmax is monotonic, so ranking is unchanged)
            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
            metric_logger.update(loss=loss.item(), lr=0)
            metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
            metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
            metric_logger.meters['clips/s'].update(batch_size / (time.time() - start_time))

            # Audio level: group clip softmaxes by audio index. One tolist()
            # replaces a per-element .item() (a device sync per clip on GPU).
            for audio_id, sm, label in zip(aud_idx.reshape(-1).tolist(), output, target):
                softmaxes.setdefault(audio_id, []).append(sm)
                labels[audio_id] = label

    audio_acc1, _ = aggregrate_audio_accuracy(
        softmaxes, labels, topk=(1, 5)
    )
    return audio_acc1.item()


def main(args, logger, writer):
    """Train and evaluate a linear probe on frozen audio features for one fold.

    The model built by ``utils.load_model`` (optionally loaded from
    ``args.weights_path``) is cut down to its audio backbone up to ``layer4``,
    followed by the pooling selected by ``args.pooltype`` and a flatten. This
    extractor is run once to precompute train/val features for fold
    ``args.fold``. A ``FinetuneModelLinear`` head is then trained on those
    features for epochs ``args.start_epoch .. args.epochs - 1`` and evaluated
    at audio level after every epoch.

    Args:
        args: parsed command-line namespace (see ``parse_args``);
            ``args.fold`` must hold a single integer fold.
        logger: logger used by ``print_or_log`` and the metric loggers.
        writer: summary writer forwarded to the metric loggers (may be None).

    Returns:
        ``(best_acc1, best_epoch, last_acc1)``: the best audio-level top-1
        accuracy (percent), the epoch where it was reached, and the accuracy of
        the last epoch run (-1 if no epoch ran).

    Raises:
        NotImplementedError: if ``args.test_only`` is set.
    """
    if args.test_only:
        # The linear head is created from scratch for every fold and is not
        # saved anywhere in this script, so there is nothing to test without
        # training. The old path called evaluate() without features and then
        # returned None, crashing either there or in the caller's unpacking.
        raise NotImplementedError(
            "--test-only is not supported: the linear head is trained from scratch for each fold")

    device = torch.device(args.device)

    # Set CudNN benchmark
    torch.backends.cudnn.benchmark = True

    # Load model
    print_or_log("Loading model", logger=logger)
    model = utils.load_model(
        model_name=args.model,
        vid_base_arch=args.vid_base_arch,
        aud_base_arch=args.aud_base_arch,
        pretrained=args.pretrained,
        norm_feat=False,
        use_mlp=args.use_mlp
    )

    # Load model weights. Both the string 'None' (command line) and a literal
    # None mean "no checkpoint".
    start = time.time()
    weights_path = args.weights_path
    if isinstance(weights_path, str):
        has_weights = weights_path != 'None'
    else:
        has_weights = weights_path is not None
    if has_weights:
        print_or_log("Loading model weights", logger=logger)
        if os.path.exists(weights_path):
            # Load onto CPU; the model itself is moved to `device` further down.
            ckpt_dict = torch.load(weights_path, map_location='cpu')
            model_weights = ckpt_dict["model"]
            args.ckpt_epoch = ckpt_dict['epoch']
            print(f"Epoch checkpoint: {args.ckpt_epoch}", flush=True)
            utils.load_model_parameters(model, model_weights)
        else:
            # Previously a missing file was skipped silently.
            print_or_log(
                f"WARNING: weights file {weights_path} not found; continuing without loading a checkpoint",
                logger=logger)
    print_or_log(f"Time to load model weights: {time.time() - start}", logger=logger)

    train_batch_size = args.batch_size
    print_or_log(f"Head_lr: {args.head_lr}", logger=logger)
    print_or_log(f"Train BS: {train_batch_size}", logger=logger)

    # Build the frozen feature extractor from the audio backbone.
    # (Original note: with 257x197 inputs and resnet9 the layer4 output is 9 x 7.)
    base = model.audio_network.base
    if args.pooltype == 'pool22-22-nopad':
        pool = torch.nn.MaxPool2d((2, 2), stride=(2, 2), padding=(0, 0))
    elif args.pooltype == 'gap':
        pool = torch.nn.AdaptiveAvgPool2d((1, 1))
    else:
        if args.pooltype != 'id':
            print_or_log(f"Unknown pooltype {args.pooltype!r}; using no pooling", logger=logger)
        pool = torch.nn.Identity()

    # Replace layer4's `relu` submodule with Identity — presumably to take
    # features before the last activation; whether it has an effect depends on
    # how layer4 uses `relu` (TODO confirm for the chosen aud_base_arch).
    base.layer4.relu = torch.nn.Identity()
    model = torch.nn.Sequential(*[
        base.conv1,
        base.bn1,
        base.relu,
        base.maxpool,
        base.layer1,
        base.layer2,
        base.layer3,
        base.layer4,
        pool,
        utils.Flatten(),
    ])
    del base  # only the Sequential should keep the backbone alive
    print_or_log("Feature extractor", logger=logger)
    # Use args.fold, not the `fold` loop variable that only exists as a global
    # when this file runs as __main__.
    if args.fold == 1:
        print(model)

    # Add model to device
    model.to(device)

    # Set up training params
    criterion = torch.nn.CrossEntropyLoss()

    # Precompute features once; the backbone is never updated afterwards.
    # Inputs go to the GPU only when the extractor lives there.
    use_cuda = device.type == 'cuda'
    features = get_embeddings_batch(
        dataset=args.dataset,
        model=model,
        mode='train',
        val_fold=args.fold,
        batch_size=128,
        use_cuda=use_cuda)
    features_val = get_embeddings_batch(
        dataset=args.dataset,
        model=model,
        mode='val',
        val_fold=args.fold,
        batch_size=128,
        use_cuda=use_cuda)

    start_time = time.time()
    best_vid_acc_1 = -1
    best_epoch = 0
    vid_acc1 = -1  # defined even if the epoch range below is empty

    # Replace the (no longer needed) extractor by the linear head.
    num_ftrs = features_val['features'].shape[1]
    model = FinetuneModelLinear(num_ftrs, DATASET_NUM_CLASSES[args.dataset]).to(device)

    # FinetuneModelLinear has neither `classifier` nor `base` (the old code
    # raised AttributeError in both modes). With precomputed features only the
    # linear head exists, so it is the only thing optimised.
    if args.feature_extract:
        print_or_log("Getting params for feature-extracting", logger=logger)
    else:
        print_or_log("Getting params for finetuning", logger=logger)
        print_or_log(
            f"Features are precomputed, so only the linear head is trained; base_lr ({args.base_lr}) is unused",
            logger=logger)
    params = [{'params': param, 'lr': args.head_lr} for param in model.parameters()]

    print_or_log('=================================\n', logger=logger)
    print_or_log(
        f"Using {args.optim_name} optimization with lr: {args.head_lr}, mom: {args.momentum}, wd: {args.weight_decay}",
        logger=logger)

    optimizer = utils.load_optimizer(
        args.optim_name,
        params,
        lr=args.head_lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
        model=model
    )

    # Multi-step LR scheduler
    lr_scheduler = None
    if args.use_scheduler:
        milestones = [int(lr) for lr in args.lr_milestones.split(',')]
        print_or_log(f"Num. of Epochs: {args.epochs}, Milestones: {milestones}", logger=logger)
        print_or_log("Using LR multi-step scheduler w/out warmp", logger=logger)
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=milestones,
            gamma=args.lr_gamma
        )

    print('starting!', flush=True)
    for epoch in range(args.start_epoch, args.epochs):
        train_one_epoch(
            args,
            model,
            criterion,
            optimizer,
            epoch,
            args.print_freq,
            features=features,
            logger=logger,
            writer=writer,
        )
        if lr_scheduler is not None:
            lr_scheduler.step()

        vid_acc1 = evaluate(
            args,
            model,
            criterion,
            logger=logger,
            writer=writer,
            epoch=epoch,
            features=features_val
        )
        if vid_acc1 > best_vid_acc_1:
            best_vid_acc_1 = vid_acc1
            best_epoch = epoch
        print(f'\t {best_vid_acc_1:0.3f} @ {best_epoch}, currently: {epoch:05d}:{vid_acc1:.3f}', end='\r')

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print_or_log(f'Training time {total_time_str}', logger=logger)
    return best_vid_acc_1, best_epoch, vid_acc1


def parse_args(argv=None):
    """Parse command-line options for the linear-probe script.

    Args:
        argv: list of argument strings to parse; None (default) parses
            ``sys.argv[1:]`` as before.

    Returns:
        argparse.Namespace with the parsed options. Options registered with
        type 'bool' accept yes/true/t/1 and no/false/f/0 (case-insensitive).
    """
    def str2bool(v):
        v = v.lower()
        if v in ('yes', 'true', 't', '1'):
            return True
        elif v in ('no', 'false', 'f', '0'):
            return False
        # argparse turns this ValueError into a usage error.
        raise ValueError('Boolean argument needs to be true or false. '
                         'Instead, it is %s.' % v)

    import argparse
    parser = argparse.ArgumentParser(description='ESC-50 Audio Classification')
    parser.register('type', 'bool', str2bool)

    ### DATA
    parser.add_argument(
        '--dataset',
        default='esc50',
        help='name of dataset',
        choices=['esc50', 'dcase2014']
    )
    parser.add_argument(
        '--fold',
        default='1,2,3,4,5',
        type=str,
        help='comma-separated list of validation folds to run'
    )
    parser.add_argument(
        '-j', '--workers',
        default=10,
        type=int,
        metavar='N',
        help='number of data loading workers (default: 10)'
    )

    ### MODEL
    parser.add_argument(
        '--cudadev',
        default='0',
        type=str,
        help='value for CUDA_VISIBLE_DEVICES (which gpu)'
    )
    parser.add_argument(
        '--model',
        default='av_gdt',
        help='model',
        choices=['r2plus1d_18', 'av_gdt']
    )
    parser.add_argument(
        '--weights-path',
        default='./model_weights/vggsound_ckpt.pth',
        help="Path to weights file ('None' to skip loading)",
    )
    parser.add_argument(
        '--ckpt-epoch',
        default=0,
        type=int,
        help='Epoch of model checkpoint',
    )
    parser.add_argument(
        '--vid-base-arch',
        default='r2plus1d_18',
        help='Video Base Arch for A-V model',
        choices=['r2plus1d_18', 'mc3_18', 'r2plus1d_34', 'r2plus1d_50']
    )
    parser.add_argument(
        '--aud-base-arch',
        default='resnet9',
        help='Audio Base Arch for A-V model',
        choices=['resnet18', 'resnet34', 'resnet9']
    )
    parser.add_argument(
        "--pretrained",
        type='bool',
        default='False',
        help="Use pre-trained models from the modelzoo",
    )
    parser.add_argument(
        '--use-mlp',
        default='True',
        type='bool',
        help='Use MLP projection head'
    )

    ### FINETUNE params
    parser.add_argument(
        "--feature-extract",
        type='bool',
        default='True',
        help="Use model as feature extractor; if False, finetune entire model",
    )
    parser.add_argument(
        "--recompute",
        type='bool',
        default='False',
        help="Recompute flag (not read by this script)",
    )
    ### TRAINING
    parser.add_argument(
        '--metaexp',
        default='experiment-name',
        type=str,
        help='name to store results'
    )
    parser.add_argument(
        '-b', '--batch-size',
        default=256,
        type=int,
        help='training mini-batch size for the linear head'
    )
    parser.add_argument(
        '--epochs',
        default=20000,
        type=int,
        metavar='N',
        help='number of total epochs to run (forced to 500 for dcase2014)'
    )
    parser.add_argument(
        '--pooltype',
        default='pool22-22-nopad',
        type=str,
        help="pooling after layer4: 'pool22-22-nopad', 'gap' or 'id'"
    )
    parser.add_argument(
        '--head-lr',
        default=1e-4,
        type=float,
        help='initial learning rate of the linear head'
    )
    parser.add_argument(
        '--base-lr',
        default=1e-4,
        type=float,
        help='initial learning rate of the base network when finetuning'
    )
    parser.add_argument(
        '--wd', '--weight-decay',
        default=1e-4,
        type=float,
        metavar='W',
        help='weight decay (default: 1e-4)',
        dest='weight_decay'
    )
    parser.add_argument(
        '--lr-milestones',
        default='10000,15000',
        type=str,
        help='decrease lr on milestones (epochs)'
    )
    parser.add_argument(
        '--optim-name',
        default='adam',
        type=str,
        help='Name of optimizer',
        choices=['sgd', 'adam']
    )

    parser.add_argument(
        '--seconds',
        default=2,
        type=int,
        help='passed to the ESC_DCASE dataset as `seconds`'
    )

    parser.add_argument(
        "--use-scaling",
        type='bool',
        default='False',
        help="Use LR scaling",
    )

    parser.add_argument(
        '--momentum',
        default=0.9,
        type=float,
        metavar='M',
        help='momentum'
    )

    parser.add_argument(
        "--use-scheduler",
        type='bool',
        default='True',
        help="Use LR scheduler",
    )
    parser.add_argument(
        '--lr-warmup-epochs',
        default=0,
        type=int,
        help='number of warmup epochs'
    )

    parser.add_argument(
        '--lr-gamma',
        default=0.1,
        type=float,
        help='decrease lr by a factor of lr-gamma'
    )

    ### LOGGING
    parser.add_argument(
        '--print-freq',
        default=50000,
        type=int,
        help='print frequency'
    )
    parser.add_argument(
        '--output-dir',
        default='.',
        help='path where to save'
    )

    ### CHECKPOINTING
    parser.add_argument(
        '--resume',
        default='',
        help='resume from checkpoint'
    )
    parser.add_argument(
        '--start-epoch',
        default=0, type=int,
        metavar='N',
        help='start epoch'
    )
    parser.add_argument(
        "--test-only",
        type='bool',
        default='False',
        help="Only test the model",
    )

    # distributed training parameters
    parser.add_argument(
        '--device',
        default='cuda',
        help='device'
    )
    parser.add_argument(
        '--distributed',
        type='bool',
        default='False',
        help="ddp mode",
    )
    parser.add_argument(
        '--dist-backend',
        default='nccl',
        type=str,
        help='distributed backend'
    )
    parser.add_argument(
        '--dist-url',
        default='env://',
        help='url used to set up distributed training'
    )
    parser.add_argument(
        '--world-size',
        default=1,
        type=int,
        help='number of distributed processes'
    )
    parser.add_argument(
        '--debug-slurm',
        type='bool',
        default='False',
        help="Debug SLURM",
    )
    parser.add_argument(
        '--local-rank',
        default=-1,
        type=int,
        help='Local rank of node')
    parser.add_argument(
        '--master-port',
        default=-1,
        type=int,
        help='Master port of Job'
    )

    args = parser.parse_args(argv)
    return args


if __name__ == "__main__":
    args = parse_args()
    # Restrict the visible GPUs (done before any CUDA query below).
    os.environ['CUDA_VISIBLE_DEVICES'] = args.cudadev
    # DCASE2014 always runs for 500 epochs, overriding --epochs.
    if args.dataset == 'dcase2014':
        args.epochs = 500
    # Init distributed mode only when a GPU is present.
    if torch.cuda.is_available():
        utils.init_distributed_mode(args)

    # Output directory named after the main hyper-parameters.
    operation = 'finetune'
    if args.feature_extract:
        operation = 'feat_extract'
    model_name = f'{operation}_{args.model}_pretrained_{args.pretrained}_ds_{args.dataset}_epochs_{args.epochs}_bsz_{args.batch_size}_optim_{args.optim_name}_lr_{args.head_lr}_scheduler_{args.use_scheduler}'
    args.output_dir = os.path.join(args.output_dir, model_name)
    # The joined path always ends in the non-empty `model_name`, so it is
    # never empty and the directory is always created.
    utils.mkdir(args.output_dir)

    logger = utils.setup_logger(
        "Video_reader, classification",
        args.output_dir,
        True,
        logname='logger.out'
    )

    # No TensorBoard writer is used.
    writer = None

    # Log run configuration and library versions.
    for entry in (
            args,
            f"torch version: {torch.__version__}",
            f"torchvision version: {torchvision.__version__}",
    ):
        print_or_log(entry, logger=logger)
    GLOBALEXP = args.metaexp

    # One linear probe per requested fold. The loop variable keeps the name
    # `fold`; being module-level, it is visible to functions in this file as a
    # global, and `args` is likewise read as a global by other functions.
    best_accs_1 = []
    best_epochs = []
    folds = list(map(int, args.fold.split(',')))
    for fold in folds:
        args.fold = fold
        best_acc1, best_epoch, _ = main(args, logger, writer)
        best_accs_1.append(best_acc1)
        best_epochs.append(best_epoch)
        print(f'Best Epochs {best_epochs}, Accs: {best_accs_1}')

    # Final summary across folds.
    print(args)
    print()
    print("=" * 60)
    print(f'Best Epochs {best_epochs}, Accs: {best_accs_1}')
    print(f'({args.dataset}): Audio Acc@1 {np.mean(best_accs_1):.3f} ')
