from __future__ import print_function, absolute_import
import time
from collections import OrderedDict, defaultdict
import numpy as np
import torch
from .utils.meters import AverageMeter
import torch.nn.functional as F

def _embed(model, imgs):
    """Run ``model.base`` + ``model.gap`` and return ``(flat, normalized)`` features.

    ``flat`` is the pooled output reshaped to ``(batch, -1)``; ``normalized``
    is the same tensor passed through ``F.normalize`` (L2 along dim 1).
    """
    outs = model.base(imgs)
    outs = model.gap(outs)
    outs = outs.view(outs.size(0), -1)
    return outs, F.normalize(outs)


def extract_features_labels(model, data_loaders, device="cuda"):
    """Group L2-normalized features by label across several data loaders.

    Every loader except the last must yield ``(imgs, labels)`` batches and the
    given labels are used.  For the LAST loader the yielded labels are ignored
    and replaced by the classifier's argmax prediction
    (``model.classifier(model.feat_bn(flat_features))``).

    Args:
        model: module exposing ``base``, ``gap``, ``feat_bn`` and
            ``classifier``.  ``model.eval()`` is called and NOT reverted.
        data_loaders: sized sequence of iterables of ``(imgs, labels)`` pairs.
        device: device the images are moved to before the forward pass
            (default ``"cuda"``, i.e. the current CUDA device).

    Returns:
        A list of ``(label, features)`` pairs sorted by ``label`` ascending,
        where ``features`` is a list of ``(1, D)`` tensors left on ``device``.
    """
    model.eval()
    label_features = defaultdict(list)
    last_index = len(data_loaders) - 1
    with torch.no_grad():
        for loader_index, data_loader in enumerate(data_loaders):
            use_predictions = loader_index == last_index
            for imgs, labels in data_loader:
                imgs = imgs.to(device)
                outs, features = _embed(model, imgs)
                if use_predictions:
                    # The last loader is labelled by the model itself.
                    cls_score = model.classifier(model.feat_bn(outs))
                    labels = cls_score.argmax(dim=1)
                # One host transfer per batch instead of a device sync per sample.
                for label, feature in zip(labels.tolist(), features):
                    label_features[int(label)].append(feature.reshape(1, -1))

    # Keys are unique, so ordering by key alone is sufficient (and never
    # falls through to comparing lists of tensors).
    return sorted(label_features.items(), key=lambda kv: kv[0])


def pairwise_distance(features, query=None, gallery=None, metric=None):
    """Compute squared Euclidean distances between feature vectors.

    Args:
        features: mapping from key to feature tensor.
        query: optional iterable of 3-tuples whose first element is a key
            into ``features``.
        gallery: same format as ``query``.
        metric: optional object with a ``transform(x)`` method applied to the
            flattened feature matrices before distances are computed.

    Returns:
        If both ``query`` and ``gallery`` are None: an ``(n, n)`` tensor of
        squared distances between all values of ``features`` (in iteration
        order).  Otherwise a tuple ``(dist, x, y)`` where ``dist`` is the
        ``(m, n)`` query-to-gallery distance tensor and ``x`` / ``y`` are the
        (possibly transformed) query / gallery matrices as NumPy arrays.
    """
    if query is None and gallery is None:
        n = len(features)
        x = torch.cat(list(features.values()))
        x = x.view(n, -1)
        if metric is not None:
            x = metric.transform(x)
        # ||x_i||^2 + ||x_j||^2 - 2 <x_i, x_j>: both row and column norms are
        # needed for the result to be correct (and symmetric) when the rows
        # do not all share the same norm.
        sq_norms = torch.pow(x, 2).sum(dim=1, keepdim=True)
        dist_m = sq_norms + sq_norms.t() - 2 * torch.mm(x, x.t())
        return dist_m

    x = torch.cat([features[f].unsqueeze(0) for f, _, _ in query], 0)
    y = torch.cat([features[f].unsqueeze(0) for f, _, _ in gallery], 0)
    m, n = x.size(0), y.size(0)
    x = x.view(m, -1)
    y = y.view(n, -1)
    if metric is not None:
        x = metric.transform(x)
        y = metric.transform(y)
    dist_m = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \
        torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()
    # dist_m = 1 * dist_m + (-2) * x @ y.T; keyword form replaces the
    # deprecated positional (beta, alpha, mat1, mat2) overload.
    dist_m.addmm_(x, y.t(), beta=1, alpha=-2)
    # .cpu() is a no-op for CPU tensors and lets CUDA features be exported.
    return dist_m, x.cpu().numpy(), y.cpu().numpy()

