import numpy as np
import pandas as pd
import random
import matplotlib.pyplot as plt
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
from sklearn import preprocessing
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, roc_auc_score
import sys,os
sys.path.append(r"/home/yh579/GAFM/GAFM/models")
from bases import FirstNet,SecondNet,torch_auc,totalvaraition,Attacks
# SplitNN
import torch


# SplitNN
import torch

# Target device for the models and training batches created in this module;
# hard-coded to CPU.
device='cpu'
class Client(torch.nn.Module):
    """Client (feature-owning party) of a SplitNN.

    Args:
        client_model (torch model): client-side model.

    Attributes:
        client_model (torch model): client-side model.
        client_side_intermidiate (torch.Tensor): output of the client-side
            model from the most recent forward pass, still attached to the
            client-side autograd graph.
        grad_from_server (torch.Tensor): gradient received from the server in
            the most recent call to :meth:`client_backward`.

    ``train()`` / ``eval()`` are inherited from ``torch.nn.Module``. Because
    ``client_model`` is assigned as an attribute it is a registered submodule,
    so the mode change propagates to it (and ``train(mode)`` keeps the
    standard signature and returns ``self``).
    """

    def __init__(self, client_model):
        super().__init__()
        self.client_model = client_model
        self.client_side_intermidiate = None
        self.grad_from_server = None

    def forward(self, inputs):
        """Run the client-side model and produce the tensor sent to the server.

        Args:
            inputs (torch.Tensor): the client's input data.

        Returns:
            torch.Tensor: a detached copy of the client-side output with
            ``requires_grad`` enabled. The server's backward pass stops at
            this tensor and leaves the gradient in its ``.grad``.
        """
        self.client_side_intermidiate = self.client_model(inputs)
        # Cut the autograd graph at the party boundary; the client finishes
        # back-propagation itself in client_backward().
        return self.client_side_intermidiate.detach().requires_grad_()

    def client_backward(self, grad_from_server):
        """Back-propagate the server's gradient through the client-side model.

        Args:
            grad_from_server (torch.Tensor): gradient of the loss w.r.t. the
                tensor returned by :meth:`forward`; must have its shape.
        """
        self.grad_from_server = grad_from_server
        self.client_side_intermidiate.backward(grad_from_server)


class Server(torch.nn.Module):
    """Server (label-side party) of a SplitNN.

    Args:
        server_model (torch model): server-side model.

    Attributes:
        server_model (torch model): server-side model.
        intermidiate_to_server (torch.Tensor): input received in the most
            recent forward pass.
        grad_to_client (torch.Tensor): gradient most recently returned by
            :meth:`server_backward`.

    ``train()`` / ``eval()`` are inherited from ``torch.nn.Module`` and
    propagate to the registered ``server_model`` submodule.
    """

    def __init__(self, server_model):
        super().__init__()
        self.server_model = server_model
        self.intermidiate_to_server = None
        self.grad_to_client = None

    def forward(self, intermidiate_to_server):
        """Run the server-side model.

        Args:
            intermidiate_to_server (torch.Tensor): the output of the
                client-side model.

        Returns:
            torch.Tensor: outputs of the server-side model.
        """
        # Remember the input so server_backward() can read its .grad later.
        self.intermidiate_to_server = intermidiate_to_server
        return self.server_model(intermidiate_to_server)

    def server_backward(self):
        """Return a copy of the loss gradient w.r.t. the server's input.

        Must be called after the loss has been back-propagated.

        Raises:
            RuntimeError: if no forward pass has happened or the input has no
                gradient yet.
        """
        received = self.intermidiate_to_server
        if received is None or received.grad is None:
            raise RuntimeError(
                "server_backward() requires a forward pass followed by "
                "loss.backward(); no gradient is available on the server input"
            )
        # Clone so later in-place changes (e.g. zero_grad) cannot alter it.
        self.grad_to_client = received.grad.clone()
        return self.grad_to_client

def max_norm(grad):
    """Server-side max-norm heuristic against norm-based label leakage.

    Reference: https://arxiv.org/abs/2102.08504

    Each sample ``j`` of the mini-batch gets one scalar Gaussian noise draw
    with standard deviation ``sqrt(g_max / ||g_j|| - 1)``, where ``g_max`` is
    the largest per-sample L2 norm in the batch. That scalar is added to
    every element of the sample's gradient.

    Args:
        grad (torch.Tensor): gradient of the loss L w.r.t. the input of the
            server-side model h (the whole model is h(f(x)), f being the
            client-side model). Dimension 0 is the batch dimension; expected
            to be at least 2-D.

    Returns:
        torch.Tensor: perturbed gradient to send to the client, same shape as
        ``grad``. Samples whose gradient norm is zero are returned unchanged,
        because their noise scale would otherwise be infinite (or NaN when the
        whole batch is zero) and would corrupt the client update.

    Example (inside a server's backward step)::

        grad_to_client = max_norm(intermidiate_to_server.grad.clone())
    """
    # Per-sample L2 norm, reducing every dimension except the batch dimension.
    g_norm = grad.pow(2).sum(dim=list(range(1, grad.dim()))).sqrt()
    # Maximum gradient norm within the mini-batch.
    g_max = g_norm.max()
    # Zero-norm samples get ratio 1 -> sigma 0 instead of inf / NaN.
    ratio = torch.where(g_norm > 0, g_max / g_norm, torch.ones_like(g_norm))
    # clamp_min guards against tiny negative values from rounding.
    sigma = torch.sqrt((ratio - 1).clamp_min(0))
    # One Gaussian draw per sample.
    perturbation = torch.normal(torch.zeros_like(sigma), sigma)
    # Reshape to (B, 1, ..., 1) so the per-sample scalar broadcasts over the
    # sample's remaining dimensions (replaces the deprecated n-D ``.T``).
    perturbation = perturbation.view(-1, *([1] * (grad.dim() - 1)))
    return grad + perturbation


class Server_with_max_norm(Server):
    """Server that perturbs the gradient it returns with :func:`max_norm`.

    Args:
        server_model (torch model): server-side model.
    """

    def __init__(self, server_model):
        super().__init__(server_model)

    def server_backward(self):
        """Return the max-norm-perturbed gradient to send to the client(s).

        ``SplitNN.backward`` and ``SplitNN_multiple.backward`` call
        ``server_backward``, so the defence has to be applied here. Overriding
        only ``backward`` meant the perturbation was never used in training.
        """
        self.grad_to_client = max_norm(super().server_backward())
        return self.grad_to_client

    def backward(self):
        """Alias of :meth:`server_backward`, kept for existing callers."""
        return self.server_backward()

class SplitNN(torch.nn.Module):
    """Two-party SplitNN wiring one client to one server.

    Args:
        client: client party; called on the inputs and expected to provide
            ``client_backward``, ``train`` and ``eval``.
        server: server party; called on the client output and expected to
            provide ``server_backward``, ``train`` and ``eval``.
        client_optimizer: optimizer for the client-side model.
        server_optimizer: optimizer for the server-side model.

    Attributes:
        grad_to_client (torch.Tensor): gradient returned by the server in the
            latest :meth:`backward`.
        intermidiate_to_server (torch.Tensor): client output from the latest
            forward pass.
        labels: labels passed to the latest forward pass.
    """

    def __init__(self, client, server,
                 client_optimizer, server_optimizer
                 ):
        super().__init__()
        self.client = client
        self.server = server
        self.client_optimizer = client_optimizer
        self.server_optimizer = server_optimizer
        self.grad_to_client = None
        self.intermidiate_to_server = None
        self.labels = None

    def forward(self, inputs, labels):
        """Run client then server.

        Returns:
            tuple: ``(server outputs, tensor the client sent to the server)``.
        """
        self.labels = labels
        self.intermidiate_to_server = self.client(inputs)
        outputs = self.server(self.intermidiate_to_server)
        return outputs, self.intermidiate_to_server

    def backward(self):
        """Pass the server's input gradient back to the client.

        Call after the loss has been back-propagated on the server side.
        """
        self.grad_to_client = self.server.server_backward()
        self.client.client_backward(self.grad_to_client)

    def zero_grads(self):
        """Reset the gradients of both optimizers."""
        self.client_optimizer.zero_grad()
        self.server_optimizer.zero_grad()

    def step(self):
        """Apply one optimizer step on both parties."""
        self.client_optimizer.step()
        self.server_optimizer.step()

    def train(self, mode=True):
        """Switch both parties to training (or, with ``mode=False``, eval) mode.

        The parties' own ``train()`` / ``eval()`` are called without
        arguments so party classes that do not accept ``mode`` still work.
        Returns ``self`` like ``torch.nn.Module.train``.
        """
        self.training = mode
        if mode:
            self.client.train()
            self.server.train()
        else:
            self.client.eval()
            self.server.eval()
        return self

    def eval(self):
        """Switch both parties to evaluation mode; returns ``self``."""
        return self.train(False)


def train_maxnorm(Epochs,features,train_loader,test_loader,lr=1e-4, info=True):
    """Train a two-party SplitNN whose server is a ``Server_with_max_norm``.

    Each epoch it measures train/test AUC and the label-leakage AUC of three
    attacks on the gradient the server returns to the client.

    Args:
        Epochs (int): number of passes over ``train_loader``.
        features: array whose last dimension is the client input size.
        train_loader: iterable of ``(inputs, labels)`` batches exposing
            ``.dataset`` with ``len()``.
        test_loader: iterable of ``(inputs, labels)`` batches; the whole loader
            is evaluated once per epoch in eval mode without autograd.
        lr (float): Adam learning rate for both parties.
        info (bool): print metrics every 10 epochs and on the last epoch.

    Returns:
        tuple: ``(train_auc, test_auc, train_tvd, na_leak_auc, ma_leak_auc,
        cos_leak_auc, splitnn)`` with the metrics from the last epoch;
        ``cos_leak_auc`` is the median-threshold attack.
    """
    def leak_auc(labels, scores):
        # An attacker may flip its decision rule, so an AUC below 0.5 leaks
        # as much as its complement.
        auc = torch_auc(labels, scores.view(-1, 1))
        return max(auc, 1 - auc)

    input_dim = features.shape[-1]
    model_1 = FirstNet(input_dim).to(device)
    model_2 = SecondNet().to(device)
    model_1.double()
    model_2.double()

    opt_1 = optim.Adam(model_1.parameters(), lr=lr)
    opt_2 = optim.Adam(model_2.parameters(), lr=lr)
    BCE = nn.BCELoss()

    splitnn_maxnorm = SplitNN(Client(model_1), Server_with_max_norm(model_2),
                              opt_1, opt_2)
    n_train = len(train_loader.dataset)

    for epoch in range(Epochs):
        splitnn_maxnorm.train()
        epoch_loss = 0
        epoch_outputs, epoch_labels = [], []
        epoch_g, epoch_g_norm, epoch_g_inner = [], [], []

        for inputs, labels in train_loader:
            splitnn_maxnorm.zero_grads()
            inputs = inputs.to(device).double()
            labels = labels.to(device).double()

            outputs, _ = splitnn_maxnorm(inputs, labels)
            loss = BCE(outputs, labels)
            # Server-side backward: fills .grad on the tensor the client sent.
            loss.backward()
            # Server returns that gradient; the client back-propagates it.
            splitnn_maxnorm.backward()
            splitnn_maxnorm.step()

            # BCELoss is a batch mean; weight it by batch size so epoch_loss
            # is the per-sample mean loss over the training set.
            epoch_loss += loss.item() * len(labels) / n_train
            # Detach so the stored outputs do not keep each batch's graph alive.
            epoch_outputs.append(outputs.detach())
            epoch_labels.append(labels)

            grad = splitnn_maxnorm.grad_to_client.detach()
            epoch_g.append(grad)
            # Norm attack score: per-sample L2 norm of the returned gradient.
            epoch_g_norm.append(grad.pow(2).sum(dim=1).sqrt())
            # Median attack score: 1 where the gradient exceeds the batch median.
            # NOTE(review): assumes one gradient value per sample (the previous
            # per-row comparisons only worked in that case) - confirm.
            epoch_g_inner.append((grad > grad.median()).long().view(-1))

        # Evaluate on the full test set with the same device/dtype handling
        # as training, without building autograd graphs.
        epoch_outputs_test, epoch_labels_test = [], []
        splitnn_maxnorm.eval()
        with torch.no_grad():
            for test_inputs, test_labels in test_loader:
                test_inputs = test_inputs.to(device).double()
                test_labels = test_labels.to(device).double()
                outputs_test, _ = splitnn_maxnorm(test_inputs, test_labels)
                epoch_outputs_test.append(outputs_test)
                epoch_labels_test.append(test_labels)

        all_labels = torch.cat(epoch_labels)
        all_g = torch.cat(epoch_g)
        train_auc = torch_auc(all_labels, torch.cat(epoch_outputs))
        test_auc = torch_auc(torch.cat(epoch_labels_test),
                             torch.cat(epoch_outputs_test))
        train_tvd = totalvaraition(all_labels, all_g)
        na_leak_auc = leak_auc(all_labels, torch.cat(epoch_g_norm))
        # The "mean" attack scores each sample by its raw returned gradient.
        ma_leak_auc = leak_auc(all_labels, all_g)
        cos_leak_auc = leak_auc(all_labels, torch.cat(epoch_g_inner))

        if info and (epoch % 10 == 0 or epoch == Epochs - 1):
            print('Epoch', epoch, 'Training Loss', epoch_loss,
                  'Training AUC', train_auc,
                  'Testing AUC', test_auc,
                  'NA Leak AUC', na_leak_auc,
                  'MA Leak AUC', ma_leak_auc,
                  'Median Leak AUC', cos_leak_auc
                  )
    return train_auc, test_auc, train_tvd, na_leak_auc, ma_leak_auc, cos_leak_auc, splitnn_maxnorm


# SplitNN with several clients, each owning a contiguous slice of the feature columns
class SplitNN_multiple(torch.nn.Module):
    """SplitNN in which several clients feed one server with their mean output.

    Args:
        clients (list): client parties. Each is called on its slice of the
            input columns and is expected to provide ``client_backward``,
            ``train`` and ``eval``.
        server: server party; expected to provide ``server_backward``,
            ``train`` and ``eval``.
        clients_optimizers (list): one optimizer per client, in client order.
        server_optimizer: optimizer for the server-side model.
        features (sequence of int): column boundaries; client ``i`` receives
            ``inputs[:, features[i]:features[i + 1]]``.

    Attributes:
        number (int): number of clients.
        intermidiate_to_server (torch.Tensor): mean of the client outputs from
            the latest forward pass (``retain_grad`` is enabled on it).
        grad_to_client (torch.Tensor): gradient returned by the server in the
            latest :meth:`backward`.

    NOTE(review): ``clients`` is stored as a plain list, so the clients are not
    registered submodules (they are absent from ``parameters()`` /
    ``state_dict()``).
    """

    def __init__(self, clients, server,
                 clients_optimizers, server_optimizer, features
                 ):
        super().__init__()
        self.clients = clients
        self.number = len(clients)
        self.server = server
        self.client_optimizers = clients_optimizers
        self.server_optimizer = server_optimizer
        self.grad_to_client = None
        self.intermidiate_to_server = 0
        self.features = features
        self.labels = None

    def forward(self, inputs, labels):
        """Run every client on its feature slice, average, then run the server.

        Returns:
            tuple: ``(server outputs, averaged client output, list of each
            client's output already divided by the number of clients)``.
        """
        self.labels = labels
        intermidiate_to_servers = [
            client(inputs[:, self.features[i]:self.features[i + 1]]) / self.number
            for i, client in enumerate(self.clients)
        ]
        # Sum of the pre-scaled outputs = mean client embedding.
        self.intermidiate_to_server = sum(intermidiate_to_servers)
        # Non-leaf tensor: keep its .grad so server_backward() can read it.
        self.intermidiate_to_server.retain_grad()
        outputs = self.server(self.intermidiate_to_server)
        return outputs, self.intermidiate_to_server, intermidiate_to_servers

    def backward(self):
        """Send each client its share of the server's input gradient.

        Call after the loss has been back-propagated on the server side.
        """
        self.grad_to_client = self.server.server_backward()
        # The server input is the mean of the client outputs, so the gradient
        # w.r.t. each client output is grad / number.
        share = self.grad_to_client / self.number
        for client in self.clients:
            client.client_backward(share)

    def zero_grads(self):
        """Reset the gradients of every client optimizer and the server's."""
        for client_optimizer in self.client_optimizers[:self.number]:
            client_optimizer.zero_grad()
        self.server_optimizer.zero_grad()

    def step(self):
        """Apply one optimizer step on every client and the server."""
        for client_optimizer in self.client_optimizers[:self.number]:
            client_optimizer.step()
        self.server_optimizer.step()

    def train(self, mode=True):
        """Switch all parties to training (or, with ``mode=False``, eval) mode.

        The parties' own ``train()`` / ``eval()`` are called without
        arguments so party classes that do not accept ``mode`` still work.
        Returns ``self`` like ``torch.nn.Module.train``.
        """
        self.training = mode
        parties = list(self.clients) + [self.server]
        for party in parties:
            if mode:
                party.train()
            else:
                party.eval()
        return self

    def eval(self):
        """Switch all parties to evaluation mode; returns ``self``."""
        return self.train(False)

def train_maxnorm_multiple(Epochs, features,train_loader,test_loader,lr=1e-4,info=False):
    """Train a three-client SplitNN whose server is a ``Server_with_max_norm``.

    Each epoch it measures train/test AUC, the total-variation statistic and
    the label-leakage AUC of the ``Attacks`` scores. These are computed on the
    full server gradient and on the share each client receives
    (``grad / 3``, see ``SplitNN_multiple.backward``).

    Args:
        Epochs (int): number of passes over ``train_loader``.
        features (sequence of int): column boundaries; client ``k`` owns
            ``inputs[:, features[k]:features[k + 1]]``. Only the first four
            entries are used (three clients).
        train_loader: iterable of ``(inputs, labels)`` batches exposing
            ``.dataset`` with ``len()``.
        test_loader: iterable of ``(inputs, labels)`` batches; the whole loader
            is evaluated once per epoch in eval mode without autograd.
        lr (float): Adam learning rate for every party.
        info (bool): print metrics every 10 epochs and on the last epoch.

    Returns:
        tuple: ``(train_auc, test_auc, train_tvd, na_leak_auc, ma_leak_auc,
        cos_leak_auc, na_leak_auc1, ma_leak_auc1, cos_leak_auc1, na_leak_auc2,
        ma_leak_auc2, cos_leak_auc2, na_leak_auc3, ma_leak_auc3, cos_leak_auc3,
        splitnn)`` with the metrics from the last epoch.
    """
    n_clients = 3

    def leak_auc(labels, scores):
        # An attacker may flip its decision rule, so an AUC below 0.5 leaks
        # as much as its complement.
        auc = torch_auc(labels, scores.view(-1, 1))
        return max(auc, 1 - auc)

    clients, client_optimizers = [], []
    for k in range(n_clients):
        model_client = FirstNet(input_dim=features[k + 1] - features[k]).to(device)
        model_client.double()
        clients.append(Client(model_client))
        client_optimizers.append(optim.Adam(model_client.parameters(), lr=lr))

    model_2 = SecondNet().to(device)
    model_2.double()
    opt_2 = optim.Adam(model_2.parameters(), lr=lr)
    server = Server_with_max_norm(model_2)

    BCE = nn.BCELoss()
    splitnn = SplitNN_multiple(clients, server, client_optimizers, opt_2, features)
    n_train = len(train_loader.dataset)

    for epoch in range(Epochs):
        splitnn.train()
        epoch_loss = 0
        epoch_outputs, epoch_labels, epoch_g = [], [], []
        # Slot 0: attack scores on the full server gradient.
        # Slot k (1..n_clients): scores on grad / n_clients, client k's share.
        epoch_g_norm = [[] for _ in range(n_clients + 1)]
        epoch_g_mean = [[] for _ in range(n_clients + 1)]
        epoch_g_inner = [[] for _ in range(n_clients + 1)]

        for inputs, labels in train_loader:
            splitnn.zero_grads()
            inputs = inputs.to(device).double()
            labels = labels.to(device).double()

            outputs, _, _ = splitnn(inputs, labels)
            loss = BCE(outputs, labels)
            # Client.forward detaches at the party boundary, so the client
            # backward passes never reuse this graph; no need to retain it.
            loss.backward()
            splitnn.backward()
            splitnn.step()

            # BCELoss is a batch mean; weight it by batch size so epoch_loss
            # is the per-sample mean loss over the training set.
            epoch_loss += loss.item() * len(labels) / n_train
            # Detach so the stored outputs do not keep each batch's graph alive.
            epoch_outputs.append(outputs.detach())
            epoch_labels.append(labels)

            grad = splitnn.grad_to_client
            epoch_g.append(grad)
            for k in range(n_clients + 1):
                attacked = grad if k == 0 else grad / n_clients
                g_norm, g_mean, g_inner = Attacks(attacked, labels)
                epoch_g_norm[k].append(g_norm)
                epoch_g_mean[k].append(g_mean)
                epoch_g_inner[k].append(g_inner)

        # Evaluate on the full test set with the same device/dtype handling
        # as training, without building autograd graphs.
        epoch_outputs_test, epoch_labels_test = [], []
        splitnn.eval()
        with torch.no_grad():
            for test_inputs, test_labels in test_loader:
                test_inputs = test_inputs.to(device).double()
                test_labels = test_labels.to(device).double()
                outputs_test, _, _ = splitnn(test_inputs, test_labels)
                epoch_outputs_test.append(outputs_test)
                epoch_labels_test.append(test_labels)

        all_labels = torch.cat(epoch_labels)
        train_auc = torch_auc(all_labels, torch.cat(epoch_outputs))
        test_auc = torch_auc(torch.cat(epoch_labels_test),
                             torch.cat(epoch_outputs_test))
        train_tvd = totalvaraition(all_labels, torch.cat(epoch_g))
        # leaks[k] = (NA, MA, Cos) leak AUC for slot k.
        leaks = [
            (leak_auc(all_labels, torch.cat(epoch_g_norm[k])),
             leak_auc(all_labels, torch.cat(epoch_g_mean[k])),
             leak_auc(all_labels, torch.cat(epoch_g_inner[k])))
            for k in range(n_clients + 1)
        ]

        if info and (epoch % 10 == 0 or epoch == Epochs - 1):
            na_leak_auc, ma_leak_auc, cos_leak_auc = leaks[0]
            print('Epoch', epoch, 'Training Loss', epoch_loss,
                  'Training AUC', train_auc,
                  'Testing AUC', test_auc,
                  'TVD', train_tvd,
                  'NA Leak AUC', na_leak_auc,
                  'MA Leak AUC', ma_leak_auc,
                  'Cos Leak AUC', cos_leak_auc
                  )
            for k in range(1, n_clients + 1):
                na_k, ma_k, cos_k = leaks[k]
                print('Client' + str(k),
                      'NA Leak AUC', na_k,
                      'MA Leak AUC', ma_k,
                      'Cos Leak AUC', cos_k
                      )
    return (train_auc, test_auc, train_tvd,
            *(auc for triple in leaks for auc in triple),
            splitnn)
