import io
import random

import faiss
import faiss.contrib.torch_utils
import kornia.augmentation as K
import lightning.pytorch as pl
import matplotlib.pyplot as plt
import scipy
import torch
import torch.nn.functional as F
from omegaconf import OmegaConf
from PIL import Image
from scipy.stats import multivariate_normal
from sklearn.metrics import (
    average_precision_score,
    roc_auc_score,
    roc_curve,
)
from torch.utils.data import DataLoader, TensorDataset
from tqdm.auto import tqdm

from configs.enums import *
from data.data import get_datasets
from data.transforms import *
from models.encoder import ResNetEncoder
from utils.optimizer import *


class ContextContrastingADModel(pl.LightningModule):
    def __init__(self, cfg, data_path=None):
        """Build the context-contrasting anomaly detection model.

        Args:
            cfg: Experiment configuration (attribute access, e.g.
                ``cfg.dataset.data_path``).
            data_path: Optional override for ``cfg.dataset.data_path``. When
                given, ``cfg`` is mutated in place before it is stored.
        """
        super().__init__()
        # Keep this call directly in __init__: Lightning's save_hyperparameters
        # collects the constructor arguments from the calling frame.
        self.save_hyperparameters()
        if data_path is not None:
            cfg.dataset.data_path = data_path
        self.cfg = cfg
        # Datasets, model components and validation buffers are created by
        # helpers defined elsewhere in this class, in this order.
        for setup_step in (
            self.configure_datasets,
            self.configure_model,
            self.configure_validation_step_outputs,
        ):
            setup_step()

    @torch.no_grad()
    def get_test_augmentation_embeddings(self, x, identity_augmentation=False):
        """Embed a batch under a fixed set of test-time augmentations ("views").

        ``cfg.augmentation.test_time_augmentations`` views are built. View ``i``
        carries the context label ``i % cfg.loss.n_context_augs`` and applies
        ``self.test_context_augmentation``. Unless ``identity_augmentation`` is
        set, it then applies ``self.test_augmentation``, and finally
        ``self.normalize``.

        All random draws happen right after seeding ``random`` and torch with
        ``cfg.seed``, so every call sees the same augmentation parameters. The
        global ``random`` and torch (CPU + CUDA) RNG states are restored
        afterwards, so this reseeding no longer leaks into training randomness
        (data shuffling, train-time augmentations).

        Args:
            x: Input batch (samples on dim 0).
            identity_augmentation: If True, skip ``self.test_augmentation`` and
                keep only the context augmentation.

        Returns:
            ``(views, projection, features)``:

            * ``views`` is a tuple with one normalised augmented batch per view.
            * ``projection`` maps ``"content"``/``"context"`` to tensors with the
              views stacked on dim 1.
            * ``features`` holds the encoder features with the views
              concatenated on dim 1.

        Raises:
            ValueError: if no sufficiently distinct kornia parameters can be
                drawn.
            NotImplementedError: for an unsupported ``augmentation_type``.
        """
        augmentation_type = self.cfg.loss.augmentation_class.augmentation_type
        if augmentation_type not in ("kornia", "custom"):
            raise NotImplementedError(
                f"Test time augmentation not implemented for augmentation type {augmentation_type}."
            )
        python_rng_state = random.getstate()
        try:
            # fork_rng snapshots the CPU and all CUDA torch generators and
            # restores them on exit.
            with torch.random.fork_rng(
                devices=list(range(torch.cuda.device_count()))
            ):
                # Fix seed for test time augmentations
                random.seed(self.cfg.seed)
                torch.manual_seed(self.cfg.seed)
                return self._embed_test_augmentations(x, identity_augmentation)
        finally:
            random.setstate(python_rng_state)

    def _embed_test_augmentations(self, x, identity_augmentation):
        """Draw, apply and embed all test-time views. The caller must seed the RNGs."""
        n_augs = self.cfg.augmentation.test_time_augmentations
        n_context_augs = self.cfg.loss.n_context_augs
        augmentation_type = self.cfg.loss.augmentation_class.augmentation_type

        # All kornia parameters are drawn up-front, before any view is built,
        # so the RNG consumption order stays fixed.
        context_aug_params = None
        if augmentation_type == "kornia":
            context_aug_params = self._sample_test_context_aug_params(
                x.shape, n_augs, n_context_augs
            )

        augmented_xs = []
        context_labels = []
        for i in range(n_augs):
            labels = torch.ones(x.shape[0]) * (i % n_context_augs)
            if augmentation_type == "kornia":
                augmented_x = self.test_context_augmentation(
                    x, params=context_aug_params[i]
                )
            else:
                augmented_x = self.test_context_augmentation(
                    x, extra_args={"labels": labels}
                )
            if not identity_augmentation:
                augmented_x = self.test_augmentation(augmented_x)
            augmented_xs.append(self.normalize(augmented_x))
            context_labels.append(labels)
        augmented_xs = torch.cat(augmented_xs)
        context_labels = torch.cat(context_labels)

        # Embed all views in one forward pass, then regroup per sample.
        projection, features = self(augmented_xs, context_labels)
        for projection_type in ("content", "context"):
            projection[projection_type] = torch.stack(
                projection[projection_type].chunk(n_augs), dim=1
            )
        # forward() already inserts a singleton dim 1 for features, so concatenate.
        features = torch.cat(features.chunk(n_augs), dim=1)
        return augmented_xs.chunk(n_augs), projection, features

    def _sample_test_context_aug_params(
        self, x_shape, n_augs, n_context_augs, max_attempts=1000
    ):
        """Draw ``n_augs`` kornia parameter sets for the test-time views.

        Each new set must differ from the previous ``n_context_augs - 1`` sets.
        Any ``n_context_augs`` consecutive views carry distinct context labels,
        so this makes them use distinct augmentations.

        Raises:
            ValueError: if more than ``max_attempts`` redraws were needed.
        """
        params_list = []
        for _ in range(n_augs):
            attempts_left = max_attempts
            while True:
                attempts_left -= 1
                new_params = self.test_context_augmentation.forward_parameters(
                    x_shape
                )
                recent = (
                    params_list[-(n_context_augs - 1) :] if n_context_augs != 1 else []
                )
                if not any(
                    any_equal_augmentation_params(params, new_params)
                    for params in recent
                ):
                    break
                if attempts_left < 0:
                    raise ValueError(
                        "Could not find new context augmentation parameters, please decrease the number of test time augmentations or increase the number of context augmentations."
                    )
            params_list.append(new_params)
        return params_list

    def forward(self, x, context_labels):
        """Encode ``x`` and project the encoding with the loss's projection heads.

        Args:
            x: Input batch.
            context_labels: Context label per sample, forwarded to
                ``self.criterion.get_projection``.

        Returns:
            ``(projection, features)``. ``projection`` is whatever
            ``criterion.get_projection`` returns (indexed by ``"content"`` and
            ``"context"`` elsewhere in this class). ``features`` is the encoder
            output with a singleton dimension inserted at dim 1.
        """
        embedding = self.encoder(x)
        if self.cfg.loss.get("hierarchy", False):
            # With a hierarchical encoder only the first level is projected.
            embedding = embedding[0]
        projection = self.criterion.get_projection(embedding, context_labels)
        return projection, embedding.unsqueeze(1)

    def create_index(self, embeddings, index_type="l2", nlist=100):
        """Build a faiss nearest-neighbour index over ``embeddings``.

        Args:
            embeddings: Database vectors, last dimension = feature dimension.
                Presumably a 2-D CPU float tensor, as faiss CPU indices
                require — TODO confirm.
            index_type: ``"l2"`` for (squared) L2 search, or ``"cos"`` for
                cosine similarity (inner product on L2-normalised vectors).
            nlist: Number of IVF cells when ``cfg.approximate_index`` is set.
                It is clamped to the number of training vectors.

        Returns:
            A trained faiss index containing ``embeddings``. It is wrapped in a
            PCA pre-transform to ``cfg.index_pca_dim`` dimensions when that
            option is set.

        Raises:
            ValueError: if ``index_type`` is neither ``"l2"`` nor ``"cos"``.
        """
        if index_type == "l2":
            metric = faiss.METRIC_L2
        elif index_type == "cos":
            metric = faiss.METRIC_INNER_PRODUCT
            embeddings = F.normalize(embeddings, dim=-1)
        else:
            raise ValueError(
                f"Unknown index type {index_type!r}; expected 'l2' or 'cos'."
            )
        embeddings = embeddings.contiguous()
        d_orig = embeddings.shape[-1]
        pca_dim = self.cfg.get("index_pca_dim", None)
        # The search index itself lives in the (possibly PCA-reduced) space.
        d = pca_dim if pca_dim is not None else d_orig

        if metric == faiss.METRIC_L2:
            index = faiss.IndexFlatL2(d)
        else:
            index = faiss.IndexFlatIP(d)
        if self.cfg.approximate_index:
            # The IVF metric must match the flat quantizer. The faiss default
            # is L2, which silently turned cosine indices into L2 indices.
            index = faiss.IndexIVFFlat(
                index, d, min(nlist, embeddings.shape[0]), metric
            )
        if pca_dim is not None:
            pca = faiss.PCAMatrix(d_orig, pca_dim)
            index = faiss.IndexPreTransform(pca, index)
        if not index.is_trained:
            # On IndexPreTransform this trains the PCA first and then the wrapped
            # IVF index on the projected vectors. Training the IVF directly on the
            # original-dimension vectors would be a dimension mismatch.
            index.train(embeddings)
        index.add(embeddings)
        return index

    @torch.no_grad()
    def prepare_test_mode(self, init=False, print_results=False):
        """Refresh everything anomaly scoring needs from the normal training data.

        Steps:

        1. Recompute the normal embeddings and projections.
        2. Fit one Gaussian per test-time view to the L2-normalised embeddings.
        3. Log context distance ratios.
        4. Rebuild, for every view, the L2 and cosine faiss indices over the
           embeddings and the context and content projections.

        Args:
            init: Forwarded to ``compute_normal_embeddings``.
            print_results: Forwarded to ``compute_context_distance_ratios``.
        """
        # Compute embeddings of training data without augmentation
        self.compute_normal_embeddings(init=init)
        n_views = self.normal_embeddings.shape[1]

        # Fit a Gaussian to the normalised normal embeddings of every view
        self.context_dists = [
            self._fit_context_distribution(
                F.normalize(self.normal_embeddings[:, i], dim=-1)
            )
            for i in tqdm(range(n_views), leave=False)
        ]

        # Log ratio between in-context and out-of-context distances of training samples
        self.compute_context_distance_ratios(print_results=print_results)

        # Update search indices for nearest neighbor computations (one per view)
        self.l2_index, self.cos_index = [], []
        self.l2_content_projections_index, self.cos_content_projections_index = [], []
        self.l2_context_projections_index, self.cos_context_projections_index = [], []
        for i in tqdm(range(n_views), desc="Creating search indices", leave=False):
            embeddings = self.normal_embeddings[:, i]
            context_projections = self.normal_projections["context"][:, i]
            content_projections = self.normal_projections["content"][:, i].contiguous()
            self.l2_index.append(self.create_index(embeddings, index_type="l2"))
            self.cos_index.append(self.create_index(embeddings, index_type="cos"))
            self.l2_context_projections_index.append(
                self.create_index(context_projections, index_type="l2")
            )
            self.cos_context_projections_index.append(
                self.create_index(context_projections, index_type="cos")
            )
            self.l2_content_projections_index.append(
                self.create_index(content_projections, index_type="l2")
            )
            self.cos_content_projections_index.append(
                self.create_index(content_projections, index_type="cos")
            )

    @staticmethod
    def _fit_context_distribution(features):
        """Fit a multivariate normal to the rows of ``features`` (samples x dims)."""
        mean = features.mean(dim=0)
        try:
            return multivariate_normal(
                mean=mean,
                cov=features.T.cov(),
                allow_singular=True,
            )
        except ValueError:
            # The covariance matrix can be invalid early in training. Fall back
            # to a diagonal covariance. Covariance.from_diagonal expects the
            # diagonal of the covariance, i.e. the variances, not the std.
            return multivariate_normal(
                mean=mean,
                cov=scipy.stats.Covariance.from_diagonal(features.var(dim=0)),
            )

    @torch.no_grad()
    def on_fit_start(self):
        """Lightning hook; intentionally does nothing for this model."""

    @torch.no_grad()
    def train_augment(self, x):
        """Build the augmented training batch and its context labels.

        ``x`` has ``n`` samples on dim 0; the ``repeat(2, 1, 1, 1)`` calls
        assume a 4-D image batch — TODO confirm. The batch is processed as
        follows:

        1. Duplicate it to ``2n`` samples.
        2. Split the ``2n`` samples into ``cfg.loss.n_context_augs`` contiguous
           groups, with sizes from one multivariate hypergeometric draw.
        3. Context-augment each group as one context.
        4. Duplicate the result again and pass it through
           ``self.augmentations`` and ``self.normalize``.

        Returns:
            ``(x_aug, context_labels)``:

            * ``x_aug`` has ``4n`` samples: two content-augmented copies of the
              ``2n`` context-augmented samples, stacked on dim 0.
            * ``context_labels`` is a float tensor with one label per
              context-augmented sample (``2n`` entries).

        Raises:
            NotImplementedError: unless the augmentation class is
                ``random_rotation`` or its ``augmentation_type`` is ``"custom"``.
        """
        # Duplicate x to ensure each sample is in different contexts
        n = x.shape[0]
        x = x.repeat(2, 1, 1, 1)
        # if self.cfg.augmentation.align_content_augmentations:
        #     x = self.augmentations(x).repeat(2, 1, 1, 1)
        #     n = 2 * n

        # Generate context augmentation parameters
        # Each context has a hypergeometric population of n, so every group size
        # is at most n. Groups are contiguous, so sample j and its duplicate at
        # j + n can never fall into the same group.
        # NOTE(review): with n_context_augs == 1 this draws 2n from a population
        # of n — confirm scipy accepts this.
        n_contexts = self.cfg.loss.n_context_augs
        n_contexts_per_label = scipy.stats.multivariate_hypergeom.rvs(
            [n for _ in range(n_contexts)], 2 * n
        )
        # Shuffle context labels to ensure there is no bias for content contrasting
        labels = torch.randperm(n_contexts)
        context_labels = [
            labels[i] * torch.ones(n_samples_in_context)
            for i, n_samples_in_context in enumerate(n_contexts_per_label)
        ]
        context_labels = torch.cat(context_labels)
        # Aggregate sampled transformation parameters and apply to x
        # NOTE(review): in the rotation branch the angle depends on the group
        # index i, while the label is the permuted labels[i]. The
        # rotation-to-label mapping therefore changes every batch — confirm
        # this is intended.
        if self.cfg.loss.augmentation_class.name == "random_rotation":
            final_rotation_params = self.context_augmentations.forward_parameters(
                (2 * n, x.shape[1], x.shape[2], x.shape[3])
            )
            if self.cfg.loss.augmentation_class.max_180_degrees and n_contexts != 1:
                # Contexts are spread evenly over [0, 180] degrees.
                final_rotation_params[0][1]["degrees"] = torch.cat(
                    [
                        torch.tensor(i / (n_contexts - 1)).repeat(n_samples_in_context)
                        * 180
                        for i, n_samples_in_context in enumerate(n_contexts_per_label)
                    ]
                )
            else:
                # Contexts are spread evenly over the full circle, shifted by angle_offset.
                final_rotation_params[0][1]["degrees"] = torch.cat(
                    [
                        torch.tensor(
                            (
                                360 * (i / n_contexts)
                                + self.cfg.loss.augmentation_class.angle_offset
                            )
                            % 360
                        ).repeat(n_samples_in_context)
                        for i, n_samples_in_context in enumerate(n_contexts_per_label)
                    ]
                )
            context_aug_x = self.context_augmentations(x, params=final_rotation_params)
        elif self.cfg.loss.augmentation_class.augmentation_type == "custom":
            # Custom augmentations choose the transformation from the label itself.
            context_aug_x = self.context_augmentations(
                x, extra_args={"labels": context_labels}
            )
        else:
            raise NotImplementedError(
                f"Augmentation parameter aggregation not yet implemented for {self.cfg.loss.augmentation_class.name}."
            )
        # if self.cfg.augmentation.align_content_augmentations:
        #     x_aug = context_aug_x.chunk(4)
        #     x_aug = torch.cat([x_aug[0], x_aug[2], x_aug[1], x_aug[3]])
        #     context_labels = context_labels.chunk(4)
        #     context_labels = torch.cat([context_labels[0], context_labels[2]])
        # else:
        # Two content (SimCLR-style) views of every context-augmented sample.
        x_aug = self.augmentations(context_aug_x.repeat(2, 1, 1, 1))
        x_aug = self.normalize(x_aug)
        return x_aug, context_labels

    def train_embed(self, x):
        """Encode a training batch and pair up its two augmented halves.

        ``train_augment`` stacks two augmented copies along dim 0. Each encoder
        output is split into those halves, and the halves are stacked on a new
        dim 1.

        Returns:
            A tensor shaped ``(N / 2, 2, ...)``. When ``cfg.loss.hierarchy`` is
            set, a list with one such tensor per hierarchy level.
        """

        def pair_halves(level):
            return torch.stack(level.chunk(2), dim=1)

        encoded = self.encoder(x)
        if self.cfg.loss.get("hierarchy", False):
            return [pair_halves(level) for level in encoded]
        return pair_halves(encoded)

    def training_step(self, batch, batch_idx=0):
        """Compute and log the contrastive training loss for one batch.

        With ``cfg.loss.hierarchy`` the criterion is applied per encoder level.
        Each level's loss is logged, and their mean is the returned loss.
        """
        views, context_labels = self.train_augment(batch[0])
        features = self.train_embed(views)

        if not self.cfg.loss.get("hierarchy", False):
            loss = self.criterion(features, context_labels, epoch=self.current_epoch)
            self.log("train/loss", loss)
            return loss

        level_losses = []
        for level_idx, level_features in enumerate(features):
            level_loss = self.criterion(
                level_features,
                context_labels,
                epoch=self.current_epoch,
                logging_prefix=f"train/hierarchy_{level_idx}_",
            )
            self.log(f"train/hierarchy_{level_idx}_loss", level_loss)
            level_losses.append(level_loss)
        loss = torch.stack(level_losses).mean()
        self.log("train/loss", loss)
        return loss

    @torch.no_grad()
    def on_validation_epoch_start(self):
        """Lightning hook: rebuild normal-data statistics and search indices.

        Runs ``prepare_test_mode`` with its default arguments before any
        validation batch is scored.
        """
        self.prepare_test_mode(init=False, print_results=False)

    @torch.no_grad()
    def validation_loss(self, features):
        """Per-sample validation loss; currently a placeholder of zeros.

        Args:
            features: Test-time features, samples on dim 0.

        Returns:
            A zero tensor of length ``features.shape[0]`` with ``features``'
            dtype and device.

        NOTE(review): everything after the first ``return`` is unreachable work
        in progress (see the TODO). It would score each sample with the
        criterion against its cosine nearest neighbours in
        ``self.normal_embeddings``.
        """
        return torch.zeros(features.shape[0]).type_as(features)
        # TODO Fix Validation loss computation
        n_context_augs = self.cfg.loss.n_context_augs

        def realign_context(x):
            # Regroup dim 1 into chunks of n_context_augs views stacked on a new dim 2.
            return torch.cat(
                x[:, :, None].chunk(x.shape[1] // n_context_augs, dim=1), dim=2
            )

        neighbor_idx = []
        # Get nearest neighbors of current features
        for i in range(features.shape[1]):
            curr_features = features[:, i].cpu().contiguous().float()
            neighbor = (
                self.cos_index[i]
                .search(
                    F.normalize(curr_features, dim=-1),
                    self.cfg.loss.val_loss_batch_size - 1,
                )[1]
                .to(self.normal_embeddings.device)
            )
            neighbor_idx.append(neighbor)
        loss = []
        features = realign_context(features)
        train_features = realign_context(
            self.normal_embeddings  # TODO: make this a config
        )
        for i in range(features.shape[2] - 1):
            curr_loss = []
            curr_features = features[:, :, i : i + 2]
            curr_train_features = train_features[neighbor_idx[i], :, i : i + 2].type_as(
                curr_features
            )

            for sample_idx in range(curr_features.shape[0]):
                feature = curr_features[sample_idx, None]
                score_features = torch.cat(
                    [
                        torch.cat(
                            [feature[:, j], curr_train_features[sample_idx, :, j]]
                        )
                        for j in range(feature.shape[1])
                    ]
                )
                context_labels = torch.arange(n_context_augs).repeat_interleave(
                    score_features.shape[0] // n_context_augs
                )
                curr_loss.append(
                    self.criterion(
                        score_features,
                        context_labels,
                        logging_prefix=None,
                    ).cpu()
                )
            loss.append(torch.tensor(curr_loss))
        loss = torch.stack(loss).mean(dim=0)
        return loss

    def validation_step(self, batch, batch_idx=0):
        """Score one validation batch and buffer the results for epoch-end logging.

        Args:
            batch: ``(x, anomaly_label)``. Label 0 marks normal samples and
                label 1 anomalous ones (see the loss bookkeeping below).
            batch_idx: Unused; kept for the Lightning hook signature.
        """
        x, anomaly_label = batch

        # Embed either over the fixed test-time augmentations, or once with
        # random context labels.
        if self.cfg.augmentation.test_time_augmentations is not None:
            _, projections, features = self.get_test_augmentation_embeddings(x)
        else:
            projections, features = self(
                x,
                context_labels=torch.randint(
                    0, self.cfg.loss.n_context_augs, (x.shape[0],)
                ),
            )

        # Accumulate per-sample validation losses separately for normal and
        # anomalous samples.
        loss = self.validation_loss(features)
        for label, losses in zip(
            (0, 1), (self.validation_loss_normal, self.validation_loss_anomaly)
        ):
            mask = (anomaly_label == label).to(loss.device)
            # Count a class even if it has a single sample in this batch. The
            # previous "> 1" threshold silently dropped such batches.
            if mask.any():
                losses += loss[mask]

        # Compute and store anomaly scores for later logging
        self.validation_gts.append(anomaly_label)
        for inference_type in self.cfg.loss.utils.inference_types():
            anomaly_predictions = self.compute_anomaly_score(
                x, features, projections, prediction_mode=inference_type
            )
            self.validation_predictions[inference_type].append(anomaly_predictions)

    def on_validation_epoch_end(self):
        """Log the epoch's validation metrics and reset the per-step buffers.

        For every configured inference type, the buffered anomaly scores are
        concatenated and evaluated against the buffered ground truth. The mean
        normal and anomaly validation losses are logged when any were collected.
        """
        gts = torch.cat(self.validation_gts).cpu()
        for score_kind in self.cfg.loss.utils.inference_types():
            # Compute threshold agnostic metrics
            scores = torch.cat(self.validation_predictions[score_kind]).cpu().float()
            self.validation_log_metrics(gts, scores, name=score_kind.value)

        loss_buffers = (
            ("validation_loss/normal", self.validation_loss_normal),
            ("validation_loss/anomaly", self.validation_loss_anomaly),
        )
        for log_key, buffer in loss_buffers:
            if len(buffer) > 0:
                self.log(log_key, torch.tensor(buffer).mean())

        # Reset validation loop output storage
        self.configure_validation_step_outputs()

    def validation_log_metrics(self, gts, predictions, name=""):
        """Log threshold-free metrics of anomaly scores against ground truth.

        The following are logged under ``validation_{name}/``:

        * FPR at 95% TPR;
        * AUROC;
        * average precision;
        * optionally, an image of the ROC curve.

        Args:
            gts: Binary ground-truth labels (1 = anomaly).
            predictions: Anomaly scores; higher means more anomalous.
            name: Suffix of the logging namespace.
        """
        # Compute ROC curve
        fprs, tprs, _ = roc_curve(gts, predictions.float())
        if self.cfg.logging.log_roc_curve:
            f, ax = plt.subplots()
            try:
                ax.plot(fprs, tprs)
                img_buf = io.BytesIO()
                f.savefig(img_buf, format="png")
                # Rewind so the image is read from the start of the PNG data.
                img_buf.seek(0)
                im = Image.open(img_buf)
                self.logger.log_image(key=f"validation_{name}/roc_curve", images=[im])
            finally:
                plt.close("all")
        # FPR at the first operating point reaching 95% TPR. NaN if there is
        # none (e.g. no positive samples) instead of raising StopIteration.
        fpr_95 = next(
            (fpr for fpr, tpr in zip(fprs, tprs) if tpr >= 0.95), float("nan")
        )
        self.log(f"validation_{name}/fpr-95", fpr_95)
        auc = roc_auc_score(gts, predictions)
        self.log(f"validation_{name}/auroc", auc)
        avg_precision = average_precision_score(gts, predictions)
        self.log(f"validation_{name}/average_precision", avg_precision)

    def train_dataloader(self, drop_last=True, augment=True, batch_size=None):
        """Return a shuffled DataLoader over the training set.

        Args:
            drop_last: Drop the final incomplete batch.
            augment: If False, iterate ``self.train_dataset_no_aug`` instead of
                ``self.train_dataset``.
            batch_size: Batch size; defaults to ``cfg.batch_size``.
        """
        dataset = self.train_dataset if augment else self.train_dataset_no_aug
        if batch_size is None:
            batch_size = self.cfg.batch_size
        num_workers = self.cfg.num_workers
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            # self.device is a torch.device (e.g. cuda:0), which never equals the
            # plain string "cuda"; compare its type instead.
            pin_memory=self.device.type == "cuda",
            # DataLoader rejects persistent_workers=True when num_workers == 0.
            persistent_workers=num_workers > 0,
            drop_last=drop_last,
        )

    def normal_embedding_dataloader(self):
        """Return a shuffled DataLoader over ``self.normal_embeddings``.

        Each batch is a one-element list holding a slice of the stored normal
        embeddings.
        """
        dataset = TensorDataset(self.normal_embeddings)
        return DataLoader(
            dataset,
            batch_size=self.cfg.batch_size,
            shuffle=True,
            num_workers=self.cfg.num_workers,
            # self.device is a torch.device (e.g. cuda:0), which never equals the
            # plain string "cuda"; compare its type instead.
            pin_memory=self.device.type == "cuda",
            drop_last=False,
        )

    def val_dataloader(self, shuffle=False):
        """Return a DataLoader over the validation set using ``cfg.batch_size_eval``.

        Args:
            shuffle: Shuffle the validation samples (off by default).
        """
        return DataLoader(
            self.validation_dataset,
            batch_size=self.cfg.batch_size_eval,
            shuffle=shuffle,
            num_workers=self.cfg.num_workers,
            # self.device is a torch.device (e.g. cuda:0), which never equals the
            # plain string "cuda"; compare its type instead.
            pin_memory=self.device.type == "cuda",
            drop_last=False,
        )

    # def test_dataloader(self):
    #     # any iterable or collection of iterables
    #     return DataLoader(self.test_dataset)

    @torch.no_grad()
    def compute_augmentation_distance_score(self, x):
        """Experimental anomaly score from distances between augmented projections.

        For ``num_perms`` augmentation draws, the method embeds the batch under
        the draw and ``num_simclr_augs`` composed augmentations. It then builds
        a contrastive cross-entropy score per sample, weighted by
        ``1 / (1 + ||projection||)``, and averages it over the draws.

        Args:
            x: Input batch (samples on dim 0).

        Returns:
            One score per sample.

        NOTE(review): as written this method cannot complete:

        * ``content_params`` is rebuilt with only ``num_simclr_augs`` (4)
          entries, but the main loop indexes ``content_params[i]`` for
          ``i < num_perms`` (10). This raises IndexError at ``i == 4``.
        * The first ``content_augmentation`` (``cfg.augmentation``, eval=True)
          is overwritten before use. It looks like it was meant to be a
          separate content augmentation, with the second one being the context
          augmentation — confirm before fixing.
        * This method reseeds the global ``random``/torch RNGs and does not
          restore them.
        """
        # Create content and context augmentations
        num_perms = 10
        num_simclr_augs = 4
        # NOTE(review): dead assignment — overwritten immediately below.
        content_augmentation = K.ImageSequential(
            *get_augmentation(
                self.cfg.augmentation,
                size=self.cfg.dataset.size,
                eval=True,
            ),
            same_on_batch=True,
        )
        content_augmentation = K.ImageSequential(
            *get_augmentation(
                self.cfg.loss.augmentation_class,
                size=self.cfg.dataset.size,
            ),
            same_on_batch=(
                True
                if self.cfg.loss.augmentation_class.name != "random_fisheye"
                else None
            ),
        )
        # Draw fixed content and content augmentations
        random.seed(self.cfg.seed)
        torch.manual_seed(self.cfg.seed)
        # num_perms pairwise-distinct parameter sets (no retry limit here).
        content_params = []
        for _ in range(num_perms):
            found_new_params = False
            while not found_new_params:
                new_params = content_augmentation.forward_parameters(x.shape)
                found_new_params = True
                for params in content_params:
                    if equal_augmentation_params(params, new_params):
                        found_new_params = False
            content_params.append(new_params)
        # NOTE(review): this discards the num_perms distinct draws above and
        # leaves only num_simclr_augs entries (see the IndexError note in the
        # docstring).
        content_params = [
            content_augmentation.forward_parameters(x.shape)
            for _ in range(num_simclr_augs)
        ]
        # Per secondary augmentation, randomly choose the order in which the two augmentations are composed.
        augmentation_order = torch.rand(num_simclr_augs)
        # Create matrix with embeddings of content/content augmented sample projections
        projections = []
        for i in range(num_perms):
            x_augs = []
            x_augs.append(content_augmentation(x, params=content_params[i]))
            for j in range(num_simclr_augs):
                if augmentation_order[j] < 0.5:
                    x_aug = content_augmentation(x, params=content_params[j])
                    x_aug = content_augmentation(x_aug, params=content_params[i])
                else:
                    x_aug = content_augmentation(x, params=content_params[i])
                    x_aug = content_augmentation(x_aug, params=content_params[j])
                x_augs.append(x_aug)
            x_augs = torch.cat(x_augs)
            x_augs = self.normalize(x_augs)
            z_content = self.encoder(x_augs)
            content_projection = self.criterion.content_loss.get_projection(z_content)
            content_projection = torch.cat(
                content_projection.chunk(num_simclr_augs + 1),
                dim=1,
            )
            projections.append(content_projection[:, None])
        unnormalized_projections = torch.cat(projections, dim=1)
        if self.cfg.loss.similarity_metric == "cos":
            projections = F.normalize(unnormalized_projections, dim=-1)
        else:
            projections = unnormalized_projections
        # Compute anomaly score based on within content and negative between content projection distances
        score = []
        for i in range(num_perms):
            norm_content_projection = projections[:, i]
            # NOTE(review): the cos branch compares against all entries
            # (including index 0) while the l2 branch uses [:, 1:]. The two
            # matrices therefore differ in width, and the l2 branch feeds
            # distances (lower = closer) into the cross-entropy as logits.
            # Confirm both before relying on the l2 path.
            if self.cfg.loss.similarity_metric == "cos":
                similarity_matrix = (
                    norm_content_projection[:, 0, None] * norm_content_projection
                ).sum(dim=-1)
                dist_to_same_content_content = (1 + similarity_matrix).mean(dim=-1)
            else:
                similarity_matrix = (
                    norm_content_projection[:, 0, None] - norm_content_projection[:, 1:]
                ).norm(dim=-1)
                dist_to_same_content_content = (1.0 / similarity_matrix).mean(dim=-1)

            similarity_matrix = similarity_matrix[:, :, None]

            # dist_to_same_content_content / dist_to_other_contents only feed
            # the commented-out alternative score at the end of this loop.
            dist_to_other_contents = []
            for j in range(num_perms):
                if i == j:
                    continue
                other_norm_content_projection = projections[:, j]
                if self.cfg.loss.similarity_metric == "cos":
                    other_similarity_matrix = (
                        norm_content_projection[:, 0, None]
                        * other_norm_content_projection
                    ).sum(dim=-1)
                    dist_to_other_contents.append(
                        (1 - other_similarity_matrix).mean(dim=-1)
                    )
                else:
                    other_similarity_matrix = (
                        norm_content_projection[:, 0, None]
                        - other_norm_content_projection
                    ).norm(dim=-1)
                    dist_to_other_contents.append(
                        (other_similarity_matrix).mean(dim=-1)
                    )

                # Append the similarities to draw j as negatives (one column group per other draw).
                other_similarity_matrix = other_similarity_matrix[:, None].repeat(
                    1, num_simclr_augs + 1, 1
                )
                similarity_matrix = torch.cat(
                    [similarity_matrix, other_similarity_matrix], dim=-1
                )

            # Column 0 is treated as the positive class of the cross-entropy.
            logits = (
                similarity_matrix.reshape(-1, similarity_matrix.shape[-1])
                / self.criterion.content_loss.temperature
            )
            labels = torch.zeros(logits.shape[0], dtype=torch.long).to(self.device)
            clf_score = F.cross_entropy(logits, labels, reduction="none").reshape(
                -1, num_simclr_augs + 1
            )
            # Down-weight samples whose (unnormalised) projections have a large norm.
            norm_score = 1.0 / (1 + unnormalized_projections[:, i].norm(dim=-1))
            curr_score = (clf_score * norm_score).mean(dim=-1)
            score.append(curr_score)

            # dist_to_other_contents = torch.stack(dist_to_other_contents).mean(dim=0)
            # score.append(-dist_to_other_contents * dist_to_same_content_content)
        score = torch.stack(score).mean(dim=0)
        return score

    @torch.no_grad()
    def compute_anomaly_score(
        self, samples, features, projections, prediction_mode, projection_type="content"
    ):
        """Compute one anomaly score per sample (higher = more anomalous).

        Nearest-neighbour and likelihood modes are evaluated per test-time view
        and averaged over views.

        Args:
            samples: Raw input batch; only used by ``augmentation_distance``.
            features: Encoder features, indexed ``features[:, view]``.
            projections: Dict with ``"content"`` and ``"context"`` projections,
                indexed the same way.
            prediction_mode: ``AnomalyScore`` member selecting the score.
            projection_type: Unused; kept for API compatibility.

        Returns:
            1-D tensor of per-sample scores.

        Raises:
            NotImplementedError: for an unsupported ``prediction_mode``.
        """
        if prediction_mode == AnomalyScore.augmentation_distance:
            return self.compute_augmentation_distance_score(samples)
        if prediction_mode == AnomalyScore.loss_score:
            return self.validation_loss(features)

        n_neighbors = 1
        score = []
        for view in range(features.shape[1]):
            # faiss CPU indices and scipy expect contiguous float32 CPU data.
            view_features = features[:, view].float().cpu().contiguous()
            view_context = projections["context"][:, view].float().cpu().contiguous()
            if prediction_mode == AnomalyScore.nearest_neighbor:
                D, _ = self.l2_index[view].search(view_features, n_neighbors)
                score.append(D.sum(dim=-1))
            elif prediction_mode in (
                AnomalyScore.nearest_cosine_neighbor,
                AnomalyScore.nearest_cosine_neighbor_and_content_prediction,
            ):
                D, _ = self.cos_index[view].search(
                    F.normalize(view_features, dim=-1), n_neighbors
                )
                # Cosine similarity: higher means more normal, hence the negation.
                score.append(-D.sum(dim=-1))
            elif prediction_mode == AnomalyScore.nearest_neighbor_projection:
                D, _ = self.l2_context_projections_index[view].search(
                    view_context, n_neighbors
                )
                score.append(D.sum(dim=-1))
            elif prediction_mode == AnomalyScore.nearest_cosine_neighbor_projection:
                D, _ = self.cos_context_projections_index[view].search(
                    F.normalize(view_context, dim=-1), n_neighbors
                )
                score.append(-D.sum(dim=-1))
            elif prediction_mode == AnomalyScore.norm_weighted_nearest_cosine_neighbor:
                # Un-normalised query against a normalised database, so the inner
                # product is scaled by the query norm.
                D, _ = self.cos_index[view].search(view_features, n_neighbors)
                score.append(-D.sum(dim=-1))
            elif (
                prediction_mode
                == AnomalyScore.norm_weighted_nearest_cosine_neighbor_projection
            ):
                D, _ = self.cos_context_projections_index[view].search(
                    view_context, n_neighbors
                )
                score.append(-D.sum(dim=-1))
            elif prediction_mode == AnomalyScore.context_likelihood:
                log_density = self.context_dists[view].logpdf(
                    F.normalize(view_features, dim=-1)
                )
                # For a single sample scipy returns a numpy scalar, which
                # torch.from_numpy rejects. as_tensor + reshape handles both cases.
                likelihood = -torch.as_tensor(log_density).reshape(-1)
                # Handle ill-conditioning of the covariance matrix at beginning of training
                # NOTE(review): an infinite negative log-density (zero density) is
                # mapped to a very "normal" score here — confirm this is intended.
                likelihood[torch.isinf(likelihood)] = 1e-6
                score.append(likelihood)
            else:
                raise NotImplementedError(
                    f"Anomaly score {prediction_mode} is not implemented."
                )

        view_score = torch.stack(score).mean(dim=0)
        if prediction_mode != AnomalyScore.nearest_cosine_neighbor_and_content_prediction:
            return view_score

        # Mean pairwise cosine similarity between the content projections of
        # the different views of each sample (diagonal excluded).
        normalized_content_proj = F.normalize(projections["content"], dim=-1)
        content_prediction_matrix = (
            normalized_content_proj[:, :, None] * normalized_content_proj[:, None]
        ).sum(dim=-1)
        off_diagonal = ~torch.eye(
            normalized_content_proj.shape[1],
            dtype=torch.bool,
            device=content_prediction_matrix.device,
        )
        content_prediction_score = (
            content_prediction_matrix[:, off_diagonal]
            .view(content_prediction_matrix.shape[0], -1)
            .mean(dim=-1)
            .to(view_score.device)
        )
        # NOTE(review): higher content similarity increases the anomaly score
        # here — confirm the sign is intended.
        return view_score + content_prediction_score

    @torch.no_grad()
    def compute_normal_embeddings(self, init=False):
        """Embed the un-augmented training set as the reference for anomaly scoring.

        Iterates ``train_dataloader(drop_last=False, augment=False)`` with
        ``cfg.batch_size_eval``. With test-time augmentations configured, each
        batch is embedded via ``get_test_augmentation_embeddings``; otherwise a
        single forward pass uses random context labels.

        Results are stored on the CPU:

        * ``self.normal_embeddings``: features concatenated over batches.
        * ``self.normal_projections``: dict with ``"content"`` and ``"context"``
          projections concatenated over batches.

        Args:
            init: If truthy, run in train mode; otherwise in eval mode. The
                previous mode is always restored afterwards, even on error.
        """
        was_training = self.training
        self.train(bool(init))
        try:
            normal_embeddings = []
            normal_projections = {"content": [], "context": []}
            loader = self.train_dataloader(
                drop_last=False, augment=False, batch_size=self.cfg.batch_size_eval
            )
            for batch in tqdm(loader, leave=False):
                x = batch[0].to(self.device)
                if self.cfg.augmentation.test_time_augmentations is not None:
                    _, projection, features = self.get_test_augmentation_embeddings(
                        x,
                        identity_augmentation=self.cfg.augmentation.knn_embeddings_with_identity,
                    )
                else:
                    projection, features = self(
                        x,
                        context_labels=torch.randint(
                            0, self.cfg.loss.n_context_augs, (x.shape[0],)
                        ),
                    )
                for projection_type, buffer in normal_projections.items():
                    buffer.append(projection[projection_type].cpu())
                normal_embeddings.append(features.cpu())
            self.normal_embeddings = torch.cat(normal_embeddings)
            self.normal_projections = {
                projection_type: torch.cat(buffer)
                for projection_type, buffer in normal_projections.items()
            }
        finally:
            # Restore the caller's mode unconditionally. Previously an eval-mode
            # model called with init=True was left in train mode.
            self.train(was_training)

    @torch.no_grad()
    def compute_context_distance_ratios(self, print_results=False):
        n_contexts = self.cfg.loss.n_context_augs
        projections = F.normalize(self.normal_projections["context"], dim=-1)
        n_contexts_logging = min(projections.shape[1], n_contexts)
        mean_distances = {
            "in_context": [[] for _ in range(n_contexts_logging)],
            "between_context": [[] for _ in range(n_contexts_logging)],
        }
        for i in range(projections.shape[1]):
            for j in range(i, projections.shape[1]):
                idx1 = torch.randperm(projections.shape[0])[:1000]
                idx2 = torch.randperm(projections.shape[0])[:1000]
                if self.cfg.loss.similarity_metric == "cos":
                    distances = (
                        torch.matmul(projections[idx1, i], projections[idx2, j].T) + 1
                    )
                elif self.cfg.loss.similarity_metric == "mse":
                    distances = 1.0 / (
                        torch.cdist(projections[:, i], projections[:, j]) + 1.0
                    )
                if i % n_contexts == j % n_contexts:
                    mean_distances["in_context"][i % n_contexts].append(
                        distances.mean(dim=-1)
                    )
                    mean_distances["in_context"][i % n_contexts].append(
                        distances.mean(dim=0)
                    )
                else:
                    mean_distances["between_context"][i % n_contexts].append(
                        distances.mean(dim=-1)
                    )
                    mean_distances["between_context"][j % n_contexts].append(
                        distances.mean(dim=0)
                    )
        mean_distances["in_context"] = [
            torch.cat(d).mean() for d in mean_distances["in_context"]
        ]
        if n_contexts_logging == 1:
            self.log(
                "context_distance_ratios/in_context_distance",
                mean_distances["in_context"][0],
            )
            if print_results:
                print(
                    "context_distance_ratios/in_context_distance",
                    mean_distances["in_context"][0],
                )
            return
        mean_distances["between_context"] = [
            torch.cat(d).mean() for d in mean_distances["between_context"]
        ]
        for i in range(n_contexts_logging):
            self.log(
                f"context_distance_ratios/in_to_out_context{i}_ratio",
                mean_distances["in_context"][i] / mean_distances["between_context"][i],
            )
            if print_results:
                print(
                    f"context_distance_ratios/in_to_out_context{i}_ratio",
                    mean_distances["in_context"][i]
                    / mean_distances["between_context"][i],
                )

    def configure_validation_step_outputs(self):
        self.validation_predictions = {}
        self.validation_score_histogram = {}
        for inference_type in self.cfg.loss.utils.inference_types():
            self.validation_predictions[inference_type] = []
            self.validation_score_histogram[inference_type] = []

        self.validation_gts = []
        self.validation_loss_normal = []
        self.validation_loss_anomaly = []

    def configure_model(self):
        # Initialize small epsilone for numerical stability
        self.eps = 1e-6
        if self.cfg.model.name == "resnet":
            self.encoder = ResNetEncoder(
                self.cfg.model,
                in_size=self.cfg.dataset.size,
                hierarchies=self.cfg.loss.get("hierarchy", False),
            )
        else:
            raise NotImplementedError(f"{self.cfg.model.name} encoder not implemented.")
        self.criterion = self.cfg.loss.utils.get_loss(
            backbone_out_dim=self.encoder.out_dim,
            **OmegaConf.to_container(self.cfg.loss),
        )
        # self.register_buffer("center", torch.zeros(self.encoder.out_dim))

    def configure_datasets(self):
        """Load the datasets and build the normalization and augmentation pipelines.

        Sets:
            train_dataset, train_dataset_no_aug, validation_dataset: the
                triple returned by ``get_datasets(cfg.dataset)``.
            normalization_mean, normalization_std, normalize, denormalize:
                the commonly used ImageNet channel statistics when
                ``cfg.model.pretrain`` is truthy, otherwise 0.5 per channel.
            augmentations: training augmentations built from
                ``cfg.augmentation``.
            test_augmentation: eval-mode augmentations from
                ``cfg.augmentation`` (``same_on_batch=True``).
            context_augmentations, test_context_augmentation: train and
                eval-mode context augmentations from
                ``cfg.loss.augmentation_class`` (``same_on_batch=True``).

        For the ``"low_frequency_context_augmentation"`` class,
        ``cfg.loss.n_context_augs`` images taken from the un-augmented
        training loader are passed to both context builders as
        ``low_frequency_samples``.

        Raises:
            ValueError: If the low-frequency augmentation needs more reference
                images than the training loader yields.
        """
        # Create train and test sets.
        (
            self.train_dataset,
            self.train_dataset_no_aug,
            self.validation_dataset,
        ) = get_datasets(
            self.cfg.dataset,
        )
        if self.cfg.model.pretrain:
            self.normalization_mean = [0.485, 0.456, 0.406]
            self.normalization_std = [0.229, 0.224, 0.225]
        else:
            self.normalization_mean = [0.5, 0.5, 0.5]
            self.normalization_std = [0.5, 0.5, 0.5]
        self.normalize = K.Normalize(
            mean=self.normalization_mean, std=self.normalization_std
        )
        self.denormalize = K.Denormalize(
            mean=self.normalization_mean, std=self.normalization_std
        )

        # Training augmentations for contrastive learning.
        augmentations = get_augmentation(
            self.cfg.augmentation,
            size=self.cfg.dataset.size,
        )
        self.augmentations = K.ImageSequential(*augmentations)

        # Test-time SimCLR augmentations.
        test_augmentation = get_augmentation(
            self.cfg.augmentation, size=self.cfg.dataset.size, eval=True
        )
        self.test_augmentation = K.ImageSequential(
            *test_augmentation,
            same_on_batch=True,
        )

        # Reference images for the low-frequency context augmentation.
        low_frequency_samples = None
        if (
            self.cfg.loss.augmentation_class.name
            == "low_frequency_context_augmentation"
        ):
            low_frequency_samples = self._collect_low_frequency_samples(
                self.cfg.loss.n_context_augs
            )

        # Context augmentations for training.
        context_augmentations = get_augmentation(
            self.cfg.loss.augmentation_class,
            low_frequency_samples=low_frequency_samples,
            size=self.cfg.dataset.size,
            n_context_augs=self.cfg.loss.n_context_augs,
        )
        self.context_augmentations = K.ImageSequential(
            *context_augmentations, same_on_batch=True
        )

        # Context test-time augmentations.
        test_context_augmentation = get_augmentation(
            self.cfg.loss.augmentation_class,
            low_frequency_samples=low_frequency_samples,
            size=self.cfg.dataset.size,
            eval=True,
            n_context_augs=self.cfg.loss.n_context_augs,
        )
        self.test_context_augmentation = K.ImageSequential(
            *test_context_augmentation,
            same_on_batch=True,
        )

    def _collect_low_frequency_samples(self, n_samples):
        """Return the first ``n_samples`` images yielded by the un-augmented train loader.

        Batches come from ``self.train_dataloader(drop_last=False,
        augment=False)``. Only ``batch[0]`` is used, sliced so that exactly
        ``n_samples`` rows are concatenated along dim 0.

        Raises:
            ValueError: If the loader is exhausted before ``n_samples``
                images have been collected. Without this check, a bare
                StopIteration would escape.
        """
        chunks = []
        n_collected = 0
        for batch in self.train_dataloader(drop_last=False, augment=False):
            chunk = batch[0][: n_samples - n_collected]
            chunks.append(chunk)
            n_collected += chunk.shape[0]
            if n_collected >= n_samples:
                return torch.cat(chunks)
        raise ValueError(
            f"low_frequency_context_augmentation needs {n_samples} reference "
            f"images, but the training loader only yielded {n_collected}."
        )

    def configure_optimizers(self):
        """Create the optimizer and, if configured, its LR scheduler.

        Returns the bare optimizer when ``cfg.optimizer.scheduler`` is None.
        Otherwise, returns Lightning's ``{"optimizer", "lr_scheduler"}`` dict.
        A ``"onecycle"`` scheduler is stepped every training step; any other
        scheduler is stepped once per epoch.
        """
        optimizer_cfg = self.cfg.optimizer
        optimizer = create_optimizer(optimizer_cfg, self.parameters())
        scheduler_name = optimizer_cfg.scheduler
        if scheduler_name is None:
            return optimizer

        scheduler = create_scheduler(
            scheduler_name, optimizer, self.train_dataloader(), self.cfg.epochs
        )
        interval = "step" if scheduler_name == "onecycle" else "epoch"
        lr_scheduler_config = {"scheduler": scheduler, "interval": interval}
        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}
