import argparse
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision

from models.embed import CropdistEmbed
from divergence.bregman import *
from divergence.ficnn import *
from divergence.baselines import *
from divergence.max_affine import *
from divergence.deepnorm_helper import DeepNormMetric, WideNormMetric

from utils.data import load_torch_data, split_and_normalize_as_tensors, train_val_split_tensors
from utils.models import get_embeddings, train_sgd, MetricModel
from utils.logging import TrainLogger
from utils.scoring import RetrievalMetrics

from datasets import CropDist


# Device for all models and batches. Prefer the GPU, but fall back to CPU so
# the script (and its helper functions) still run on machines without CUDA
# instead of failing on the first ``.to(DEVICE)``.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'


def train_metric(metric, train_loader, optimizer):
    """Train *metric* for one pass over *train_loader* using an L1 regression loss.

    Every batch ``(x1, x2, target)`` is moved to ``DEVICE``; the model output for
    ``(x1, x2)`` is regressed onto ``target`` (cast to float) and one optimizer
    step is taken per batch.

    Args:
        metric: module called as ``metric(x1, x2)``; put into train mode here.
        train_loader: iterable yielding ``(x1, x2, target)`` tensor triples.
        optimizer: optimizer that updates the model (presumably built over
            ``metric.parameters()`` -- confirm at the call site).
    """
    print('starting')
    metric.train()
    criterion = nn.L1Loss()

    for lhs, rhs, labels in train_loader:
        lhs = lhs.to(DEVICE)
        rhs = rhs.to(DEVICE)
        labels = labels.to(DEVICE).float()

        batch_loss = criterion(metric(lhs, rhs), labels)

        # Clear gradients left from the previous batch, backprop, then update.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()


def eval_metric(metric, data_loader, device=None):
    """Return the mean squared error of *metric* over all batches of *data_loader*.

    The model is put into eval mode and run without autograd; predictions and
    float-cast targets from every batch are concatenated before the mean of
    ``(pred - target) ** 2`` is taken.

    Args:
        metric: module called as ``metric(x1, x2)``; switched to eval mode here.
        data_loader: iterable yielding ``(x1, x2, target)`` tensor triples.
        device: device batches are moved to; ``None`` (default) uses ``DEVICE``.

    Returns:
        The MSE as a Python float.

    Raises:
        RuntimeError: if *data_loader* yields no batches (``torch.cat`` of an
            empty list).
    """
    if device is None:
        device = DEVICE
    metric.eval()
    preds = []
    truth = []
    # No gradients are needed for evaluation; skipping graph construction
    # saves the activation memory the old ``.detach()`` approach still paid for.
    with torch.no_grad():
        for x1, x2, target in data_loader:
            x1, x2 = x1.to(device), x2.to(device)
            target = target.to(device).float()
            preds.append(metric(x1, x2))
            truth.append(target)

    preds = torch.cat(preds)
    truth = torch.cat(truth)
    return torch.mean((preds - truth) ** 2).item()


class MetricModel(nn.Module):
    """Siamese wrapper: encode both inputs with one shared embedding, then compare.

    Both arguments of ``forward`` go through the same ``CropdistEmbed`` instance
    (constructed with ``out_dim=128``), and the two embeddings are passed to the
    wrapped metric's ``pairwise_distance``.

    NOTE(review): this local class shadows the ``MetricModel`` imported from
    ``utils.models`` at the top of the file.

    Args:
        metric: module exposing ``pairwise_distance(a, b)``; stored as
            ``self.metric``.
    """

    def __init__(self, metric):
        super().__init__()
        # Registration order (embedding first) fixes the order of parameters().
        self.embedding = CropdistEmbed(out_dim=128)
        self.metric = metric

    def forward(self, x, y):
        """Return ``self.metric.pairwise_distance`` of the embeddings of *x* and *y*."""
        encode = self.embedding
        left, right = encode(x), encode(y)
        return self.metric.pairwise_distance(left, right)


# Factory for every supported ``--metric`` value. Each entry builds the
# (untrained) metric head that MetricModel wraps around its shared embedding;
# the lambdas defer construction so only the selected metric is ever built.
METRIC_BUILDERS = {
    'mahalanobis': lambda: MahalanobisDistance(128, 128),
    'widenorm': lambda: WidenormDivergence(
        WideNormMetric(128, 32, 32, 5, mode='maxavg')
    ),
    'deepnorm': lambda: DeepnormDivergence(
        DeepNormMetric(
            128, [128, 128, 128],
            activation=lambda: MaxReLUPairwiseActivation(128),
            concave_activation_size=5,
            mode='maxavg',
        )
    ),
    'maxaffine': lambda: MaxAffineDivergence(MaxAffineNet(128, K=50)),
    'bregman': lambda: BregmanDivergence(
        FICNN(128, 1, 2, 128, activation='softplus')
    ),
}


def worker_init_fn(worker_id):
    """Reseed NumPy inside a DataLoader worker, offset by *worker_id*.

    Under the ``fork`` start method each worker inherits a copy of the parent's
    NumPy RNG state; offsetting the first word of that MT19937 key by
    *worker_id* gives each worker a distinct, reproducible seed (the training
    loop calls ``np.random.seed(epoch)`` in the parent). Presumably CropDist
    draws its random crops from ``np.random`` -- confirm in ``datasets``.

    Defined at module level rather than inside ``__main__`` so it can be
    pickled when workers are started with ``spawn``.
    """
    base_seed = int(np.random.get_state()[1][0])
    # np.random.seed only accepts 0 .. 2**32 - 1; the raw sum can overflow that.
    np.random.seed((base_seed + worker_id) % 2 ** 32)


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    # Validate up front: an unknown or missing name used to leave ``metric``
    # unbound and crash later with NameError.
    parser.add_argument('--metric', required=True, choices=sorted(METRIC_BUILDERS))
    parser.add_argument('--run', type=int, default=0)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--eval-every', type=int, default=1,
                        help='evaluate and log every N epochs')
    args = parser.parse_args()
    if args.eval_every < 1:
        parser.error('--eval-every must be >= 1')

    log = TrainLogger(
        log_dir='./results/cropdist',
        name=f'cropdist_{args.metric}_{args.run}'
    )

    train_ds = CropDist(split='train', img_width=64, resize=72)
    test_ds = CropDist(split='test', img_width=64, resize=72)

    train_loader = DataLoader(
        train_ds, batch_size=128, shuffle=True, num_workers=4,
        worker_init_fn=worker_init_fn)

    test_loader = DataLoader(
        test_ds, batch_size=128, shuffle=False, num_workers=4,
        worker_init_fn=worker_init_fn)

    metric = MetricModel(METRIC_BUILDERS[args.metric]()).to(DEVICE)

    opt = torch.optim.Adam(metric.parameters(), lr=5e-4)

    for epoch in range(args.epochs):
        # Reseed per epoch so worker seeds (derived in worker_init_fn) are reproducible.
        np.random.seed(epoch)

        train_metric(metric, train_loader, opt)

        if epoch % args.eval_every == 0:
            tr_loss = eval_metric(metric, train_loader)
            te_loss = eval_metric(metric, test_loader)

            print('epoch ', epoch)
            print(tr_loss, te_loss)

            log.add('epoch', epoch)
            log.add('train_loss', tr_loss)
            log.add('test_loss', te_loss)
            log.export()

            # MetricModel itself registers no ``beta`` (reading it raised
            # AttributeError); report it from the wrapped metric when that
            # metric defines one -- presumably only some divergences do.
            beta = getattr(metric.metric, 'beta', None)
            if beta is not None:
                print(float(beta))
