import socket
HOST = '127.0.0.1'  # Standard loopback interface address (localhost)
PORT = 65432  # Port to listen on (non-privileged ports are > 1023)

import torch
import numpy as np
import random
import torchvision
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset, ConcatDataset
from torchvision import datasets, transforms
cuda = torch.cuda.is_available()
from analysis.private_knn import PrivateKnn
import analysis
from pow.hashcash import mint_iteractive, generate_challenge, check, _to_binary
from pow.proof_of_work import PoW
import dfmenetwork
import scipy
import scipy.stats
import math
import pickle
import time
from models.ensemble_model import EnsembleModel
from models.load_models import load_private_models
from models.load_models import load_victim_model
from models.private_model import get_private_model_by_id
from parameters import get_parameters
import argparse


parser = argparse.ArgumentParser(description='Server Setup')
# Only these three datasets have both a PATE-ensemble configuration and a
# victim model below. Any other value would fall through to mismatched
# defaults (an SVHN ensemble paired with an MNIST victim), so reject it at
# parse time instead.
parser.add_argument('--dataset', default='mnist', type=str,
                    choices=['mnist', 'cifar10', 'svhn'])
parser.add_argument('--mode', default='other', type=str,
                    help="select type of attack being used (dfme or other); "
                         "any value other than 'dfme' uses the default setup")
# TODO: Can we remove the mode parameter and have a unified setup?
args = parser.parse_args()
args2 = get_parameters()  # original (project-wide) parameters

DAY1 = 60 * 60 * 24  # Seconds in a day

# Functions used to estimate the privacy cost of the queries asked
# (utility scores based on entropy / top-2 gap, and PATE-kNN votes)

def computeentropy(t_logits, num_classes=None):
    """
    Normalized prediction entropy of each query, used as a utility score.

    :param t_logits: torch tensor of logits; dim 1 is the class dimension
    :param num_classes: number of classes used for normalization; defaults
        to the size of dim 1 of `t_logits` (previously hard-coded to 10)
    :return: numpy array with one score per query in [0, 1], where 1 means
        the softmax distribution is uniform (maximum entropy)
    """
    if num_classes is None:
        num_classes = t_logits.shape[1]
    prob = F.softmax(t_logits, dim=1).cpu().detach().numpy()
    entropy = scipy.stats.entropy(prob, axis=1)
    if num_classes < 2:
        # A single class has zero entropy and log(1) == 0; avoid 0/0 -> nan.
        return np.zeros_like(entropy)
    # The uniform distribution attains the maximum entropy log(num_classes).
    entropy_max = np.log(num_classes)
    return entropy / entropy_max

def computegap(t_logits):
    """
    Utility score from the gap between the two largest logits of each query.

    Only the two largest logits are passed through a softmax; the score is
    ``1 - (p_top1 - p_top2)``, so values near 1 mean the model is undecided
    between its two best classes.

    :param t_logits: torch tensor of logits, one row per query
    :return: numpy array with one utility value per query
    """
    ranked = torch.sort(t_logits, dim=-1, descending=True).values
    top_two = F.softmax(ranked[:, :2], dim=1)
    top_two = top_two.cpu().detach().numpy()
    best, runner_up = top_two[:, 0], top_two[:, 1]
    return 1 - (best - runner_up)

def get_votes_for_pate_knn(model, t_logits, train_represent,
                           train_labels, args = None, num_teachers=300,
                           num_classes=10):
    """
    Compute PATE-kNN teacher votes for a batch of queries.

    Each query's log-softmax output is compared (Euclidean distance) with the
    training representations, and the labels of the `num_teachers` nearest
    training points are counted as that query's votes.

    :param model: the victim model (unused; kept for interface compatibility)
    :param t_logits: torch tensor of logits for the queries, one row per query
    :param train_represent: numpy array of training-point representations
        (log-softmax outputs), one row per training point
    :param train_labels: numpy array of integer labels for the training points
    :param args: the program parameters (unused; kept for compatibility)
    :param num_teachers: number of nearest neighbours acting as teachers; if
        it is not smaller than the number of training points, every training
        point votes (np.argpartition would reject such a `kth`)
    :param num_classes: minimum length of each vote vector
    :return: numpy array of vote counts, one row per query; it has
        `num_classes` columns when all labels are < num_classes
    """
    n_train = len(train_represent)
    with torch.no_grad():
        outputs = F.log_softmax(t_logits, dim=-1).cpu().numpy()

    votes = []
    for output in outputs:
        if num_teachers < n_train:
            dis = np.linalg.norm(train_represent - output, axis=-1)
            k_index = np.argpartition(dis, kth=num_teachers)[:num_teachers]
        else:
            k_index = np.arange(n_train)
        teachers_preds = np.array(train_labels[k_index], dtype=np.int32)
        votes.append(np.bincount(teachers_preds, minlength=num_classes))

    if not votes:
        # np.stack cannot stack an empty list; return an empty vote matrix.
        return np.zeros((0, num_classes), dtype=np.intp)
    return np.stack(votes)


class PateKNN:
    """
    Track the PATE-kNN privacy cost of queries answered by a victim model.

    The victim's log-softmax outputs on its training data serve as the
    representation in which nearest neighbours are searched.
    """

    def __init__(self, model, trainloader, args):
        """
        Args:
            model: the victim model.
            trainloader: iterable of (data, target) batches from the
                training set.
            args: the program parameters (stored and forwarded to the vote
                helper).
        """
        self.model = model
        self.args = args

        # Victim log-softmax outputs and ground-truth labels of every
        # training point, gathered batch by batch.
        represent_batches, label_batches = [], []
        with torch.no_grad():
            for inputs, targets in trainloader:
                if cuda:
                    inputs = inputs.cuda()
                log_probs = F.log_softmax(model(inputs), dim=-1)
                represent_batches.append(log_probs.cpu().numpy())
                label_batches.append(targets.cpu().numpy())
        self.train_represent = np.concatenate(represent_batches, axis=0)
        self.train_labels = np.concatenate(label_batches, axis=0)

        self.private_knn = PrivateKnn(
            delta=1e-5, sigma_gnmax=28, apply_data_independent_bound=False)

    def compute_privacy_cost(self, t_logits):
        """
        Args:
            t_logits: victim-model logits for the new queries.
        Returns:
            Whatever PrivateKnn.add_privacy_cost returns for these queries'
            votes (per the original docs, the total privacy cost of all
            queries seen so far).
        """
        votes = get_votes_for_pate_knn(
            model=self.model,
            t_logits=t_logits,
            train_represent=self.train_represent,
            train_labels=self.train_labels,
            args=self.args,
        )
        return self.private_knn.add_privacy_cost(votes=votes)

print("cuda available", cuda)
dataset = args.dataset
mode = args.mode
# From here on `args` is None; `args2` carries the program parameters.
args = None
# Initialization of PATE ensemble:
print(f"Using {dataset} dataset")
# Dataset-specific settings of the PATE teacher ensemble.
if dataset == "cifar10":
    args2.dataset = "cifar10"
    args2.num_models = 50
    args2.architecture = "ResNet18"
    args2.sigma_gnmax = 2
    args2.delta = 1e-5
elif dataset == "mnist":
    args2.dataset = "mnist"
    args2.num_models = 250
    args2.architecture = "MnistNetPate"
    args2.sigma_gnmax = 10
    args2.delta = 1e-5
else:
    # Any other dataset name falls back to the SVHN ensemble.
    args2.dataset = "svhn"
    args2.num_models = 250
    args2.architecture = "ResNet10"
    args2.sigma_gnmax = 10
    args2.delta = 1e-6
# Settings shared by every dataset.
args2.begin_id = 0
args2.end_id = args2.num_models
# Set for every dataset; the MNIST branch previously left `architectures`
# at whatever get_parameters() produced, unlike CIFAR-10 and SVHN.
args2.architectures = [args2.architecture]
args2.class_type = "multiclass"
args2.target_model = "pate"
# e.g. "private-models/mnist/MnistNetPate/250-models"
args2.private_model_path = (
    f"private-models/{args2.dataset}/{args2.architecture}/"
    f"{args2.num_models}-models"
)
args2.cuda = torch.cuda.is_available()
args2.num_classes = 10

# Load the PATE teacher models (from args2.private_model_path) and wrap them
# in an ensemble; its votes drive the privacy-cost accounting in the server
# loop.
private_models = load_private_models(
    args=args2, model_path=args2.private_model_path)
victim_model = EnsembleModel(
    model_id=-1, args=args2, private_models=private_models)
# Pick the victim model (and, for DFME victims, its checkpoint path) for the
# selected dataset/mode.
if dataset == "cifar10":
    victim = dfmenetwork.resnet_8x.ResNet34_8x(num_classes=10)
    ckpt = 'dfmodels/teacher/cifar10-resnet34_8x.pt'
    # trainloader = torch.utils.data.DataLoader(
    # torchvision.datasets.CIFAR10('/ssd003/home/akaleem/data', train=True, download=False,
    #                  transform=transforms.Compose([
    #                      transforms.ToTensor(),
    #                      transforms.Normalize((0.4914, 0.4822, 0.4465),
    #                                           (0.2023, 0.1994, 0.2010)),
    #                  ])),
    # batch_size=64)
elif dataset == "svhn":
    victim = dfmenetwork.resnet_8x.ResNet34_8x(num_classes=10)
    ckpt = 'dfmodels/teacher/svhn-resnet34_8x.pt'
    # trainloader = torch.utils.data.DataLoader(torchvision.datasets.SVHN(
    #     root='/ssd003/home/akaleem/data',
    #     split='train',
    #     transform=transforms.Compose([
    #         transforms.ToTensor(),
    #         transforms.Normalize(
    #             (0.43768212, 0.44376972, 0.47280444),
    #             (
    #                 0.19803013, 0.20101563,
    #                 0.19703615))]),
    #     download=False), batch_size=64)
else:
    def load_private_model_by_idcopy():
        """
        Load the single MNIST teacher checkpoint used as the victim model.

        Takes no parameters; it reads the module-level `args` and `cuda`.

        :return: the MnistNetPate instance in eval mode (moved to the GPU
            when CUDA is available)
        :raises Exception: if the checkpoint file does not exist
        """
        import os
        from architectures.mnist_net_pate import MnistNetPate
        filepath = "private-models/mnist/MnistNetPate/1-models/checkpoint-model(1).pth.tar"
        if not os.path.isfile(filepath):
            raise Exception(
                f"Checkpoint file {filepath} does not exist, please generate it via "
                f"train_private_models(args)!")
        model = MnistNetPate(name='model({:d})'.format(0 + 1), args = args)
        # Without CUDA, remap tensors that were saved on a GPU onto the CPU.
        checkpoint = torch.load(
            filepath, map_location=None if cuda else torch.device('cpu'))
        model.load_state_dict(checkpoint['state_dict'])
        if cuda:
            model.cuda()
        # Compare strings with `==` (not `is`), and tolerate `args` being
        # None or lacking `label_reweight`.
        if ('label_weights' in checkpoint
                and getattr(args, 'label_reweight', None) == 'apply'):
            model.label_weights = checkpoint['label_weights']
        model.eval()
        return model

    # Normalization shared by the MNIST training loader in both modes.
    mnist_transforms = [
        transforms.ToTensor(),
        transforms.Normalize(
            (0.13251461,),
            (0.31,)),
    ]
    if mode == "dfme":
        victim = dfmenetwork.lenet.LeNet5()
        ckpt = 'dfmodels/teacher/mnist-lenet5.pt'
        # In DFME mode the images are additionally resized to 32x32 first.
        mnist_transforms.insert(0, transforms.Resize((32, 32)))
    else:
        victim = load_private_model_by_idcopy()
    trainloader = torch.utils.data.DataLoader(torchvision.datasets.MNIST(
        root='/ssd003/home/akaleem/data/MNIST',
        train=True,
        transform=transforms.Compose(mnist_transforms),
        download=False), batch_size=64)
# Every victim except the non-DFME MNIST one (whose weights are loaded
# together with the model) still needs its checkpoint loaded from `ckpt`.
if dataset != "mnist" or mode == "dfme":
    # `map_location` is a torch.load argument; passing it to
    # load_state_dict raised a TypeError on CPU-only hosts.
    state_dict = torch.load(
        ckpt, map_location=None if cuda else torch.device('cpu'))
    victim.load_state_dict(state_dict)
    if cuda:
        victim = victim.cuda()
victim.eval()
print("Done loading victim")
# --- Query server -----------------------------------------------------------
# Listens on HOST:PORT, receives pickled query batches from a client, runs
# them through the victim, accumulates the PATE privacy cost of the queries
# and sends the pickled predictions back.
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind((HOST, PORT))
print("Started server")
#s.listen(5)
# `i` is the outer loop counter; it controls when a new connection is
# accepted (`i % 500 == 0`) and when the server stops (`i > 1000`).
# NOTE(review): `i` is also the loop variable of the per-vote loop further
# down, which overwrites this counter with `datalength - 1` after every
# processed batch, so it does not count batches. For example, if a batch
# yields one vote row, `i` becomes 0 and the next iteration blocks in
# accept() while the current client is still connected. Decide the intended
# session/termination policy before renaming the inner variable.
i = 0
entropy_cost = 0
gap_cost = 0
pknn_cost = 0
# Running total of the per-query PATE dp_eps values (never reset).
privacy_cost = 0
# pate_knn = PateKNN(model=victim, trainloader=trainloader,
#                        args=args)
# NOTE(review): `pow` shadows the Python builtin of the same name; it is only
# referenced from the commented-out proof-of-work code below.
pow = PoW(dataset=dataset)
while i<=1000:
    if i % 500 == 0:
        # Accept a (new) client; a previous `conn` is not closed here.
        s.listen(1)
        conn, addr = s.accept()
        print('Connected by', addr)
    i += 1
    data = []
    # Read 4096-byte chunks until the client sends b'doneiter' (end of the
    # session: resets `i` so the next iteration accepts a new connection),
    # b'done' (end of this batch), or closes the socket.
    # NOTE(review): TCP does not preserve message boundaries, so the markers
    # are only recognized when they arrive as a recv() result on their own.
    while True:
        packet = conn.recv(4096)
        if packet == b'doneiter':
            i = 0
            break
        #print(packet)
        if not packet or packet == b'done':
            break
        data.append(packet)
        #print("rec pack")
    #print("Done")
    # Skipped after b'doneiter' (i == 0).
    if i >= 1:
        # NOTE(review): pickle.loads on bytes received from the socket can
        # execute arbitrary code; only acceptable for trusted clients (the
        # server binds to HOST, the loopback address). If the client closes
        # the socket without sending b'doneiter', `data` can be empty here
        # and pickle.loads raises EOFError.
        data = pickle.loads(b"".join(data))
        #print("Done unpick")
        #data = conn.recv(40960000)
        #data = pickle.loads(data)
        # Non-DFME MNIST queries are reshaped to (N, 1, 28, 28); non-DFME
        # CIFAR-10/SVHN queries to (N, 3, 32, 32); in DFME mode (any dataset)
        # the batch is used as received.
        if dataset == "mnist" and mode != "dfme":
            data = data.reshape((-1, 1, 28, 28))
        elif mode == "dfme":
            pass
        else:
            data = data.reshape((-1, 3, 32, 32))
        # Assumes the client pickled a torch tensor (uses .reshape/.to).
        data = data.to(torch.float32)
        if cuda:
            data = data.cuda()
        # NOTE(review): runs outside torch.no_grad(), so `preds` keeps
        # autograd state when it is pickled below — confirm the client
        # expects that.
        preds = victim(data)
        # Compute metrics
        # entropy_scores = computeentropy(preds)
        # entropy_cost += entropy_scores.sum()
        # gap_scores = computegap(preds)
        # gap_cost += gap_scores.sum()
        # pknn_cost = pate_knn.compute_privacy_cost(preds)
        # print("pknn cost", pknn_cost)
        # PATE privacy cost:
        # Pair every query with a dummy label 0 so the batch can be fed to
        # the ensemble through a DataLoader.
        #tdataset = [(data[a], 0) for a in range(data.size()[0])]
        tdataset = [(a, 0) for a in data]
        #print(tdataset)
        adaptive_loader = DataLoader(
            tdataset,
            batch_size=64,
            shuffle=False)

        # Teacher votes of the PATE ensemble; each row is analyzed on its own
        # so its dp_eps can be added to the running total.
        votes_victim = victim_model.inference(adaptive_loader, args2)
        datalength = len(votes_victim)
        # NOTE(review): this `i` overwrites the outer loop counter (see the
        # note where `i` is initialized).
        for i in range(datalength):
            curvote = votes_victim[i][np.newaxis, :]
            # threshold=0 / sigma_threshold=0 presumably disable the
            # confident-GNMax threshold step — confirm against `analysis`.
            max_num_query, dp_eps, partition, answered, order_opt = analysis.analyze_multiclass_confident_gnmax(
                votes=curvote,
                threshold=0,
                sigma_threshold=0,
                sigma_gnmax=args2.sigma_gnmax,
                budget=args2.budget,
                file=None,
                delta=args2.delta,
                show_dp_budget=False,
                args=args2
            )
            # print(f'dp_eps for vote {i}: {dp_eps[0]}')
            privacy_cost += dp_eps[0]
        print('pate cost', privacy_cost)
        #print("entropy", entropy_cost)
        # Outside DFME mode predictions are moved to the CPU before sending;
        # in DFME mode they are sent from whatever device they are on.
        if mode != "dfme":
            preds = preds.cpu()

        # server

        # POW currently not being used
        # bits = pow.get_leading_zero_bits_for_challenge_through_time(
        #     privacy_cost=privacy_cost)
        # print("bits", bits)
        # do we need to use recompute_timings based on this user first?
        # recompute_timings needs access to timings and privacy costs. We could keep track of this over
        # time and then keep updating the timings as we go on?

        # xtype = 'bin'  # 'bin' or 'hex'
        # resource = 'model-extraction-warning'
        # challenge = generate_challenge(resource=resource, bits=bits)
        # challengestr = pickle.dumps(challenge)
        # conn.sendall(challengestr)
        # stamp = conn.recv(4096)
        # stamp = pickle.loads(stamp)
        #
        # is_correct = check(stamp=stamp, resource=resource, bits=bits,
        #                    check_expiration=DAY1, xtype=xtype)
        # With the PoW check disabled every batch is answered.
        is_correct = True

        #print("is correct", is_correct)

        #print(preds)
        # ans = torch.argmax(preds).cpu()
        # ans = ans.item()
        # ans = str(ans)
        # conn.sendall(ans.encode())
        if is_correct:
            # Send the pickled victim outputs back to the client.
            predsstr = pickle.dumps(preds)
            conn.sendall(predsstr)
            # Only for larger batch sizes
            # if mode == "dfme":
            #     time.sleep(0.01)
            #     str = "donesend"
            #     conn.sendall(str.encode())
    #i+=1
    #print(i)
# Only the most recently accepted connection is closed here.
conn.close()
s.close()

# This server runs separately and can be connected to by different attackers. The victim model returns logits; the PoW protocol is currently disabled.




