import torch
import pygmtools as pygm
import random
import math
from torch.distributions import multivariate_normal
import functools
from statsmodels.stats.proportion import proportion_confint
from scipy.stats import norm
import time
import datetime
import os
from src.utils.print_easydict import print_easydict
import argparse
pygm.BACKEND = 'pytorch'
_ = torch.manual_seed(1)

def correlation_matrix(P1, length, batch_size, gamma=5):
    """Build a distance-based similarity matrix for each point set in a batch.

    For every pair of points (i, j) of a batch item the entry is
    ``1 / (1 + d_ij / gamma)`` where ``d_ij`` is the Euclidean distance
    computed from the first two coordinates of the points. The diagonal is 1.

    Args:
        P1: tensor indexed as ``P1[batch][point][coord]``; only the first
            ``batch_size`` items, the first ``length`` points and the first
            two coordinates are read.
        length: number of points per batch item.
        batch_size: number of batch items.
        gamma: distance scale in the similarity formula (default 5, the value
            that used to be hard-coded).

    Returns:
        A ``(batch_size, length, length)`` tensor created with ``torch.ones``
        (default dtype/device) holding the symmetric similarity matrices.
    """
    A1 = torch.ones(batch_size, length, length)

    points = P1[:batch_size, :length, :2]
    if not points.is_floating_point():
        # sqrt below needs a floating dtype
        points = points.to(torch.get_default_dtype())

    # Pairwise coordinate differences via broadcasting: (batch, length, length, 2).
    # An explicit difference (instead of a matmul-based distance) keeps the
    # diagonal distance exactly 0, hence the diagonal similarity exactly 1.
    diff = points.unsqueeze(2) - points.unsqueeze(1)
    dist = diff.pow(2).sum(dim=-1).sqrt()

    # copy_ converts to the dtype/device of A1, matching the original output type
    A1.copy_(1 / (1 + dist / gamma))
    return A1

def _real_eigvals(matrix):
    """Return the real parts of the eigenvalues of a square matrix as a 1-D tensor."""
    # torch.eig (whose first output column held the real parts) is no longer
    # available in current PyTorch; torch.linalg.eigvals is its replacement.
    return torch.linalg.eigvals(matrix).real


def Sigma(B, noise_scale, length, s_dis):
    """Compute the smoothing covariance and the quantities used for the radii.

    Args:
        B: batched square matrices; the covariance is ``bmm(B, B)``.
        noise_scale: base noise standard deviation.
        length: number of nodes; enters the exponent ``1 / (length * length)``
            and the gamma-function term of ``prod_proxy``.
        s_dis: when equal to ``True`` the covariance is rescaled and analysed
            through its eigenvalues; otherwise the closed-form values for
            ``noise_scale ** 2`` are returned.

    Returns:
        Tuple ``(covar_result, prod_proxy, min_lamda, max_lamda)``:
        the (possibly rescaled) covariance, a scalar tensor proxy (commented
        as "lvolumn" by the original author), and the smallest / largest
        eigenvalue (real part) of the inverse covariance.

    NOTE(review): in the ``s_dis`` branch only batch item 0 is used for the
    eigen-analysis and for the normalisation factor applied to every item.
    """
    covar_result = torch.bmm(B, B)  # Sigma

    if s_dis == True:  # exact comparison kept so non-bool arguments behave as before
        min_sigma = torch.min(_real_eigvals(covar_result[0]))
        # normalization: smallest eigenvalue of item 0 becomes 1.2 * noise_scale^2
        covar_result = (covar_result / min_sigma) * noise_scale * noise_scale * 1.2
        evals_covar = _real_eigvals(covar_result[0])

        pra = (1 / math.gamma(length * length / 2 + 1)) ** (1 / (length * length))
        prod_proxy = (math.pi * (1 / pra) ** (2)) * torch.prod(
            (evals_covar ** (1 / (length * length))).reshape(-1))  # lvolumn

        inv_result = torch.inverse(covar_result[0])  # inverse of Sigma
        evals = _real_eigvals(inv_result)
        min_lamda = torch.min(evals)
        max_lamda = torch.max(evals)
    else:
        prod_proxy = torch.tensor(noise_scale * noise_scale, dtype=float)
        min_lamda = max_lamda = 1 / prod_proxy

    return covar_result, prod_proxy, min_lamda, max_lamda


def model(batch_size,length,noise_scale,P1,X_gt,s_dis,num):
    """Run the matcher on ``num`` noisy copies of ``P1`` and count the matchings.

    The second graph is the first graph's similarity matrix permuted by
    ``X_gt``. For each of the ``num`` samples, noise drawn from a zero-mean
    multivariate normal is added to the keypoints, the similarity matrix is
    rebuilt, the pair is matched with RRWM and discretised with the Hungarian
    algorithm, and the resulting matrix is added to ``counts``.

    Args:
        batch_size: number of graph pairs in the batch.
        length: number of nodes per graph.
        noise_scale: base noise standard deviation.
        P1: keypoints, reshaped noise of shape ``(batch_size, length, 2)`` is
            added to it.
        X_gt: ground-truth matching matrices used to build the second graph.
        s_dis: ``True`` for the noise covariance built from the similarity
            matrix, otherwise an identity matrix scaled by ``noise_scale``.
        num: number of noise samples.

    Returns:
        Tuple ``(counts, min_lamda, max_lamda, prod_proxy)`` where ``counts``
        is a ``(batch_size, length, length)`` float tensor and the remaining
        values come from ``Sigma``.
    """
    A1_ori = correlation_matrix(P1, length, batch_size)
    A2 = torch.bmm(torch.bmm(X_gt.transpose(1, 2), A1_ori), X_gt)

    if s_dis==True:
        # joint distribution noise: block-diagonal matrix [[A1, 0], [0, A1]]
        zero_block = torch.zeros(batch_size, length, length)
        B_upper = torch.cat((A1_ori, zero_block), 2)
        B_lower = torch.cat((zero_block, A1_ori), 2)
        B = torch.cat((B_upper, B_lower), 1)
    else:
        B = (torch.eye(length*2)*noise_scale).unsqueeze(0)

    cov,prod_proxy,min_lamda,max_lamda = Sigma(B,noise_scale,length,s_dis)
    multivar = multivariate_normal.MultivariateNormal(torch.zeros(batch_size,length * 2), cov)

    # Everything below depends only on the clean inputs, so it is computed
    # once instead of once per noise sample.
    n1 = n2 = torch.tensor([length] * batch_size)
    conn2, edge2, ne2 = pygm.utils.dense_to_sparse(A2)
    gaussian_aff = functools.partial(pygm.utils.gaussian_aff_fn, sigma=1.)  # set affinity function

    #smooth
    counts = torch.zeros([batch_size, length, length], dtype=torch.float)
    for _ in range(num):

        noise = multivar.sample().reshape(batch_size, length, 2)
        A1 = correlation_matrix(P1 + noise, length, batch_size)

        # Build affinity matrix
        conn1, edge1, ne1 = pygm.utils.dense_to_sparse(A1)
        K = pygm.utils.build_aff_mat(None, edge1, conn1, None, edge2, conn2, n1, None, n2, None, edge_aff_fn=gaussian_aff)

        # Solve by RRWM. Note that X is normalized with a sum of 1
        X = pygm.rrwm(K, n1, n2, beta=100)
        counts += pygm.hungarian(X)

    return counts,min_lamda,max_lamda,prod_proxy

def _lower_confidence_bound(NA, N, alpha):
    """Return the lower end of the ``method="beta"`` proportion interval.

    ``proportion_confint`` is asked for a two-sided interval at level
    ``2 * alpha`` for ``NA`` successes out of ``N`` trials; only its lower
    endpoint is returned.
    """
    lower, _upper = proportion_confint(NA, N, alpha=2 * alpha, method="beta")
    return lower

def certify(batch_size, length, noise_scale, P1, X_gt, s_dis, n0,n,alpha):
    """Certify the smoothed matching of every node of every graph in the batch.

    ``n0`` noise samples are used to guess the most frequent match of each
    node; ``n`` further samples estimate how often that match occurs. When
    the lower confidence bound on that frequency is below 0.5 the node is
    marked as abstained, otherwise three radii are derived from the bound.

    Args:
        batch_size: number of graph pairs.
        length: number of nodes per graph.
        noise_scale: base noise standard deviation passed to ``model``.
        P1: keypoints of the first graph.
        X_gt: ground-truth matching matrices.
        s_dis: noise-distribution switch passed to ``model``.
        n0: number of samples for selecting the candidate match.
        n: number of samples for estimating its probability.
        alpha: failure probability passed to ``_lower_confidence_bound``.

    Returns:
        Three lists ``(result_Lvolume, result_Llower, result_Lmax)``. Each has
        ``batch_size * length`` entries ordered batch-major (all nodes of
        graph 0, then graph 1, ...); an entry is ``[match, radius]`` with
        ``[-1, 0.0]`` meaning abstain.
    """
    ABSTAIN=-1

    counts_selection, min_lamda, max_lamda, prod_proxy = model(batch_size, length, noise_scale, P1, X_gt, s_dis, n0)
    # use these samples to take a guess at the top match
    cAHat = counts_selection.argmax(axis=2)
    counts_estimation, min_lamda, max_lamda, prod_proxy = model(batch_size, length, noise_scale, P1, X_gt, s_dis, n)

    # These factors do not depend on the node, so compute them once.
    sqrt_proxy = torch.sqrt(prod_proxy)
    sqrt_max_lamda = torch.sqrt(max_lamda)
    sqrt_min_lamda = torch.sqrt(min_lamda)

    result_Lvolume = []
    result_Llower = []
    result_Lmax = []
    for j in range(batch_size):
        for i in range(length):
            nA_item = counts_estimation[j][i][cAHat[j][i]]
            pABar_item = _lower_confidence_bound(nA_item, n, alpha)

            if pABar_item < 0.5:
                result_Lvolume.append([ABSTAIN, 0.0])
                result_Llower.append([ABSTAIN, 0.0])
                result_Lmax.append([ABSTAIN, 0.0])
                continue

            quantile = norm.ppf(pABar_item)
            result_Lvolume.append([cAHat[j][i], sqrt_proxy * quantile / 2])
            result_Llower.append([cAHat[j][i], quantile / (2 * sqrt_max_lamda)])
            result_Lmax.append([cAHat[j][i], quantile / (2 * sqrt_min_lamda)])

    # Return only after every batch item has been processed (the return used
    # to sit inside the batch loop and cut it short after the first graph).
    return result_Lvolume, result_Llower, result_Lmax

def _write_prediction_rows(predictions, out_file, length, X_gt, total_label, number, time_elapsed):
    """Write one tab-separated result row per (graph, node) to ``out_file``."""
    for i in range(X_gt.shape[0]):
        for j in range(length):
            # predictions are stored batch-major: all nodes of graph 0 first
            predicted, radius = predictions[i * length + j]
            predicted = int(predicted)
            # an abstained node (-1) is always counted as not correct
            correct = 1 if predicted != -1 and X_gt[i][j][predicted] == 1 else 0
            print("{}\t{}\t{}\t{:.3}\t{}\t{}".format(number, int(total_label[i][j]), predicted, radius, correct, time_elapsed), file=out_file, flush=True)


def output(prediction_Lvolume, prediction_Llower, prediction_Lmax,length,X_gt,f_volume,f_max,f_lower,number,time_elapsed):
    """Append the certification results of one sample to the three result files.

    Each file receives exactly one row per node with the fields
    ``number, ground-truth label, predicted label, radius, correct, time``.
    The number of graphs is taken from ``X_gt.shape[0]``.

    Args:
        prediction_Lvolume, prediction_Llower, prediction_Lmax: flat lists of
            ``[match, radius]`` entries, ``length`` entries per graph; a match
            of ``-1`` means the node was abstained.
        length: number of nodes per graph.
        X_gt: ground-truth matching matrices, indexed ``X_gt[graph][node][match]``.
        f_volume, f_max, f_lower: writable text files for the three radius types.
        number: index of the current sample (first column).
        time_elapsed: value written in the last column.
    """
    total_label = X_gt.argmax(axis=2)
    for predictions, out_file in (
        (prediction_Lvolume, f_volume),
        (prediction_Llower, f_lower),
        (prediction_Lmax, f_max),
    ):
        _write_prediction_rows(predictions, out_file, length, X_gt, total_label, number, time_elapsed)

# Command-line options of the certification script.
# NOTE: the arguments are parsed at module level, i.e. also when this file is imported.
parser = argparse.ArgumentParser(description='Certify examples')
parser.add_argument("--method", choices=["rrwm"], default="rrwm", help="matching methods")
parser.add_argument("--n0", type=int, default=10, help="number of samples used to select the candidate match")
parser.add_argument("--n", type=int, default=100, help="number of samples to use")
parser.add_argument("--sample_number", type=int, default=100, help="number of random graph pairs to certify")
args = parser.parse_args()
# Module-level copies of the options; they are read by the __main__ block below.
n0 = args.n0
n = args.n
method = args.method
sample_number = args.sample_number

if __name__ == '__main__':
    # Driver: for every noise scale and noise distribution, certify
    # `sample_number` random graph pairs and log one row per node to three
    # result files (one per radius type).

    file_name ="rrwm_node"
    result_dir = "result_" + file_name + "/" + method
    # Creates both directory levels and tolerates directories that already exist.
    os.makedirs(result_dir, exist_ok=True)

    for noise_scale in [0.3,0.4,0.5]: #original sigma
        for s_dis in [True,False]: #smooth distribution
            # File-name tag: "RS" for s_dis == False, "cov" for s_dis == True.
            tag = '_RS_' if s_dis == False else '_cov_'
            prefix = (result_dir + "/noise_scale" + str(noise_scale) + '_n' + str(n)
                      + '_n0' + str(n0) + '_sample' + str(sample_number) + tag)

            # The with-statement closes all three files even if a sample raises.
            with open(prefix + 'Lvolume', 'w') as f_volume, \
                    open(prefix + 'Lmax', 'w') as f_max, \
                    open(prefix + 'Llower', 'w') as f_lower:

                print("noise_scale",noise_scale,"s_dis",s_dis)
                for result_file in (f_volume, f_max, f_lower):
                    print("idx\tlabel\tpredict\tradius\tcorrect\ttime", file=result_file, flush=True)

                batch_size = 1
                for number in range(sample_number):
                    print("sample_number: ",number)
                    # Random instance: 5-10 keypoints in [0, 100)^2 and a random
                    # permutation as ground-truth matching.
                    length = random.randint(5, 10)
                    P1 = torch.rand(batch_size, length, 2) * 100
                    X_gt = torch.zeros(batch_size, length, length)
                    X_gt[:, torch.arange(0, length, dtype=torch.int64), torch.randperm(length)] = 1

                    before_time = time.time()
                    prediction_Lvolume, prediction_Llower, prediction_Lmax = certify(batch_size, length, noise_scale, P1, X_gt,
                                                                                  s_dis, n0, n, alpha=0.01)
                    after_time = time.time()
                    time_elapsed = str(datetime.timedelta(seconds=(after_time - before_time)))
                    output(prediction_Lvolume, prediction_Llower, prediction_Lmax, length, X_gt, f_volume, f_max, f_lower, number,time_elapsed)

