import torch
from scipy.stats import norm, binom_test
import numpy as np
from math import ceil
from statsmodels.stats.proportion import proportion_confint
import math
from torch.distributions import multivariate_normal
import os
from src.utils.config import cfg
import random
import copy
import os

# Make CUDA device 0 the current device for this process. This executes at
# import time, as a side effect of importing the module.
# NOTE(review): this presumably fails on hosts without a usable CUDA device,
# which would make the module un-importable there -- confirm this is intended.
torch.cuda.set_device(0)

class Smooth(object):
    """A smoothed classifier g.

    Wraps a base matching model and certifies its per-row predictions by
    Monte-Carlo sampling Gaussian noise that is added to the first set of
    keypoint coordinates (``x['Ps'][0]``).
    """

    # to abstain, Smooth returns this int
    ABSTAIN = -1

    def __init__(self, base_classifier: torch.nn.Module):
        """Store the model to smooth.

        Args:
            base_classifier: Model invoked as ``base_classifier(batch)``; the
                result is indexed with ``'perm_mat'``.
        """
        self.base_classifier = base_classifier

    def certify(self, x: dict, n0: int, n: int, alpha: float, batch_size: int, clas: int, sigma_pro: int,
                sigma: float, ancer_theta: torch.Tensor, if_ancer: bool, down: int, k: float) -> tuple:
        """Certify the prediction of every row under Gaussian keypoint noise.

        Args:
            x: Input batch dictionary; see ``_sample_noise`` for the keys read.
            n0: Number of noise samples used to pick the candidate column.
            n: Number of noise samples used to estimate the confidence bound.
            alpha: Failure probability handed to ``_lower_confidence_bound``.
            batch_size: Number of noisy copies evaluated per model call.
            clas: Number of rows (and columns) of the prediction matrix.
            sigma_pro, sigma, ancer_theta, if_ancer, down: Covariance settings,
                forwarded to ``co_variance``.
            k: Forwarded unchanged; not used by any computation in this class.

        Returns:
            Three lists ``(result_Lvolume, result_Llower, result_Lmax)`` with
            ``clas`` entries each. An entry is ``[column, radius]``, or
            ``[Smooth.ABSTAIN, 0.0]`` when the lower confidence bound is
            below 0.5.
        """
        self.base_classifier.eval()

        # Selection pass: use n0 samples to take a guess at the top match.
        counts_selection, _, _, _ = self._sample_noise(
            x, n0, batch_size, clas, sigma_pro, ancer_theta, if_ancer, sigma, down, k)
        cAHat = counts_selection.argmax(dim=1)

        # Estimation pass: fresh samples for the confidence bound.
        counts_estimation, min_lamda, max_lamda, prod_proxy = self._sample_noise(
            x, n, batch_size, clas, sigma_pro, ancer_theta, if_ancer, sigma, down, k)

        # Row-independent factors of the three radii.
        volume_scale = torch.sqrt(prod_proxy)
        lower_denominator = 2 * torch.sqrt(max_lamda)
        max_denominator = 2 * torch.sqrt(min_lamda)

        result_Lvolume = []
        result_Llower = []
        result_Lmax = []
        for i in range(clas):
            nA_item = counts_estimation[i][cAHat[i]].item()
            pABar_item = self._lower_confidence_bound(nA_item, n, alpha)

            if pABar_item < 0.5:
                result_Lvolume.append([Smooth.ABSTAIN, 0.0])
                result_Llower.append([Smooth.ABSTAIN, 0.0])
                result_Lmax.append([Smooth.ABSTAIN, 0.0])
                continue

            quantile = norm.ppf(pABar_item)
            result_Lvolume.append([cAHat[i], volume_scale * quantile / 2])
            result_Llower.append([cAHat[i], quantile / lower_denominator])
            result_Lmax.append([cAHat[i], quantile / max_denominator])

        return result_Lvolume, result_Llower, result_Lmax

    def _sample_noise(self, x: dict, num: int, batch_size, num_classes, sigma_pro, ancer_theta: torch.Tensor,
                      if_ancer: bool, sigma: float, down: int, k: float) -> tuple:
        """Count model predictions over ``num`` noisy copies of the input.

        Args:
            x: Batch dictionary. ``x['Ps']`` is always read; when ``num > 100``
                the entries ``'images'``, ``'Ps'``, ``'ns'`` and ``'As'`` (index
                0 and 1 of each) are tiled ``batch_size`` times. ``x`` itself is
                never modified; a deep copy is perturbed instead.
            num: Number of noise samples to tally. When ``num <= 100`` the
                samples are evaluated one at a time.
            batch_size: Noisy copies per model call (forced to 1 if
                ``num <= 100``).
            num_classes: Side length of the square count matrix.
            sigma_pro, ancer_theta, if_ancer, sigma, down, k: Forwarded to
                ``co_variance``.

        Returns:
            ``(counts, min_lamda, max_lamda, prod_proxy)`` where ``counts`` is
            the ``(num_classes, num_classes)`` sum of ``'perm_mat'`` over all
            tallied samples and the remaining values come from
            ``co_variance``.
        """
        with torch.no_grad():
            tiled = num > 100
            if not tiled:
                batch_size = 1

            point = x['Ps']  # the keypoint coordinates
            # Allocate on the device of the keypoints; the noise is added to
            # them, so both have to live on the same device anyway.
            device = point[0].device
            counts = torch.zeros([batch_size, num_classes, num_classes], dtype=torch.float, device=device)

            batch = copy.deepcopy(x)  # deepcopy, don't change the original x
            if tiled:
                tilings = (
                    ("images", (batch_size, 1, 1, 1)),
                    ("Ps", (batch_size, 1, 1)),
                    ("ns", (batch_size,)),
                    ("As", (batch_size, 1, 1)),
                )
                for key, repeats in tilings:
                    for side in (0, 1):
                        batch[key][side] = batch[key][side].repeat(repeats)

            # Covariance of the noise plus the eigenvalue statistics of it.
            covar, min_lamda, max_lamda, prod_proxy = self.co_variance(
                point[0], sigma_pro, ancer_theta, if_ancer, sigma, down, k)
            length = len(covar)
            num_points = length // 2
            multivar = multivariate_normal.MultivariateNormal(
                torch.zeros(length, device=covar.device, dtype=covar.dtype), covar)

            remaining = num
            while remaining > 0:
                take = min(batch_size, remaining)

                # co_variance orders the covariance as [all x coordinates,
                # all y coordinates], so each flat draw is unpacked as (2, K)
                # and transposed to one (x, y) pair per keypoint.
                flat_noise = multivar.sample((batch_size,))
                noise = flat_noise.reshape(batch_size, 2, num_points).transpose(1, 2).contiguous()

                batch['Ps'][0] = point[0] + noise
                predictions = self.base_classifier(batch)
                if take == batch_size:
                    counts += predictions['perm_mat']
                else:
                    # Final partial batch: tally only `take` predictions so
                    # exactly `num` samples are counted (certify uses `n` as
                    # the number of trials in the confidence bound).
                    # Assumes the leading dimension of 'perm_mat' indexes the
                    # noisy copies -- TODO confirm against the model.
                    counts[:take] += predictions['perm_mat'][:take]
                remaining -= take

            counts_result = torch.sum(counts, dim=0)
            return counts_result, min_lamda, max_lamda, prod_proxy

    @staticmethod
    def _real_eigenvalues(matrix: torch.Tensor) -> torch.Tensor:
        """Return the real parts of the eigenvalues of a square matrix."""
        linalg = getattr(torch, "linalg", None)
        if linalg is not None and hasattr(linalg, "eigvals"):
            return linalg.eigvals(matrix).real
        # Older PyTorch releases only ship torch.eig, whose first output
        # column holds the real parts.
        return torch.eig(matrix)[0][:, 0]

    def co_variance(self, point: torch.Tensor, sigma_pro: int, ancer_theta: torch.Tensor, if_ancer: bool,
                    sigma: float, down: int, k: float) -> tuple:
        """Build the noise covariance and its eigenvalue statistics.

        With ``K = point.size()[1]`` the covariance is ``2K x 2K`` and ordered
        as ``[x_0 .. x_{K-1}, y_0 .. y_{K-1}]``.

        * ``sigma_pro == 10`` and ``if_ancer`` falsy (DDRS): ``sigma * I``.
        * ``sigma_pro == 10`` and ``if_ancer`` truthy (ANCER): diagonal, with
          ``ancer_theta[i][0]`` for x_i and ``ancer_theta[i][1]`` for y_i.
        * ``sigma_pro != 10`` (SCR-GM): ``B @ B`` where ``B`` holds
          ``1 / (1 + distance / down)`` between keypoints (pruned by
          ``select_sigma``), rescaled by its smallest eigenvalue and by the
          squared product of ``ancer_theta ** (1 / (2K))``.

        Args:
            point: Keypoint coordinates; only ``point[0]`` is read, as K rows
                of (x, y).
            sigma_pro: Mode selector / pruning amount, see above.
            ancer_theta: Per-coordinate values; unused in the DDRS case.
            if_ancer: Chooses ANCER over DDRS when ``sigma_pro == 10``.
            sigma: Scale of the isotropic covariance (DDRS only).
            down: Distance normaliser of the SCR-GM similarity.
            k: Unused.

        Returns:
            ``(covariance, min_lamda, max_lamda, prod_proxy)``; the two
            lambdas are the smallest and largest eigenvalue of the inverse
            covariance.

        Raises:
            ZeroDivisionError: In the SCR-GM case if ``down`` is zero and there
                is more than one keypoint.
        """
        length = point.size()[1]
        dim = 2 * length
        device = point.device

        if sigma_pro == 10:
            if if_ancer:  # ANCER
                covar_result = torch.eye(dim, device=device)
                for idx in range(dim):
                    covar_result[idx, idx] = ancer_theta[idx % length][idx // length]
            else:  # DDRS
                covar_result = torch.eye(dim, device=device) * sigma
        else:  # SCR-GM
            if length > 1 and down == 0:
                raise ZeroDivisionError("down must be non-zero")

            coords = point[0]
            xs = coords[:, 0]
            ys = coords[:, 1]
            dx = xs.unsqueeze(1) - xs.unsqueeze(0)
            dy = ys.unsqueeze(1) - ys.unsqueeze(0)
            # Square root and similarity are evaluated in double precision
            # and rounded when written into B.
            distance = (dx * dx + dy * dy).double().sqrt()
            similarity = 1 / (1 + distance / down)

            # Same similarity block for the x and the y coordinates; the
            # cross blocks stay zero. The diagonal is 1 (distance 0).
            B = torch.eye(dim, device=device)
            B[:length, :length] = similarity
            B[length:, length:] = similarity

            B = self.select_sigma(B, sigma_pro)  # B
            covar_result = torch.matmul(B, B)  # Sigma

            min_sigma = torch.min(self._real_eigenvalues(covar_result))
            ancer_sigma = torch.prod(ancer_theta ** (1 / (2 * length)))
            covar_result = (covar_result / min_sigma) * (ancer_sigma * ancer_sigma)

        evals_covar = self._real_eigenvalues(covar_result)
        pra = (1 / math.gamma(length + 1)) ** (1 / (length * 2))
        prod_proxy = (math.pi * (pra ** 2)) * torch.prod(evals_covar ** (1 / (2 * length)))

        inv_result = torch.inverse(covar_result)  # inverse of Sigma
        inv_evals = self._real_eigenvalues(inv_result)
        min_lamda = torch.min(inv_evals)
        max_lamda = torch.max(inv_evals)

        return covar_result, min_lamda, max_lamda, prod_proxy

    def select_sigma(self, Sigma: torch.Tensor, sigma_pro: int) -> torch.Tensor:
        """Zero the smallest entries of ``Sigma`` in place and return it.

        The flattened matrix is sorted ascending and every second position of
        that order is visited; the entry found there and its transposed
        counterpart are set to 0. The number of visited positions is
        ``min(n * (n - 1) * sigma_pro // 20, n * (n - 1) // 2)`` with
        ``n = Sigma.size()[1]``.

        NOTE(review): entries that are already 0 sort first, so for a matrix
        with many zeros the first visits change nothing -- confirm that this
        is the intended pruning amount.
        """
        length = Sigma.size()[1]
        _, order = Sigma.view(-1).sort(dim=0, descending=False)

        pair_budget = length * (length - 1)
        count = min(pair_budget * sigma_pro // 20, pair_budget // 2)
        if count > 0:
            # The order is computed once, before any entry is modified; read
            # all needed flat indices in a single transfer.
            for flat_index in order[0:2 * count:2].tolist():
                row, col = divmod(flat_index, length)
                Sigma[row, col] = 0
                Sigma[col, row] = 0

        return Sigma

    def _lower_confidence_bound(self, NA: int, N: int, alpha: float) -> float:
        """Return a one-sided lower confidence bound on a binomial proportion.

        Takes the lower end of the two-sided ``method="beta"`` interval of
        ``proportion_confint`` at level ``2 * alpha``.

        Args:
            NA: Number of successes.
            N: Number of trials.
            alpha: One-sided failure probability.
        """
        return proportion_confint(NA, N, alpha=2 * alpha, method="beta")[0]