""" Helper modules for benchmarking_utils SSL models """

# Copyright (c) 2020. Lightly AG and its affiliates.
# All Rights Reserved

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import pytorch_lightning as pl
from sklearn.cluster import KMeans
from collections import defaultdict
from tqdm import tqdm

# The kNN prediction code (knn_predict below) is adapted from:
# https://colab.research.google.com/github/facebookresearch/moco/blob/colab-notebook/colab/moco_cifar10_demo.ipynb
from benchmarking_utils.clustering_metrics import ncc_predict, ncc_predict_superclass


def predict_k_cluster_ncc(feature: torch.Tensor,
                          feature_bank: torch.Tensor,
                          feature_labels: torch.Tensor,
                          num_clusters: int) -> tuple:
    """Cluster a feature bank with k-means and score it with NCC accuracy and CDNV.

    The feature bank is clustered with scikit-learn ``KMeans``. Each row of
    ``feature`` is then assigned to its nearest cluster center (nearest class
    center, NCC), and that assignment is compared with the k-means label at
    the same row index. Finally, the class-distance-normalized variance (CDNV)
    of ``feature`` under the k-means partition is averaged over all pairs of
    clusters.

    Args:
        feature:
            Features to score, laid out feature-major as [D, N]
            (transposed internally).
        feature_bank:
            Features to cluster, laid out feature-major as [D, N]
            (transposed internally). Row i of ``feature`` is compared with
            k-means label i of the bank, so both presumably hold the same N
            samples — TODO confirm with callers.
        feature_labels:
            Unused; kept for interface compatibility.
        num_clusters:
            Number of k-means clusters.

    Returns:
        A tuple ``(ncc_acc, cluster_centers, avg_cdnv)``:
        the fraction of rows whose nearest center equals their k-means label
        (0-dim tensor), the [num_clusters, D] tensor of k-means centers, and
        the CDNV averaged over all cluster pairs (``0`` with a single cluster).

    Raises:
        ValueError: if some cluster has no assigned samples (CDNV undefined).
    """
    # Inputs arrive feature-major ([D, N]); k-means and the distance
    # computations below work on sample-major rows ([N, D]).
    feature_bank = feature_bank.T
    feature = feature.T.cpu()

    kmeans = KMeans(n_clusters=num_clusters, random_state=0).fit(feature_bank.cpu())
    cluster_ids_x = torch.tensor(kmeans.labels_)
    cluster_centers = torch.tensor(kmeans.cluster_centers_)

    # Euclidean distance from every row to every center -> [N, num_clusters].
    # cdist needs matching dtypes (the centers may come back as float64), and
    # the direct mode avoids the less precise matmul-based expansion.
    common_dtype = torch.promote_types(feature.dtype, cluster_centers.dtype)
    ncc_scores = torch.cdist(feature.to(common_dtype),
                             cluster_centers.to(common_dtype),
                             compute_mode='donot_use_mm_for_euclid_dist')
    ncc_pred = torch.argmin(ncc_scores, dim=1)
    ncc_acc = (ncc_pred == cluster_ids_x).sum() / ncc_pred.shape[0]

    avg_cdnv = _average_cdnv(feature, cluster_ids_x, num_clusters)
    return ncc_acc, cluster_centers, avg_cdnv


def _average_cdnv(feature: torch.Tensor,
                  cluster_ids: torch.Tensor,
                  num_clusters: int) -> float:
    """Average the CDNV of ``feature`` ([N, D]) over all pairs of clusters in ``cluster_ids``."""
    means = []
    variances = []
    for c in range(num_clusters):
        h_c = feature[cluster_ids == c]
        n_c = h_c.shape[0]
        if n_c == 0:
            raise ValueError(f"cluster {c} has no assigned samples; CDNV is undefined")
        mean_c = torch.sum(h_c, dim=0) / n_c
        mean_sq_c = torch.sum(torch.square(h_c)) / n_c
        # Within-cluster variance E||h||^2 - ||E h||^2; abs() presumably guards
        # against small negative values from floating-point cancellation.
        variances.append(abs(mean_sq_c.item() - torch.sum(torch.square(mean_c)).item()))
        means.append(mean_c)

    total_num_pairs = num_clusters * (num_clusters - 1) / 2
    avg_cdnv = 0
    for i in range(num_clusters):
        for j in range(i + 1, num_clusters):
            variance_avg = (variances[i] + variances[j]) / 2
            dist = (torch.norm(means[i] - means[j]) ** 2).item()
            avg_cdnv += variance_avg / dist / total_num_pairs
    return avg_cdnv


def knn_predict(feature: torch.Tensor,
                feature_bank: torch.Tensor,
                feature_labels: torch.Tensor,
                num_classes: int,
                knn_k: int = 200,
                knn_t: float = 0.1) -> torch.Tensor:
    """Run kNN predictions on features based on a feature bank.

    This method is commonly used to monitor performance of self-supervised
    learning methods.

    The default parameters are the ones
    used in https://arxiv.org/pdf/1805.01978v1.pdf.

    Args:
        feature:
            Tensor of shape [B, D] for which you want predictions.
        feature_bank:
            Tensor of shape [D, N] holding the database of features used for
            kNN. It is feature-major, so ``feature @ feature_bank`` is [B, N].
        feature_labels:
            Integer labels of shape [N] (or [1, N]) for the columns of
            ``feature_bank``.
        num_classes:
            Number of classes (e.g. `10` for CIFAR-10).
        knn_k:
            Number of neighbors used for kNN. Clamped to the bank size N so
            that a small bank does not make ``topk`` fail.
        knn_t:
            Temperature used to reweight the similarities.

    Returns:
        A [B, num_classes] tensor of class indices sorted from most to least
        likely; column 0 holds the top-1 prediction.

    Examples:
        >>> images, targets, _ = batch
        >>> feature = backbone(images).squeeze()
        >>> # we recommend to normalize the features
        >>> feature = F.normalize(feature, dim=1)
        >>> pred_labels = knn_predict(
        >>>     feature,
        >>>     feature_bank,
        >>>     targets_bank,
        >>>     num_classes=10,
        >>> )
    """
    # Never request more neighbors than the bank holds.
    knn_k = min(knn_k, feature_bank.size(1))
    # dot-product similarity (cosine similarity for L2-normalized inputs) ---> [B, N]
    sim_matrix = torch.mm(feature, feature_bank)
    # [B, K]
    sim_weight, sim_indices = sim_matrix.topk(k=knn_k, dim=-1)
    # labels of the K nearest neighbors ---> [B, K]
    sim_labels = torch.gather(feature_labels.expand(feature.size(0), -1),
                              dim=-1, index=sim_indices)
    # reweight the similarities with a temperature-scaled exponential
    sim_weight = (sim_weight / knn_t).exp()
    # one-hot encode the neighbor labels ---> [B*K, C]
    one_hot_label = torch.zeros(feature.size(0) * knn_k, num_classes,
                                device=sim_labels.device)
    one_hot_label = one_hot_label.scatter(dim=-1, index=sim_labels.view(-1, 1), value=1.0)
    # similarity-weighted class scores ---> [B, C]
    pred_scores = torch.sum(one_hot_label.view(feature.size(0), -1, num_classes)
                            * sim_weight.unsqueeze(dim=-1), dim=1)
    pred_labels = pred_scores.argsort(dim=-1, descending=True)
    return pred_labels


class BenchmarkModule(pl.LightningModule):
    """LightningModule base that builds per-layer feature banks for evaluation.

    ``self.backbone`` is only an empty ``nn.Module`` placeholder here;
    presumably subclasses assign the real network — TODO confirm.
    ``training_epoch_end`` indexes ``self.backbone(img)[1]`` as a sequence of
    per-layer feature tensors, and both loaders are unpacked as
    ``(img, target, _)`` triples.

    Args:
        dataloader_kNN: Loader used to fill ``feature_bank``/``targets_bank``.
        dataloader_test: Loader used to fill
            ``feature_bank_test``/``targets_bank_test``.
        num_classes: Number of classes; stored on the instance.
        knn_k: Number of kNN neighbors; stored on the instance.
        knn_t: kNN temperature; stored on the instance.
    """

    def __init__(self,
                 dataloader_kNN: DataLoader,
                 dataloader_test: DataLoader,
                 num_classes: int,
                 knn_k: int = 200,
                 knn_t: float = 0.1):
        super().__init__()
        self.backbone = nn.Module()
        self.max_knn_accuracy = 0.0
        self.max_nnc_accuracy = 0.0
        self.dataloader_kNN = dataloader_kNN
        self.dataloader_test = dataloader_test
        self.num_classes = num_classes
        self.knn_k = knn_k
        self.knn_t = knn_t

        # Zero-size parameter whose ``.device`` follows the module when it is
        # moved; used below to place input batches.
        self.dummy_param = nn.Parameter(torch.empty(0))

    def training_epoch_end(self, outputs,
                           is_offline=False, from_layer=2, use_test=False,
                           mode='train'):
        """Rebuild the per-layer feature/target banks by running the backbone over a loader.

        Args:
            outputs: Training-step outputs passed in by Lightning; unused.
            is_offline: Unused.
            from_layer: Index of the first entry of ``self.backbone(img)[1]``
                to bank; earlier entries are dropped. Bank keys are the
                original entry indices.
            use_test: If True, fill ``feature_bank_test``/``targets_bank_test``
                from ``dataloader_test``; otherwise fill
                ``feature_bank``/``targets_bank`` from ``dataloader_kNN``.
            mode: When ``'train'``, the rebuild only runs if the current epoch
                is a multiple of ``trainer.check_val_every_n_epoch`` or one of
                the two epochs before one.

        After a rebuild, each feature bank entry is the concatenation of the
        per-batch features, L2-normalized along dim 1, flattened per sample and
        stored transposed as [D, N]. Each target bank entry is the
        concatenation of the batch targets (``.t()`` is a no-op for 1-D
        targets). The backbone is switched to eval mode for the pass and back
        to train mode afterwards, even if the pass raises.

        NOTE(review): the banks of the selected split are reset before the
        epoch check, so on skipped epochs they are left empty. This is kept
        as-is; confirm it is intended.
        """
        # Pick the loader and reset the banks for the requested split.
        if not use_test:
            cur_loader = self.dataloader_kNN
            self.feature_bank = defaultdict(list)
            self.targets_bank = defaultdict(list)
            cur_feature_bank = self.feature_bank
            cur_target_bank = self.targets_bank
        else:
            cur_loader = self.dataloader_test
            self.feature_bank_test = defaultdict(list)
            self.targets_bank_test = defaultdict(list)
            cur_feature_bank = self.feature_bank_test
            cur_target_bank = self.targets_bank_test

        print(f"use_test:{use_test}, cur_loader:{len(cur_loader)}")
        print(f"starting training_epoch_end, epoch:{self.current_epoch}, use_test:{use_test}")

        if mode == 'train':
            step = self.trainer.check_val_every_n_epoch
            print(f"step:{step}")

            # Rebuild only at a multiple of ``step`` or in the two epochs
            # right before one; skip every other epoch.
            if self.current_epoch % step != 0 and \
                    (self.current_epoch + 1) % step != 0 and \
                    (self.current_epoch + 2) % step != 0:
                return

            print(f"made it past the condition!, self.current_epoch:{self.current_epoch}")
            print(f"self.current_epoch % step: {self.current_epoch % step}, "
                  f"(self.current_epoch + 1) % step:{(self.current_epoch + 1) % step}"
                  f"(self.current_epoch + 2) % step:{(self.current_epoch + 2) % step}")

        self.backbone.eval()
        try:
            with torch.no_grad():
                for data in tqdm(cur_loader, total=len(cur_loader)):
                    img, target, _ = data
                    img = img.to(self.dummy_param.device)
                    # Targets are only ever stored on the CPU, so skip the
                    # round trip through the model device.
                    target_cpu = target.cpu()
                    features = self.backbone(img)[1][from_layer:]
                    for layer_idx, feature in enumerate(features, start=from_layer):
                        feature = F.normalize(feature, dim=1)
                        cur_feature_bank[layer_idx].append(feature.cpu())
                        cur_target_bank[layer_idx].append(target_cpu)

                    del features

            # Concatenate the per-batch chunks, flatten each sample and store
            # the bank feature-major ([D, N]).
            for key, chunks in cur_feature_bank.items():
                stacked = torch.cat(chunks, dim=0)
                cur_feature_bank[key] = stacked.reshape(stacked.shape[0], -1).t().contiguous()
                print(f"key:{key}, feature_bank:{cur_feature_bank[key].shape}")

            for key, chunks in cur_target_bank.items():
                cur_target_bank[key] = torch.cat(chunks, dim=0).t().contiguous()
        finally:
            # Restore training mode even if feature extraction fails.
            self.backbone.train()

    def validation_step(self, batch, batch_idx, from_layer=2):
        """Placeholder validation step: prints a marker line and returns ``None``.

        ``batch``, ``batch_idx`` and ``from_layer`` are unused here.
        """
        print(f"\nIn validation step!: False, epoch:{self.current_epoch}.")

    def validation_epoch_end(self, outputs):
        """Group the first validation output into NCC/CDNV metrics, print and log them.

        Does nothing when ``outputs`` is empty. Otherwise ``outputs[0]`` must
        map each key to a 3-tuple ``(ncc, <ignored>, cdnv)``. This is
        presumably the tuple returned by ``predict_k_cluster_ncc`` — TODO
        confirm. Keys containing ``'superclass'`` go to the superclass groups
        and all other keys to the "orig" groups. Every value is logged with
        ``self.log`` as ``'<group>_<NCC|CDNV>_<key>_clusters'``.
        """
        if not outputs:
            return

        print(f"outputs:{outputs}")
        print_res_ncc_superclass, print_res_ncc_orig = {}, {}
        print_res_cdnv_superclass, print_res_cdnv_orig = {}, {}
        for key, (ncc, _, cdnv) in outputs[0].items():
            if 'superclass' in key:
                print_res_ncc_superclass[f'NCC_{key}_clusters'] = ncc
                print_res_cdnv_superclass[f'CDNV_{key}_clusters'] = cdnv
            else:
                print_res_ncc_orig[f'NCC_{key}_clusters'] = ncc
                print_res_cdnv_orig[f'CDNV_{key}_clusters'] = cdnv

        print(print_res_ncc_superclass)
        print(print_res_ncc_orig)
        print(print_res_cdnv_superclass)
        print(print_res_cdnv_orig)

        names = ['NCC/superclass', 'NCC/orig', 'CDNV/superclass', 'CDNV/orig']
        dicts = [print_res_ncc_superclass, print_res_ncc_orig,
                 print_res_cdnv_superclass, print_res_cdnv_orig]

        for name, cur_dict in zip(names, dicts):
            for key, value in cur_dict.items():
                self.log(f'{name}_{key}',
                         value, on_epoch=True, sync_dist=True)