import torch
from pyod.models.iforest import IForest
from pyod.models.copod import COPOD
from pyod.models.ecod import ECOD
from pyod.models.suod import SUOD
from pyod.models.mad import MAD


def min_max_normalization(x):
    """Linearly rescale ``x`` so its global minimum maps to 0 and its maximum to 1.

    The minimum and maximum are taken over all elements of ``x`` (not per row).
    If all elements are equal (zero range), the result is all zeros instead of
    the NaNs a plain ``0 / 0`` division would produce.

    Args:
        x (torch.Tensor): non-empty tensor of any shape.

    Returns:
        torch.Tensor: floating-point tensor with the same shape as ``x``.
    """
    x_min = torch.min(x)
    x_max = torch.max(x)
    value_range = x_max - x_min
    if value_range == 0:
        # Match the dtype true division would have produced (float for int input).
        return torch.zeros_like(x, dtype=torch.result_type(x, 1.0))
    return (x - x_min) / value_range


class CognitiveDistillationAnalysis():
    """Outlier scoring / flagging on masks or patterns distilled by CognitiveDistillation.

    The mode is selected by substring match on ``od_type``:

    * contains ``'Ensemble'``: a PyOD ``SUOD`` ensemble (``IForest`` with 100
      and 500 estimators, ``ECOD``, ``COPOD``; scores averaged) is fitted on
      ``data`` as given.
    * contains ``'MAD'``: a PyOD ``MAD`` detector is fitted on the per-sample
      L1 norm of ``data``, reshaped to a single column.
    * otherwise (e.g. the default ``'l1_norm'``): the mean and standard
      deviation of the per-sample L1 norms are stored and ``predict`` applies
      a z-score threshold.

    Args:
        od_type (str): detector selector, matched by substring as above.
        norm_only (bool): only used by the z-score mode. If True, ``train`` and
            ``predict`` treat ``data`` as already-computed per-sample norms.
    """

    def __init__(self, od_type='l1_norm', norm_only=False):
        self.od_type = od_type
        self.norm_only = norm_only
        self.clf = None   # PyOD detector; only set in 'Ensemble' / 'MAD' modes
        self.mean = None  # mean per-sample L1 norm; set by train() in z-score mode
        self.std = None   # std of per-sample L1 norm; set by train() in z-score mode
        if 'Ensemble' in od_type:
            detector_list = [IForest(n_estimators=100),
                             IForest(n_estimators=500),
                             ECOD(),
                             COPOD()]
            self.clf = SUOD(base_estimators=detector_list, combination='average', verbose=False)
        elif 'MAD' in od_type:
            self.clf = MAD()

    @staticmethod
    def _l1_norm(data):
        """Per-sample L1 norm of a (b, c, h, w) tensor, reducing dims 1-3 to shape (b,)."""
        return torch.norm(data, dim=[1, 2, 3], p=1)

    @staticmethod
    def _to_numpy(x):
        """Return a tensor as a host NumPy array (detached, on CPU); pass other inputs through."""
        if isinstance(x, torch.Tensor):
            return x.detach().cpu().numpy()
        return x

    def _uses_pyod(self):
        """True when a PyOD detector (Ensemble or MAD) handles fitting and prediction."""
        return 'Ensemble' in self.od_type or 'MAD' in self.od_type

    def _pyod_input(self, data):
        """Build the array fed to the PyOD detector.

        In 'MAD' mode, reduce to a (b, 1) column of per-sample L1 norms.
        """
        if 'MAD' in self.od_type:
            data = self._l1_norm(data).view(-1, 1)
        # PyOD works on NumPy arrays; converting here also handles CUDA / grad tensors.
        return self._to_numpy(data)

    def train(self, data):
        """Fit the detector, or in z-score mode record mean/std of the per-sample L1 norms.

        Args:
            data: (b, c, h, w) tensor of distilled masks/patterns. In z-score
                mode with ``norm_only=True``, a tensor of precomputed norms.
        """
        if self._uses_pyod():
            self.clf.fit(self._pyod_input(data))
            return
        if not self.norm_only:
            data = self._l1_norm(data)
        self.mean = torch.mean(data).item()
        self.std = torch.std(data).item()

    def predict(self, data, t=1):
        """Return binary outlier flags per sample as a NumPy array.

        In PyOD modes this is ``clf.predict``. In z-score mode a sample is
        flagged (1) when its L1 norm lies more than ``t`` training standard
        deviations *below* the training mean; otherwise 0.

        Args:
            data: same layout as for ``train``.
            t (float): z-score threshold for the z-score mode.
        """
        if self._uses_pyod():
            return self.clf.predict(self._pyod_input(data))
        if not self.norm_only:
            data = self._l1_norm(data)
        # Positive z means the norm is below the training mean.
        z = (self.mean - data) / self.std
        flags = torch.where((z > t) & (z > 0), 1, 0)
        return self._to_numpy(flags)

    def analysis(self, data, is_test=False):
        """Compute per-sample scores for the distilled masks/patterns.

        Args:
            data (torch.Tensor): (b, c, h, w) mask or pattern extracted by
                CognitiveDistillation. In ``norm_only`` mode, a 1-D tensor of
                precomputed norms is also accepted.
            is_test (bool): PyOD modes only. If True, return ``clf.predict``
                labels for ``data``. If False, return the detector's
                ``decision_scores_`` from fitting; ``data`` is ignored then.

        Returns:
            numpy.ndarray: in the z-score mode, ``1 - min_max(L1 norm)``. Samples
            with smaller norms therefore get larger values.
        """
        if self._uses_pyod():
            if is_test:
                return self.clf.predict(self._pyod_input(data))
            return self.clf.decision_scores_
        if self.norm_only and len(data.shape) <= 1:
            score = data  # already a vector of norms
        else:
            score = self._l1_norm(data)
        score = min_max_normalization(score)
        # Invert so smaller-norm samples score higher (original note: norms are lower for BD).
        return 1 - self._to_numpy(score)
