import argparse
import os
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset

from pytorch_metric_learning import losses, miners, reducers

from divergence.bregman import *
from divergence.ficnn import *
from divergence.baselines import *
from divergence.max_affine import *
from divergence.deepnorm_helper import DeepNormMetric, WideNormMetric

from utils.data import load_data, split_and_normalize_as_tensors, train_val_split_tensors
from utils.models import gridsearch_model_params, train_sgd, MetricModel
from utils.logging import TrainLogger
from utils.scoring import RetrievalMetrics, ClusteringMetrics

from pbdl.piecewise_linear_estimation import PBDL


DATASETS = ['abalone', 'car', 'iris', 'wine', 'balance-scale', 'transfusion']
METHODS = ['pbdl', 'deep-div', 'nbd', 'euclidean', 'mahalanobis', 'deepnorm', 'widenorm']

# Command-line interface. --dataset and --method are mandatory: data loading,
# model selection and the log-file name are all keyed on them, so letting them
# default to None would only defer the failure to a confusing later crash.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', '-d', choices=DATASETS, required=True,
                    help='dataset to load')
parser.add_argument('--method', '-m', choices=METHODS, required=True,
                    help='divergence / distance learning method to evaluate')
parser.add_argument('--seed', type=int, default=0,
                    help='random seed (default: 0)')
args = parser.parse_args()

# Run on the GPU when one is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# One log per (dataset, method, seed) combination.
logger = TrainLogger(
    name=f'log_div_{args.dataset}_{args.method}_{args.seed}',
    log_dir='./results/classify')


def validation_tuner(model, X_train, y_train, batch_size=256,
                     epochs=(125, 250, 500), lrs=(1e-3, 2e-3, 5e-3),
                     val_size=0.25):
    """Grid-search the number of training epochs and the Adam learning rate.

    ``X_train``/``y_train`` are split into a tuning-train part and a
    validation part (``val_size`` fraction held out). For every
    (epochs, lr) combination the model is reset to its initial weights,
    trained with a triplet-margin loss and triplet miner on the tuning-train
    part, then scored with ``RetrievalMetrics(...).mean_average_precision()``
    on the validation part. The combinations are run by
    ``gridsearch_model_params``.

    Args:
        model: model exposing ``.metric``; it is used as the distance for
            the loss, the miner and the retrieval scoring. It is trained in
            place during the search, moved to ``device``, and restored to
            its initial weights before this function returns.
        X_train, y_train: tensors to split into tuning-train / validation.
        batch_size: mini-batch size of the tuning-train loader.
        epochs: candidate epoch counts.
        lrs: candidate Adam learning rates.
        val_size: fraction of the data held out for validation.

    Returns:
        The ``best_params`` returned by ``gridsearch_model_params`` for the
        grid keys ``'model'``, ``'epochs'`` and ``'lr'``.

    Side effects:
        Prints progress. Writes the initial weights to
        ``'model_base_to_tune.pt'`` in the current working directory.
    """
    print('Starting validation tuning: ')

    X_tr, X_val, y_tr, y_val = train_val_split_tensors(
        X_train, y_train, val_size=val_size
    )
    train_loader = torch.utils.data.DataLoader(
        TensorDataset(X_tr, y_tr), batch_size=batch_size, shuffle=True
    )

    # Kept on disk for callers that reload the base weights from this file
    # after tuning.
    torch.save(model.state_dict(), 'model_base_to_tune.pt')
    # The per-config resets use a private in-memory copy instead. The path
    # above is shared by every run in the working directory, so concurrent
    # runs could otherwise reset this model to another run's weights.
    # clone() matters: state_dict() tensors alias the live parameters.
    initial_state = {name: t.clone() for name, t in model.state_dict().items()}

    def objective(model, epochs, lr):
        # reset model to original initialization
        model.load_state_dict(initial_state)
        reducer = reducers.ThresholdReducer(low=0)
        loss_func = losses.TripletMarginLoss(
            margin=0.2, distance=model.metric, reducer=reducer)
        mining_func = miners.TripletMarginMiner(
            margin=0.2, distance=model.metric, type_of_triplets='all')

        model = model.to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)

        for _ in range(epochs):
            train_sgd(
                train_loader, model, [optimizer], loss_func,
                mining_func, device, verbose=False
            )

        metrics = RetrievalMetrics(
            model.metric, X_val, y_val, X_tr, y_tr
        )
        return metrics.mean_average_precision()

    best_score, best_params = gridsearch_model_params(
        objective,
        {'model': [model], 'epochs': list(epochs), 'lr': list(lrs)},
        verbose=True
    )
    print(f'Best val search: {best_score}, {best_params}')

    # Hand the model back at its initial weights rather than whatever the
    # last grid point trained it to.
    model.load_state_dict(initial_state)
    return best_params


def _build_model(method, n_feat):
    """Construct the trainable ``MetricModel`` for a learned *method*.

    Every model uses an ``nn.Identity()`` embedding, so all trainable
    parameters live in ``model.metric``.

    Args:
        method: one of 'nbd', 'deep-div', 'widenorm', 'deepnorm',
            'mahalanobis'.
        n_feat: number of input features.

    Raises:
        ValueError: if *method* has no trainable model defined here.
    """
    if method == 'nbd':
        metric = BregmanDivergence(
            FICNN(n_feat, 1, 2, 32, activation='softplus', nonneg_constraint='abs'),
            normalize_embeddings=False
        )
    elif method == 'deep-div':
        metric = MaxAffineDivergence(
            MaxAffineNet(n_feat, K=50), normalize_embeddings=False)
    elif method == 'widenorm':
        metric = WidenormDivergence(
            WideNormMetric(n_feat, 32, 32, 5, mode='maxavg'),
            normalize_embeddings=False
        )
    elif method == 'deepnorm':
        metric = DeepnormDivergence(
            DeepNormMetric(n_feat, [128, 128, 128],
                           activation=lambda: MaxReLUPairwiseActivation(128),
                           concave_activation_size=5,
                           mode='maxavg'),
            normalize_embeddings=False
        )
    elif method == 'mahalanobis':
        metric = MahalanobisDistance(
            n_feat, n_feat, normalize_embeddings=False)
    else:
        raise ValueError(f'no trainable model for method {method!r}')
    return MetricModel(embedding=nn.Identity(), metric=metric)


def _log_final_metrics(metric, X_test, y_test, X_train, y_train):
    """Log final retrieval and clustering scores of *metric*, then export.

    Retrieval uses the test split against the train split (kNN accuracy
    with k=5, mAP, AUC). Clustering uses ``metric.pairwise_distance`` on the
    test split (purity, Rand index).
    """
    metrics = RetrievalMetrics(metric, X_test, y_test, X_train, y_train)
    logger.add('final_knn', metrics.knn_accuracy(5))
    logger.add('final_map', metrics.mean_average_precision())
    logger.add('final_auc', metrics.area_under_curve())

    metrics = ClusteringMetrics(metric.pairwise_distance, X_test, y_test)
    logger.add('final_purity', metrics.purity())
    logger.add('final_rand', metrics.rand_index())
    logger.export()


def main():
    """Fit / evaluate ``args.method`` on ``args.dataset`` and log test metrics."""
    # Without this, --seed only reaches the data split; seeding torch makes
    # model initialisation and DataLoader shuffling reproducible as well.
    torch.manual_seed(args.seed)

    X, y = load_data(args.dataset)
    n_feat = X.shape[1]

    X_train, X_test, y_train, y_test = split_and_normalize_as_tensors(
        X, y, test_size=0.25, seed=args.seed
    )
    X_train, X_test, y_train, y_test = [
        arr.to(device) for arr in (X_train, X_test, y_train, y_test)
    ]

    if args.method == 'pbdl':
        model = PBDL(device=device)
        model.fit(X_train, y_train)

        for task in ('knn', 'map', 'auc'):
            score = model.score(X_test, y_test, X_train, y_train, task=task)
            logger.add(f'final_{task}', score)

        metrics = ClusteringMetrics(model.bregman_div, X_test, y_test)
        logger.add('final_purity', metrics.purity())
        logger.add('final_rand', metrics.rand_index())
        logger.export()
        return

    if args.method == 'euclidean':
        # Fixed baseline: nothing to train, evaluate directly.
        model = MetricModel(
            embedding=nn.Identity(),
            metric=EuclideanDistance(normalize_embeddings=False)
        )
        _log_final_metrics(model.metric, X_test, y_test, X_train, y_train)
        return

    model = _build_model(args.method, n_feat)

    # Snapshot the untrained weights in memory. The tuner also saves them to
    # 'model_base_to_tune.pt', but that relative path is shared by every run
    # in the working directory. Reloading it here could therefore pick up
    # another dataset/method/seed's weights when runs execute concurrently.
    initial_state = {name: t.clone() for name, t in model.state_dict().items()}

    # tune optimization parameters using train-val split
    best_params = validation_tuner(model, X_train, y_train)
    # Final training starts from the same initialisation that was tuned.
    model.load_state_dict(initial_state)

    batch_size = 256
    n_epochs = best_params['epochs']
    lr = best_params['lr']

    train_loader = torch.utils.data.DataLoader(
        TensorDataset(X_train, y_train), batch_size=batch_size, shuffle=True
    )

    reducer = reducers.ThresholdReducer(low=0)
    loss_func = losses.TripletMarginLoss(
        margin=0.2, distance=model.metric, reducer=reducer)
    mining_func = miners.TripletMarginMiner(
        margin=0.2, distance=model.metric, type_of_triplets='all')

    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(n_epochs):
        train_sgd(train_loader, model, [optimizer], loss_func, mining_func, device)

        # Periodic test-set snapshot for learning curves only. It does not
        # feed back into hyperparameters, which were fixed above on the
        # validation split.
        if epoch % 5 == 0:
            metrics = RetrievalMetrics(
                model.metric, X_test, y_test, X_train, y_train
            )
            logger.add('epoch', epoch)
            logger.add('knn', metrics.knn_accuracy(5))
            logger.add('map', metrics.mean_average_precision())
            logger.add('auc', metrics.area_under_curve())
            logger.export()

    _log_final_metrics(model.metric, X_test, y_test, X_train, y_train)


if __name__ == '__main__':
    main()
