import torch
import torch.nn as nn
import torch.nn.functional as F
from models.smooth import Smooth
from models.torch_utils import join
from functorch import vmap
from toolz.curried import partial

class TramerTransform(nn.Module):
    """
    Turn a trained classifier into a detector via the Tramer transformation.

    The classifier's label for each input is compared with its label on the
    adversarial example produced by ``attacker``. The attack acts as a proxy for
    the robustness criterion of Tramer's paper (https://arxiv.org/abs/2107.11630).

    Output: a long tensor of labels where rejected inputs are marked with -1.
    By default an input is rejected when the attack flips the label. With
    ``invert=True`` it is rejected when the label survives the attack.

    Extra keyword arguments are accepted and ignored.
    """

    def __init__(self, model, attacker, invert=False, **kwargs):
        super().__init__()
        self.model = model
        # Called as attacker(model, (inputs, labels)) -> (_, adversarial_inputs, _).
        self.attacker = attacker
        self.invert = invert

    def forward(self, x):
        clean_labels = self.model(x).argmax(dim=1)

        # The attack targets the model's own predictions, not ground-truth labels.
        _, perturbed, _ = self.attacker(self.model, (x, clean_labels))
        attacked_labels = self.model(perturbed).argmax(dim=1)

        survived = clean_labels == attacked_labels
        reject = survived if self.invert else ~survived

        # -1 signals rejection.
        return clean_labels.masked_fill(reject, -1).long()


# Based on Cohen et al., https://github.com/locuslab/smoothing
class RandomizedSmoothing(nn.Module):
    def __init__(
            self,
            model,
            attacker,
            rejection_method="tramer",
            epsilon_defense=0.015,
            sigma_eps_scale=0.5,
            num_classes=10,
            n=1000,
            n_certificate=10000,
            alpha=0.05,
            batch_size=128,
            **kwargs
    ):
        super().__init__()
        self.model = model
        self.attacker = attacker
        self.epsilon_defense = epsilon_defense
        self.sigma = self.epsilon_defense * sigma_eps_scale
        self.n = n
        self.n_certificate = n_certificate
        self.num_classes = num_classes
        self.alpha = alpha
        self.batch_size = batch_size 
        self.rejection_method = rejection_method

        self.smoothed = Smooth(self.model, num_classes=self.num_classes, sigma=self.sigma, **kwargs)

        def predict(x):
            predictions = torch.empty(x.shape[0], dtype=torch.long, device=x.device)

            for i in range(x.shape[0]):
                predictions[i] = self.smoothed.predict(x[i], n=self.n, alpha=self.alpha, batch_size=self.batch_size)

            return predictions
        
        self.smoothed_predict = predict

        self.tramer_transformed = TramerTransform(self.model, self.attacker)

        def certificate_rejection(x):
            predictions = torch.empty(x.shape[0], dtype=torch.long, device=x.device)

            for i in range(x.shape[0]):
                prediction, radius = self.smoothed.certify(x[i], n_0=n, n=n_certificate, alpha=self.alpha, batch_size=self.batch_size)

                predictions[i] = -1 if radius < self.epsilon_defense else prediction

            return predictions

        self.certificate_rejection = certificate_rejection
        
    def forward(self, x):
        predictions = self.smoothed_predict(x)

        match self.rejection_method:
            case "tramer":
                rejected = self.tramer_transformed(x) < 0
            case "certificate":
                rejected = self.certificate_rejection(x) < 0
            case _:
                rejected = torch.zeros_like(predictions).bool()

        return ~rejected * predictions - rejected * torch.ones_like(predictions)


class Rejectron(nn.Module):
    """
    Selective classifier based on Goldwasser et al.'s Rejectron algorithm.

    A fixed ``classifier`` supplies labels, and a transductively learned
    ``discriminator`` scores every input. An input is rejected (label -1) when
    ``sigmoid(discriminator(x))`` exceeds ``tau``. That value is the
    discriminator's confidence that the input comes from Q.

    Extra keyword arguments are accepted and ignored.

    See https://papers.nips.cc/paper/2020/hash/b6c8cf4c587f2ead0c08955ee6e2502b-Abstract.html.
    """

    def __init__(self, classifier, discriminator, tau=0.5, **kwargs):
        super().__init__()
        self.classifier = classifier
        self.discriminator = discriminator
        self.tau = tau

    def forward(self, x):
        labels = self.classifier(x).argmax(dim=-1)

        # Strictly greater than tau: a score exactly at the threshold is accepted.
        looks_like_q = torch.sigmoid(self.discriminator(x).squeeze()) > self.tau
        labels[looks_like_q] = -1

        return labels


class DisagreementRejection(nn.Module):
    """
    Similar to the ideas of Goldwasser (but in a single-step form), attempts to train
    two (or more) identical networks with similar test predictions but differing
    predictions on the evaluation set. Should only be used transductively.

    Output depends on the ``output_intermediate`` flag, not on train/eval mode:

    - ``output_intermediate=True`` (the default set in ``__init__``): returns a
      list with one logits tensor per model (``count`` entries).
    - ``output_intermediate=False``: returns a 1-D tensor of class predictions,
      one per input. Rejected inputs are marked with -1.

    Args:
        model_generator: zero-argument callable returning a fresh model; it is
            called ``count`` times.
        count: number of models in the ensemble.
        agreement_needed: fraction of models that must agree with the most
            common prediction for the input to be accepted.
    """

    def __init__(self, model_generator, count=2, agreement_needed=1.):
        super().__init__()
        self.models = nn.ModuleList([model_generator() for _ in range(count)])
        self.agreement_needed = agreement_needed
        self.output_intermediate = True

    def forward(self, x):
        logits = [m(x) for m in self.models]

        if self.output_intermediate:
            return logits

        # NOTE(review): assumes `join` stacks the per-model (batch, classes) logits
        # into (batch, count, classes) — confirm against models.torch_utils.join.
        per_model = join(*logits, dim=1).argmax(dim=-1)
        top_predictions = per_model.mode(dim=-1, keepdim=True).values
        agreement_rate = (per_model == top_predictions).float().mean(dim=-1)
        disagreed = agreement_rate < self.agreement_needed

        # Squeeze only the kept model axis. A bare squeeze() would also drop the
        # batch axis when the batch has a single input, giving a 0-d result.
        predictions = top_predictions.squeeze(-1)
        predictions[disagreed] = -1

        return predictions
