import sys
import torch
import pickle
import numpy as np
import torch.nn as nn
from torch.autograd import Variable
from torchmetrics import ROC, PrecisionRecallCurve

torch.set_default_dtype(torch.float32)
# First CUDA device when one is available, otherwise CPU, so the script can still
# run (slowly) on a machine without a GPU instead of failing at the first .to(device).
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

snp_dim = 5000  # number of SNPs to consider (tune according to GPU memory limits)
ind_dim = 800  # total number of individuals (in-population + reference, see dataloaders below)
noise_dim = 0  # dimension of the random noise input vector for Attacker
Dnoise_dim = 30  # dimension of the random noise input vector for Defender
n_beacons = 10  # number of beacons or AAF vectors to sample in each run
bcloss = nn.BCELoss()
# torch.autograd.Variable is deprecated; a plain tensor (requires_grad=False) is equivalent.
beta = torch.tensor(1.5)  # Beta captures the relative weight of privacy vs utility
gamma = torch.tensor(0.5)  # NOTE(review): not referenced by any cost function in this file — confirm intent

def _load_genotypes(path):
    """Load a pickled collection of per-individual rows as an ``int`` NumPy array.

    Every element of the unpickled object must support ``.tolist()``. The rows are
    stacked with ``dtype=int`` and truncated to the first ``snp_dim`` columns. The
    file handle is closed before returning (the previous inline
    ``pickle.load(open(...))`` left it open).
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(path, 'rb') as fh:
        rows = pickle.load(fh)
    return np.array([row.tolist() for row in rows], dtype=int)[:, 0:snp_dim]


# Matrix of all individuals: rows from In_Pop.pkl first, then rows from
# Not_In_Pop.pkl. The membership vectors in the training loop select rows by index.
U = np.concatenate((_load_genotypes('In_Pop.pkl'), _load_genotypes('Not_In_Pop.pkl')), axis=0)

l2loss = nn.MSELoss()

print("Got here!")

##################################################
# Attacker's Neural Network
##################################################
class Attacker(nn.Module):
    """Membership-inference network.

    Takes a (noisy) AAF vector of length ``snp_dim``, concatenated with
    ``noise_dim`` random features, and emits ``ind_dim`` scores squashed by a
    final sigmoid, one per individual. Training (elsewhere) fits them to the
    0/1 membership vector with BCE.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Linear(snp_dim + noise_dim, 3400),
            # NOTE(review): no activation between these two Linear layers, so they
            # compose to one affine map; confirm this is intended.
            nn.Linear(3400, 2000),
            nn.BatchNorm1d(2000),
            nn.ReLU(),
            nn.Linear(2000, ind_dim),
            nn.Sigmoid(),
        ]
        # Same layer order as before, so state_dict keys (net.0, net.1, ...) are unchanged.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

##################################################
# Defender's Neural Network
##################################################

class ScaledSigmoid(nn.Module):
    """Sigmoid shifted down by 0.5, giving values in [-0.5, 0.5] centred on zero."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        squashed = x.sigmoid()
        return squashed - 0.5

class Defender(nn.Module):
    """Noise generator for the released AAF vectors.

    The input is the membership vector (``ind_dim``) concatenated with
    ``Dnoise_dim`` random features. The output is a length-``snp_dim``
    perturbation bounded to [-0.5, 0.5] by ``ScaledSigmoid``.
    """

    def __init__(self):
        super().__init__()
        trunk = [
            nn.Linear(ind_dim + Dnoise_dim, 1500),
            # NOTE(review): no activation between these two Linear layers, so they
            # compose to one affine map; confirm this is intended.
            nn.Linear(1500, 1100),
            nn.BatchNorm1d(1100),
            nn.ReLU(),
            nn.Linear(1100, 500),
            nn.BatchNorm1d(500),
            nn.LeakyReLU(),
        ]
        self.net = nn.Sequential(*trunk)
        # Built after the trunk, matching the original parameter-initialisation order.
        self.noise_head = nn.Sequential(nn.Linear(500, snp_dim), ScaledSigmoid())

    def forward(self, x):
        features = self.net(x)
        return self.noise_head(features)

def get_beacons(x, mask, flip):
    """Return ``x`` with sign flips and zero masking applied, as a grad-requiring leaf.

    Entries where ``flip >= 0.5`` have their sign negated. Entries where
    ``mask >= 0.5`` are set to zero.
    """
    sign = torch.where(flip >= 0.5, -1, 1)
    keep = torch.where(mask >= 0.5, 0, 1)
    return Variable(x * sign * keep, requires_grad=True)

##############################
# Defender's Cost with Vector of Kappa
# Per-SNP weight vector for the defender's noise penalty in cost_defender_vec.
# Its size must broadcast against the defender's noise (noise * kappa there).
# NOTE(review): torch.load may unpickle arbitrary objects; only load trusted files.
kappa_path = 'kappa_Vector.pt'
kappa = torch.load(kappa_path, map_location=device)  # loaded straight onto `device`
def cost_defender_vec(noise, sl, b):
    """Defender objective using the per-SNP weight vector ``kappa``.

    Returns ``(total_cost, m, t, noise_nok)``:

    * ``m``: sum of ``b * log(clamp(sl, 1e-6, 1 - 1e-6))`` divided by ``n_beacons``.
    * ``t``: sum of ``|noise * kappa|`` divided by ``n_beacons``.
    * ``noise_nok``: sum of ``|noise|`` divided by ``n_beacons * snp_dim``. It is
      unweighted and not part of ``total_cost``.
    * ``total_cost``: ``t + m``, with no ``beta`` factor, unlike ``cost_defender``.

    ``noise`` is moved to ``device`` before being combined with ``kappa``.
    """
    eps = 1e-6
    clipped_scores = torch.clamp(sl, eps, 1 - eps)
    membership_term = torch.sum(b.float() * torch.log(clipped_scores)) / (n_beacons)

    noise_on_device = noise.to(device)
    weighted_noise = torch.sum(torch.abs(noise_on_device * kappa)) / (n_beacons)
    plain_noise = torch.sum(torch.abs(noise_on_device)) / (n_beacons * snp_dim)

    return weighted_noise + membership_term, membership_term, weighted_noise, plain_noise

###################################
# Defender's Cost with Scalar kappa
def cost_defender(noise, sl, b):
    """Defender objective with the scalar privacy weight ``beta``.

    Returns ``(total_cost, m, t)``:

    * ``m``: sum of ``b * log(clamp(sl, 1e-6, 1 - 1e-6))`` divided by
      ``n_beacons``. With 0/1 ``b`` this is the attacker's log-score on true
      members.
    * ``t``: sum of ``|noise|`` divided by ``n_beacons``.
    * ``total_cost``: ``beta * t + m``.
    """
    eps = 1e-6
    log_scores = torch.log(torch.clamp(sl, eps, 1 - eps))
    membership_term = torch.sum(b.float() * log_scores) / (n_beacons)
    noise_term = torch.sum(torch.abs(noise)) / (n_beacons)
    return (beta * noise_term) + membership_term, membership_term, noise_term

###################################
# Attacker's cost: binary cross-entropy of membership scores (the 0<gamma<1 weight is not applied here)
def cost_attacker(sl, b):
    """Return the module-level ``bcloss`` (BCE) of attacker scores ``sl`` vs. labels ``b``.

    Both inputs are cast to float32 first.
    """
    return bcloss(sl.float(), b.float())

def get_noisy(aafs, noise):
    """Add ``noise`` to ``aafs`` and clip every entry into [0, 1]."""
    return (aafs + noise).clamp(min=0, max=1)

def BiPrivacyLoss(sl, b):
    """Hard-thresholded privacy loss, normalised by ``n_beacons``.

    Scores ``sl`` are thresholded at 0.5, multiplied by ``b`` and summed. When
    ``b`` is a 0/1 membership matrix, this counts true members the attacker
    flags, per beacon.
    """
    flagged = sl >= 0.5
    hits = flagged * b
    return hits.sum() / n_beacons


# Build both networks. Keep this order: each construction draws initial weights
# from the global RNG.
nnA = Attacker().to(device)
nnD = Defender().to(device)

# Adam for both players; the defender's learning rate is 10x the attacker's.
optD, optA = (
    torch.optim.Adam(nnD.parameters(), lr=0.001, weight_decay=0.00001),
    torch.optim.Adam(nnA.parameters(), lr=0.0001, weight_decay=0.00001),
)

# Exponential learning-rate decay; the training loop steps these every 50 epochs.
decayRate = 0.988
my_lr_schedulerD, my_lr_schedulerA = (
    torch.optim.lr_scheduler.ExponentialLR(optimizer=opt, gamma=decayRate)
    for opt in (optD, optA)
)

print("Defined everything!")

###############################################################################################################
def _binary_auc(scores, labels):
    """Return the ROC AUC of ``scores`` against binary ``labels`` as a Python float.

    Uses the rank-sum (Mann-Whitney U) form. Tied scores get their average rank,
    so no threshold sweep is needed. Both tensors are flattened, and labels
    >= 0.5 count as positives. Returns NaN when only one class is present,
    because AUC is undefined then.
    """
    flat_scores = scores.detach().flatten().double()
    positive = labels.detach().flatten() >= 0.5
    n_pos = int(positive.sum())
    n_neg = positive.numel() - n_pos
    if n_pos == 0 or n_neg == 0:
        return float('nan')
    _, inverse, counts = torch.unique(
        flat_scores, sorted=True, return_inverse=True, return_counts=True)
    counts = counts.double()
    # 1-based rank of the first element in each tie group, then the group's mean rank.
    first_rank = torch.cumsum(counts, 0) - counts + 1
    mean_rank = first_rank + (counts - 1) / 2
    rank_sum_pos = mean_rank[inverse][positive].sum()
    u_stat = rank_sum_pos - n_pos * (n_pos + 1) / 2
    return (u_stat / (n_pos * n_neg)).item()


# Alternating schedule: within every 200 epochs, epochs 0-100 update the attacker
# and 101-199 update the defender. Both forward passes run every epoch, so
# train-mode BatchNorm running statistics update even when a step is skipped.
i = 0
for epoch in range(0, 50000):

    optD.zero_grad()
    optA.zero_grad()

    ##################################################
    # Train Defender's neural network
    membership_vector = np.random.randint(0, 2, (n_beacons, ind_dim)).astype(float)
    b = torch.tensor(membership_vector, requires_grad=False).to(device)
    # One AAF vector per beacon: per-SNP mean over the rows of U selected by its membership vector.
    aafs = np.array([np.mean(U[np.where(mb == 1)], axis=0) for mb in
                     membership_vector])
    aafs = torch.tensor(aafs, requires_grad=False).to(device)
    r_d = torch.tensor(np.random.uniform(0, 1, (n_beacons, Dnoise_dim)), requires_grad=False).to(
        device)
    d_in = torch.cat((b, r_d), 1)

    nnD.train()
    nnA.eval()

    d_in = d_in.float()
    noise = nnD(d_in)
    noisy_aafs = get_noisy(aafs, noise).to(device)

    r_a = torch.tensor(np.random.uniform(0, 1, (n_beacons, noise_dim)), requires_grad=False).to(device)
    a_in = torch.cat((torch.clip(noisy_aafs, min=0.001, max=0.999), r_a),1)
    a_in = a_in.float()
    s = nnA(a_in)
    loss_D, m, t = cost_defender(noise, s, b)
    if (epoch % 200) > 100:
        loss_D.backward()
        optD.step()

    # Attacker ROC AUC on this batch; monitoring only, it does not affect training.
    auc = _binary_auc(s, b)

    nnD.eval()
    nnA.train()

    ##################################################
    # Train Attacker's neural network
    membership_vector2 = np.random.randint(0, 2, (n_beacons, ind_dim)).astype(float)
    b2 = torch.tensor(membership_vector2, requires_grad=False).to(device)
    aafs2 = np.array([np.mean(U[np.where(mb == 1)], axis=0) for mb in membership_vector2])
    aafs2 = torch.tensor(aafs2, requires_grad=False).to(device)
    r_d2 = torch.tensor(np.random.uniform(0, 1, (n_beacons, Dnoise_dim)), requires_grad=False).to(device)
    d_in2 = torch.cat((b2, r_d2), 1)
    d_in2 = d_in2.float()
    # The defender is never stepped on loss_A, so skip building its autograd graph.
    # loss_A still back-propagates into nnA's parameters.
    with torch.no_grad():
        noise2 = nnD(d_in2)
    noisy_aafs2 = get_noisy(aafs2, noise2).to(device)
    r_a2 = torch.tensor(np.random.uniform(0, 1, (n_beacons, noise_dim)), requires_grad=False).to(device)
    a_in2 = torch.cat((torch.clip(noisy_aafs2, min=0.001, max=0.999), r_a2),
                      1)
    a_in2 = a_in2.float()
    s2 = nnA(a_in2)
    loss_A = cost_attacker(s2, b2)
    if (epoch % 200) <= 100:
        loss_A.backward()
        optA.step()

    i += 1
    if epoch % 50 == 0 and epoch > 1:
        my_lr_schedulerD.step()
        my_lr_schedulerA.step()


sys.exit()




















