import torch
import scipy.stats as stats

from convexrobust.utils import torch_utils as TU

from convexrobust.model.base_certifiable import BaseCertifiable, Certificate, Norm

import lib.smoothingSplittingNoise.src.smooth as rs_smooth
import lib.smoothingSplittingNoise.src.noises as rs_noises


def custom_loss(preds, targets):
    """Return the mean negative log-likelihood of ``targets`` under ``preds``.

    Args:
        preds: Unnormalized logits; the last dimension indexes the classes.
        targets: Integer class labels compatible with ``preds`` minus its
            class dimension.

    Returns:
        Scalar tensor ``-mean(log p(targets))``, where ``p`` is the
        categorical distribution parameterized by ``preds`` as logits.
    """
    distribution = torch.distributions.Categorical(logits=preds)
    log_likelihood = distribution.log_prob(targets)
    return log_likelihood.mean().neg()


class RandsmoothCertifiable(BaseCertifiable):
    """Classifier whose predictions are smoothed and certified with random noise.

    Noise distributions come from ``lib.smoothingSplittingNoise``. Predictions
    and certificates report abstention as class ``-1``.
    """

    def __init__(self, sigma=0.25, n0=100, n=100000, nb=100, alpha=0.001, cert_n_scale=1,
                 noise='gaussian', **kwargs):
        """
        Args:
            sigma: Noise scale passed to the selected noise constructor.
            n0: Number of noisy samples ``certify`` uses to select the top class.
            n: Number of noisy samples used by ``predict`` (unless overridden)
                and by the probability lower bound in ``certify``.
            nb: Noise batch size forwarded to the smoothing routines.
            alpha: Significance level for the abstain test in ``predict``; ``certify``
                passes ``2 * alpha`` to ``certify_prob_lb``.
            cert_n_scale: Forwarded to ``certify_prob_lb`` as ``sample_size_scale``.
            noise: One of ``'gaussian'``, ``'uniform'``, ``'laplace'``, ``'split'``,
                ``'split_derandomized'``.
            **kwargs: Forwarded to ``BaseCertifiable.__init__``.

        Raises:
            KeyError: If ``noise`` is not one of the supported names.
        """
        super().__init__(single_logit=False, custom_loss=custom_loss, **kwargs)

        noise_kwargs = {'dim': self.datamodule.in_n, 'sigma': sigma, 'device': TU.device()}

        noises = {
            'gaussian': rs_noises.Gaussian, 'uniform': rs_noises.Uniform,
            'laplace': rs_noises.Laplace, 'split': rs_noises.SplitMethod,
            'split_derandomized': rs_noises.SplitMethodDerandomized
        }

        self.noise = noises[noise](**noise_kwargs)

        self.n0 = n0
        self.n = n
        self.nb = nb
        self.alpha = alpha
        self.cert_n_scale = cert_n_scale

    def training_signal_modify(self, x):
        """Return ``x`` passed through ``self.noise.sample``, keeping its shape.

        Each example is flattened to a vector before sampling and reshaped back.
        """
        return self.noise.sample(x.view(len(x), -1)).view(x.shape)

    def predict(self, x_batch, n=None, n_scale=1):
        """Return smoothed hard predictions for every example in ``x_batch``.

        Puts the model in eval mode. Each example is processed on its own as a
        batch of one.

        Args:
            x_batch: Batch of inputs.
            n: Noisy samples per example; defaults to ``self.n``. Not used for the
                derandomized split noise.
            n_scale: Factor (truncated to ``int``) applied to both top-two counts
                before the binomial test.

        Returns:
            Tensor stacking one prediction per example; ``-1`` means abstain.
        """
        n_scale = int(n_scale)
        self.eval()

        return torch.stack([self._predict_single(x.unsqueeze(0), n, n_scale) for x in x_batch])

    def _predict_single(self, x, n, n_scale):
        """Predict a single example given as a batch of one; ``-1`` on abstention."""
        if isinstance(self.noise, rs_noises.SplitMethodDerandomized):
            counts = rs_smooth.smooth_predict_hard_derandomized(
                self.forward_balanced, x, self.noise, noise_batch_size=self.nb)
            cats, _ = self.noise.classify_and_certify_l1_exact_from_counts(counts)
            return cats.squeeze(0)

        counts = rs_smooth.smooth_predict_hard(
            self.forward_balanced, x, self.noise, self.n if n is None else n,
            raw_count=True, noise_batch_size=self.nb
        )
        top_counts, top_cats = torch.topk(counts.int(), 2)
        # Convert to Python ints: scaling int32 tensor counts by n_scale can
        # silently overflow, while Python ints have arbitrary precision.
        count_first, count_second = top_counts.tolist()
        top_cat = top_cats[0]

        # scipy's binomtest is two-sided by default; predict the top class only
        # if it significantly beats the runner-up.
        p_value = stats.binomtest(
            count_first * n_scale, (count_first + count_second) * n_scale, 0.5
        ).pvalue
        if p_value <= self.alpha:
            return top_cat

        # Abstain with the same dtype/device as a real prediction so that
        # torch.stack in `predict` always yields a uniformly-typed tensor.
        return torch.full_like(top_cat, -1)

    def certify(self, x, _):
        """Certify a single example.

        Args:
            x: Input batch that must contain exactly one example.
            _: Unused.

        Returns:
            ``(prediction, certificate)``. If the lower bound on the top-class
            probability does not exceed 0.5, the prediction is ``-1`` and the
            certificate is ``Certificate.zero()``.
        """
        assert x.shape[0] == 1

        self.eval()

        if isinstance(self.noise, rs_noises.SplitMethodDerandomized):
            counts = rs_smooth.smooth_predict_hard_derandomized(
                self.forward_balanced, x, self.noise, noise_batch_size=self.nb)
            cats, certs = self.noise.classify_and_certify_l1_exact_from_counts(counts)
            return cats, Certificate.from_l1(certs.item(), self.datamodule.in_n)

        # Selection step: pick the candidate class from n0 noisy samples.
        selection = rs_smooth.smooth_predict_hard(
            self.forward_balanced, x, self.noise, self.n0, noise_batch_size=self.nb)
        top_cats = selection.probs.argmax(dim=1)

        # Estimation step with n samples. NOTE(review): 2 * alpha presumably turns
        # a two-sided confidence level inside certify_prob_lb into a one-sided
        # alpha bound (as in Cohen et al.) -- confirm against rs_smooth.
        prob_lb = rs_smooth.certify_prob_lb(
            self.forward_balanced, x, top_cats, 2 * self.alpha, self.noise,
            self.n, noise_batch_size=self.nb, sample_size_scale=self.cert_n_scale
        )

        if prob_lb > 0.5:
            certificate = Certificate({
                Norm.L1: self.noise.certify_l1(prob_lb).item(),
                Norm.L2: self.noise.certify_l2(prob_lb).item(),
                Norm.LInf: self.noise.certify_linf(prob_lb).item(),
            })
            return top_cats, certificate

        # Abstain with the same dtype/device/shape as `top_cats`.
        return torch.full_like(top_cats, -1), Certificate.zero()
