import argparse
import math
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, TensorDataset
import torchvision
from sklearn.datasets import make_spd_matrix

from divergence.bregman import *
from divergence.ficnn import *
from divergence.baselines import *
from divergence.max_affine import *
from divergence.deepnorm_helper import DeepNormMetric, WideNormMetric

from utils.logging import TrainLogger
from utils.linalg import gen_condition_corrmat
from pbdl.piecewise_linear_estimation import PBDL


# Device used for models and batches. Prefer the GPU, but fall back to the
# CPU so the script still runs on machines without CUDA.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'


def compute_div(X1, X2, form):
    """Compute the row-wise ground-truth divergence between ``X1`` and ``X2``.

    Args:
        X1: Tensor whose rows are points (features along dim 1).
        X2: Tensor with the same shape as ``X1``.
        form: Name of the divergence:

            * ``'euclidean'``   -- squared Euclidean distance.
            * ``'mahalanobis'`` -- ``(x - y)^T Q (x - y)`` with a fixed SPD
              matrix ``Q`` drawn with ``random_state=0``.
            * ``'xlogx'``       -- Bregman divergence generated by
              ``phi(x) = sum(x * log(x))``; entries must be positive for
              the logarithms to be finite.
            * ``'KL'``          -- KL divergence (base-2 logarithm) between
              the softmax of each row of ``X1`` and of ``X2``.

    Returns:
        Tensor with one divergence value per row.

    Raises:
        ValueError: If ``form`` is not one of the names listed above.
    """
    if form == 'euclidean':
        div = torch.norm(X1 - X2, dim=1).pow(2)

    elif form == 'mahalanobis':
        # make_spd_matrix returns a float64 NumPy array; convert it to a
        # tensor matching the inputs so the matmul works for any dtype/device.
        Q = torch.as_tensor(
            make_spd_matrix(X1.size(1), random_state=0),
            dtype=X1.dtype, device=X1.device,
        )
        diff = X1 - X2
        div = torch.mul(diff @ Q, diff).sum(-1)

    elif form == 'xlogx':
        # D(x, y) = phi(x) - phi(y) - <grad phi(y), x - y>
        phi_x = (X1 * torch.log(X1)).sum(-1)
        phi_y = (X2 * torch.log(X2)).sum(-1)
        grad_div = ((torch.log(X2) + 1) * (X1 - X2)).sum(-1)
        div = phi_x - phi_y - grad_div

    elif form == 'KL':
        p1, p2 = F.softmax(X1, dim=-1), F.softmax(X2, dim=-1)
        div = torch.sum(p1 * torch.log2(p1 / p2), dim=-1)

    else:
        # Previously fell through to `return div` with `div` unbound.
        raise ValueError(
            f"unknown divergence form {form!r}; expected one of "
            "'euclidean', 'mahalanobis', 'xlogx', 'KL'"
        )

    return div


def make_data(N=10000, p=20, p_informative=2, noise_scale=0.0,
              covariance=None, form='euclidean', seed=None):
    """Sample pairs of Gaussian points and label them with a divergence.

    Two independent sets of ``N`` points are drawn from a zero-mean
    ``p``-dimensional normal distribution. The target for each pair is the
    divergence computed on the first ``p_informative`` features only.

    Args:
        N: Number of pairs to draw.
        p: Number of features per point.
        p_informative: Number of leading features the target depends on.
        noise_scale: If positive, Gaussian noise with standard deviation
            ``sqrt(noise_scale) * std(y)`` is added to the targets.
        covariance: ``(p, p)`` covariance matrix of the sampling
            distribution. ``None`` selects the identity matrix.
        form: Divergence name forwarded to ``compute_div``.
        seed: If not ``None``, seeds both the torch and NumPy global RNGs
            before sampling. ``0`` is a valid seed.

    Returns:
        ``TensorDataset`` of ``(X1, X2, y)``.
    """
    # `if seed:` would silently skip seeding for seed == 0.
    if seed is not None:
        seed = int(seed)
        torch.manual_seed(seed)
        np.random.seed(seed)

    if covariance is None:
        covariance = np.eye(p)

    mean = np.zeros(p)
    X1 = torch.tensor(
        np.random.multivariate_normal(mean, covariance, size=N)
    ).float()
    X2 = torch.tensor(
        np.random.multivariate_normal(mean, covariance, size=N)
    ).float()

    if form == 'xlogx':
        # The xlogx divergence takes logarithms, so keep inputs strictly positive.
        X1, X2 = torch.abs(X1) + 0.001, torch.abs(X2) + 0.001

    y = compute_div(X1[:, :p_informative], X2[:, :p_informative], form)

    if noise_scale > 0:
        y += torch.randn(N) * math.sqrt(noise_scale) * y.std()
    return TensorDataset(X1, X2, y)


def train_metric(metric, train_loader, optimizer):
    """Train ``metric`` for one pass over ``train_loader``.

    Each batch is a triple ``(x1, x2, target)``. The model's
    ``pairwise_distance(x1, x2)`` is regressed onto ``target`` with a
    mean-squared-error loss, and ``optimizer`` takes one step per batch.

    Args:
        metric: Module exposing ``pairwise_distance``; switched to train mode.
        train_loader: Iterable yielding ``(x1, x2, target)`` batches.
        optimizer: Optimizer over the parameters of ``metric``.
    """
    metric.train()

    for lhs, rhs, labels in train_loader:
        lhs = lhs.to(DEVICE)
        rhs = rhs.to(DEVICE)
        labels = labels.to(DEVICE).float()

        predicted = metric.pairwise_distance(lhs, rhs)
        batch_loss = F.mse_loss(predicted, labels)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()


def eval_metric(metric, data_loader):
    """Return the mean absolute error of ``metric`` over ``data_loader``.

    Args:
        metric: Module exposing ``pairwise_distance``; switched to eval mode.
        data_loader: Iterable yielding ``(x1, x2, target)`` batches.

    Returns:
        Python float: mean of ``|prediction - target|`` over all samples.
    """
    metric.eval()

    # NOTE(review): gradient tracking is left enabled here; presumably some
    # metrics differentiate internally -- confirm before adding no_grad().
    predictions, targets = [], []
    for lhs, rhs, labels in data_loader:
        lhs = lhs.to(DEVICE)
        rhs = rhs.to(DEVICE)
        labels = labels.to(DEVICE).float()

        predictions.append(metric.pairwise_distance(lhs, rhs).detach())
        targets.append(labels)

    abs_error = torch.abs(torch.cat(predictions) - torch.cat(targets))
    return abs_error.mean().item()


def visualize(metric_mod, div_form, save_path):
    """Plot the true divergence against the model's prediction on a 30x30 grid.

    Args:
        metric_mod: Callable taking the two batched inputs and returning the
            predicted divergence.
        div_form: Divergence name forwarded to ``compute_div``.
        save_path: Destination forwarded to ``contour_diagnostic``.
    """
    # xlogx takes logarithms, so its grid stays strictly positive.
    lower = 0.01 if div_form == 'xlogx' else -0.3
    upper = 0.3
    axis_x = np.linspace(lower, upper, 30)
    axis_y = np.linspace(lower, upper, 30)
    X, Y = np.meshgrid(axis_x, axis_y)

    # meshgrid_to_batch / div_to_meshgrid / contour_diagnostic are not defined
    # in this file -- presumably supplied by the star imports; verify.
    batch_a, batch_b = meshgrid_to_batch(X, Y)
    batch_a = batch_a.to(DEVICE)
    batch_b = batch_b.to(DEVICE)

    # The ground truth is evaluated on the first feature column only.
    true_div = compute_div(batch_a[:, :1], batch_b[:, :1], div_form)
    truth = div_to_meshgrid(true_div)
    pred = div_to_meshgrid(metric_mod(batch_a, batch_b))

    contour_diagnostic(X, Y, truth, pred, save_path)


if __name__ == '__main__':
    # Fit a learned metric to a synthetic divergence and log the train/test
    # mean absolute error after every epoch.

    # Every model below is built for this input width, so the generated data
    # must have the same number of features.
    N_FEATS = 20
    METRIC_CHOICES = ['mahalanobis', 'nbd', 'deep-div', 'widenorm', 'deepnorm']

    parser = argparse.ArgumentParser()
    # --div and --metric are required and validated up front; previously a
    # missing/unknown value only failed later with an obscure error.
    parser.add_argument('--div', required=True,
                        choices=['euclidean', 'xlogx', 'mahalanobis', 'KL'])
    parser.add_argument('--metric', required=True, choices=METRIC_CHOICES)
    parser.add_argument('--inform', type=int, default=20)
    parser.add_argument('--noise', type=float, default=0.0)
    parser.add_argument('--condition', choices=['low', 'med', 'high'])
    parser.add_argument('--seed', type=int, default=0)
    args = parser.parse_args()
    print(args)

    log = TrainLogger(
        log_dir='./results/div_fit',
        name=f'div_{args.div}_metric_{args.metric}_feats_{args.inform}_cond_{args.condition}_{args.seed}'
    )

    # make fixed test dataset
    np.random.seed(0)
    cov = gen_condition_corrmat(n_feats=N_FEATS, condition=args.condition)

    test_ds = make_data(N=10000, p=N_FEATS, p_informative=args.inform,
                        noise_scale=args.noise, covariance=cov,
                        form=args.div, seed=args.seed)
    test_loader = DataLoader(test_ds, batch_size=1000, shuffle=False)

    if args.metric == 'mahalanobis':
        metric = MahalanobisDistance(N_FEATS, N_FEATS).to(DEVICE)

    elif args.metric == 'nbd':
        metric = BregmanDivergence(
            FICNN(N_FEATS, 1, 2, 32, activation='softplus')
        ).to(DEVICE)

    elif args.metric == 'deep-div':
        metric = MaxAffineDivergence(MaxAffineNet(N_FEATS, K=20)).to(DEVICE)

    elif args.metric == 'widenorm':
        metric = WideNormMetric(N_FEATS, 32, 32, 5, mode='maxavg').to(DEVICE)

    elif args.metric == 'deepnorm':
        metric = DeepNormMetric(N_FEATS, [128, 128, 128],
                                activation=lambda: MaxReLUPairwiseActivation(128),
                                concave_activation_size=5,
                                mode='maxavg').to(DEVICE)

    else:
        # Guards against METRIC_CHOICES and this chain drifting apart.
        raise ValueError(f'unsupported metric: {args.metric!r}')

    opt = torch.optim.Adam(metric.parameters(), lr=1e-3)

    for epoch in range(100):

        # sample new train dataset each epoch
        # NOTE(review): training targets are always noise-free
        # (noise_scale=0.0) while the test set uses --noise -- confirm intended.
        train_ds = make_data(N=50000, p=N_FEATS, p_informative=args.inform,
                             noise_scale=0.0, covariance=cov,
                             form=args.div)
        train_loader = DataLoader(train_ds, batch_size=1000, shuffle=True)

        train_metric(metric, train_loader, opt)
        tr_loss = eval_metric(metric, train_loader)
        te_loss = eval_metric(metric, test_loader)

        print('epoch: ', epoch, 'train loss: ', tr_loss, 'test loss: ', te_loss)

        log.add('epoch', epoch)
        log.add('train_loss', tr_loss)
        log.add('test_loss', te_loss)
        log.export()

    # The contour plot is only produced for a single informative feature.
    if args.inform == 1:
        visualize(metric, args.div, f'./results/div_fit/_IMG_{args.div}_{args.metric}')
