import torch
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.nn.functional import cosine_similarity
from xad.training.ad_trainer import XADTrainer
from sklearn.cluster import k_means


class DSVDDTrainer(XADTrainer):
    """ implements deep support vector data description (Deep SVDD) to perform unsupervised anomaly detection

    All normal training samples are mapped towards a single center in feature space. The anomaly score of a sample is
    its squared Euclidean distance d to that center, squashed into [0, 1) via d / (1 + d).
    """

    def prepare_metric(self, cstr: str, loader: DataLoader, model: torch.nn.Module, seed: int, **kwargs) -> torch.Tensor:
        """
        Computes the DSVDD center as the mean feature vector of all normal (label 0) samples yielded by the loader.

        :param cstr: class name; only used in the progress-bar description and error message.
        :param loader: yields (images, labels, _) triples; only samples whose label is 0 contribute to the center.
        :param model: feature extractor, evaluated under torch.no_grad() (its train/eval mode is not changed here).
        :param seed: unused by this trainer; kept for interface compatibility.
        :param kwargs: 'eps' (default 1e-8): center coordinates with magnitude below eps are pushed to -eps / +eps
            according to their sign (coordinates that are exactly 0 are left unchanged).
        :return: the center with a leading singleton dimension, i.e. shape (1, *feature_shape), on self.device.
        :raises RuntimeError: if the loader yields no normal samples at all.
        """
        eps = kwargs.get('eps', 1e-8)
        feature_sum = None
        n_normal = 0
        for imgs, lbls, _ in tqdm(loader, desc=f'cls {cstr} preparing DSVDD center'):
            normal_imgs = imgs[lbls == 0]
            if len(normal_imgs) == 0:
                # the mean of an empty batch is NaN and would corrupt the center; such batches carry no information
                continue
            with torch.no_grad():
                image_features = model(normal_imgs.to(self.device)).cpu()
            # accumulate a running sum (instead of averaging per-batch means) so every sample is weighted equally,
            # even when batches differ in size or in their number of normal samples
            batch_sum = image_features.sum(0)
            feature_sum = batch_sum if feature_sum is None else feature_sum + batch_sum
            n_normal += image_features.shape[0]
        if n_normal == 0:
            raise RuntimeError(f'cls {cstr}: no normal (label 0) samples available to compute the DSVDD center')
        center = (feature_sum / n_normal).unsqueeze(0).to(self.device)
        # push near-zero coordinates away from zero; presumably mirrors the reference Deep SVDD implementation, where a
        # center coordinate at ~0 could be matched trivially by a network unit with zero weights
        center[(abs(center) < eps) & (center < 0)] = -eps
        center[(abs(center) < eps) & (center > 0)] = eps
        return center

    def compute_anomaly_score(self, features: torch.Tensor, center: torch.Tensor, train: bool = False, **kwargs) -> torch.Tensor:
        """
        Scores samples by their squashed squared distance to the center.

        :param features: feature vectors, reduced over the last dimension (assumes (batch, dim) - TODO confirm).
        :param center: the center returned by prepare_metric; broadcast against features.
        :param train: unused; scoring is identical during training and evaluation.
        :return: d / (1 + d) where d is the squared Euclidean distance along the last dimension; values lie in [0, 1).
        """
        sq_dist = (features - center).pow(2).sum(-1)
        return sq_dist.div(1 + sq_dist)

    def loss(self, features: torch.Tensor, labels: torch.Tensor, center: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Mean anomaly score of the given features, i.e. the DSVDD objective pulling features towards the center.

        :param features: feature vectors of the current batch.
        :param labels: unused; every sample in the batch contributes to the loss.
        :param center: the center returned by prepare_metric.
        :return: scalar tensor.
        """
        return self.compute_anomaly_score(features, center, True, **kwargs).mean()


class MultiDSVDDTrainer(XADTrainer):
    """ implements a multi-center variant of deep support vector data description (Deep SVDD) to perform unsupervised
    anomaly detection

    The features of the normal training samples are clustered with k-means. The anomaly score of a sample is the
    squashed squared distance d / (1 + d) to its closest cluster center.
    """

    def prepare_metric(self, cstr: str, loader: DataLoader, model: torch.nn.Module, seed: int, **kwargs) -> torch.Tensor:
        """
        Computes the MultiDSVDD centers by running k-means on the features of all normal (label 0) samples.

        :param cstr: class name; only used in the progress-bar description and error message.
        :param loader: yields (images, labels, _) triples; only samples whose label is 0 are clustered.
        :param model: feature extractor, evaluated under torch.no_grad() (its train/eval mode is not changed here).
        :param seed: random state for k-means, making the centers reproducible.
        :param kwargs: 'eps' (default 1e-4): center coordinates with magnitude below eps are pushed to -eps / +eps
            according to their sign (coordinates that are exactly 0 are left unchanged);
            'n_clusters' (default 2): number of k-means centers.
        :return: float32 tensor of shape (n_clusters, feature_dim) on self.device.
        :raises RuntimeError: if the loader yields no normal samples at all.
        """
        eps = kwargs.get('eps', 1e-4)
        n_clusters = kwargs.get('n_clusters', 2)
        features = []
        for imgs, lbls, _ in tqdm(loader, desc=f'cls {cstr} preparing MultiDSVDD centers'):
            normal_imgs = imgs[lbls == 0]
            if len(normal_imgs) == 0:
                # nothing to cluster in this batch; avoid running the model on an empty input
                continue
            with torch.no_grad():
                features.append(model(normal_imgs.to(self.device)).cpu())
        if not features:
            raise RuntimeError(f'cls {cstr}: no normal (label 0) samples available to compute the MultiDSVDD centers')
        # seed k-means so that the centers - and therefore all anomaly scores - are reproducible for a given seed
        centroids = k_means(torch.cat(features).numpy(), n_clusters, random_state=seed)[0]
        clusters = torch.from_numpy(centroids).float().to(self.device)
        # push near-zero coordinates away from zero, as in the single-center DSVDD trainer
        clusters[(abs(clusters) < eps) & (clusters < 0)] = -eps
        clusters[(abs(clusters) < eps) & (clusters > 0)] = eps
        return clusters

    def compute_anomaly_score(self, features: torch.Tensor, center: torch.Tensor, train: bool = False, **kwargs) -> torch.Tensor:
        """
        Scores samples by their squashed squared distance to the closest center.

        :param features: feature vectors of shape (batch, dim).
        :param center: centers of shape (n_clusters, dim), as returned by prepare_metric.
        :param train: unused; scoring is identical during training and evaluation.
        :return: tensor of shape (batch,) holding min over centers of d / (1 + d), with d the squared distance.
        """
        # (n_clusters, 1, dim) - (batch, dim) broadcasts to (n_clusters, batch, dim); reduce to (n_clusters, batch)
        sq_dist = (features - center.unsqueeze(1)).pow(2).sum(-1)
        return sq_dist.div(1 + sq_dist).min(0)[0]

    def loss(self, features: torch.Tensor, labels: torch.Tensor, center: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Mean anomaly score of the given features, pulling each feature towards its closest center.

        :param features: feature vectors of the current batch.
        :param labels: unused; every sample in the batch contributes to the loss.
        :param center: centers returned by prepare_metric.
        :return: scalar tensor.
        """
        return self.compute_anomaly_score(features, center, True, **kwargs).mean()
