import argparse
import math
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision

from models.embed import MnistBase
from divergence.bregman import *
from divergence.ficnn import *
from divergence.baselines import *
from divergence.max_affine import *
from divergence.deepnorm_helper import DeepNormMetric, WideNormMetric

from utils.data import load_torch_data, split_and_normalize_as_tensors, train_val_split_tensors
from utils.models import get_embeddings, train_sgd, MetricModel
from utils.logging import TrainLogger
from utils.scoring import RetrievalMetrics

from datasets import BregMNIST

DEVICE = 'cuda'


def train_metric(metric, train_loader, optimizer):
    """Run one training pass regressing ``metric(x1, x2)`` onto ``target``.

    Each batch is moved to ``DEVICE``, the target is cast to float, and the
    metric is updated with one optimizer step on the mean-squared-error loss.

    Args:
        metric: module called as ``metric(x1, x2)``; switched to train mode.
        train_loader: iterable of ``(x1, x2, target)`` batches.
        optimizer: optimizer over ``metric``'s parameters.
    """
    print('starting')
    metric.train()
    criterion = nn.MSELoss()

    for first, second, labels in train_loader:
        first = first.to(DEVICE)
        second = second.to(DEVICE)
        labels = labels.to(DEVICE).float()

        batch_loss = criterion(metric(first, second), labels)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()


def eval_metric(metric, data_loader, device=None):
    """Return the mean squared error of ``metric(x1, x2)`` against ``target``.

    The metric is switched to eval mode and run under ``torch.no_grad()``:
    evaluation never back-propagates, so building autograd graphs would only
    waste memory and time.

    Args:
        metric: module called as ``metric(x1, x2)``; left in eval mode.
        data_loader: iterable of ``(x1, x2, target)`` batches.
        device: device every batch is moved to; defaults to the module-level
            ``DEVICE`` when ``None``.

    Returns:
        The MSE over all elements of all batches, as a Python float.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    if device is None:
        device = DEVICE
    metric.eval()
    preds = []
    truth = []
    with torch.no_grad():
        for x1, x2, target in data_loader:
            x1, x2, target = x1.to(device), x2.to(device), target.to(device)
            preds.append(metric(x1, x2))
            truth.append(target.float())

    # torch.cat([]) would fail with an opaque error; report the real cause.
    if not preds:
        raise ValueError('data_loader yielded no batches')

    preds = torch.cat(preds)
    truth = torch.cat(truth)

    return torch.mean((preds - truth) ** 2).item()


class MetricModel(nn.Module):
    """Embed both inputs with one shared ``MnistBase`` and compare the embeddings.

    NOTE: this class shadows the ``MetricModel`` imported from ``utils.models``
    at the top of the file; code below this definition uses this version.

    Args:
        metric: object exposing ``pairwise_distance(phi_x, phi_y)``, applied to
            the ``MnistBase(out_dim=128)`` embeddings of the two inputs.
    """

    def __init__(self, metric):
        super().__init__()
        # Registration order (embedding, then metric) fixes parameter order.
        self.embedding = MnistBase(out_dim=128)
        self.metric = metric

    def forward(self, x, y):
        """Return ``metric.pairwise_distance`` of the embeddings of ``x`` and ``y``."""
        return self.metric.pairwise_distance(self.embedding(x), self.embedding(y))


def _euclid_div(x, y):
    """Bregman divergence generated by ``phi(t) = t ** 2``, elementwise."""
    phi_x = x ** 2
    phi_y = y ** 2
    grad_div = 2 * y * (x - y)
    return phi_x - phi_y - grad_div


def _xlogx_div(x, y):
    """Bregman divergence generated by ``phi(t) = (t + 1) * log(t + 1)``, elementwise."""
    phi_x = (x + 1) * np.log(x + 1)
    phi_y = (y + 1) * np.log(y + 1)
    grad_div = (np.log(y + 1) + 1) * (x - y)
    return phi_x - phi_y - grad_div


# Divergences selectable with --div. They live at module scope (not inside the
# __main__ guard) so BregMNIST datasets holding them can be pickled into
# DataLoader workers under the "spawn" start method.
DIV_FNS = {
    'euclid': _euclid_div,
    'xlogx': _xlogx_div,
}

# Metric heads selectable with --metric (see _build_metric).
METRIC_NAMES = ('concat', 'mahalanobis', 'widenorm', 'deepnorm', 'maxaffine', 'bregman')


def worker_init_fn(worker_id):
    """Give each DataLoader worker its own deterministic NumPy seed.

    The base value is the first word of the NumPy state visible in the worker
    (under fork, inherited from the parent right after its per-epoch
    ``np.random.seed(epoch)``). A plain ``base + worker_id`` would collide
    across epochs (epoch e / worker w+1 == epoch e+1 / worker w), so the pair
    is mixed through ``SeedSequence`` instead.
    """
    base_seed = int(np.random.get_state()[1][0])
    seed = np.random.SeedSequence([base_seed, worker_id]).generate_state(1)[0]
    np.random.seed(int(seed))


def _build_metric(name):
    """Return a ``MetricModel`` with the metric head ``name``, moved to ``DEVICE``.

    Raises:
        ValueError: if ``name`` is not one of ``METRIC_NAMES``.
    """
    if name == 'concat':
        head = PairedInputDistance(128)
    elif name == 'mahalanobis':
        head = MahalanobisDistance(128, 128)
    elif name == 'widenorm':
        head = WidenormDivergence(
            WideNormMetric(128, 32, 32, 5, mode='maxavg')
        )
    elif name == 'deepnorm':
        head = DeepnormDivergence(
            DeepNormMetric(128, [128, 128, 128],
                           activation=lambda: MaxReLUPairwiseActivation(128),
                           concave_activation_size=5,
                           mode='maxavg')
        )
    elif name == 'maxaffine':
        head = MaxAffineDivergence(MaxAffineNet(128, K=50))
    elif name == 'bregman':
        head = BregmanDivergence(FICNN(128, 1, 2, 128, activation='softplus'))
    else:
        raise ValueError(f'unknown metric: {name!r}')
    return MetricModel(head).to(DEVICE)


def _parse_args():
    """Parse the command line; --div and --metric are mandatory."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--div', choices=list(DIV_FNS), required=True)
    parser.add_argument('--metric', choices=METRIC_NAMES, required=True)
    parser.add_argument('--run', type=int, default=0)
    parser.add_argument('--data-dir', default='/home/xxxx/data')
    parser.add_argument('--epochs', type=int, default=200)
    return parser.parse_args()


if __name__ == '__main__':
    args = _parse_args()

    log = TrainLogger(
        log_dir='./results/bregmnist',
        name=f'bregmnist_{args.div}_{args.metric}_{args.run}'
    )

    div_fn = DIV_FNS[args.div]

    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,))
    ])

    mnist_train_ds = torchvision.datasets.MNIST(args.data_dir, train=True, download=True,
                                                transform=transform)
    mnist_test_ds = torchvision.datasets.MNIST(args.data_dir, train=False, download=False,
                                               transform=transform)

    train_ds = BregMNIST(mnist_train_ds, div_fn)
    test_ds = BregMNIST(mnist_test_ds, div_fn)

    train_loader = DataLoader(
        train_ds, batch_size=256, shuffle=True, num_workers=4,
        worker_init_fn=worker_init_fn)

    test_loader = DataLoader(
        test_ds, batch_size=256, shuffle=False, num_workers=4,
        worker_init_fn=worker_init_fn)

    metric = _build_metric(args.metric)

    opt = torch.optim.Adam(metric.parameters(), lr=1e-3)

    for epoch in range(args.epochs):
        # Per-epoch NumPy seed; worker_init_fn derives each worker's seed from it.
        np.random.seed(epoch)

        train_metric(metric, train_loader, opt)

        tr_loss = eval_metric(metric, train_loader)
        te_loss = eval_metric(metric, test_loader)

        print('epoch ', epoch)
        print(tr_loss, te_loss)

        log.add('epoch', epoch)
        log.add('train_loss', tr_loss)
        log.add('test_loss', te_loss)
        log.export()

