from __future__ import print_function

from copy import deepcopy

import torch
import torch.nn.functional as F
from IPython import embed
from utils import utils
from clients_attackers import Attacker_Text, Text_Backdoor_Utils
from utils.gptlm import GPT2LM
import time

class Server():
    """Federated-learning server.

    Holds the global model, drives local training on selected clients,
    aggregates their updates with a configurable aggregation rule (``self.AR``)
    and evaluates the global model, optionally behind a test-time text-backdoor
    defense (``args.defense``: 'none', 'filter', 'ONION' or 'RAP').
    """

    def __init__(self, args, model, dataLoader, text_backdoor_utils, criterion=F.nll_loss, device='cpu'):
        """
        Args:
            args: run configuration; this class reads ``defense``, ``AR``,
                ``num_clients``, ``epochs`` and ``vocab`` from it.
            model: the global ``torch.nn.Module``.
            dataLoader: iterable of ``(data, target)`` evaluation batches.
            text_backdoor_utils: helper object; this class uses its ``vocab``,
                ``get_poison_batch``, ``backdoor_label`` and ``item_to_sent``.
            criterion: loss, called as ``criterion(output, target, reduction='sum')``.
            device: device the model is moved to during evaluation.
        """
        self.clients = []
        self.args = args
        self.model = model
        self.dataLoader = dataLoader
        self.device = device
        self.emptyStates = None
        self.init_stateChange()
        self.Delta = None  # last aggregated update; forwarded to clients by distribute()
        self.utils = text_backdoor_utils
        self.iter = 0
        self.AR = self.FedAvg
        self.func = torch.mean
        self.isSaveChanges = False
        self.savePath = './AggData'
        self.criterion = criterion
        self.path_to_aggNet = ""
        # NOTE(review): self.step and self.epochs are read by CRFL and the
        # pointWise rules but are not initialised here; presumably the caller
        # sets them before aggregation -- confirm.
        if self.args.defense == 'ONION':
            # Language model used by get_PPL() to score sentence perplexity.
            self.LM = GPT2LM(use_tf=False, device='cuda' if torch.cuda.is_available() else 'cpu')
        self.UNK = self.utils.vocab.stoi['<unk>']
        self.PAD = self.utils.vocab.stoi['<pad>']
        # The last vocabulary index is the substitution token used by RAP/get_score.
        self.replace_word = len(self.utils.vocab.itos) - 1
        if 'His' in args.AR:
            self.history = torch.zeros(args.num_clients, args.num_clients)  # with history
        else:
            self.history = None  # no history

    def init_stateChange(self):
        """Store in ``self.emptyStates`` a copy of the model state dict with every tensor zeroed.

        It serves as the template for aggregated updates.
        """
        states = deepcopy(self.model.state_dict())
        for values in states.values():
            # zero_() instead of ``*= 0``: it also clears inf/NaN entries
            # (inf * 0 is NaN) and works for bool buffers.
            values.zero_()
        self.emptyStates = states

    def attach(self, c):
        """Register client ``c`` with the server."""
        self.clients.append(c)

    def distribute(self):
        """Push the current global weights and the last aggregated update to every client."""
        for c in self.clients:
            c.setModelParameter(self.model.state_dict())
            c.Delta_server = self.Delta

    ## Evaluation ##
    def _apply_defense(self, data):
        """Apply the test-time defense selected by ``args.defense`` to a token batch.

        Raises:
            ValueError: if ``args.defense`` is not a known defense.
        """
        defense = self.args.defense
        if defense == 'none':
            return data
        handlers = {'filter': self.filter_word, 'ONION': self.ONION, 'RAP': self.RAP}
        if defense not in handlers:
            raise ValueError(f"Unknown defense: {defense!r}")
        return handlers[defense](data)

    @staticmethod
    def _predict(output):
        """Turn model output into predictions: sigmoid+round for 1-D output, else argmax over dim 1."""
        if output.dim() == 1:
            return torch.round(torch.sigmoid(output))
        return output.argmax(dim=1, keepdim=True)  # index of the max log-probability

    def _run_eval(self, steps, poison=False):
        """Evaluate the global model over ``self.dataLoader``.

        Args:
            steps: current round; the defense is applied only when it equals ``args.epochs``.
            poison: if True, every batch is first passed through ``utils.get_poison_batch``.

        Returns:
            ``(summed_loss, correct, count)`` over the evaluated samples.
        """
        self.model.to(self.device)
        self.model.eval()
        test_loss = 0.0
        correct = 0
        count = 0
        try:
            with torch.no_grad():
                for data, target in self.dataLoader:
                    if poison:
                        data, target = self.utils.get_poison_batch(data, target, backdoor_fraction=1, backdoor_label=self.utils.backdoor_label, evaluation=True)
                    data, target = data.to(self.device), target.to(self.device)
                    if steps == self.args.epochs:
                        data = self._apply_defense(data)
                    output = self.model(data)
                    test_loss += self.criterion(output, target, reduction='sum').item()  # sum up batch loss
                    pred = self._predict(output)
                    correct += pred.eq(target.view_as(pred)).sum().item()
                    count += pred.shape[0]
        finally:
            self.model.cpu()  ## avoid occupying gpu when idle (also on error)
        return test_loss, correct, count

    def test(self, steps):
        """Evaluate clean accuracy. Returns ``(average_loss, accuracy_percent)``."""
        print("[Server] Start testing")
        test_loss, correct, count = self._run_eval(steps)
        test_loss /= count
        accuracy = 100. * correct / count
        print('[Server] Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.3f}%)\n'.format(test_loss, correct, count, accuracy))
        return test_loss, accuracy

    def test_backdoor(self, steps):
        """Evaluate the backdoor success rate on poisoned batches.

        Returns ``(average_loss, success_rate_percent)``. The loss is averaged
        over the samples actually evaluated; the poisoning step may return
        fewer samples than the dataset holds.
        """
        print("[Server] Start testing backdoor\n")
        test_loss, correct, total = self._run_eval(steps, poison=True)
        test_loss /= total
        accuracy = 100. * correct / total
        print('[Server] Test set (Backdoored): Average loss: {:.4f}, Success rate: {}/{} ({:.3f}%)\n'.format(test_loss, correct, total, accuracy))
        return test_loss, accuracy

    ## Training ##
    def train(self, group):
        """Run one federated round on the clients whose indices are listed in ``group``.

        Each selected client trains and computes its update. The updates are
        aggregated with ``self.AR``, and the result is added in place to the
        global model.
        """
        selectedClients = [self.clients[i] for i in group]
        tic = time.perf_counter()
        for c in selectedClients:
            c.train()
            c.update()
        toc = time.perf_counter()
        print(f"[Server] The training takes {toc - tic:0.6f} seconds.\n")

        if self.isSaveChanges:
            self.saveChanges(selectedClients)

        tic = time.perf_counter()
        Delta = self.AR(selectedClients)
        toc = time.perf_counter()
        print(f"[Server] The aggregation takes {toc - tic:0.6f} seconds.\n")

        # state_dict() tensors share storage with the model, so the in-place
        # addition updates the global model. Fetch the dict once, not per key.
        state = self.model.state_dict()
        for param, value in state.items():
            value += Delta[param]
        self.Delta = Delta
        self.iter += 1

    def saveChanges(self, clients):
        """Save the clients' trainable-parameter updates (as PCA projections) under ``self.savePath``."""
        import os

        deltas = [c.getDelta() for c in clients]
        param_trainable = utils.getTrainableParameters(self.model)
        trainable = set(param_trainable)

        Delta = deepcopy(self.emptyStates)
        for param in [p for p in Delta if p not in trainable]:
            del Delta[param]
        print(f"[Server] Saving the model weight of the trainable paramters:\n {Delta.keys()}")
        for param in param_trainable:
            ## stack the weight in the innermost dimension -> (numel, n_clients)
            param_stack = torch.stack([delta[param] for delta in deltas], -1)
            Delta[param] = param_stack.view(-1, len(clients))

        saveAsPCA = True
        saveOriginal = False
        os.makedirs(self.savePath, exist_ok=True)  # torch.save does not create directories
        if saveAsPCA:
            from utils import convert_pca
            proj_vec = convert_pca._convertWithPCA(Delta)
            savepath = f'{self.savePath}/pca_{self.iter}.pt'
            torch.save(proj_vec, savepath)
            print(f'[Server] The PCA projections of the update vectors have been saved to {savepath} (with shape {proj_vec.shape})')
        if saveOriginal:
            savepath = f'{self.savePath}/{self.iter}.pt'
            torch.save(Delta, savepath)
            print(f'[Server] Update vectors have been saved to {savepath}')

    ## Aggregation functions ##
    def set_AR(self, ar):
        """Select the aggregation rule from a spec string.

        Accepted forms:
          * ``crfl`` or ``crfl-<sigma>-<rou_k>-<rou_b>``. Both ``rou_k`` and
            ``rou_b`` equal to '0' disable them (stored as -1).
          * ``<name>``, ``<name>-<sigma>-<K_div>`` or
            ``<name>-<sigma>-<K_div>-<rou_k>-<rou_b>`` for the other rules.
            ``sigma`` is kept as a string.

        Raises:
            ValueError: on a malformed spec or an unknown rule name.
        """
        parts = ar.split('-')
        if 'crfl' in ar:
            self.sigma = 0.01
            self.rou_k = 0.05
            self.rou_b = 2
            if len(parts) > 1:
                if len(parts) != 4:
                    raise ValueError(f"Malformed CRFL rule {ar!r}: expected 'crfl-<sigma>-<rou_k>-<rou_b>'")
                ar, sigma, rou_k, rou_b = parts
                self.sigma = float(sigma)
                if (rou_k == '0') and (rou_b == '0'):
                    self.rou_k = -1
                    self.rou_b = -1
                else:
                    self.rou_k = float(rou_k)
                    self.rou_b = float(rou_b)
        else:
            self.sigma = '0'
            self.K_div = 1000
            self.rou_k = -1
            self.rou_b = -1
            if len(parts) == 3:
                ar, sigma, K_div = parts
                self.sigma = str(sigma)
                self.K_div = int(K_div)
            elif len(parts) == 5:
                ar, sigma, K_div, rou_k, rou_b = parts
                self.sigma = str(sigma)
                self.K_div = int(K_div)
                self.rou_k = float(rou_k)
                self.rou_b = float(rou_b)
            elif len(parts) != 1:
                raise ValueError(f"Malformed aggregation rule {ar!r}: expected '<name>', "
                                 "'<name>-<sigma>-<K_div>' or '<name>-<sigma>-<K_div>-<rou_k>-<rou_b>'")

        rules = {
            'fedavg': self.FedAvg,
            'median': self.FedMedian,
            'gm': self.geometricMedian,
            'krum': self.krum,
            'mkrum': self.mkrum,
            'bulyan': self.bulyan,
            'foolsgold': self.foolsGold,
            'crfl': self.CRFL,
            'residualbase': self.residualBase,
            'attention': self.net_attention,
            'mlp': self.net_mlp,
        }
        # pointWise rules are also accepted with a 'His' (history) suffix.
        for base in ('element', 'neuro', 'melement', 'delement', 'lelement', 'mdis',
                     'ddis', 'ldis', 'dneuro', 'mneuro', 'belement', 'bneuro'):
            rules[base] = rules[base + 'His'] = getattr(self, base)
        if ar not in rules:
            raise ValueError("Not a valid aggregation rule or aggregation rule not implemented")
        self.AR = rules[ar]

    def FedAvg(self, clients):
        """Coordinate-wise mean of the client updates."""
        return self.FedFuncWholeNet(clients, lambda arr: torch.mean(arr, dim=-1, keepdim=True))

    def FedMedian(self, clients):
        """Coordinate-wise median of the client updates."""
        return self.FedFuncWholeNet(clients, lambda arr: torch.median(arr, dim=-1, keepdim=True)[0])

    def geometricMedian(self, clients):
        from rules.geometricMedian import Net
        self.Net = Net
        return self.FedFuncWholeNet(clients, lambda arr: Net().cpu()(arr.cpu()))

    def CRFL(self, clients):
        from rules.CRFL import Net
        self.Net = Net
        return self.FedFuncWholeNet(clients, lambda arr: Net().cpu()(arr.cpu(), self.step, self.epochs, self.sigma, self.rou_k, self.rou_b))

    def krum(self, clients):
        from rules.multiKrum import Net
        self.Net = Net
        return self.FedFuncWholeNet(clients, lambda arr: Net('krum').cpu()(arr.cpu()))

    def mkrum(self, clients):
        from rules.multiKrum import Net
        self.Net = Net
        return self.FedFuncWholeNet(clients, lambda arr: Net('mkrum').cpu()(arr.cpu()))

    def bulyan(self, clients):
        from rules.bulyan import Net
        self.Net = Net
        return self.FedFuncWholeNet(clients, lambda arr: Net().cpu()(arr.cpu()))

    def _pointWise(self, clients, granularity, mode):
        """Aggregate with ``rules.pointWise.Net`` at the given granularity ('element'/'neuro') and selection mode."""
        from rules.pointWise import Net
        self.Net = Net
        return self.FedFuncWholeNet(
            clients,
            lambda arr: Net(self.model, self.args.vocab).cpu()(
                arr.cpu(), granularity, self.step, self.epochs, self.sigma,
                self.K_div, self.rou_k, self.rou_b, mode=mode, history=self.history))

    def element(self, clients):
        return self._pointWise(clients, 'element', 'krum')

    def delement(self, clients):
        return self._pointWise(clients, 'element', 'dkrum')

    def lelement(self, clients):
        return self._pointWise(clients, 'element', 'lkrum')

    def mdis(self, clients):
        return self._pointWise(clients, 'element', 'mdis')

    def ddis(self, clients):
        return self._pointWise(clients, 'element', 'ddis')

    def ldis(self, clients):
        return self._pointWise(clients, 'element', 'ldis')

    def melement(self, clients):
        return self._pointWise(clients, 'element', 'mkrum')

    def belement(self, clients):
        return self._pointWise(clients, 'element', 'bulyan')

    def neuro(self, clients):
        return self._pointWise(clients, 'neuro', 'krum')

    def mneuro(self, clients):
        return self._pointWise(clients, 'neuro', 'mkrum')

    def dneuro(self, clients):
        return self._pointWise(clients, 'neuro', 'dkrum')

    def bneuro(self, clients):
        return self._pointWise(clients, 'neuro', 'bulyan')

    def foolsGold(self, clients):
        from rules.foolsGold import Net
        self.Net = Net
        return self.FedFuncWholeNet(clients, lambda arr: Net().cpu()(arr.cpu()))

    def residualBase(self, clients):
        from rules.residualBase import Net
        return self.FedFuncWholeStateDict(clients, Net().main)

    def net_attention(self, clients):
        from aaa.attention import Net

        net = Net()
        net.path_to_net = self.path_to_aggNet
        return self.FedFuncWholeStateDict(clients, lambda arr: net.main(arr, self.model))

    def net_mlp(self, clients):
        from aaa.mlp import Net

        net = Net()
        net.path_to_net = self.path_to_aggNet
        return self.FedFuncWholeStateDict(clients, lambda arr: net.main(arr, self.model))

    ## Helper functions, act as adaptor from aggregation function to the federated learning system ##

    def FedFuncWholeNet(self, clients, func):
        '''
        The aggregation rule views the update vectors as stacked vectors (1 by d by n).

        Updates containing nan/inf are dropped. If none remain, a zero update is
        returned and the global model stays unchanged.
        '''
        Delta = deepcopy(self.emptyStates)
        vecs = [utils.net2vec(c.getDelta()) for c in clients]
        vecs = [vec for vec in vecs if torch.isfinite(vec).all().item()]
        if not vecs:
            print("[Server] No finite update vectors to aggregate; skipping this round.")
            return Delta
        result = func(torch.stack(vecs, 1).unsqueeze(0))  # input as 1 by d by n
        utils.vec2net(result.view(-1), Delta)
        return Delta

    def FedFuncWholeStateDict(self, clients, func):
        '''
        The aggregation rule views the update vectors as a set of state dict.

        Updates containing nan/inf are dropped. If none remain, a zero update is
        returned and the global model stays unchanged.
        '''
        Delta = deepcopy(self.emptyStates)
        deltas = [c.getDelta() for c in clients]
        # sanity check, remove update vectors with nan/inf values
        deltas = [delta for delta in deltas if torch.isfinite(utils.net2vec(delta)).all().item()]
        if not deltas:
            print("[Server] No finite update vectors to aggregate; skipping this round.")
            return Delta
        Delta.update(func(deltas))
        return Delta

    ## Test-time backdoor defenses ##
    def filter_word(self, data, max_token_id=5000):
        """Replace every token id greater than ``max_token_id`` with ``<unk>`` (in place)."""
        data[data > max_token_id] = self.UNK
        return data

    def ONION(self, data, thr=-0.1):
        """ONION defense: replace a word with ``<unk>`` when removing it lowers perplexity by more than ``-thr``.

        Modifies ``data`` in place and returns it.
        """
        sents = [self.utils.item_to_sent(data[i]) for i in range(data.size(0))]
        all_PPL = self.get_PPL(sents)
        for i, scores in enumerate(all_PPL):
            # scores[:-1]: PPL with word j removed; scores[-1]: PPL of the full sentence.
            ppl = torch.tensor(scores[:-1], device=data.device) - scores[-1]
            missing = data.size(1) - ppl.size(0)
            if missing > 0:
                # positions beyond the sentence (e.g. padding) get a neutral score
                ppl = torch.cat((ppl, ppl.new_zeros(missing)), dim=0)
            data[i, ppl < thr] = self.UNK
        return data

    def RAP(self, data, thr=0.5):
        """RAP-style defense: replace a token with ``<unk>`` when substituting it shifts any
        class probability by more than ``thr``.

        Modifies ``data`` in place and returns it.
        """
        all_score = torch.cat(self.get_score(data), dim=0)  # (sent_length + 1, batch, classes)
        delta_score = (all_score[:-1] - all_score[-1:]).abs().max(dim=-1)[0]  # (sent_length, batch)
        missing = data.size(1) - delta_score.size(0)
        if missing > 0:
            pad = delta_score.new_zeros(missing, delta_score.size(1))
            delta_score = torch.cat((delta_score, pad), dim=0)
        data[delta_score.t() > thr] = self.UNK
        return data

    def get_PPL(self, sents):
        """For each sentence, return the LM perplexities with each word removed in turn.

        The final entry of each list is the PPL of the unmodified sentence.
        """
        all_PPL = []
        for sent in sents:
            words = sent.split(' ')
            # j == len(words) removes nothing -> full-sentence perplexity at the end.
            all_PPL.append([self.LM(' '.join(words[:j] + words[j + 1:]))
                            for j in range(len(words) + 1)])
        return all_PPL

    def get_score(self, data):
        """Return class probabilities with each position replaced by ``replace_word`` in turn.

        The result is a list of ``(1, batch, classes)`` tensors; the final entry
        is for the unmodified batch. ``data`` is restored after each perturbation.
        Assumes padding is trailing and the model outputs log-probabilities.
        """
        sent_score = []
        sent_length = int(data.ne(self.PAD).sum(dim=1).max())
        for j in range(sent_length + 1):
            perturb = j < sent_length
            if perturb:
                saved = data[:, j].clone()
                data[:, j] = self.replace_word
            outputs = self.model(data)
            probs = torch.softmax(outputs, dim=1)
            sent_score.append(probs.view(1, data.size(0), -1))
            if perturb:
                data[:, j] = saved
        return sent_score