import copy
import random
import io
from itertools import zip_longest
from tqdm.auto import tqdm
from omegaconf import OmegaConf

from PIL import Image
import matplotlib.pyplot as plt

import lightning.pytorch as pl
from lightning.pytorch.utilities import CombinedLoader
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.stats import multivariate_normal
from sklearn.neighbors import KernelDensity
from sklearn.ensemble import IsolationForest
from sklearn.neighbors import LocalOutlierFactor

import faiss
import faiss.contrib.torch_utils

from models.encoder import ResNetEncoder
from utils.optimizer import *

from configs.enums import *

from sklearn.metrics import (
    accuracy_score,
    roc_auc_score,
    roc_curve,
    balanced_accuracy_score,
    average_precision_score,
)
import scipy

from data.data import get_datasets
from data.transforms import *
from torchvision.utils import make_grid
import torchvision.transforms as transforms
import kornia.augmentation as K
from torch.utils.data import TensorDataset, ConcatDataset, DataLoader

from hydra.utils import instantiate
import multiprocessing as mp


class ContextContrastingADModel(pl.LightningModule):
    def __init__(self, cfg, data_path=None):
        """Store the configuration and run the model's setup routines.

        Args:
            cfg: Configuration object. It is stored on ``self.cfg``; when
                ``data_path`` is given, ``cfg.dataset.data_path`` is
                overwritten in place.
            data_path: Optional replacement for ``cfg.dataset.data_path``.
        """
        super().__init__()
        # Capture the constructor arguments before any further locals exist.
        self.save_hyperparameters()
        if data_path is not None:
            cfg.dataset.data_path = data_path
        self.cfg = cfg
        # Setup order is kept as written: datasets, model, validation buffers.
        for configure in (
            self.configure_datasets,
            self.configure_model,
            self.configure_validation_step_outputs,
        ):
            configure()

    @torch.no_grad()
    def get_test_augmentation_embeddings(self, x, identity_augmentation=False):
        """Embed ``x`` under a fixed set of test-time augmentations.

        The global ``random`` and ``torch`` RNGs are re-seeded with
        ``cfg.seed`` on every call so the same augmentations are drawn each
        time. Note that this also resets the RNG state seen by any code that
        runs afterwards.

        Args:
            x: Batch of input samples.
            identity_augmentation: When True, only the context augmentation
                is applied and ``self.test_augmentation`` is skipped.

        Returns:
            A tuple ``(samples, projection, features)`` where ``samples`` is
            the augmented input split into one chunk per augmentation,
            ``projection`` is a dict with the keys ``"content"`` and
            ``"context"`` whose tensors hold one entry per augmentation along
            dim 1, and ``features`` is the encoder output arranged the same
            way.

        Raises:
            NotImplementedError: If the configured augmentation type is
                neither ``"kornia"`` nor ``"custom"``.
            ValueError: If no sufficiently distinct kornia parameters could
                be drawn within the retry budget.
        """
        n_augs = self.cfg.augmentation.test_time_augmentations
        n_context_augs = self.cfg.loss.n_context_augs
        augmentation_type = self.cfg.loss.augmentation_class.augmentation_type
        # Fail early with a clear message; previously an unsupported type only
        # surfaced later as an unbound local variable.
        if augmentation_type not in ("kornia", "custom"):
            raise NotImplementedError(
                "Test time augmentation is not implemented for augmentation "
                f"type {augmentation_type!r}."
            )
        # Fix seed for test time augmentations
        random.seed(self.cfg.seed)
        torch.manual_seed(self.cfg.seed)

        # Create n_augs many different context augmentation parameters
        context_aug_params = []
        if augmentation_type == "kornia":
            for _ in range(n_augs):
                attempts_left = 1000
                while True:
                    attempts_left -= 1
                    new_params = self.test_context_augmentation.forward_parameters(
                        x.shape
                    )
                    # Only the most recent (n_context_augs - 1) draws must
                    # differ; with a single context no check is needed.
                    collides = n_context_augs != 1 and any(
                        any_equal_augmentation_params(params, new_params)
                        for params in context_aug_params[-(n_context_augs - 1) :]
                    )
                    if not collides:
                        break
                    if attempts_left < 0:
                        raise ValueError(
                            "Could not find new context augmentation parameters, please decrease the number of test time augmentations or increase the number of context augmentations."
                        )
                context_aug_params.append(new_params)

        # Augment samples with simclr augmentations and context augmentations
        augmented_xs = []
        context_labels = []
        for i in range(n_augs):
            labels = torch.ones(x.shape[0]) * (i % n_context_augs)
            if augmentation_type == "kornia":
                augmented_x = self.test_context_augmentation(
                    x, params=context_aug_params[i]
                )
            else:  # "custom": the augmentation is selected by the labels
                augmented_x = self.test_context_augmentation(
                    x, extra_args={"labels": labels}
                )
            if not identity_augmentation:
                augmented_x = self.test_augmentation(augmented_x)
            augmented_xs.append(self.normalize(augmented_x))
            context_labels.append(labels)
        augmented_xs = torch.cat(augmented_xs)
        context_labels = torch.cat(context_labels)

        # Get embeddings of augmented features
        projection, features = self(augmented_xs, context_labels)
        # The batch is laid out augmentation after augmentation along dim 0;
        # move the augmentation index to its own axis (dim 1).
        for projection_type in ["content", "context"]:
            per_augmentation = projection[projection_type][:, None].chunk(n_augs)
            projection[projection_type] = torch.cat(per_augmentation, dim=1)
        features = torch.cat(features.chunk(n_augs), dim=1)
        return (
            augmented_xs.chunk(n_augs),
            projection,
            features,
        )

    def forward(self, x, context_labels):
        """Encode ``x`` and project the encoding.

        When ``cfg.loss`` has a truthy ``"hierarchy"`` entry the encoder
        output is indexed with ``[0]`` and only that element is used.

        Returns:
            A tuple ``(projection, features)``: the result of
            ``self.criterion.get_projection`` and the encoding with an extra
            axis inserted at dim 1.
        """
        encoded = self.encoder(x)
        hierarchical = self.cfg.loss.get("hierarchy", False)
        z = encoded[0] if hierarchical else encoded
        return self.criterion.get_projection(z, context_labels), z[:, None]

    def create_index(self, embeddings, index_type="l2"):
        """Build a faiss search index over ``embeddings``.

        Args:
            embeddings: 2-D tensor of vectors to index (one vector per row).
            index_type: ``"l2"`` for squared-L2 search or ``"cos"`` for
                inner-product search over L2-normalised vectors.

        Returns:
            The populated faiss index. If ``cfg.index_pca_dim`` is set, the
            index is wrapped in a PCA pre-transform to that dimension; if
            ``cfg.approximate_index`` is truthy, an IVF index with 100 lists
            is used instead of an exact flat index.

        Raises:
            ValueError: If ``index_type`` is not ``"l2"`` or ``"cos"``.
        """
        if index_type not in ("l2", "cos"):
            raise ValueError(
                f"Unknown index_type {index_type!r}; expected 'l2' or 'cos'."
            )
        input_dim = embeddings.shape[-1]
        pca_dim = self.cfg.get("index_pca_dim", None)
        # Dimension seen by the actual search index (after the optional PCA).
        index_dim = input_dim if pca_dim is None else pca_dim

        if index_type == "l2":
            index = faiss.IndexFlatL2(index_dim)
            metric = faiss.METRIC_L2
        else:
            index = faiss.IndexFlatIP(index_dim)
            metric = faiss.METRIC_INNER_PRODUCT
            embeddings = F.normalize(embeddings, dim=-1)
            # NOTE(review): with PCA enabled the projected vectors are no
            # longer unit length, so the inner product is only an
            # approximation of cosine similarity -- confirm this is intended.
        embeddings = embeddings.contiguous()

        if self.cfg.approximate_index:
            # Pass the metric explicitly: IndexIVFFlat defaults to L2, which
            # would silently turn the "cos" index into a distance index.
            index = faiss.IndexIVFFlat(index, index_dim, 100, metric)
        if pca_dim is not None:
            pca = faiss.PCAMatrix(input_dim, pca_dim)
            index = faiss.IndexPreTransform(pca, index)
        if self.cfg.approximate_index or pca_dim is not None:
            # Train once on the outermost index. The pre-transform trains the
            # PCA first and then feeds the projected vectors to the IVF
            # index, so the IVF never sees vectors of the wrong dimension.
            index.train(embeddings)
        index.add(embeddings)
        return index

    @torch.no_grad()
    def prepare_test_mode(self, init=False, print_results=False):
        """Refresh everything that scoring needs from the training data.

        Recomputes the embeddings of the training data, fits one
        multivariate normal per test-time view on the L2-normalised
        embeddings, logs the context distance ratios and rebuilds the L2 and
        cosine search indices for the embeddings and for both projection
        types.

        Args:
            init: Forwarded to ``compute_normal_embeddings``.
            print_results: Forwarded to ``compute_context_distance_ratios``.
        """
        # Compute embeddings of training data without augmentation
        self.compute_normal_embeddings(init=init)
        n_views = self.normal_embeddings.shape[1]

        # Fit one Gaussian per view to the normalised normal embeddings
        self.context_dists = []
        for view in tqdm(range(n_views), leave=False):
            view_features = F.normalize(self.normal_embeddings[:, view], dim=-1)
            gaussian = multivariate_normal(
                mean=view_features.mean(dim=0),
                cov=view_features.T.cov(),
                allow_singular=True,
            )
            self.context_dists.append(gaussian)

        # Log ratio between in-context and out-of-context distances of training samples
        self.compute_context_distance_ratios(print_results=print_results)

        # Update search index for nearest neighbor computations
        self.l2_index, self.cos_index = [], []
        self.l2_content_projections_index, self.cos_content_projections_index = [], []
        self.l2_context_projections_index, self.cos_context_projections_index = [], []
        # Create one index of every kind for every test time augmentation
        for view in tqdm(
            range(n_views),
            desc="Creating search indices",
            leave=False,
        ):
            targets = (
                (
                    self.l2_index,
                    self.cos_index,
                    self.normal_embeddings[:, view],
                ),
                (
                    self.l2_context_projections_index,
                    self.cos_context_projections_index,
                    self.normal_projections["context"][:, view],
                ),
                (
                    self.l2_content_projections_index,
                    self.cos_content_projections_index,
                    self.normal_projections["content"][:, view].contiguous(),
                ),
            )
            for l2_store, cos_store, vectors in targets:
                l2_store.append(self.create_index(vectors, index_type="l2"))
                cos_store.append(self.create_index(vectors, index_type="cos"))

    @torch.no_grad()
    def on_fit_start(self):
        """Hook that intentionally does nothing."""
        return None

    @torch.no_grad()
    def train_augment(self, x):
        """Create the augmented training batch and its context labels.

        The batch is duplicated, every copy is assigned to one of
        ``cfg.loss.n_context_augs`` contexts, the context augmentation is
        applied, and the result is duplicated once more before the regular
        augmentations and normalisation are applied.

        Args:
            x: Batch of ``n`` samples; indexed as a 4-D tensor below
                (``x.shape[1]`` .. ``x.shape[3]``).

        Returns:
            A tuple ``(x_aug, context_labels)``. ``x_aug`` is built from
            ``4 * n`` samples (two rounds of ``repeat(2, ...)``), while
            ``context_labels`` has ``2 * n`` entries, one per
            context-augmented sample.

        Raises:
            NotImplementedError: If the configured augmentation is neither
                ``"random_rotation"`` nor of type ``"custom"``.
        """
        # Duplicate x to ensure each sample is in different contexts
        n = x.shape[0]
        x = x.repeat(2, 1, 1, 1)
        # if self.cfg.augmentation.align_content_augmentations:
        #     x = self.augmentations(x).repeat(2, 1, 1, 1)
        #     n = 2 * n

        # Generate context augmentation parameters
        n_contexts = self.cfg.loss.n_context_augs
        # Draw how many of the 2n samples go to each context: 2n draws without
        # replacement from n_contexts groups of n items each, so no context
        # receives more than n samples and the counts sum to 2n.
        # NOTE(review): this needs n_contexts >= 2 (2n items cannot be drawn
        # from a single group of n) -- confirm n_contexts == 1 is never used.
        n_contexts_per_label = scipy.stats.multivariate_hypergeom.rvs(
            [n for _ in range(n_contexts)], 2 * n
        )
        # Shuffle context labels to ensure there is no bias for content contrasting
        labels = torch.randperm(n_contexts)
        # Labels are laid out in consecutive blocks, one block per context;
        # block i carries the shuffled label labels[i].
        context_labels = [
            labels[i] * torch.ones(n_samples_in_context)
            for i, n_samples_in_context in enumerate(n_contexts_per_label)
        ]
        context_labels = torch.cat(context_labels)
        # Aggregate sampled transformation parameters and apply to x
        if self.cfg.loss.augmentation_class.name == "random_rotation":
            # Sample parameters only to get the right structure; the rotation
            # angles are overwritten below.
            final_rotation_params = self.context_augmentations.forward_parameters(
                (2 * n, x.shape[1], x.shape[2], x.shape[3])
            )
            # NOTE(review): the angle of block i is derived from the block
            # position i, whereas its label is labels[i] (shuffled above), so
            # the label-to-angle mapping can differ between batches -- confirm
            # this is intended.
            if self.cfg.loss.augmentation_class.max_180_degrees and n_contexts != 1:
                # Angles evenly spaced over [0, 180] degrees (both ends included)
                final_rotation_params[0][1]["degrees"] = torch.cat(
                    [
                        torch.tensor(i / (n_contexts - 1)).repeat(n_samples_in_context)
                        * 180
                        for i, n_samples_in_context in enumerate(n_contexts_per_label)
                    ]
                )
            else:
                # Angles evenly spaced over the full circle, shifted by the
                # configured offset and wrapped to [0, 360)
                final_rotation_params[0][1]["degrees"] = torch.cat(
                    [
                        torch.tensor(
                            (
                                360 * (i / n_contexts)
                                + self.cfg.loss.augmentation_class.angle_offset
                            )
                            % 360
                        ).repeat(n_samples_in_context)
                        for i, n_samples_in_context in enumerate(n_contexts_per_label)
                    ]
                )
            context_aug_x = self.context_augmentations(x, params=final_rotation_params)
        elif self.cfg.loss.augmentation_class.augmentation_type == "custom":
            # Custom augmentations receive the context labels directly
            context_aug_x = self.context_augmentations(
                x, extra_args={"labels": context_labels}
            )
        else:
            raise NotImplementedError(
                f"Augmentation parameter aggregation not yet implemented for {self.cfg.loss.augmentation_class.name}."
            )
        # if self.cfg.augmentation.align_content_augmentations:
        #     x_aug = context_aug_x.chunk(4)
        #     x_aug = torch.cat([x_aug[0], x_aug[2], x_aug[1], x_aug[3]])
        #     context_labels = context_labels.chunk(4)
        #     context_labels = torch.cat([context_labels[0], context_labels[2]])
        # else:
        # Duplicate the context-augmented batch once more before applying the
        # regular augmentations to the copies.
        x_aug = self.augmentations(context_aug_x.repeat(2, 1, 1, 1))
        x_aug = self.normalize(x_aug)
        return x_aug, context_labels

    def train_embed(self, x):
        """Encode ``x`` and place the two halves of the batch side by side.

        Returns:
            A tensor in which the first and second half of the batch are
            stacked along a new dim 1. When ``cfg.loss`` has a truthy
            ``"hierarchy"`` entry, a list with one such tensor per element of
            the encoder output is returned instead.
        """

        def pair_halves(tensor):
            # Insert a view axis, split the batch in two and join the halves
            # along that axis.
            halves = tensor[:, None].chunk(2)
            return torch.cat(list(halves), dim=1)

        encoded = self.encoder(x)
        if self.cfg.loss.get("hierarchy", False):
            return [pair_halves(level) for level in encoded]
        return pair_halves(encoded)

    def training_step(self, batch, batch_idx=0):
        """Augment, embed and score one training batch.

        Only ``batch[0]`` is used. With a truthy ``"hierarchy"`` entry in
        ``cfg.loss`` the criterion is evaluated once per feature level, each
        level loss is logged, and their mean is returned; otherwise the
        criterion is evaluated once.

        Returns:
            The loss, which is also logged as ``train/loss``.
        """
        augmented, context_labels = self.train_augment(batch[0])

        # Embed augmented samples
        features = self.train_embed(augmented)

        if not self.cfg.loss.get("hierarchy", False):
            loss = self.criterion(features, context_labels, epoch=self.current_epoch)
            self.log("train/loss", loss)
            return loss

        level_losses = []
        for level, level_features in enumerate(features):
            level_loss = self.criterion(
                level_features,
                context_labels,
                epoch=self.current_epoch,
                logging_prefix=f"train/hierarchy_{level}_",
            )
            self.log(f"train/hierarchy_{level}_loss", level_loss)
            level_losses.append(level_loss)
        loss = torch.stack(level_losses).mean()
        self.log("train/loss", loss)
        return loss

    @torch.no_grad()
    def on_validation_epoch_start(self):
        """Refresh the training-data embeddings before validating.

        Delegates to ``prepare_test_mode`` with its default arguments.
        """
        self.prepare_test_mode()

    @torch.no_grad()
    def validation_loss(self, features):
        """Return a placeholder validation loss of zero for every sample.

        The real computation is disabled; its draft is kept in
        ``_draft_neighbor_validation_loss`` until it is fixed.

        Args:
            features: Tensor whose first dimension is the batch dimension.

        Returns:
            A 1-D tensor of zeros with one entry per sample, with the dtype
            and device of ``features``.
        """
        # TODO Fix Validation loss computation
        return torch.zeros(features.shape[0]).type_as(features)

    @torch.no_grad()
    def _draft_neighbor_validation_loss(self, features):
        """Disabled draft of a neighbour-based validation loss.

        This code used to sit, unreachable, behind the early return of
        ``validation_loss``. It is preserved for reference, is not called
        from ``validation_loss`` and is known to need fixing before use.
        """
        n_context_augs = self.cfg.loss.n_context_augs

        def realign_context(x):
            return torch.cat(
                x[:, :, None].chunk(x.shape[1] // n_context_augs, dim=1), dim=2
            )

        neighbor_idx = []
        # Get nearest neighbors of current features
        for i in range(features.shape[1]):
            curr_features = features[:, i].cpu().contiguous().float()
            neighbor = (
                self.cos_index[i]
                .search(
                    F.normalize(curr_features, dim=-1),
                    self.cfg.loss.val_loss_batch_size - 1,
                )[1]
                .to(self.normal_embeddings.device)
            )
            neighbor_idx.append(neighbor)
        loss = []
        features = realign_context(features)
        train_features = realign_context(
            self.normal_embeddings  # TODO: make this a config
        )
        for i in range(features.shape[2] - 1):
            curr_loss = []
            curr_features = features[:, :, i : i + 2]
            curr_train_features = train_features[neighbor_idx[i], :, i : i + 2].type_as(
                curr_features
            )

            for sample_idx in range(curr_features.shape[0]):
                feature = curr_features[sample_idx, None]
                score_features = torch.cat(
                    [
                        torch.cat(
                            [feature[:, j], curr_train_features[sample_idx, :, j]]
                        )
                        for j in range(feature.shape[1])
                    ]
                )
                context_labels = torch.arange(n_context_augs).repeat_interleave(
                    score_features.shape[0] // n_context_augs
                )
                curr_loss.append(
                    self.criterion(
                        score_features,
                        context_labels,
                        logging_prefix=None,
                    ).cpu()
                )
            loss.append(torch.tensor(curr_loss))
        loss = torch.stack(loss).mean(dim=0)
        return loss

    def validation_step(self, batch, batch_idx=0):
        """Embed a validation batch and store its losses and anomaly scores.

        Results are appended to the validation buffers on ``self``
        (``validation_loss_normal``, ``validation_loss_anomaly``,
        ``validation_gts`` and ``validation_predictions``); nothing is
        returned.

        Args:
            batch: Pair ``(x, anomaly_label)``; label 0 is routed to the
                "normal" loss buffer and label 1 to the "anomaly" one.
        """
        x, anomaly_label = batch

        # Get test time augmentations to compute scores
        if self.cfg.augmentation.test_time_augmentations is None:
            # Without test time augmentations, embed x once with random contexts
            random_contexts = torch.randint(
                0, self.cfg.loss.n_context_augs, (x.shape[0],)
            )
            projections, features = self(x, context_labels=random_contexts)
        else:
            _, projections, features = self.get_test_augmentation_embeddings(x)

        # Compute validation loss and sort it into the per-class buffers
        loss = self.validation_loss(features)
        buckets = {
            0: self.validation_loss_normal,
            1: self.validation_loss_anomaly,
        }
        for label, bucket in buckets.items():
            # Only recorded when more than one sample of this class is present
            n_matching = (anomaly_label == label).long().sum()
            if n_matching > 1.0:
                bucket += loss[anomaly_label.cpu() == label]

        # Compute and store anomaly scores for later logging
        self.validation_gts.append(anomaly_label)
        for inference_type in self.cfg.loss.utils.inference_types():
            scores = self.compute_anomaly_score(
                x, features, projections, prediction_mode=inference_type
            )
            self.validation_predictions[inference_type].append(scores)

    def on_validation_epoch_end(self):
        """Log the metrics gathered during validation and reset the buffers.

        For every inference type the stored predictions are concatenated and
        passed to ``validation_log_metrics`` under the type's ``value``. The
        mean of each non-empty validation loss buffer is logged as well.
        """
        gts = torch.cat(self.validation_gts).cpu()
        # Compute threshold agnostic metrics per inference type
        for mode in self.cfg.loss.utils.inference_types():
            stacked = torch.cat(self.validation_predictions[mode])
            self.validation_log_metrics(gts, stacked.cpu().float(), name=mode.value)

        loss_buckets = (
            ("validation_loss/normal", self.validation_loss_normal),
            ("validation_loss/anomaly", self.validation_loss_anomaly),
        )
        for key, bucket in loss_buckets:
            if len(bucket) > 0:
                self.log(key, torch.tensor(bucket).mean())
        # Reset validation loop output storage
        self.configure_validation_step_outputs()

    def validation_log_metrics(self, gts, predictions, name=""):
        """Log threshold-agnostic metrics for one set of anomaly scores.

        Logs the false positive rate at the first ROC point whose true
        positive rate reaches 0.95, the AUROC and the average precision
        under the prefix ``validation_{name}``. When
        ``cfg.logging.log_roc_curve`` is set, a plot of the ROC curve is
        logged as an image too.

        Args:
            gts: Ground-truth labels.
            predictions: Anomaly scores, one per label.
            name: Suffix of the logging prefix.
        """
        prefix = f"validation_{name}"
        # Compute ROC curve
        fprs, tprs, _ = roc_curve(gts, predictions.float())
        if self.cfg.logging.log_roc_curve:
            figure, axis = plt.subplots()
            axis.plot(fprs, tprs)
            # Render the figure into memory and hand it to the logger as an image
            buffer = io.BytesIO()
            figure.savefig(buffer, format="png")
            roc_image = Image.open(buffer)
            self.logger.log_image(key=f"{prefix}/roc_curve", images=[roc_image])
            plt.close("all")
        # Compute fpr-95: first operating point with a TPR of at least 0.95
        first_idx = next(i for i, tpr in enumerate(tprs) if tpr >= 0.95)
        self.log(f"{prefix}/fpr-95", fprs[first_idx])
        self.log(f"{prefix}/auroc", roc_auc_score(gts, predictions))
        self.log(
            f"{prefix}/average_precision",
            average_precision_score(gts, predictions),
        )

    def train_dataloader(self, drop_last=True, augment=True, batch_size=None):
        """Build a shuffling data loader over the training data.

        Args:
            drop_last: Whether an incomplete final batch is dropped.
            augment: Selects ``self.train_dataset`` when True and
                ``self.train_dataset_no_aug`` otherwise.
            batch_size: Batch size; defaults to ``cfg.batch_size``.

        Returns:
            The configured ``DataLoader``.
        """
        dataset = self.train_dataset if augment else self.train_dataset_no_aug
        if batch_size is None:
            batch_size = self.cfg.batch_size
        num_workers = self.cfg.num_workers
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            # Compare the device *type*: a torch.device never equals the
            # string "cuda", so the previous check was always False.
            pin_memory=torch.device(self.device).type == "cuda",
            # DataLoader rejects persistent workers when there are no workers.
            persistent_workers=num_workers > 0,
            drop_last=drop_last,
        )

    def normal_embedding_dataloader(self):
        """Build a shuffling data loader over ``self.normal_embeddings``.

        Returns:
            A ``DataLoader`` that yields batches of ``cfg.batch_size`` stored
            embeddings and keeps the incomplete final batch.
        """
        dataset = TensorDataset(
            self.normal_embeddings,
        )
        return DataLoader(
            dataset,
            batch_size=self.cfg.batch_size,
            shuffle=True,
            num_workers=self.cfg.num_workers,
            # Compare the device *type*: a torch.device never equals the
            # string "cuda", so the previous check was always False.
            pin_memory=torch.device(self.device).type == "cuda",
            drop_last=False,
        )

    def val_dataloader(self, shuffle=False):
        """Build the data loader over ``self.validation_dataset``.

        Args:
            shuffle: Whether the validation samples are shuffled.

        Returns:
            A ``DataLoader`` with batches of ``cfg.batch_size_eval`` samples
            that keeps the incomplete final batch.
        """
        return DataLoader(
            self.validation_dataset,
            batch_size=self.cfg.batch_size_eval,
            shuffle=shuffle,
            num_workers=self.cfg.num_workers,
            # Compare the device *type*: a torch.device never equals the
            # string "cuda", so the previous check was always False.
            pin_memory=torch.device(self.device).type == "cuda",
            drop_last=False,
        )

    # def test_dataloader(self):
    #     # any iterable or collection of iterables
    #     return DataLoader(self.test_dataset)

    @torch.no_grad()
    def compute_augmentation_distance_score(self, x):
        """Score samples by how well their context can be told apart.

        ``num_perms`` distinct context augmentations and ``num_simclr_augs``
        content augmentations are drawn with a fixed seed. For every context
        the sample is embedded once with the context augmentation alone and
        once per content augmentation combined with it. The score is a
        norm-weighted cross-entropy that treats "same context" as the correct
        class among all contexts, averaged over views and contexts.

        The global ``random`` and ``torch`` RNGs are re-seeded with
        ``cfg.seed`` on every call.

        Previously both augmentation pipelines and both parameter lists
        shared one name each, so the later assignment replaced the earlier
        one and a four-element list was indexed with up to ``num_perms``
        positions. They are kept apart here.
        NOTE(review): the split (first pipeline = content, second = context)
        is reconstructed from the config sections each one reads -- confirm.

        Args:
            x: Batch of input samples.

        Returns:
            A 1-D tensor with one score per sample.

        Raises:
            ValueError: If ``num_perms`` distinct context augmentation
                parameters cannot be drawn within the retry budget.
        """
        # Create content and context augmentations
        num_perms = 10
        num_simclr_augs = 4
        max_draw_attempts = 1000
        content_augmentation = K.ImageSequential(
            *get_augmentation(
                self.cfg.augmentation,
                size=self.cfg.dataset.size,
                eval=True,
            ),
            same_on_batch=True,
        )
        context_augmentation = K.ImageSequential(
            *get_augmentation(
                self.cfg.loss.augmentation_class,
                size=self.cfg.dataset.size,
            ),
            same_on_batch=(
                True
                if self.cfg.loss.augmentation_class.name != "random_fisheye"
                else None
            ),
        )
        # Draw fixed context and content augmentations
        random.seed(self.cfg.seed)
        torch.manual_seed(self.cfg.seed)
        context_params = []
        for _ in range(num_perms):
            # Bounded retry so an augmentation with too few distinct
            # parameter sets cannot stall the loop forever.
            for _attempt in range(max_draw_attempts):
                new_params = context_augmentation.forward_parameters(x.shape)
                if not any(
                    equal_augmentation_params(params, new_params)
                    for params in context_params
                ):
                    break
            else:
                raise ValueError(
                    f"Could not draw {num_perms} distinct context augmentation "
                    "parameter sets."
                )
            context_params.append(new_params)
        content_params = [
            content_augmentation.forward_parameters(x.shape)
            for _ in range(num_simclr_augs)
        ]
        # Decides per content augmentation whether it runs before or after
        # the context augmentation.
        augmentation_order = torch.rand(num_simclr_augs)

        # Create matrix with embeddings of context/content augmented sample projections
        projections = []
        for i in range(num_perms):
            x_augs = [context_augmentation(x, params=context_params[i])]
            for j in range(num_simclr_augs):
                if augmentation_order[j] < 0.5:
                    x_aug = content_augmentation(x, params=content_params[j])
                    x_aug = context_augmentation(x_aug, params=context_params[i])
                else:
                    x_aug = context_augmentation(x, params=context_params[i])
                    x_aug = content_augmentation(x_aug, params=content_params[j])
                x_augs.append(x_aug)
            x_augs = self.normalize(torch.cat(x_augs))
            z = self.encoder(x_augs)
            # NOTE(review): the indexing below (``[:, 0, None]`` and
            # ``[:, :, None]``) only lines up if ``get_projection`` returns a
            # tensor with a singleton dim 1 -- confirm against the criterion.
            projection = self.criterion.content_loss.get_projection(z)
            projection = torch.cat(
                projection.chunk(num_simclr_augs + 1),
                dim=1,
            )
            projections.append(projection[:, None])
        unnormalized_projections = torch.cat(projections, dim=1)
        use_cosine = self.cfg.loss.similarity_metric == "cos"
        if use_cosine:
            projections = F.normalize(unnormalized_projections, dim=-1)
        else:
            projections = unnormalized_projections

        # Compute anomaly score: classify the own context against all others
        score = []
        for i in range(num_perms):
            own_views = projections[:, i]
            anchor = own_views[:, 0, None]
            if use_cosine:
                similarity_matrix = (anchor * own_views).sum(dim=-1)
            else:
                # NOTE(review): this branch drops the anchor column and uses
                # distances as logits, so the shapes below do not line up
                # with the cosine branch -- left unchanged, needs review.
                similarity_matrix = (anchor - own_views[:, 1:]).norm(dim=-1)
            # Column 0 holds the "same context" entry, which is the target class.
            columns = [similarity_matrix[:, :, None]]

            for j in range(num_perms):
                if i == j:
                    continue
                other_views = projections[:, j]
                if use_cosine:
                    other_similarity = (anchor * other_views).sum(dim=-1)
                else:
                    other_similarity = (anchor - other_views).norm(dim=-1)
                columns.append(
                    other_similarity[:, None].repeat(1, num_simclr_augs + 1, 1)
                )
            similarity_matrix = torch.cat(columns, dim=-1)

            logits = (
                similarity_matrix.reshape(-1, similarity_matrix.shape[-1])
                / self.criterion.content_loss.temperature
            )
            labels = torch.zeros(logits.shape[0], dtype=torch.long).to(self.device)
            clf_score = F.cross_entropy(logits, labels, reduction="none").reshape(
                -1, num_simclr_augs + 1
            )
            # Down-weight the classification loss for large projection norms
            norm_score = 1.0 / (1 + unnormalized_projections[:, i].norm(dim=-1))
            score.append((clf_score * norm_score).mean(dim=-1))
        return torch.stack(score).mean(dim=0)

    @torch.no_grad()
    def compute_anomaly_score(
        self, samples, features, projections, prediction_mode, projection_type="content"
    ):
        """Compute one anomaly score per sample for ``prediction_mode``.

        For the index- and likelihood-based modes a score is computed for
        every test-time view (dim 1 of ``features``) and the per-view scores
        are averaged.

        Args:
            samples: Raw input batch; only used by the augmentation distance
                mode.
            features: Encoder features with one entry per view along dim 1.
            projections: Dict with the keys ``"content"`` and ``"context"``.
            prediction_mode: Member of ``AnomalyScore`` selecting the score.
            projection_type: Unused; kept for interface compatibility.

        Returns:
            A 1-D tensor with one score per sample.

        Raises:
            NotImplementedError: If ``prediction_mode`` is not supported.
        """
        if prediction_mode == AnomalyScore.augmentation_distance:
            return self.compute_augmentation_distance_score(samples)
        if prediction_mode == AnomalyScore.loss_score:
            return self.validation_loss(features)

        def as_query(tensor):
            # The indices are queried with contiguous float32 CPU tensors.
            return tensor.float().cpu().contiguous()

        n_neighbors = 1
        score = []
        for view in range(features.shape[1]):
            # Queries are built lazily so only the tensor a mode actually
            # needs is copied to the CPU.
            if prediction_mode == AnomalyScore.nearest_neighbor:
                D, _ = self.l2_index[view].search(
                    as_query(features[:, view]), n_neighbors
                )
                score.append(D.sum(dim=-1))
            elif prediction_mode in [
                AnomalyScore.nearest_cosine_neighbor,
                AnomalyScore.nearest_cosine_neighbor_and_content_prediction,
            ]:
                D, _ = self.cos_index[view].search(
                    F.normalize(as_query(features[:, view]), dim=-1), n_neighbors
                )
                # Similarities are negated so that larger means more anomalous
                score.append(-D.sum(dim=-1))
            elif prediction_mode == AnomalyScore.nearest_neighbor_projection:
                D, _ = self.l2_context_projections_index[view].search(
                    as_query(projections["context"][:, view]), n_neighbors
                )
                score.append(D.sum(dim=-1))
            elif prediction_mode == AnomalyScore.nearest_cosine_neighbor_projection:
                D, _ = self.cos_context_projections_index[view].search(
                    F.normalize(as_query(projections["context"][:, view]), dim=-1),
                    n_neighbors,
                )
                score.append(-D.sum(dim=-1))
            elif prediction_mode == AnomalyScore.norm_weighted_nearest_cosine_neighbor:
                # The query is deliberately left unnormalised, so the inner
                # product is scaled by the query norm.
                D, _ = self.cos_index[view].search(
                    as_query(features[:, view]), n_neighbors
                )
                score.append(-D.sum(dim=-1))
            elif (
                prediction_mode
                == AnomalyScore.norm_weighted_nearest_cosine_neighbor_projection
            ):
                D, _ = self.cos_context_projections_index[view].search(
                    as_query(projections["context"][:, view]), n_neighbors
                )
                score.append(-D.sum(dim=-1))
            elif prediction_mode == AnomalyScore.context_likelihood:
                log_density = self.context_dists[view].logpdf(
                    F.normalize(as_query(features[:, view]), dim=-1)
                )
                # logpdf returns a bare scalar for a batch of one; as_tensor
                # plus reshape keeps the result 1-D in that case as well.
                likelihood = -torch.as_tensor(log_density).reshape(-1)
                # Handle ill-conditioning of the covariance matrix at beginning of training
                likelihood[torch.isinf(likelihood)] = 1e-6
                score.append(likelihood)
            else:
                raise NotImplementedError()

        mean_score = torch.stack(score).mean(dim=0)
        if (
            prediction_mode
            != AnomalyScore.nearest_cosine_neighbor_and_content_prediction
        ):
            return mean_score

        # Add the mean pairwise cosine similarity between the content
        # projections of different views of the same sample.
        # NOTE(review): this similarity is added with a positive sign while
        # the neighbour similarity above is negated -- confirm the intended
        # direction of the combined score.
        normalized_content_proj = F.normalize(projections["content"], dim=-1)
        content_prediction_matrix = (
            normalized_content_proj[:, :, None] * normalized_content_proj[:, None]
        ).sum(dim=-1)
        # Drop the diagonal (similarity of each view with itself)
        mask = torch.eye(normalized_content_proj.shape[1], dtype=torch.bool).to(
            self.device
        )
        content_prediction_matrix = content_prediction_matrix[:, ~mask].view(
            content_prediction_matrix.shape[0],
            -1,
        )
        content_prediction_score = content_prediction_matrix.mean(dim=-1).to(
            mean_score.device
        )
        return mean_score + content_prediction_score

    @torch.no_grad()
    def compute_normal_embeddings(self, init=False):
        """Embed the un-augmented training data and store the results.

        The results are stored on ``self.normal_embeddings`` and
        ``self.normal_projections`` (a dict with the keys ``"content"`` and
        ``"context"``), all on the CPU.

        Args:
            init: When True the forward passes run in train mode, otherwise
                in eval mode. The mode the module had on entry is restored
                afterwards in both cases, also if an error occurs.
        """
        was_training = self.training
        if init:
            self.train()
        else:
            self.eval()
        try:
            normal_embeddings = []
            normal_projections = {"content": [], "context": []}
            loader = self.train_dataloader(
                drop_last=False, augment=False, batch_size=self.cfg.batch_size_eval
            )
            for batch in tqdm(loader, leave=False):
                x = batch[0].to(self.device)
                if self.cfg.augmentation.test_time_augmentations is not None:
                    _, projection, features = self.get_test_augmentation_embeddings(
                        x,
                        identity_augmentation=self.cfg.augmentation.knn_embeddings_with_identity,
                    )
                else:
                    # Without test time augmentations every sample is embedded
                    # once with a random context label.
                    projection, features = self(
                        x,
                        context_labels=torch.randint(
                            0, self.cfg.loss.n_context_augs, (x.shape[0],)
                        ),
                    )
                for projection_type, chunks in normal_projections.items():
                    chunks.append(projection[projection_type].cpu())
                normal_embeddings.append(features.cpu())
            self.normal_embeddings = torch.cat(normal_embeddings)
            self.normal_projections = {
                projection_type: torch.cat(chunks)
                for projection_type, chunks in normal_projections.items()
            }
        finally:
            # Restore the original mode unconditionally. Previously a module
            # that was in eval mode stayed in train mode after init=True.
            self.train(was_training)

    @torch.no_grad()
    def compute_context_distance_ratios(self, print_results=False):
        """Log how similar views are within a context compared to across contexts.

        Works on the cached ``self.normal_projections["context"]``, which is
        L2-normalized along the last dimension and indexed as
        ``[sample, view]`` (expected layout ``(n_samples, n_views, dim)``).
        View ``i`` is treated as belonging to context
        ``i % cfg.loss.n_context_augs``.

        For every pair of views ``(i, j)`` with ``i <= j`` up to 1000 randomly
        drawn samples per view are compared. Despite the "distance" wording
        in the log keys, the pairwise values are similarities (larger means
        closer):

        * ``"cos"``: cosine similarity shifted by one, i.e. in ``[0, 2]``.
        * ``"mse"``: ``1 / (1 + euclidean distance)``.

        With a single context only the mean in-context value is logged under
        ``context_distance_ratios/in_context_distance``; otherwise the ratio
        of in-context to between-context means is logged per context under
        ``context_distance_ratios/in_to_out_context{i}_ratio``.

        Args:
            print_results: If true, every logged value is also printed.

        Raises:
            ValueError: If ``cfg.loss.similarity_metric`` is neither ``"cos"``
                nor ``"mse"``.
        """
        n_contexts = self.cfg.loss.n_context_augs
        metric = self.cfg.loss.similarity_metric
        if metric not in ("cos", "mse"):
            raise ValueError(
                f"Unsupported similarity metric {metric!r}; expected 'cos' or 'mse'."
            )

        projections = F.normalize(self.normal_projections["context"], dim=-1)
        n_samples, n_views = projections.shape[0], projections.shape[1]
        n_contexts_logging = min(n_views, n_contexts)
        in_context = [[] for _ in range(n_contexts_logging)]
        between_context = [[] for _ in range(n_contexts_logging)]

        for i in range(n_views):
            for j in range(i, n_views):
                # Subsample both views to bound the pairwise matrix to at
                # most 1000 x 1000, independent of the training-set size.
                idx1 = torch.randperm(n_samples)[:1000]
                idx2 = torch.randperm(n_samples)[:1000]
                view_i = projections[idx1, i]
                view_j = projections[idx2, j]
                if metric == "cos":
                    similarities = torch.matmul(view_i, view_j.T) + 1
                else:
                    # Use the same subsample as the cosine branch instead of
                    # the full (n_samples x n_samples) distance matrix.
                    similarities = 1.0 / (torch.cdist(view_i, view_j) + 1.0)

                # Row means describe the samples of view i, column means the
                # samples of view j.
                row_means = similarities.mean(dim=-1)
                col_means = similarities.mean(dim=0)
                if i % n_contexts == j % n_contexts:
                    in_context[i % n_contexts].append(row_means)
                    in_context[i % n_contexts].append(col_means)
                else:
                    between_context[i % n_contexts].append(row_means)
                    between_context[j % n_contexts].append(col_means)

        in_context_means = [torch.cat(values).mean() for values in in_context]
        if n_contexts_logging == 1:
            # A single context has no between-context counterpart to form a
            # ratio with, so only the raw in-context value is reported.
            self.log(
                "context_distance_ratios/in_context_distance",
                in_context_means[0],
            )
            if print_results:
                print(
                    "context_distance_ratios/in_context_distance",
                    in_context_means[0],
                )
            return

        between_context_means = [
            torch.cat(values).mean() for values in between_context
        ]
        for i in range(n_contexts_logging):
            ratio = in_context_means[i] / between_context_means[i]
            self.log(f"context_distance_ratios/in_to_out_context{i}_ratio", ratio)
            if print_results:
                print(f"context_distance_ratios/in_to_out_context{i}_ratio", ratio)

    def configure_validation_step_outputs(self):
        """Reset all per-epoch validation accumulators to empty containers.

        For every inference type reported by
        ``self.cfg.loss.utils.inference_types()`` an empty list is created in
        both ``self.validation_predictions`` and
        ``self.validation_score_histogram``. The ground-truth and the
        normal/anomaly loss buffers are reset to empty lists as well.
        """
        inference_types = tuple(self.cfg.loss.utils.inference_types())

        # Each key gets its own fresh list; the two dicts never share storage.
        self.validation_predictions = {name: [] for name in inference_types}
        self.validation_score_histogram = {name: [] for name in inference_types}

        self.validation_gts = []
        self.validation_loss_normal = []
        self.validation_loss_anomaly = []

    def configure_model(self):
        """Create the encoder backbone and the training criterion.

        Sets ``self.eps``, ``self.encoder`` and ``self.criterion``. The
        criterion is obtained from ``cfg.loss.utils.get_loss`` and receives
        the encoder's output dimension plus the whole ``cfg.loss`` section
        converted to plain Python containers as keyword arguments.

        Raises:
            NotImplementedError: If ``cfg.model.name`` is not ``"resnet"``.
        """
        # Small epsilon for numerical stability
        self.eps = 1e-6

        model_cfg = self.cfg.model
        if model_cfg.name != "resnet":
            raise NotImplementedError(f"{model_cfg.name} encoder not implemented.")

        self.encoder = ResNetEncoder(
            model_cfg,
            in_size=self.cfg.dataset.size,
            hierarchies=self.cfg.loss.get("hierarchy", False),
        )

        build_loss = self.cfg.loss.utils.get_loss
        self.criterion = build_loss(
            backbone_out_dim=self.encoder.out_dim,
            **OmegaConf.to_container(self.cfg.loss),
        )
        # self.register_buffer("center", torch.zeros(self.encoder.out_dim))

    def configure_datasets(self):
        """Build the datasets, normalization ops and all augmentation pipelines.

        Sets the following attributes:

        * ``train_dataset``, ``train_dataset_no_aug``, ``validation_dataset``
          as returned by ``get_datasets(cfg.dataset)``.
        * ``normalization_mean`` / ``normalization_std`` and the matching
          ``normalize`` / ``denormalize`` kornia ops.
        * ``augmentations`` / ``test_augmentation``: train and eval pipelines
          built from ``cfg.augmentation``.
        * ``context_augmentations`` / ``test_context_augmentation``: train and
          eval pipelines built from ``cfg.loss.augmentation_class``.

        Raises:
            ValueError: If the low-frequency context augmentation is selected
                but the training data yields fewer than
                ``cfg.loss.n_context_augs`` samples.
        """
        # Create train and test sets
        (
            self.train_dataset,
            self.train_dataset_no_aug,
            self.validation_dataset,
        ) = get_datasets(self.cfg.dataset)

        if self.cfg.model.pretrain:
            # Channel statistics commonly used with ImageNet-pretrained backbones.
            self.normalization_mean = [0.485, 0.456, 0.406]
            self.normalization_std = [0.229, 0.224, 0.225]
        else:
            self.normalization_mean = [0.5, 0.5, 0.5]
            self.normalization_std = [0.5, 0.5, 0.5]
        self.normalize = K.Normalize(
            mean=self.normalization_mean, std=self.normalization_std
        )
        self.denormalize = K.Denormalize(
            mean=self.normalization_mean, std=self.normalization_std
        )

        # Create training augmentations for contrastive learning
        augmentations = get_augmentation(
            self.cfg.augmentation,
            size=self.cfg.dataset.size,
        )
        self.augmentations = K.ImageSequential(*augmentations)

        # Create test time SimCLR augmentations
        test_augmentation = get_augmentation(
            self.cfg.augmentation, size=self.cfg.dataset.size, eval=True
        )
        self.test_augmentation = K.ImageSequential(
            *test_augmentation,
            same_on_batch=True,
        )

        # The low-frequency context augmentation needs one reference image per
        # context; take the first n_context_augs un-augmented training samples.
        low_frequency_samples = None
        if (
            self.cfg.loss.augmentation_class.name
            == "low_frequency_context_augmentation"
        ):
            n_augs = self.cfg.loss.n_context_augs
            chunks = []
            n_collected = 0
            for batch in self.train_dataloader(drop_last=False, augment=False):
                chunk = batch[0][: n_augs - n_collected]
                chunks.append(chunk)
                n_collected += chunk.shape[0]
                if n_collected >= n_augs:
                    break
            if not chunks or n_collected < n_augs:
                # Previously surfaced as a bare StopIteration from next().
                raise ValueError(
                    f"The training data provides only {n_collected} samples, "
                    f"but {n_augs} low-frequency reference samples are "
                    "required (loss.n_context_augs)."
                )
            low_frequency_samples = torch.cat(chunks)

        # Create context augmentations for training
        context_augmentations = get_augmentation(
            self.cfg.loss.augmentation_class,
            low_frequency_samples=low_frequency_samples,
            size=self.cfg.dataset.size,
            n_context_augs=self.cfg.loss.n_context_augs,
        )
        self.context_augmentations = K.ImageSequential(
            *context_augmentations, same_on_batch=True
        )

        # Create context test time augmentations
        test_context_augmentation = get_augmentation(
            self.cfg.loss.augmentation_class,
            low_frequency_samples=low_frequency_samples,
            size=self.cfg.dataset.size,
            eval=True,
            n_context_augs=self.cfg.loss.n_context_augs,
        )
        self.test_context_augmentation = K.ImageSequential(
            *test_context_augmentation,
            same_on_batch=True,
        )

    def configure_optimizers(self):
        """Create the optimizer and, if configured, its learning-rate scheduler.

        Returns:
            The bare optimizer when ``cfg.optimizer.scheduler`` is ``None``.
            Otherwise a dict with the optimizer and an ``lr_scheduler`` entry
            whose interval is ``"step"`` for the ``"onecycle"`` scheduler and
            ``"epoch"`` for every other scheduler.
        """
        optimizer = create_optimizer(self.cfg.optimizer, self.parameters())

        scheduler_name = self.cfg.optimizer.scheduler
        if scheduler_name is None:
            return optimizer

        lr_scheduler = create_scheduler(
            scheduler_name, optimizer, self.train_dataloader(), self.cfg.epochs
        )

        if self.cfg.optimizer.scheduler != "onecycle":
            interval = "epoch"
        else:
            interval = "step"

        scheduler_config = {"scheduler": lr_scheduler, "interval": interval}
        return {"optimizer": optimizer, "lr_scheduler": scheduler_config}
