import torch
from scipy.stats import norm, binom_test
import numpy as np
from math import ceil
from statsmodels.stats.proportion import proportion_confint
import math
from torch.distributions import multivariate_normal
import os
from src.utils.config import cfg
import random
import copy
import os

# NOTE(review): import-time side effect — importing this module selects CUDA
# device 0 for the whole process. This presumably fails on hosts without CUDA
# (confirm); the sampling code below also creates/moves tensors on the GPU.
torch.cuda.set_device(0)


class Smooth(object):
    """A smoothed keypoint-matching model g built on top of a base matcher.

    ``certify`` perturbs the first image of a matching pair with Gaussian noise
    (isotropic, or scaled per pixel by a keypoint-dependent mask), counts how
    often each keypoint is matched to each target over many noisy draws, and
    converts a Clopper-Pearson lower bound on the top-match frequency into
    three certified-radius variants.
    """

    # to abstain, Smooth returns this int
    ABSTAIN = -1

    def __init__(self, base_classifier: torch.nn.Module):
        # Called on the batch dict; its output must contain a 'perm_mat' entry
        # (see _sample_noise).
        self.base_classifier = base_classifier

    def certify(self, x: dict, n0: int, n: int, alpha: float, batch_size: int, clas: int, sigma_pro: int,
                sigma: float, if_ancer: bool, down: int, k: float) -> tuple:
        """Certify every keypoint match of ``x`` under Gaussian noise.

        :param x: batch dict with 'images', 'Ps', 'ns' and 'As' entries, each a
            two-element list (one per image of the pair); presumably a single
            pair — TODO confirm.
        :param n0: number of noisy samples used to select each top match.
        :param n: number of noisy samples used to estimate its probability.
        :param alpha: failure probability of the one-sided confidence bound.
        :param batch_size: samples per forward pass (only used when more than
            100 samples are requested).
        :param clas: number of keypoints (rows/columns of 'perm_mat').
        :param sigma_pro: 0 selects the keypoint-dependent mask from
            ``find_mask``; any other value uses isotropic noise ``sigma``.
        :param sigma: noise scale.
        :param if_ancer, down, k: forwarded to ``_sample_noise``, which ignores them.
        :return: ``(result_Lvolume, result_Llower, result_Lmax)``; each holds one
            ``[match_index, radius]`` pair per keypoint, or
            ``[Smooth.ABSTAIN, 0.0]`` when the lower bound is below 0.5.
        """
        self.base_classifier.eval()

        # Selection round: only the counts are needed, to guess each top match.
        counts_selection, _, _, _ = self._sample_noise(x, n0, batch_size, clas, sigma_pro, if_ancer,
                                                       sigma, down, k)
        cAHat = counts_selection.argmax(dim=1)
        # Estimation round: fresh noise draws, independent of the selection round.
        counts_estimation, min_lamda, max_lamda, prod_proxy = self._sample_noise(
            x, n, batch_size, clas, sigma_pro, if_ancer, sigma, down, k)

        result_Lvolume = []
        result_Llower = []
        result_Lmax = []
        for i in range(clas):
            nA_item = counts_estimation[i][cAHat[i]].item()
            pABar_item = self._lower_confidence_bound(nA_item, n, alpha)

            if pABar_item < 0.5:
                result_Lvolume.append([Smooth.ABSTAIN, 0.0])
                result_Llower.append([Smooth.ABSTAIN, 0.0])
                result_Lmax.append([Smooth.ABSTAIN, 0.0])
                continue

            z = norm.ppf(pABar_item)
            radius_Lvolume = torch.sqrt(prod_proxy) * z / 2
            radius_Llower = z / (2 * torch.sqrt(max_lamda))
            radius_Lmax = z / (2 * torch.sqrt(min_lamda))

            result_Lvolume.append([cAHat[i], radius_Lvolume])
            result_Llower.append([cAHat[i], radius_Llower])
            result_Lmax.append([cAHat[i], radius_Lmax])
            # Debug output kept from the original implementation.
            print(radius_Lvolume, radius_Lmax, radius_Llower)

        return result_Lvolume, result_Llower, result_Lmax

    def _sample_noise(self, x: dict, num: int, batch_size, num_classes, sigma_pro,
                      if_ancer: bool, sigma: float, down: int, k: float) -> tuple:
        """Run the base model on ``num`` noisy copies of ``x`` and tally its matches.

        Noise is added to ``x['images'][0]`` only; the second image is untouched.
        Exactly ``num`` draws are counted: the last forward pass may run a full
        ``batch_size`` batch, but only its first ``remaining`` rows are tallied,
        so the count passed to the confidence bound never exceeds ``num``.

        ``if_ancer``, ``down`` and ``k`` are accepted for interface compatibility
        and are not used.

        :return: ``(counts, min_lamda, max_lamda, prod_proxy)`` where ``counts`` is
            a ``(num_classes, num_classes)`` float tensor summing
            ``base_classifier(batch)['perm_mat']`` over the ``num`` draws
            (presumably a 0/1 assignment per draw), and the other three describe
            the noise covariance (from ``find_mask`` or the isotropic formulas).
        """
        with torch.no_grad():
            if num <= 100:
                # Few samples: one per forward pass, and x is used without repetition.
                batch_size = 1

            image = x['images']  # the clean images (never modified)
            point = x['Ps']  # the keypoint coordinates
            batch = copy.deepcopy(x)  # deepcopy, don't change the original x

            if num > 100:
                # Tile every input along dim 0 so one pass evaluates batch_size draws.
                for side in (0, 1):
                    batch['images'][side] = batch['images'][side].repeat((batch_size, 1, 1, 1))
                    batch['Ps'][side] = batch['Ps'][side].repeat((batch_size, 1, 1))
                    batch['ns'][side] = batch['ns'][side].repeat((batch_size))
                    batch['As'][side] = batch['As'][side].repeat((batch_size, 1, 1))

            if sigma_pro == 0:
                mask, min_lamda, max_lamda, prod_proxy = self.find_mask(image[0], point[0], sigma)
            else:
                mask = torch.ones(image[0].size()) * sigma
                min_lamda = torch.tensor(1 / (sigma * sigma), dtype=float)
                max_lamda = torch.tensor(1 / (sigma * sigma), dtype=float)
                prod_proxy = torch.tensor(sigma * sigma, dtype=float)

            clean = image[0]
            # Per-pixel noise std, moved to the image's device once instead of per sample.
            std = mask[0].to(clean.device)
            counts = torch.zeros([num_classes, num_classes], dtype=torch.float, device=clean.device)

            remaining = int(num)
            while remaining > 0:
                this_batch = min(batch_size, remaining)
                # One randn_like per sample (same RNG call sequence as drawing them
                # one by one), stacked along a new batch dim.
                noise = torch.stack([torch.randn_like(clean[0]) * std for _ in range(batch_size)])
                batch['images'][0] = clean + noise
                perm = self.base_classifier(batch)['perm_mat']
                # Only the first `this_batch` draws count toward the `num` requested.
                counts += perm[:this_batch].sum(dim=0).to(counts)
                remaining -= this_batch

            return counts, min_lamda, max_lamda, prod_proxy

    @staticmethod
    def _keypoint_scales(coords: torch.Tensor) -> torch.Tensor:
        """Return 1.1 * rowsum / min(rowsum) of (pairwise Euclidean distances + identity)."""
        xy = coords[:, :2].to(torch.float64)
        diff = xy[:, None, :] - xy[None, :, :]
        # Diagonal distance is 0; adding the identity puts 1 there, as the original eye() init did.
        distance_matrix = diff.pow(2).sum(dim=-1).sqrt().float() + torch.eye(coords.size(0))
        row_sums = torch.sum(distance_matrix, dim=1)
        return (row_sums / torch.min(row_sums)) * 1.1

    @staticmethod
    def _window_indices(positions: torch.Tensor, limit: int) -> torch.Tensor:
        """Return ``int(p)`` (truncation) for each ``p`` with ``0 < p < limit``."""
        return positions[(positions > 0) & (positions < limit)].long()

    def find_mask(self, image_left: torch.Tensor, point_left: torch.Tensor, sigma: float):
        """Build a per-pixel noise mask that is larger around keypoints.

        Each keypoint gets a scale proportional to its summed distance to the
        other keypoints (normalized so the smallest is 1.1); pixels in a
        20-pixel window around it take that scale, all other pixels take 1.05.
        The mask is then multiplied by ``sigma``.

        :param image_left: tensor whose shape (indexed as ``[0][c][row][col]``,
            so at least 4-D) defines the mask shape.
        :param point_left: keypoints, indexed as ``point_left[0][i][0:2]``.
        :param sigma: noise scale.
        :return: ``(mask.sqrt(), min_lamda, max_lamda, prod_proxy)`` — all on CPU;
            ``min_lamda = 1 / max(mask)**2``, ``max_lamda = 1 / min(mask)**2``,
            ``prod_proxy`` is the squared geometric mean of the mask entries.

        NOTE(review): the returned per-pixel std is ``sqrt(B * sigma)`` (variance
        ``B * sigma``), yet the lamdas use ``1 / (B * sigma)**2``; in the
        isotropic branch of ``_sample_noise`` the std is ``sigma`` and the lamda
        is ``1 / sigma**2``. Confirm which scaling is intended.
        """
        coords = point_left[0].detach().cpu()
        n_points = coords.size(0)

        mask = torch.ones(image_left.size()) * 1.05
        height, width = mask.size(2), mask.size(3)
        plane = mask[0]  # a view: writes below land in `mask`

        if n_points > 0:
            scales = self._keypoint_scales(coords)
            offsets = torch.arange(20)
            for i in range(n_points):
                # NOTE(review): coordinate 0 indexes rows and coordinate 1 columns;
                # if 'Ps' stores (x, y) image coordinates this is transposed — confirm.
                x_point, y_point = coords[i, 0], coords[i, 1]
                value = scales[i].item()
                # NOTE(review): as in the original, only the two diagonal quadrants
                # (x-j, y-k) and (x+j, y+k) are filled, not a full square — confirm.
                # Later keypoints overwrite earlier ones where windows overlap.
                for sign in (-1, 1):
                    rows = self._window_indices(x_point + sign * offsets, height)
                    cols = self._window_indices(y_point + sign * offsets, width)
                    plane[:3, rows[:, None], cols[None, :]] = value

        mask = mask * sigma
        mask_max = torch.max(mask)
        mask_min = torch.min(mask)
        min_lamda = 1 / (mask_max * mask_max)
        max_lamda = 1 / (mask_min * mask_min)

        # Geometric mean, taking the root before the product (presumably to avoid
        # overflow/underflow over hundreds of thousands of factors), then squared.
        prod_proxy_trans = torch.prod((mask ** (1 / mask.numel())).reshape(-1))
        prod_proxy = (prod_proxy_trans * prod_proxy_trans).float()

        return mask.sqrt(), min_lamda, max_lamda, prod_proxy

    def _lower_confidence_bound(self, NA: int, N: int, alpha: float) -> float:
        """One-sided (1 - alpha) Clopper-Pearson lower bound on a binomial proportion.

        ``method="beta"`` is the exact Clopper-Pearson interval; requesting a
        two-sided interval at ``2 * alpha`` makes its lower end a one-sided
        bound at level ``alpha``.
        """
        return proportion_confint(NA, N, alpha=2 * alpha, method="beta")[0]