import argparse
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from pytorch_metric_learning import losses, miners, reducers

from divergence.bregman import *
from divergence.ficnn import *
from divergence.baselines import *
from divergence.max_affine import *
from divergence.deepnorm_helper import DeepNormMetric, WideNormMetric

from utils.data import load_torch_data, split_and_normalize_as_tensors, train_val_split_tensors
from utils.models import get_embeddings, train_sgd, MetricModel
from utils.logging import TrainLogger
from utils.scoring import RetrievalMetrics, ClusteringMetrics
from utils.mixture import generate_mixture

from pbdl.piecewise_linear_estimation import PBDL


DISTRIBUTIONS = ['gaussian', 'multinomial', 'exponential']
METHODS = ['deep-div', 'nbd', 'euclidean', 'mahalanobis', 'pbdl']

# Command-line interface. --dist and --method are required: without --method
# no model is constructed in the main block (NameError on `model`), and
# without --dist generate_mixture would receive family=None.
parser = argparse.ArgumentParser(
    description='Learn and evaluate divergences on synthetic mixture data')
parser.add_argument('--dist', '-d', choices=DISTRIBUTIONS, required=True,
    help='Distribution family passed to generate_mixture')
parser.add_argument('--method', '-m', choices=METHODS, required=True,
    help='Distance / divergence learning method')
parser.add_argument('--mix', action='store_true',
    help='Apply mixing matrix to features for embedding learning')

parser.add_argument('--normalize', action='store_true',
    help='L2-normalize embeddings before computing distances')
parser.add_argument('--embedding', default='none', choices=['none', 'linear', 'fnn'],
    help='Embedding applied to features before the metric')
parser.add_argument('--seed', type=int, default=0,
    help='Seed for the train/test split')
args = parser.parse_args()

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# One log per configuration; the name encodes every CLI option.
logger = TrainLogger(
    name=f'log_dist_{args.dist}_{args.method}_{args.embedding}_mix_{args.mix}_norm_{args.normalize}_{args.seed}',
    log_dir='./results/mixture')


def test(test_loader, model, device):
    """Evaluate the clustering quality of ``model`` on ``test_loader`` and log it.

    Embeds the test set with ``model.embedding`` (via ``get_embeddings``),
    L2-normalizes the embeddings when ``--normalize`` is set, then scores them
    with ``ClusteringMetrics`` using ``model.metric.pairwise_distance``.
    The ``purity`` and ``rand`` scores are added to the module-level
    ``logger``, which is then exported.

    Args:
        test_loader: DataLoader over the test set (built from a
            ``TensorDataset(X_test, y_test)`` in the main block).
        model: a ``MetricModel`` exposing ``.embedding`` and ``.metric``.
        device: device on which the embeddings are computed.
    """
    model.eval()
    # Evaluation only: avoid building autograd graphs through the metric net.
    with torch.no_grad():
        test_embeddings, test_labels = get_embeddings(test_loader, model.embedding, device)
        if args.normalize:
            test_embeddings = test_embeddings / torch.norm(test_embeddings, dim=1, keepdim=True)

        metrics = ClusteringMetrics(
            model.metric.pairwise_distance, test_embeddings, test_labels.flatten()
        )
        logger.add('purity', metrics.purity())
        logger.add('rand', metrics.rand_index())
    logger.export()


if __name__ == '__main__':

    BATCH_SIZE = 128
    EPOCHS = 200
    LR = 1e-3
    N_FEAT = 10

    # Synthetic mixture data. The positional args (1000, 5) are presumably the
    # sample count and number of components -- see utils.mixture.generate_mixture.
    X, y = generate_mixture(1000, 5, n_dim=N_FEAT, family=args.dist)

    if args.mix:
        # Fixed seed (0, independent of --seed) so every --mix run applies the
        # same random linear distortion to the features.
        rng = np.random.default_rng(0)
        Q = rng.uniform(-1, 1, size=(N_FEAT, N_FEAT))
        X = X @ Q

    X_train, X_test, y_train, y_test = split_and_normalize_as_tensors(
        X, y, test_size=0.25, seed=args.seed
    )
    X_train, X_test, y_train, y_test = \
        [arr.to(device) for arr in [X_train, X_test, y_train, y_test]]

    # Select the embedding from --embedding by comparing the string value: the
    # default is the non-empty string 'none', so a plain truthiness test would
    # always pick a learnable layer.
    if args.embedding == 'none':
        embedding = nn.Identity()
    elif args.embedding == 'linear':
        embedding = nn.Linear(N_FEAT, N_FEAT)
    else:  # 'fnn'
        # NOTE(review): small two-layer feed-forward net; this option used to
        # silently fall back to a single nn.Linear -- confirm intended width/depth.
        embedding = nn.Sequential(
            nn.Linear(N_FEAT, N_FEAT),
            nn.ReLU(),
            nn.Linear(N_FEAT, N_FEAT),
        )

    if args.method == 'pbdl':
        # PBDL is fit directly on the raw features; `embedding` is unused here.
        model = PBDL(device=device)
        model.fit(X_train, y_train)

        for task in ('knn', 'map', 'auc'):
            score = model.score(X_test, y_test, X_train, y_train, task=task)
            logger.add(f'final_{task}', score)

        # Flatten labels as test() does, so every method is scored the same way.
        metrics = ClusteringMetrics(model.bregman_div, X_test, y_test.flatten())
        logger.add('purity', metrics.purity())
        logger.add('rand', metrics.rand_index())
        logger.export()

    elif args.method == 'euclidean':
        # Untrained baseline: Euclidean distance on the raw test features
        # (model.embedding is never applied in this branch).
        model = MetricModel(
            embedding=embedding,
            metric=EuclideanDistance(normalize_embeddings=False)
        )

        metrics = ClusteringMetrics(
            model.metric.pairwise_distance, X_test, y_test.flatten()
        )
        logger.add('purity', metrics.purity())
        logger.add('rand', metrics.rand_index())
        logger.export()

    else:
        if args.method == 'nbd':
            model = MetricModel(
                embedding=embedding,
                metric=BregmanDivergence(
                    FICNN(N_FEAT, 1, 2, 32, activation='softplus', nonneg_constraint='abs'),
                    normalize_embeddings=args.normalize
                )
            )

        elif args.method == 'deep-div':
            model = MetricModel(
                embedding=embedding,
                metric=MaxAffineDivergence(
                    MaxAffineNet(N_FEAT, K=50),
                    normalize_embeddings=args.normalize)
            )

        elif args.method == 'mahalanobis':
            model = MetricModel(
                embedding=embedding,
                metric=MahalanobisDistance(
                    N_FEAT, N_FEAT,
                    normalize_embeddings=args.normalize)
            )

        else:
            # Without this, an unmatched method surfaces later as a NameError on `model`.
            parser.error(f'unsupported --method: {args.method!r}')

        train_ds = TensorDataset(X_train, y_train)
        train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
        test_ds = TensorDataset(X_test, y_test)
        test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

        # Triplet-margin loss with mined triplets, both using the learned metric
        # as the distance function (see pytorch-metric-learning docs).
        reducer = reducers.ThresholdReducer(low=0)
        loss_func = losses.TripletMarginLoss(margin=0.2, distance=model.metric, reducer=reducer)
        mining_func = miners.TripletMarginMiner(margin=0.2, distance=model.metric, type_of_triplets='all')

        model = model.to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=LR)

        for epoch in range(EPOCHS):
            # test() switches the model to eval mode; restore train mode before
            # every epoch in case train_sgd does not do so itself.
            model.train()
            train_sgd(train_loader, model, [optimizer], loss_func, mining_func, device)
            if epoch % 5 == 1:
                logger.add('epoch', epoch)
                test(test_loader, model, device)  # test() exports the logger itself

        logger.add('epoch', 'final')
        test(test_loader, model, device)  # test() exports the logger itself
