import argparse
import math
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision

from models.embed import CifarBase
from divergence.bregman import *
from divergence.ficnn import *
from divergence.baselines import *
from divergence.max_affine import *
from divergence.deepnorm_helper import DeepNormMetric, WideNormMetric

from utils.data import load_torch_data, split_and_normalize_as_tensors, train_val_split_tensors
from utils.models import get_embeddings, train_sgd, MetricModel
from utils.logging import TrainLogger
from utils.scoring import RetrievalMetrics

from datasets import ClustCIFAR


# Device that the model and every batch are moved to. Uses the GPU when one is
# available and falls back to the CPU otherwise, so the script still runs
# (slowly) on machines without CUDA instead of failing on the first `.to()`.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'


class DiffDistance(nn.Module):
    """MLP applied to the element-wise absolute difference of two inputs.

    The network is ``Linear -> Softplus -> Linear -> Softplus -> Linear`` with
    two hidden layers of width 100 and a single output unit.
    """

    def __init__(self, input_dim=20):
        """Build the network.

        Args:
            input_dim: size of the last dimension of each input.
        """
        super().__init__()

        hidden = 100
        layers = [
            nn.Linear(input_dim, hidden),
            nn.Softplus(),
            nn.Linear(hidden, hidden),
            nn.Softplus(),
            nn.Linear(hidden, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, X1, X2):
        """Return ``net(|X1 - X2|)``.

        The trailing dimension of size 1 produced by the last linear layer is
        kept, so the result is symmetric in its arguments but not flattened.
        """
        gap = (X1 - X2).abs()
        return self.net(gap)


class PairedInputDistance(nn.Module):
    """MLP that scores a pair of inputs from their concatenation.

    The two inputs are concatenated along dimension 1 and passed through
    ``Linear -> Softplus -> Linear -> Softplus -> Linear`` (hidden width 100,
    one output unit). Nothing in the architecture forces the score to be
    symmetric or non-negative.
    """

    def __init__(self, input_dim=20):
        """Build the network.

        Args:
            input_dim: size of dimension 1 of *each* input; the first linear
                layer therefore takes ``2 * input_dim`` features.
        """
        super().__init__()

        self.net = nn.Sequential(
            nn.Linear(2 * input_dim, 100),
            nn.Softplus(),
            nn.Linear(100, 100),
            nn.Softplus(),
            nn.Linear(100, 1)
        )

    def forward(self, X1, X2):
        """Return one score per row pair as a flattened 1-D tensor."""
        X = torch.cat([X1, X2], dim=1)
        return self.net(X).flatten()

    def pairwise_distance(self, X1, X2):
        """Return the same scores as calling the module directly.

        ``MetricModel.forward`` in this file calls
        ``metric.pairwise_distance(...)``; without this method, wrapping a
        ``PairedInputDistance`` in ``MetricModel`` fails with AttributeError.
        """
        return self(X1, X2)


def train_metric(metric, train_loader, optimizer):
    """Run one pass over ``train_loader``, fitting ``metric`` with MSE loss.

    Args:
        metric: module called as ``metric(x1, x2)`` to predict a value per pair.
        train_loader: iterable yielding ``(x1, x2, target)`` batches.
        optimizer: optimizer over the parameters of ``metric``.

    Returns nothing; the module is left in training mode.
    """
    print('starting')
    metric.train()
    criterion = nn.MSELoss()

    for batch in train_loader:
        left, right, labels = batch
        left = left.to(DEVICE)
        right = right.to(DEVICE)
        # Targets are cast to float so they match the regression output.
        labels = labels.to(DEVICE).float()

        optimizer.zero_grad()
        prediction = metric(left, right)
        criterion(prediction, labels).backward()
        optimizer.step()


def eval_metric(metric, data_loader):
    """Return the mean squared error of ``metric`` over all of ``data_loader``.

    Predictions and targets from every batch are concatenated first, so the
    result is the mean over all samples rather than a mean of per-batch means.

    Args:
        metric: module called as ``metric(x1, x2)``.
        data_loader: iterable yielding ``(x1, x2, target)`` batches. It must
            yield at least one batch, because ``torch.cat`` is applied to the
            collected lists.

    Returns:
        The squared error averaged over all samples, as a Python float.

    The module is switched to eval mode and is left in that mode.
    """
    metric.eval()
    preds = []
    truth = []

    # Evaluation needs no gradients: disabling autograd avoids building and
    # holding a graph for every batch, and makes a per-batch detach unnecessary.
    with torch.no_grad():
        for x1, x2, target in data_loader:
            x1, x2, target = x1.to(DEVICE), x2.to(DEVICE), target.to(DEVICE)
            target = target.float()

            preds.append(metric(x1, x2))
            truth.append(target)

    preds = torch.cat(preds)
    truth = torch.cat(truth)
    return torch.mean((preds - truth) ** 2).item()


from models.resnets import resnet20
class MetricModel(nn.Module):
    """Embed two inputs with a shared network and compare the embeddings.

    The embedding network is ``resnet20()`` with its ``linear`` attribute
    replaced by ``nn.Linear(64, 128)``; ``metric`` must expose a
    ``pairwise_distance(a, b)`` method.

    NOTE(review): this definition shadows the ``MetricModel`` imported from
    ``utils.models`` at the top of the file -- confirm that is intended.
    """

    def __init__(self, metric):
        """Create the embedding network and store ``metric`` as a submodule."""
        super().__init__()
        backbone = resnet20()
        # Swap the final layer so the embedding has 128 output features
        # (assumes the layer it replaces also takes 64 inputs -- TODO confirm).
        backbone.linear = nn.Linear(64, 128)
        self.embedding = backbone
        self.metric = metric

    def forward(self, x, y):
        """Return ``metric.pairwise_distance`` of the embeddings of ``x`` and ``y``."""
        return self.metric.pairwise_distance(self.embedding(x), self.embedding(y))


if __name__ == '__main__':
    # Script entry point: train the metric selected by --metric on ClustCIFAR
    # pairs for 200 epochs, logging train/test MSE after every epoch.

    # Every name listed here has a matching branch in the construction chain
    # below. Validating up front turns a missing or misspelled --metric into an
    # immediate usage error instead of a NameError after the data is loaded.
    SUPPORTED_METRICS = (
        'concat', 'mahalanobis', 'widenorm', 'deepnorm', 'maxaffine', 'bregman',
    )

    parser = argparse.ArgumentParser()
    parser.add_argument('--metric', required=True, choices=SUPPORTED_METRICS)
    parser.add_argument('--run', type=int, default=0)
    args = parser.parse_args()

    log = TrainLogger(
        log_dir='./results/clustcifar',
        name=f'clustcifar_{args.metric}_{args.run}'
    )

    # Convert images to tensors, then normalize each channel with the given
    # per-channel mean and standard deviation.
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])

    # Only the training split triggers a download; the test split is expected
    # to be present in the same directory afterwards.
    train_ds = torchvision.datasets.CIFAR10('/home/xxxx/data', train=True, download=True,
                                            transform=transform)
    test_ds = torchvision.datasets.CIFAR10('/home/xxxx/data', train=False, download=False,
                                           transform=transform)

    # Wrap both splits in the project's pair dataset. The loops in this file
    # unpack its batches as (x1, x2, target).
    # NOTE(review): the meaning of the empty dict argument is not visible here.
    train_ds = ClustCIFAR(train_ds, {})
    test_ds = ClustCIFAR(test_ds, {})

    def worker_init_fn(worker_id):
        # Reseed NumPy inside each DataLoader worker from the first word of
        # the current NumPy state plus the worker id -- presumably so that
        # workers do not all draw the same NumPy random stream.
        np.random.seed(np.random.get_state()[1][0] + worker_id)

    train_loader = torch.utils.data.DataLoader(
        train_ds, batch_size=256, shuffle=True, num_workers=4,
        worker_init_fn=worker_init_fn)

    test_loader = torch.utils.data.DataLoader(
        test_ds, batch_size=256, shuffle=False, num_workers=4,
        worker_init_fn=worker_init_fn)

    # Build the distance/divergence head for the requested metric. All heads
    # take 128-dimensional inputs, matching the embedding built by MetricModel.
    if args.metric == 'concat':
        metric = MetricModel(
            PairedInputDistance(128)
        ).to(DEVICE)

    elif args.metric == 'mahalanobis':
        metric = MetricModel(
            MahalanobisDistance(128, 128)
        ).to(DEVICE)

    elif args.metric == 'widenorm':
        metric = MetricModel(
            WidenormDivergence(
                WideNormMetric(128, 32, 32, 5,
                               mode='maxavg'
                )
            )
        ).to(DEVICE)

    elif args.metric == 'deepnorm':
        metric = MetricModel(
            DeepnormDivergence(
                DeepNormMetric(128, [128, 128, 128],
                               activation=lambda: MaxReLUPairwiseActivation(128),
                               concave_activation_size=5,
                               mode='maxavg'
                )
            )
        ).to(DEVICE)

    elif args.metric == 'maxaffine':
        metric = MetricModel(
            MaxAffineDivergence(MaxAffineNet(128, K=50))
        ).to(DEVICE)

    elif args.metric == 'bregman':
        metric = MetricModel(
            BregmanDivergence(FICNN(128, 1, 2, 256, activation='softplus'))
        ).to(DEVICE)

    else:
        # Unreachable while SUPPORTED_METRICS and this chain agree; kept so a
        # name added to one but not the other fails loudly.
        raise ValueError(f'unsupported metric: {args.metric!r}')

    opt = torch.optim.Adam(metric.parameters(), lr=1e-3)

    for epoch in range(200):
        # Reseed NumPy at the start of every epoch.
        # NOTE(review): presumably ClustCIFAR samples with NumPy -- confirm.
        np.random.seed(epoch)

        train_metric(metric, train_loader, opt)

        # The modulus is 1, so evaluation and logging happen after every epoch.
        if epoch % 1 == 0:
            tr_loss = eval_metric(metric, train_loader)
            te_loss = eval_metric(metric, test_loader)

            print('epoch ', epoch)
            print(tr_loss, te_loss)

            log.add('epoch', epoch)
            log.add('train_loss', tr_loss)
            log.add('test_loss', te_loss)
            log.export()

