from abc import ABC

import numpy as np
import torch
from sklearn.metrics import accuracy_score
from torch.nn import functional as F

from src.algorithms.base.client_base import BaseClient
from src.utils.accuracy import compute_fnr, compute_auroc


class FOOGDClient(BaseClient, ABC):
    """Client that evaluates a backbone classifier together with a score model.

    ``self.backbone`` and ``self.device`` are not set in this class; they are
    presumably provided by ``BaseClient`` (confirm against the base class).
    """

    # Values accepted for ``score_method`` by
    # ``test_classification_detection_ability``.
    _SCORE_METHODS = ("energy", "msp", "sm")

    def __init__(self, client_args):
        """Initialise the base client and keep ``client_args["score_model"]``."""
        super().__init__(client_args)
        self.score_model = client_args["score_model"]

    @staticmethod
    def _batch_accuracy(logit, targets):
        """Return the accuracy of the arg-max predictions of one batch."""
        pred = logit.argmax(dim=1)
        return accuracy_score(
            targets.detach().cpu().numpy(), pred.detach().cpu().numpy()
        )

    @staticmethod
    def _mean_batch_accuracy(batch_accuracies, loader_name):
        """Return the unweighted mean of the per-batch accuracies.

        Raises:
            ValueError: if no batch was evaluated, instead of an opaque
                ``ZeroDivisionError``.
        """
        if not batch_accuracies:
            raise ValueError(
                f"{loader_name} produced no batch to evaluate; cannot compute accuracy"
            )
        return sum(batch_accuracies) / len(batch_accuracies)

    @staticmethod
    def _concat_scores(chunks):
        """Join per-batch score arrays; an empty list gives an empty array."""
        return np.concatenate(chunks) if chunks else np.array([])

    def _logit_and_ood_scores(self, data, score_method):
        """Run the backbone on ``data`` and score each sample.

        Returns:
            ``(logit, scores)`` where ``logit`` is the backbone output tensor
            and ``scores`` is a NumPy array of per-sample scores for the
            requested ``score_method`` (already validated by the caller).
        """
        if score_method == "sm":
            # Only this method needs the latent features and the score model.
            latents = self.backbone.intermediate_forward(data)
            logit = self.backbone(data)
            scores = self.score_model(latents).norm(dim=-1)
        else:
            logit = self.backbone(data)
            if score_method == "energy":
                # Negative log-sum-exp of the logits.
                scores = -torch.logsumexp(logit, dim=1)
            else:  # "msp"
                # Largest softmax probability.
                scores = F.softmax(logit, dim=1).max(dim=1).values
        return logit, scores.detach().cpu().numpy()

    @torch.no_grad()
    def test_corrupt_accuracy(self, cor_loader):
        """Return the backbone's mean per-batch accuracy on ``cor_loader``.

        Args:
            cor_loader: iterable yielding ``(data, targets)`` batches.

        Returns:
            The unweighted mean of the per-batch accuracies, so a smaller
            final batch counts as much as a full one.

        Raises:
            ValueError: if every batch was skipped or the loader is empty.
        """
        self.backbone.to(self.device)
        self.backbone.eval()

        batch_accuracies = []
        for data, targets in cor_loader:
            # Batches holding a single sample are skipped, as in the original
            # code; the reason is not visible in this file.
            if len(data) == 1:
                continue
            data, targets = data.to(self.device), targets.to(self.device)
            batch_accuracies.append(self._batch_accuracy(self.backbone(data), targets))
        return self._mean_batch_accuracy(batch_accuracies, "cor_loader")

    @torch.no_grad()
    def test_classification_detection_ability(self, id_loader, ood_loader, score_method="sm"):
        """Evaluate classification accuracy and OOD-detection metrics.

        Args:
            id_loader: iterable yielding ``(data, target)`` batches; used for
                both the accuracy and the first set of scores.
            ood_loader: iterable yielding ``(data, _)`` batches; only the data
                is used, to compute the second set of scores.
            score_method: ``"energy"`` (negative log-sum-exp of the logits),
                ``"msp"`` (maximum softmax probability) or ``"sm"`` (norm of
                the score model output on the backbone's intermediate
                features).

        Returns:
            ``(id_accuracy, fpr95, auroc)``. ``id_accuracy`` is the unweighted
            mean of the per-batch accuracies on ``id_loader``; ``fpr95`` and
            ``auroc`` are whatever ``compute_fnr`` and ``compute_auroc``
            return (their exact definitions live in ``src.utils.accuracy``).

        Raises:
            ValueError: if ``score_method`` is not supported, or if
                ``id_loader`` yields no batch.
        """
        # Fail fast: previously an unknown method ran the whole evaluation and
        # then crashed with UnboundLocalError on ``fpr95``.
        if score_method not in self._SCORE_METHODS:
            raise ValueError(
                f"Unknown score_method {score_method!r}; "
                f"expected one of {self._SCORE_METHODS}"
            )

        self.backbone.to(self.device)
        self.score_model.to(self.device)
        self.backbone.eval()
        self.score_model.eval()

        id_chunks = []
        ood_chunks = []
        batch_accuracies = []

        for data, target in id_loader:
            data, target = data.to(self.device), target.to(self.device)
            logit, scores = self._logit_and_ood_scores(data, score_method)
            batch_accuracies.append(self._batch_accuracy(logit, target))
            id_chunks.append(scores)

        for data, _ in ood_loader:
            _, scores = self._logit_and_ood_scores(data.to(self.device), score_method)
            ood_chunks.append(scores)

        ood_score_id = self._concat_scores(id_chunks)
        ood_score_ood = self._concat_scores(ood_chunks)

        if score_method == "msp":
            # MSP is the only score passed ID-first; energy and sm are passed
            # OOD-first, matching the original argument order.
            first, second = ood_score_id, ood_score_ood
        else:
            first, second = ood_score_ood, ood_score_id

        fpr95 = compute_fnr(first, second)
        auroc = compute_auroc(first, second)

        id_accuracy = self._mean_batch_accuracy(batch_accuracies, "id_loader")

        return id_accuracy, fpr95, auroc
