import argparse
import copy
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from pytorch_metric_learning import losses, miners, reducers

from divergence.bregman import *
from divergence.ficnn import *
from divergence.baselines import *
from divergence.max_affine import *
from divergence.deepnorm_helper import DeepNormMetric, WideNormMetric

from models.embed import MnistBase
from models.resnets import ResNet, BasicBlock

from utils.data import load_torch_data, split_and_normalize_as_tensors, train_val_split_tensors
from utils.models import get_embeddings, train_sgd, MetricModel
from utils.logging import TrainLogger
from utils.scoring import RetrievalMetrics


DATASETS = ['svhn', 'stl10', 'cifar10']
# Every metric the `__main__` section can build. 'mahalanobis' and 'l2' have
# construction branches in the script body, so they are accepted here too
# (previously argparse rejected them, leaving those branches unreachable).
METHODS = ['deep-div', 'nbd', 'euclidean', 'deepnorm', 'widenorm', 'mahalanobis', 'l2']
# Optimizers the script body knows how to construct; any other value used to
# leave `optimizer` undefined and crash with a NameError.
OPTIMIZERS = ['adam', 'rmsprop']

parser = argparse.ArgumentParser()
# Without a dataset/method the script cannot build a model at all, so both are
# required instead of silently defaulting to None.
parser.add_argument('--dataset', '-d', choices=DATASETS, required=True)
parser.add_argument('--method', '-m', choices=METHODS, required=True)

parser.add_argument('--normalize', action='store_true',
                    help='L2-normalise embeddings (also forwarded to the metric)')
parser.add_argument('--contrast', choices=['triplet', 'contrastive'], default='triplet',
                    help="loss family; only 'triplet' is currently used for training")
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--epochs', type=int, default=30)
parser.add_argument('--lr', type=float, default=1e-4)
parser.add_argument('--emb_dim', type=int, default=32)
parser.add_argument('--optim', choices=OPTIMIZERS, default='adam',
                    help="'adam' selects torch.optim.AdamW")

parser.add_argument('--eval', choices=['valid', 'test'], default='valid',
                    help='evaluation split passed to load_torch_data as eval_on')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--log', action='store_true',
                    help='export the training log after every epoch')
args = parser.parse_args()

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Log name encodes dataset, method, normalisation flag and seed so that runs
# with different settings do not overwrite each other.
logger = TrainLogger(
    name=f'div_{args.dataset}_{args.method}_norm_{args.normalize}_{args.seed}',
    log_dir='./results/deep_clf')


def test(train_loader, test_loader, model, device, tag=''):
    """Evaluate retrieval quality of ``model`` and log the results.

    Both loaders are embedded with ``model.embedding``. When ``--normalize`` is
    set, the embeddings are L2-normalised row-wise. ``RetrievalMetrics`` then
    scores them with ``model.metric``. The test embeddings are passed first and
    the train embeddings second; presumably the train split acts as the
    retrieval reference set (TODO confirm against RetrievalMetrics).

    k-NN accuracy for k in (1, 3, 5, 10), precision@10 and MAP@10 are added to
    the module-level ``logger`` in that order, and a short summary is printed.

    Args:
        train_loader: loader whose embeddings are passed as the second set.
        test_loader: loader whose embeddings are passed as the first set.
        model: module exposing ``embedding`` and ``metric`` attributes.
        device: device forwarded to ``get_embeddings``.
        tag: label forwarded to ``RetrievalMetrics``.

    Returns:
        dict mapping each logged metric name to its value.
    """
    # Remember the caller's mode so evaluation does not leave the model stuck in
    # eval mode (e.g. frozen BatchNorm statistics) if the training routine does
    # not switch back to train mode itself.
    was_training = model.training
    model.eval()
    try:
        train_embeddings, train_labels = get_embeddings(train_loader, model.embedding, device)
        test_embeddings, test_labels = get_embeddings(test_loader, model.embedding, device)
        if args.normalize:
            # F.normalize divides by max(norm, eps): identical to dividing by the
            # norm, except that an all-zero row stays finite instead of becoming NaN.
            train_embeddings = torch.nn.functional.normalize(train_embeddings, dim=1)
            test_embeddings = torch.nn.functional.normalize(test_embeddings, dim=1)

        metrics = RetrievalMetrics(model.metric,
                                   test_embeddings,
                                   test_labels,
                                   train_embeddings,
                                   train_labels,
                                   tag=tag
                                   )

        # Compute every metric exactly once; the same values are logged and printed.
        results = {f'knn@{k}': metrics.knn_accuracy(k) for k in (1, 3, 5, 10)}
        results['prec@10'] = metrics.precision_at_k(10)
        results['map@10'] = metrics.map_at_k(10)
    finally:
        model.train(was_training)

    for name, value in results.items():
        logger.add(name, value)

    print('1-NN: ', results['knn@1'])
    print('Prec@10: ', results['prec@10'])
    print('MAP@10: ', results['map@10'])
    return results


def _build_embedding(dataset, emb_dim):
    """Return the trunk network that maps inputs to ``emb_dim``-dimensional embeddings.

    'mnist' / 'fmnist' use the project's ``MnistBase``. Every other dataset uses
    a pretrained torchvision ResNet-18 whose final ``fc`` layer is replaced by a
    fresh ``nn.Linear(512, emb_dim)`` projection.
    """
    if dataset in ('mnist', 'fmnist'):
        return MnistBase(out_dim=emb_dim)
    import torchvision  # imported lazily: only needed for the ResNet trunk
    embedding = torchvision.models.resnet18(pretrained=True)
    embedding.fc = nn.Linear(512, emb_dim)
    return embedding


def _build_metric(method, emb_dim, normalize):
    """Return the learnable distance / divergence module selected by ``method``.

    Args:
        method: one of the names accepted by ``--method``.
        emb_dim: dimensionality of the embeddings the metric compares.
        normalize: forwarded to the metric as ``normalize_embeddings``.

    Raises:
        ValueError: for an unrecognised ``method``. Previously this surfaced
            later as a NameError on an undefined ``model``.
    """
    if method == 'nbd':
        return BregmanDivergence(
            FICNN(emb_dim, 1, 2, 32, activation='softplus', nonneg_constraint='abs'),
            normalize_embeddings=normalize)
    if method == 'deep-div':
        return MaxAffineDivergence(
            MaxAffineNet(emb_dim, K=50, linear_subnet=False),
            normalize_embeddings=normalize)
    if method == 'widenorm':
        return WidenormDivergence(
            WideNormMetric(emb_dim, 32, 32, 5, mode='maxavg'),
            normalize_embeddings=normalize)
    if method == 'deepnorm':
        return DeepnormDivergence(
            DeepNormMetric(emb_dim, [128, 128, 128],
                           activation=lambda: MaxReLUPairwiseActivation(128),
                           concave_activation_size=5,
                           mode='maxavg'),
            normalize_embeddings=normalize)
    if method == 'mahalanobis':
        return MahalanobisDistance(emb_dim, emb_dim, normalize_embeddings=normalize)
    if method == 'euclidean':
        return EuclideanDistance(normalize_embeddings=normalize)
    if method == 'l2':
        # Same class as 'euclidean', constructed with squared=True.
        return EuclideanDistance(squared=True, normalize_embeddings=normalize)
    raise ValueError(f'Unknown method: {method!r}')


def _build_optimizer(name, params, lr):
    """Return the optimizer selected by ``--optim``.

    Raises:
        ValueError: for an unrecognised ``name``. Previously this surfaced later
            as a NameError on an undefined ``optimizer``.
    """
    if name == 'adam':
        # NOTE: 'adam' maps to AdamW (Adam with decoupled weight decay).
        return torch.optim.AdamW(params, lr=lr)
    if name == 'rmsprop':
        return torch.optim.RMSprop(params, lr=lr)
    raise ValueError(f'Unknown optimizer: {name!r}')


if __name__ == '__main__':
    import warnings

    BATCH_SIZE = args.batch_size
    EPOCHS = args.epochs
    EMB_DIM = args.emb_dim
    LR = args.lr
    TRIPLET_MARGIN = 0.2

    # --seed used to be parsed but never applied (it only affected the log
    # name). Seed torch (CPU and CUDA) and NumPy so that weight init, data
    # splits and loader shuffling are reproducible; cuDNN and worker-level
    # nondeterminism can still remain.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    if args.contrast != 'triplet':
        # Only the triplet loss is wired up below. Make that visible instead of
        # silently ignoring the flag.
        warnings.warn(f'--contrast {args.contrast!r} is not implemented; '
                      'training with the triplet margin loss instead.')

    train_ds, test_ds = load_torch_data(args.dataset, eval_on=args.eval)

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=3)
    test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, num_workers=3)

    # Embedding is built before the metric, matching the original construction
    # order (this matters for RNG consumption now that the run is seeded).
    embedding = _build_embedding(args.dataset, EMB_DIM)
    model = MetricModel(
        embedding=embedding,
        metric=_build_metric(args.method, EMB_DIM, args.normalize),
    )

    # Loss and miner share the learned metric, so mined triplets and the loss
    # are measured with the same (trainable) distance.
    reducer = reducers.ThresholdReducer(low=0)
    loss_func = losses.TripletMarginLoss(
        margin=TRIPLET_MARGIN, distance=model.metric, reducer=reducer)
    mining_func = miners.TripletMarginMiner(
        margin=TRIPLET_MARGIN, distance=model.metric, type_of_triplets='all')

    model = model.to(device)
    optimizer = _build_optimizer(args.optim, model.parameters(), LR)

    for epoch in range(EPOCHS):
        logger.add('epoch', epoch)
        train_sgd(train_loader, model, [optimizer], loss_func, mining_func, device)

        suffix = f'norm_{epoch}' if args.normalize else f'{epoch}'
        tag = f'{args.dataset}_{args.method}_{suffix}'
        test(train_loader, test_loader, model, device, tag=tag)

        if args.log:
            logger.export()
