import sys
import torch
import pickle
import numpy as np
import torch.nn as nn
from torch.autograd import Variable

# ---------------------------------------------------------------------------
# Global configuration
# ---------------------------------------------------------------------------
# All tensors created without an explicit dtype are float32.
torch.set_default_dtype(torch.float32)
# Every model and tensor in this script is placed on the first CUDA device.
device = torch.device('cuda:0')

snp_dim = 5000  # number of SNPs to consider (tune according to GPU memory limits)
ind_dim = 800  # total number of individuals (in-population + reference, see dataloaders below)
noise_dim = 0  # dimension of the random noise input vector for Attacker
Dnoise_dim = 30  # dimension of the random noise input vector for Defender
n_beacons = 10  # number of beacons or AAF vectors to sample in each run

# Binary cross-entropy criterion used by cost_attacker.
bcloss = nn.BCELoss()

# Plain 0-dim tensors; the legacy ``Variable`` wrapper is a deprecated no-op
# around tensors, so it is not needed here.
beta = torch.tensor(1.5)  # Beta captures the relative weight of privacy vs utility
gamma = torch.tensor(0.5)  # NOTE(review): not used by any function visible in this file
Kp = 30  # percentile (in percent) handed to ADP_LRTA as its adaptive threshold

def _load_genotypes(path):
    """Load a pickled collection of genotype records as an integer matrix.

    The pickle is expected to hold an iterable whose items expose
    ``.tolist()`` (one item per row).  Only the first ``snp_dim`` columns
    of the resulting matrix are kept.

    NOTE: ``pickle.load`` can execute arbitrary code while deserialising;
    only point this at files from a trusted source.
    """
    # The context manager guarantees the file handle is closed even if
    # unpickling fails.
    with open(path, 'rb') as handle:
        records = pickle.load(handle)
    rows = [record.tolist() for record in records]
    return np.array(rows, dtype=int)[:, 0:snp_dim]


d = _load_genotypes('In_Pop.pkl')
d_n = _load_genotypes('Not_In_Pop.pkl')
bits = []

# U stacks the rows of In_Pop.pkl on top of the rows of Not_In_Pop.pkl.
U = np.concatenate((d, d_n), axis=0)
# Drop the per-file matrices; only U is needed from here on.
d = []
d_n = []

l2loss = nn.MSELoss()

print("Got here!")

##################################################
# Attacker's Neural Network
##################################################
class Attacker(nn.Module):
    """Feed-forward network used as the attacker.

    Takes a vector of ``snp_dim + noise_dim`` features per sample and emits
    ``ind_dim`` sigmoid outputs, i.e. one value in (0, 1) per individual.
    """

    def __init__(self):
        super().__init__()
        input_width = snp_dim + noise_dim
        # Layers are kept in a positional Sequential so parameter names
        # stay ``net.<index>.*``.
        stack = [
            nn.Linear(input_width, 3400),
            nn.Linear(3400, 2000),
            nn.BatchNorm1d(2000),
            nn.ReLU(),
            nn.Linear(2000, ind_dim),
            nn.Sigmoid(),
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the sigmoid outputs."""
        return self.net(x)

##################################################
# Defender's Neural Network
##################################################

class ScaledSigmoid(nn.Module):
    """Parameter-free activation: a sigmoid shifted down by 0.5.

    Because the sigmoid lies in (0, 1), the output lies in (-0.5, 0.5) and
    is zero for a zero input.
    """

    def forward(self, x):
        """Return ``sigmoid(x) - 0.5`` element-wise."""
        squashed = torch.sigmoid(x)
        return squashed - 0.5

class Defender(nn.Module):
    """Feed-forward network used as the defender.

    Consumes ``ind_dim + Dnoise_dim`` features per sample and produces
    ``snp_dim`` values, each in (-0.5, 0.5) thanks to the ``ScaledSigmoid``
    at the end of ``noise_head``.
    """

    def __init__(self):
        super().__init__()
        input_width = ind_dim + Dnoise_dim

        # Shared trunk: maps the input down to a 500-wide representation.
        trunk = [
            nn.Linear(input_width, 1500),
            nn.Linear(1500, 1100),
            nn.BatchNorm1d(1100),
            nn.ReLU(),
            nn.Linear(1100, 500),
            nn.BatchNorm1d(500),
            nn.LeakyReLU(),
        ]
        self.net = nn.Sequential(*trunk)

        # Output head: one bounded value per SNP.
        head = [
            nn.Linear(500, snp_dim),
            ScaledSigmoid(),
        ]
        self.noise_head = nn.Sequential(*head)

    def forward(self, x):
        """Apply the trunk and then the noise head to ``x``."""
        hidden = self.net(x)
        return self.noise_head(hidden)

def get_beacons(x, mask, flip):
    """Return a sign-flipped, masked copy of ``x`` as a fresh leaf tensor.

    Args:
        x: tensor of values to transform.
        mask: tensor broadcastable to ``x``; entries ``>= 0.5`` zero out the
            corresponding value.
        flip: tensor broadcastable to ``x``; entries ``>= 0.5`` negate the
            corresponding value.

    Returns:
        A tensor cut off from any upstream autograd graph, with
        ``requires_grad=True``.
    """
    sign = torch.where(flip >= 0.5, -1, 1)
    keep = torch.where(mask >= 0.5, 0, 1)
    # detach() + requires_grad_() is the supported replacement for the
    # deprecated ``Variable(tensor, requires_grad=True)`` constructor.
    return ((x * sign) * keep).detach().requires_grad_(True)

##############################
# Defender's Cost with Vector of Kappa
# Per-element weights that cost_defender_vec multiplies with the noise.
# NOTE(review): presumably broadcastable against the (n_beacons, snp_dim) noise — verify.
# NOTE: torch.load deserialises with pickle; only load files from a trusted source.
kappa_path = 'kappa_Vector.pt' # kappa_Vector.pt is the vector of kappas with consistent size
kappa = torch.load(kappa_path,map_location=device)
def cost_defender_vec(noise, sl, b):
    """Defender cost that weights the noise element-wise by ``kappa``.

    Args:
        noise: noise tensor produced by the defender.
        sl: scores in [0, 1]; clamped to [1e-6, 1 - 1e-6] before the log.
        b: membership indicators, cast to float.

    Returns:
        Tuple ``(total_cost, m, t, noise_nok)`` where
        ``m`` is ``sum(b * log(sl)) / n_beacons``,
        ``t`` is ``sum(|noise * kappa|) / n_beacons``,
        ``noise_nok`` is ``sum(|noise|) / (n_beacons * snp_dim)`` (no kappa),
        and ``total_cost = t + m``.
    """
    tiny = 1e-6
    members = b.float()
    # Keep the scores strictly inside (0, 1) so the log stays finite.
    bounded = torch.clamp(sl, tiny, 1 - tiny)
    m = (members * torch.log(bounded)).sum() / (n_beacons)

    noise = noise.to(device)
    weighted = torch.abs(noise * kappa)
    t = weighted.sum() / (n_beacons)

    noise_nok = torch.abs(noise).sum() / (n_beacons * snp_dim)
    return t + m, m, t, noise_nok

###################################
# Defender's Cost with Scalar kappa
def cost_defender(noise, sl, b):
    """Defender cost with a single scalar weight ``beta`` on the noise term.

    Args:
        noise: noise tensor produced by the defender.
        sl: scores in [0, 1]; clamped to [1e-6, 1 - 1e-6] before the log.
        b: membership indicators, cast to float.

    Returns:
        Tuple ``(total_cost, m, t)`` where
        ``m`` is ``sum(b * log(sl)) / n_beacons``,
        ``t`` is ``sum(|noise|) / n_beacons``,
        and ``total_cost = beta * t + m``.
    """
    tiny = 1e-6
    members = b.float()
    # Keep the scores strictly inside (0, 1) so the log stays finite.
    bounded = torch.clamp(sl, tiny, 1 - tiny)
    m = (members * torch.log(bounded)).sum() / (n_beacons)
    t = torch.abs(noise).sum() / (n_beacons)
    return (beta * t) + m, m, t

###################################
# Attacker's Cost with 0<gamma<1
def cost_attacker(sl, b):
    """Return the attacker loss for predictions ``sl`` against targets ``b``.

    Both arguments are cast to float and passed to the module-level
    ``bcloss`` criterion (predictions first, targets second).
    """
    predictions = sl.float()
    targets = b.float()
    return bcloss(predictions, targets)

def get_noisy(aafs, noise):
    """Add ``noise`` to ``aafs`` and clip the result to the interval [0, 1]."""
    perturbed = aafs + noise
    return perturbed.clamp(min=0, max=1)

def BiPrivacyLoss(sl, b):
    """Thresholded privacy loss.

    Scores in ``sl`` are binarised at 0.5; the result is the sum of the
    binarised scores multiplied by ``b``, divided by ``n_beacons``.
    """
    flagged = sl >= 0.5
    return (flagged * b).sum() / n_beacons

# Build both networks on the target device (attacker first, then defender).
nnA = Attacker().to(device)
nnD = Defender().to(device)

# Both learning rates decay by this multiplicative factor whenever the
# corresponding scheduler is stepped.
decayRate = 0.988

# Adam optimisers: the defender uses a 10x larger learning rate than the
# attacker; both share the same weight decay.
optD = torch.optim.Adam(params=nnD.parameters(), lr=1e-3, weight_decay=1e-5)
optA = torch.optim.Adam(params=nnA.parameters(), lr=1e-4, weight_decay=1e-5)

my_lr_schedulerD = torch.optim.lr_scheduler.ExponentialLR(optD, gamma=decayRate)
my_lr_schedulerA = torch.optim.lr_scheduler.ExponentialLR(optA, gamma=decayRate)


###########################
# LRTs Start
###########################

# d and d_n have already been reset to empty lists by this point, so
# concatenating them again would only yield an empty array.  U already holds
# the stacked genotype matrix, so reuse it.
Utotal = U
# The same matrix serves as both the scored population and the reference set.
U_n = U
U_tensor = torch.tensor(U, dtype=torch.float32)
U_n_tensor = torch.tensor(U_n, dtype=torch.float32)
# Per-SNP mean over all rows of the reference set.  requires_grad_() is
# applied before the device move, so the CUDA tensor is a non-leaf, exactly
# as with a ``torch.tensor(t, requires_grad=True).to(device)`` construction
# but without copy-constructing a tensor from a tensor.
# NOTE(review): no optimiser in this file updates aafs_n — confirm gradient
# tracking is actually needed here.
aafs_n = torch.mean(U_n_tensor, dim=0).requires_grad_(True).to(device)

#################################################
# Adaptive-Threshold LRT Attack
def ADP_LRTA(noise_aafs, Kp):
    """Adaptive-threshold likelihood-ratio-test (LRT) attack scores.

    For every row of ``noise_aafs`` a per-SNP log-likelihood-ratio
    contribution is computed for each individual in ``U_tensor`` and in the
    reference set ``U_n_tensor``.  The reference individuals whose summed
    score is at or below the ``Kp``-th percentile define a per-SNP baseline
    (their mean contribution), which is subtracted before summing over SNPs.

    Args:
        noise_aafs: 2-D tensor (rows x SNPs) of published frequencies;
            values are clamped to [1e-4, 0.9999] before taking logs.
        Kp: percentile in percent (e.g. 30 selects the 0.30 quantile).

    Returns:
        Float tensor with one value in (0, 1) per (row of ``noise_aafs``,
        individual in ``U_tensor``): ``sigmoid(-(score - baseline))``.
    """
    a_min = 0.0001
    a_max = 0.9999
    # Clamp both frequency vectors away from 0 and 1 so the logs are finite.
    noise_aafs = torch.clamp(noise_aafs, min=a_min, max=a_max)
    aafs_n2 = torch.clamp(aafs_n, min=a_min, max=a_max)

    # Shape (rows, 1, SNPs) so the terms broadcast against (1, individuals, SNPs).
    logLH = torch.log(aafs_n2 / noise_aafs).unsqueeze(1).to(device)
    logLH_not = torch.log((1 - aafs_n2) / (1 - noise_aafs)).unsqueeze(1).to(device)

    InPop = U_tensor.unsqueeze(0).to(device)
    NotInPop = U_n_tensor.unsqueeze(0).to(device)

    LRTscore = logLH * InPop + logLH_not * (1 - InPop)
    NotIn_LRTscore = logLH * NotInPop + logLH_not * (1 - NotInPop)
    NotIn_eta = torch.sum(NotIn_LRTscore, dim=2)

    # One quantile per row, computed in a single call.  The threshold stays on
    # the same device and in the same dtype as the scores, so the comparison
    # below is not affected by a round-trip through a float32 CPU buffer.
    thresholds = torch.quantile(NotIn_eta, Kp / 100.0, dim=1, keepdim=True)
    below = NotIn_eta <= thresholds

    # Per-row baseline: mean per-SNP contribution of the selected reference
    # individuals.  The number selected can differ between rows, hence the
    # loop; stacking sizes the result from the actual input instead of the
    # global n_beacons and keeps the dtype of the scores.
    mean_tensor = torch.stack([
        NotIn_LRTscore[k][below[k]].mean(dim=0)
        for k in range(NotIn_eta.shape[0])
    ])

    inputs = torch.sum(LRTscore - mean_tensor.unsqueeze(1), dim=2)
    inputstat = inputs.float()
    SIG_LRT = torch.sigmoid(-inputstat)
    return SIG_LRT
###########################
# LRTs End
###########################

print("Defined everything!")

###############################################################################################################
def _sample_defender_batch():
    """Draw one random batch of beacons together with the defender input.

    Returns:
        Tuple ``(b, aafs, d_in)`` of tensors on ``device``:
        ``b`` is a random 0/1 membership matrix of shape (n_beacons, ind_dim),
        ``aafs`` holds, per beacon, the column means of ``U`` over the rows
        whose membership bit is 1, and ``d_in`` is ``b`` concatenated with
        ``Dnoise_dim`` uniform(0, 1) values per beacon, cast to float32.
    """
    membership = np.random.randint(0, 2, (n_beacons, ind_dim)).astype(float)
    b = torch.tensor(membership).to(device)
    # NOTE(review): np.mean yields float64, so aafs (and everything computed
    # from it downstream) is float64 — confirm that precision is intended.
    aafs = np.array([np.mean(U[row == 1], axis=0) for row in membership])
    aafs = torch.tensor(aafs).to(device)
    r_d = torch.tensor(np.random.uniform(0, 1, (n_beacons, Dnoise_dim))).to(device)
    d_in = torch.cat((b, r_d), 1).float()
    return b, aafs, d_in


i = 0
for epoch in range(0, 50000):

    # ---- Defender update -------------------------------------------------
    optD.zero_grad()
    b, aafs, d_in = _sample_defender_batch()

    nnD.train()
    noise = nnD(d_in)
    noisy_aafs = get_noisy(aafs, noise).to(device)
    st = ADP_LRTA(noisy_aafs, Kp)
    loss_D, m, t = cost_defender(noise, st, b)
    loss_D.backward()
    optD.step()

    # ---- Evaluation pass on a fresh batch --------------------------------
    # No backward pass follows, so skip autograd bookkeeping; this avoids
    # holding a second copy of the large intermediate tensors in memory.
    nnD.eval()
    with torch.no_grad():
        b2, aafs2, d_in2 = _sample_defender_batch()
        noise2 = nnD(d_in2)
        noisy_aafs2 = get_noisy(aafs2, noise2).to(device)
        s2 = ADP_LRTA(noisy_aafs2, Kp)

    i += 1
    # Decay the defender learning rate every 50 epochs (first at epoch 50).
    if epoch % 50 == 0 and epoch > 1:
        my_lr_schedulerD.step()

sys.exit()
















