"""
Wu C, Wu F, Lyu L, et al. A federated graph neural network framework for privacy-preserving personalization
[J]. Nature Communications, 2022, 13(1): 3091.

TY  - JOUR
AU  - Wu, Chuhan
AU  - Wu, Fangzhao
AU  - Lyu, Lingjuan
AU  - Qi, Tao
AU  - Huang, Yongfeng
AU  - Xie, Xing
PY  - 2022
DA  - 2022/06/02
TI  - A federated graph neural network framework for privacy-preserving personalization
JO  - Nature Communications
SP  - 3091
VL  - 13
IS  - 1
AB  - Graph neural network (GNN) is effective in modeling high-order interactions and has been widely used in various personalized applications such as recommendation. However, mainstream personalization methods rely on centralized GNN learning on global graphs, which have considerable privacy risks due to the privacy-sensitive nature of user data. Here, we present a federated GNN framework named FedPerGNN for both effective and privacy-preserving personalization. Through a privacy-preserving model update method, we can collaboratively train GNN models based on decentralized graphs inferred from local data. To further exploit graph information beyond local interactions, we introduce a privacy-preserving graph expansion protocol to incorporate high-order information under privacy protection. Experimental results on six datasets for personalization in different scenarios show that FedPerGNN achieves 4.0% ~ 9.6% lower errors than the state-of-the-art federated personalization methods under good privacy protection. FedPerGNN provides a promising direction to mining decentralized graph data in a privacy-preserving manner for responsible and intelligent personalization.
SN  - 2041-1723
UR  - https://doi.org/10.1038/s41467-022-30714-9
DO  - 10.1038/s41467-022-30714-9
ID  - Wu2022
ER  - 

@article{wu2022federated,
  title={A federated graph neural network framework for privacy-preserving personalization},
  author={Wu, Chuhan and Wu, Fangzhao and Lyu, Lingjuan and Qi, Tao and Huang, Yongfeng and Xie, Xing},
  journal={Nature Communications},
  volume={13},
  number={1},
  pages={3091},
  year={2022},
  publisher={Nature Publishing Group UK London}
}

https://github.com/wuch15/FedPerGNN
"""
import copy
import numpy as np
import torch.nn as nn
import torch
import logging
from dataloaders.BaseDataLoader import *
from framework.fed.client import ClientBase
from framework.fed.server import ServerBase
from framework.modules.models import BaseModel, AE
from framework.modules.layers import MLP_Block
from framework.modules.utils import get_device
import scipy.sparse as sp
from framework.utils import calculate_model_size
from thop import profile
    
class model(BaseModel):
    """Graph-propagated user/item embedding model shared by server and client.

    User and item embeddings are stacked into one table and multiplied by a
    sparse (normalised) adjacency matrix ``layer_num`` times.  Predictions
    combine the *propagated* user embedding with the *raw* item embedding.

    The prediction head depends on ``task``:
      * ``"regression"``: an ``MLP_Block`` on the concatenated embeddings,
        rescaled as ``pred * 4 + 1`` (presumably mapping a sigmoid output
        onto a 1..5 rating scale -- confirm against ``MLP_Block``).
      * ``"rank"``: sigmoid of the dot product.
      * anything else: the plain dot product.
    """

    def __init__(self, 
                user_num, 
                item_num, 
                embedding_dim, 
                layer_num,
                learning_rate, 
                optimizer,
                loss_fn,
                task,
                light=True,
                device=-1, embedding_regularizer=None, net_regularizer=None, metrics=None, *args, **kwargs):
        """Build the embedding tables, the optional MLP head and the optimizer.

        ``light`` and any extra ``*args`` / ``**kwargs`` are accepted for
        interface compatibility only; this class does not read them.
        """
        super().__init__(device, embedding_regularizer, net_regularizer, metrics)

        self.embedding_user = nn.Embedding(num_embeddings=user_num, embedding_dim=embedding_dim)
        self.embedding_item = nn.Embedding(num_embeddings=item_num, embedding_dim=embedding_dim)
        self.user_num = user_num
        self.item_num = item_num
        self.embedding_dim = embedding_dim
        self.layer_num = layer_num
        self.task = task
        self.output_activation = nn.Sigmoid()
        if task == "regression":
            self.predictor = MLP_Block(input_dim=embedding_dim * 2, output_dim=1, output_activation="sigmoid")

        # reset_parameters / compile / model_to_device come from BaseModel.
        self.reset_parameters()
        # Applied after reset_parameters so the N(0, 0.1) init is what remains.
        self.__init_weight()
        self.compile(optimizer=optimizer, loss=loss_fn, lr=learning_rate)
        self.model_to_device()

    def __init_weight(self, ):
        """Initialise both embedding tables from a normal distribution (std 0.1)."""
        nn.init.normal_(self.embedding_user.weight, std=0.1)
        nn.init.normal_(self.embedding_item.weight, std=0.1)

    def propagate(self, graph):
        """Propagate embeddings through ``graph`` for ``layer_num`` steps.

        ``graph`` is a sparse matrix multiplied against the stacked
        ``[users; items]`` embedding table.  Only the output of the final
        step is returned (no averaging across layers).

        Returns:
            ``(users, items)``: the propagated table split back into its
            ``user_num`` and ``item_num`` row groups.
        """
        users_emb = self.embedding_user.weight
        items_emb = self.embedding_item.weight
        all_emb = torch.cat([users_emb, items_emb])
        for _ in range(self.layer_num):
            all_emb = torch.sparse.mm(graph, all_emb)
        users, items = torch.split(all_emb, [self.user_num, self.item_num])
        return users, items

    def forward(self, all_users, users, items):
        """Score ``(users, items)`` pairs.

        ``all_users`` is the propagated user table from :meth:`propagate`;
        items are looked up in the raw (non-propagated) item embedding.
        """
        user_agg = all_users[users]
        items_emb = self.embedding_item(items)
        return self.predict(user_agg, items_emb)

    def train_step(self, graph, users, items, labels):
        """Run one optimisation step on point-wise data and return the loss."""
        self.train()
        self.optimizer.zero_grad()
        all_users, _ = self.propagate(graph)
        pred = self.forward(all_users, users, items)
        loss = self.loss_fn(pred, labels, ) + self.add_regularization()
        loss.backward()
        self.optimizer.step()
        return loss

    def train_step_triple(self, graph, users, pos, neg):
        """Run one optimisation step on (user, positive, negative) triples.

        Returns the loss tensor (pairwise loss plus triple regularisation).
        """
        self.train()
        # NOTE: the previous `self.embedding_p.requires_grad_ = False` was
        # removed: this class defines no `embedding_p`, so the lookup raised
        # AttributeError, and assigning to `requires_grad_` would only have
        # shadowed the method rather than freezing any parameter.
        self.optimizer.zero_grad()
        all_users, _ = self.propagate(graph)
        pred_pos = self.forward(all_users, users, pos)
        pred_neg = self.forward(all_users, users, neg)
        regularization = self.add_regularization_triple(
            all_users[users], self.embedding_item(pos), self.embedding_item(neg))
        loss = self.loss_fn(pred_pos, pred_neg, ) + regularization
        loss.backward()
        self.optimizer.step()
        return loss

    def predict(self, users_emb, items_emb):
        """Turn aligned user/item embedding rows into one score per row."""
        if self.task == "regression":
            pred = self.predictor(torch.cat([users_emb, items_emb], dim=-1)).squeeze(1)
            # Rescale the head output; presumably (0, 1) -> (1, 5) ratings.
            gamma = pred * 4.0 + 1.0
        elif self.task == "rank":
            gamma = (users_emb * items_emb).sum(1)
            gamma = self.output_activation(gamma)
        else:
            gamma = (users_emb * items_emb).sum(1)
        return gamma

    def get_pred(self, graph, users, items):
        """Score pairs in eval mode.

        The result is ``squeeze()``-d, so a single pair yields a 0-dim
        tensor.  Gradients are not disabled here; callers wrap the call in
        ``torch.no_grad()`` when needed.
        """
        self.eval()
        all_users, _ = self.propagate(graph)
        pred = self.forward(all_users, users, items).squeeze()
        return pred

    def get_pred_triple(self, graph, users, pos, neg, global_model=None):
        """Score positive and negative items for ``users`` in eval mode.

        ``global_model`` is accepted but unused.

        Returns:
            ``(pred_pos, pred_neg)``.
        """
        self.eval()
        all_users, _ = self.propagate(graph)
        pred_pos = self.forward(all_users, users, pos)
        pred_neg = self.forward(all_users, users, neg)
        return pred_pos, pred_neg
    
        
class Client(ClientBase):
    """Federated client: trains locally, then returns a sampled and partly
    swapped set of predictions for the server to learn from.

    A single instance is reused for every selected user; the per-user state
    is the loaded model weights and the server-provided ``samples``.
    """
    model:model 
    def __init__(self, client_id, model, task, sample_ratio, swap_ratio):
        """
        Args:
            client_id: forwarded to ``ClientBase``.
            model: the local model instance, forwarded to ``ClientBase``.
            task: task name; lower-cased and compared against ``"triple"``.
            sample_ratio: fraction of local examples to upload as samples.
            swap_ratio: fraction of uploaded samples that get perturbed.
        """
        super().__init__(client_id, model)
        self.task = task.lower()
        self.sample_ratio = sample_ratio
        self.swap_ratio = swap_ratio
        self.samples = None
        # Defined up front so local_data_num() is safe before any training.
        self.__local_data_num = 0

    def load_model(self, model):
        """Load the given weights into the local model and move it to its device."""
        self.model.load_weights(model)
        self.model.to(self.model.device)
    
    def load_sample(self, sample):
        """Store a private copy of the server samples; ``None`` keeps the old ones."""
        if sample is not None:
            self.samples = copy.deepcopy(sample)

    def local_train(self, graph, user, local_epoch, dataload,):
        """Train on ``user``'s data and build the samples to upload.

        Returns:
            ``(loss, samples)`` where ``loss`` is the last training loss
            (``None`` if ``local_epoch`` is 0) and ``samples`` is
            ``[users, pos, neg]`` for the ``"triple"`` task or
            ``[users, items, labels]`` otherwise.
        """
        if self.task == "triple":
            return self._local_train_triple(graph, user, local_epoch, dataload)
        return self._local_train_pointwise(graph, user, local_epoch, dataload)

    def _local_train_triple(self, graph, user, local_epoch, dataload):
        """Train on (user, pos, neg) triples, then sample/swap items to upload."""
        users, pos, neg = dataload.get_traindata(user)
        self.__local_data_num = users.size(0)
        if self.samples is not None:
            # Server samples are [pos_items, neg_items]; append them and pad
            # the shorter sides by resampling so all three tensors align.
            pos_all = torch.cat([pos, self.samples[0]], dim=0)
            neg_all = torch.cat([neg, self.samples[1]], dim=0)

            pos_len = pos_all.size(0)
            neg_len = neg_all.size(0)
            target_len = max(pos_len, neg_len)

            if pos_len < target_len:
                extra_idx = torch.randint(0, pos_len, (target_len - pos_len,), device=pos.device)
                pos_all = torch.cat([pos_all, pos_all[extra_idx]], dim=0)

            if neg_len < target_len:
                extra_idx = torch.randint(0, neg_len, (target_len - neg_len,), device=neg.device)
                neg_all = torch.cat([neg_all, neg_all[extra_idx]], dim=0)

            pos, neg = pos_all, neg_all

            user_extra_len = target_len - self.__local_data_num
            if user_extra_len > 0:
                extra_idx = torch.randint(0, self.__local_data_num, (user_extra_len,), device=users.device)
                users = torch.cat([users, users[extra_idx]], dim=0)

        self.model.train()
        loss = None
        for _ in range(local_epoch):
            loss = self.model.train_step_triple(graph, users, pos, neg)

        # --- sample: pick a random subset of the training triples.  This has
        # to happen BEFORE scoring; it previously came after and the scoring
        # call read users_new/pos_new/neg_new while they were still unbound.
        num_samples = len(users)
        sample_size = max(1, int(num_samples * self.sample_ratio))
        sampled_idx = torch.randperm(num_samples, device=users.device)[:sample_size]
        users_new = users[sampled_idx]
        pos_new = pos[sampled_idx]
        neg_new = neg[sampled_idx]

        self.model.eval()
        with torch.no_grad():
            pred_pos, pred_neg = self.model.get_pred_triple(graph, users_new, pos_new, neg_new)

        # --- collapse duplicate items, keeping scores aligned with items.
        # Each unique item gets the mean of the scores of its occurrences
        # (the old indexing merely truncated the score vector).
        all_scores = torch.cat([pred_pos.reshape(-1), pred_neg.reshape(-1)], dim=0)
        all_items = torch.cat([pos_new.reshape(-1), neg_new.reshape(-1)], dim=0)
        unique_items, inverse = torch.unique(all_items, return_inverse=True)
        score_sum = torch.zeros(unique_items.size(0), dtype=all_scores.dtype, device=all_scores.device)
        score_sum.index_add_(0, inverse, all_scores)
        occurrences = torch.bincount(inverse, minlength=unique_items.size(0))
        item_scores = score_sum / occurrences.to(score_sum.dtype)

        # Highest-scored items become positives, lowest-scored negatives.
        k = min(sample_size, item_scores.size(0) // 2)
        topk_idx = torch.topk(item_scores, k=k, largest=True).indices
        lowk_idx = torch.topk(item_scores, k=k, largest=False).indices
        users_new = users_new[:k]
        pos_new = unique_items[topk_idx]
        neg_new = unique_items[lowk_idx]
        pred_pos_new = item_scores[topk_idx]

        # --- swap: exchange the most confident positives with negatives.
        # At least one pair is swapped when possible; clamped to k so an
        # empty candidate set no longer crashes topk.
        swap_count = min(k, max(1, int(k * self.swap_ratio)))
        if swap_count > 0:
            swap_idx = torch.topk(pred_pos_new.view(-1), k=swap_count).indices
            pos_new[swap_idx], neg_new[swap_idx] = neg_new[swap_idx], pos_new[swap_idx]

        return loss, [users_new, pos_new, neg_new]

    def _local_train_pointwise(self, graph, user, local_epoch, dataload):
        """Train on (user, item, label) rows, then sample/swap scores to upload."""
        users, items, labels = dataload.get_traindata(user)
        self.__local_data_num = labels.size(0)
        if self.samples is not None:
            # Server samples are [users, items, scores]; only add items the
            # user has not interacted with locally.
            items_new = self.samples[1]
            keep_mask = ~torch.isin(items_new, items)
            if bool(keep_mask.any()):
                users = torch.cat([users, self.samples[0][keep_mask]], dim=0)
                items = torch.cat([items, items_new[keep_mask]], dim=0)
                labels = torch.cat([labels, self.samples[2][keep_mask]], dim=0)

        self.model.train()
        loss = None
        for _ in range(local_epoch):
            loss = self.model.train_step(graph, users, items, labels)

        self.model.eval()
        with torch.no_grad():
            # reshape(-1): get_pred squeezes, so one row would come back 0-dim
            # and could not be indexed with the masks below.
            scores = self.model.get_pred(graph, users, items,).reshape(-1)

        # --- sample: split by label sign, then draw from each side.
        pos_mask = labels > 0
        neg_mask = labels == 0
        pos_items, pos_scores = items[pos_mask], scores[pos_mask]
        neg_items, neg_scores = items[neg_mask], scores[neg_mask]

        sample_size = max(1, int(pos_items.size(0) * self.sample_ratio))
        perm_pos = torch.randperm(pos_items.size(0), device=users.device)[:sample_size]
        sampled_pos_items = pos_items[perm_pos]
        sampled_pos_scores = pos_scores[perm_pos]

        perm_neg = torch.randperm(neg_items.size(0), device=users.device)[:sample_size]
        sampled_neg_items = neg_items[perm_neg]
        sampled_neg_scores = neg_scores[perm_neg]

        # --- swap: exchange the top positive scores with random negative
        # scores.  Clamped to what is available so an empty positive or
        # negative set (e.g. labels that are never 0) does not crash.
        k = min(
            max(1, int(sample_size * self.swap_ratio)),
            sampled_pos_scores.size(0),
            sampled_neg_scores.size(0),
        )
        if k > 0:
            topk_idx = torch.topk(sampled_pos_scores.view(-1), k=k, largest=True).indices
            rand_idx = torch.randperm(sampled_neg_scores.size(0), device=users.device)[:k]
            sampled_pos_scores[topk_idx], sampled_neg_scores[rand_idx] = (
                sampled_neg_scores[rand_idx], sampled_pos_scores[topk_idx])

        users_new = users[:sampled_pos_items.size(0) + sampled_neg_items.size(0)]
        items_new = torch.cat([sampled_pos_items, sampled_neg_items], dim=0)
        labels_new = torch.cat([sampled_pos_scores, sampled_neg_scores], dim=0)

        return loss, [users_new, items_new, labels_new]
    
    def local_data_num(self):
        """Return the number of local training rows seen by the last ``local_train``."""
        return self.__local_data_num
    
class Server(ServerBase):
    """Federated server: keeps per-user client weights, trains its own model on
    the samples uploaded by clients, and hands samples back to clients.
    """
    model:model
    def __init__(self, model, server_epoch, sample_size, sample_portion, graph):
        """
        Args:
            model: the server model, forwarded to ``ServerBase``.
            server_epoch: training steps run per aggregation.
            sample_size: total number of samples sent to a client.
            sample_portion: share of ``sample_size`` chosen by item
                frequency; the remainder is chosen by predicted score.
            graph: sparse adjacency used for propagation.
        """
        super().__init__(model)
        self.task = self.model.task
        self.graph = graph
        self.client_models = dict()
        self.server_epoch = server_epoch
        self.sample_freq = int(sample_size * sample_portion)
        self.sample_pred = sample_size - self.sample_freq
        self.samples = None
        self.items = None

    def count_parameters(self):
        """Log the size of every state-dict entry except the user embedding,
        plus the total.  Sizes come from ``calculate_model_size`` and are
        reported as MB in the log messages.
        """
        self.model.eval()
        base_model_dict = copy.deepcopy(self.model.state_dict())
        model_size = 0.
        for name, tensor in base_model_dict.items():
            if "embedding_user" in name:
                continue
            _, param_size = calculate_model_size(tensor)
            logging.info("Model {} size: {:.8f}MB".format(name, param_size))
            model_size += param_size
        self.model.load_weights(copy.deepcopy(base_model_dict))
        logging.info("Model all size: {:.8f}MB".format(model_size))
    
    def distribute_model(self, user):
        """Return the weights last uploaded for ``user``, else the server weights."""
        if user in self.client_models:
            return self.client_models[user]
        return self.model.state_dict()

    def distribute_sample(self, user):
        """Build the samples to send to ``user``.

        For the ``"triple"`` task, or before any point-wise aggregation has
        happened, this returns ``self.samples`` unchanged.  Otherwise it
        returns ``[users, items, scores]`` made of the most frequently
        uploaded items plus the items the server scores highest for
        ``user``.  Fewer than ``sample_size`` rows are returned when there
        are not enough distinct items.
        """
        if self.task == "triple" or self.items is None:
            return self.samples

        self.model.eval()
        unique_it, counts = torch.unique(self.items, return_counts=True)

        # Score every candidate item once; both selections index into this.
        user_all = torch.full((unique_it.size(0),), int(user), dtype=torch.int64, device=self.model.device)
        with torch.no_grad():
            # reshape(-1): get_pred squeezes, so one item would come back 0-dim.
            scores_all = self.model.get_pred(self.graph, user_all, unique_it).reshape(-1)

        freq_idx = torch.argsort(counts, descending=True)[:self.sample_freq]
        item_freq = unique_it[freq_idx]
        score_freq = scores_all[freq_idx]

        pred_idx = torch.argsort(scores_all, descending=True)[:self.sample_pred]
        item_pred = unique_it[pred_idx]
        score_pred = scores_all[pred_idx]

        items = torch.cat([item_freq, item_pred], dim=0).view(-1)
        scores = torch.cat([score_freq, score_pred], dim=0).view(-1)
        # Size the user column from the items actually selected, so the three
        # tensors stay aligned when there are fewer distinct items than asked.
        users = torch.full((items.size(0),), int(user), dtype=torch.int64, device=self.model.device)
        return [users, items, scores]

    def aggregation(self, user_list, model_list, sample_list, num_list, loss_list):
        """Store client weights and train the server model on uploaded samples.

        Args:
            user_list: users trained this round.
            model_list: uploaded weights, aligned with ``user_list``.
            sample_list: uploaded samples, one 3-element list per client.
            num_list: local data counts (not used here).
            loss_list: client losses, only used for logging.
        """
        # NOTE(review): the uploaded objects are stored as-is; if the client
        # returns a live reference to its state dict these entries may alias
        # one another -- confirm against ClientBase.upload_model.
        for user, client_weights in zip(user_list, model_list):
            self.client_models[user] = client_weights

        loss = None  # stays None when server_epoch is 0
        if self.task == "triple":
            users = torch.cat([sample[0] for sample in sample_list], dim=0).view(-1)
            pos = torch.cat([sample[1] for sample in sample_list], dim=0).view(-1)
            neg = torch.cat([sample[2] for sample in sample_list], dim=0).view(-1)
            self.model.train()
            for _ in range(self.server_epoch):  # server local epoch
                loss = self.model.train_step_triple(self.graph, users, pos, neg)

            self.model.eval()
            with torch.no_grad():
                # The graph argument was previously missing here (TypeError).
                pred_pos, pred_neg = self.model.get_pred_triple(self.graph, users, pos, neg)

            # Frequency-based selection over all uploaded items.
            all_items = torch.cat([pos, neg], dim=0)
            unique, counts = torch.unique(all_items, return_counts=True)
            freq_topk = torch.topk(counts, k=min(self.sample_freq, len(unique)))
            top_freq_items = unique[freq_topk.indices]

            pos_new_from_freq = top_freq_items[torch.isin(top_freq_items, pos)]
            neg_new_from_freq = top_freq_items[torch.isin(top_freq_items, neg)]

            # Score-based selection; `source` remembers whether each entry
            # was uploaded as a positive (1) or a negative (0).
            scores = torch.cat([pred_pos.view(-1), pred_neg.view(-1)], dim=0)
            source = torch.cat([torch.ones_like(pos), torch.zeros_like(neg)], dim=0)

            score_topk = torch.topk(scores, k=min(self.sample_pred, len(scores)))
            top_score_items = all_items[score_topk.indices]
            top_score_src = source[score_topk.indices]

            pos_new_from_score = top_score_items[top_score_src == 1]
            neg_new_from_score = top_score_items[top_score_src == 0]

            pos_new = torch.unique(torch.cat([pos_new_from_freq, pos_new_from_score], dim=0))
            neg_new = torch.unique(torch.cat([neg_new_from_freq, neg_new_from_score], dim=0))

            self.samples = [pos_new, neg_new]
        else:
            users = torch.cat([sample[0] for sample in sample_list], dim=0).view(-1)
            items = torch.cat([sample[1] for sample in sample_list], dim=0).view(-1)
            scores = torch.cat([sample[2] for sample in sample_list], dim=0).view(-1)
            self.model.train()
            for _ in range(self.server_epoch):  # server local epoch
                loss = self.model.train_step(self.graph, users, items, scores)
            
            # Point-wise samples are rebuilt per user in distribute_sample.
            self.items = items
            self.samples = None

        logging.info("Clients average loss: {}, Server average loss: {}".format(torch.mean(torch.tensor(loss_list)), loss))
        
    def evaluate(self, graph, dataload, user_list):
        """Score every user's test data and return the metric dictionary."""
        self.model.eval()
        y_pred = []
        y_true = []
        group_id = []
        # Inference only: skip building the autograd graph for the
        # propagation and the per-user forward passes.
        with torch.no_grad():
            all_users, _ = self.model.propagate(graph)
            for user in user_list:
                users, items, labels = dataload.get_testdata(user)
                y_pred.extend(self.model.forward(all_users, users, items).data.cpu().numpy().reshape(-1))
                y_true.extend(labels.data.cpu().numpy().reshape(-1))
                group_id.extend(users.data.cpu().numpy().reshape(-1))
        y_pred = np.array(y_pred, np.float64)
        y_true = np.array(y_true, np.float64)
        group_id = np.array(group_id) if len(group_id) > 0 else None
        # NOTE(review): `metrcis` looks like a typo of `metrics`, but it is
        # kept because the attribute is defined by BaseModel -- confirm there.
        val_logs = self.model.evaluate_metrics(y_true, y_pred, self.model.metrcis, group_id)
        logging.info('[Metrics] ' + ' - '.join('{}: {:.6f}'.format(k, v) for k, v in val_logs.items()))
        return val_logs

class FedLightGCN_PTF:
    """Driver that wires the data loader, graph, server, client and item
    auto-encoder together and runs the federated training loop.
    """

    def __init__(self,
                dataload: "BaseDataLoaderFL",
                clients_num_per_turn, 
                local_epoch, 
                train_turn,
                user_num,
                item_num,
                embedding_dim,
                layer_num,
                light,
                server_epoch,
                sample_ratio,
                swap_ratio,
                sample_size, 
                sample_portion,
                embedding_regularizer, 
                net_regularizer, 
                learning_rate,
                optimizer, 
                loss_fn,
                device,
                metrics,
                path,
                task,
                *args, **kwargs
                ):
        """Build the graph, the server/client models and the auto-encoder.

        Required keyword arguments (read from ``**kwargs``): ``pre_epoch``,
        ``g_hidden_units``, ``g_hidden_activations``, ``sen_embedding_dim``.
        ``path`` is used as a string prefix for the cached graph files.
        """
        self.clients_num_per_turn = clients_num_per_turn
        self.task = task.lower()
        self.local_epoch = local_epoch
        self.train_turn = train_turn
        # Resolved device; getSpareseGraph moves the graph onto it.  (A
        # duplicated assignment block used to overwrite this with the raw
        # `device` argument at the end of __init__.)
        self.device = get_device(device)
        self.user_num = user_num
        self.item_num = item_num
        self.dataload = dataload
        self.pre_epoch = kwargs["pre_epoch"]
        self.graph = self.getSpareseGraph(path, light)

        # Server and client get separate model instances with identical
        # hyper-parameters.
        model_kwargs = dict(
            user_num=user_num,
            item_num=item_num,
            embedding_dim=embedding_dim,
            layer_num=layer_num,
            light=light,
            task=task.lower(),
            device=device,
            embedding_regularizer=embedding_regularizer,
            net_regularizer=net_regularizer,
            learning_rate=learning_rate,
            optimizer=optimizer,
            loss_fn=loss_fn,
            metrics=metrics,
        )
        self.server_model = model(**model_kwargs)
        self.server_model.reset_parameters()
        self.server = Server(self.server_model, server_epoch, sample_size, sample_portion, self.graph)
        self.client = Client(client_id=0, model=model(**model_kwargs),
                             task=task.lower(), sample_ratio=sample_ratio, swap_ratio=swap_ratio)
        # Auto-encoder whose latent code (size embedding_dim) seeds the
        # server's item embedding in fit().
        self.g_model = AE(hidden_units = kwargs["g_hidden_units"],
                hidden_activations = kwargs["g_hidden_activations"],
                embedding_dim = kwargs["sen_embedding_dim"], 
                embedding_dim_latent = embedding_dim,
                device = device, 
                embedding_regularizer=0., 
                net_regularizer=1e-2, 
                learning_rate=1e-4,
                optimizer="adam",
                loss_fn = "mse_loss",)

    def getSpareseGraph(self, path, light):
        """Load or build the symmetrically normalised adjacency matrix.

        The matrix is ``D^-1/2 A D^-1/2`` over the stacked user+item nodes,
        with self-loops added when ``light`` is false.  It is cached as
        ``path + 'light_graph.npz'`` or ``path + 'gcn_graph.npz'``.

        Adapted from
        https://github.com/gusye1234/LightGCN-PyTorch/blob/master/code/dataloader.py#L332

        Returns:
            A coalesced ``torch`` sparse COO tensor on ``self.device``.
        """
        cache_file = path + ('light_graph.npz' if light else 'gcn_graph.npz')
        try:
            norm_adj = sp.load_npz(cache_file)
        except Exception:
            # Best effort: a missing or unreadable cache triggers a rebuild.
            # `Exception` (not a bare except) lets KeyboardInterrupt through.
            logging.info("Graph cache %s unavailable; rebuilding.", cache_file)
            node_num = self.user_num + self.item_num
            adj_mat = sp.dok_matrix((node_num, node_num), dtype=np.int64)
            adj_mat = adj_mat.tolil()
            R = self.dataload.get_aj_graph()
            adj_mat[:self.user_num, self.user_num:] = R
            adj_mat[self.user_num:, :self.user_num] = R.T
            adj_mat = adj_mat.todok()
            if not light:
                adj_mat = adj_mat + sp.eye(adj_mat.shape[0])
            
            rowsum = np.array(adj_mat.sum(axis=1))
            # Isolated nodes have degree 0 -> inf, which is zeroed right after.
            with np.errstate(divide="ignore"):
                d_inv = np.power(rowsum, -0.5).flatten()
            d_inv[np.isinf(d_inv)] = 0.
            d_mat = sp.diags(d_inv)
            
            norm_adj = d_mat.dot(adj_mat).dot(d_mat).tocsr()
            sp.save_npz(cache_file, norm_adj)

        Graph = self._convert_sp_mat_to_sp_tensor(norm_adj)
        Graph = Graph.coalesce().to(self.device)
        return Graph
    
    def _convert_sp_mat_to_sp_tensor(self, X):
        """Convert a SciPy sparse matrix to a float32 torch sparse COO tensor.

        Indices are converted straight to int64.  Routing them through a
        float32 tensor (as ``torch.Tensor(...).long()`` does) rounds any
        index above 2**24.
        """
        coo = X.tocoo().astype(np.float32)
        index = torch.from_numpy(np.vstack((coo.row, coo.col)).astype(np.int64))
        data = torch.as_tensor(coo.data, dtype=torch.float32)
        return torch.sparse_coo_tensor(index, data, torch.Size(coo.shape), dtype=torch.float32)
    
    def fit(self):
        """Pre-train the item auto-encoder, run the federated rounds, evaluate.

        Returns:
            The metric dictionary from ``Server.evaluate`` over all users.
        """
        self.server.count_parameters()

        # Pre-train the auto-encoder on item features and use its latent
        # codes as the server's initial item embedding.
        item_feature = self.dataload.get_item_feature()
        for _ in range(self.pre_epoch):
            self.g_model.train_step(item_feature)
        latent = self.g_model.get_latent(item_feature)
        self.server.model.embedding_item.weight.data = copy.deepcopy(latent.detach())

        for turn in range(self.train_turn):
            logging.info("********* Train Turn {} *********".format(turn))
            select_users = self.server.select_clients(self.user_num, self.clients_num_per_turn)
            client_sample = []
            client_model = []
            client_local_data_num = []
            losses = []
            # One Client object is reused: load the user's state, train,
            # then collect what it uploads.
            for user in select_users:
                self.client.load_client(user)
                self.client.load_model(self.server.distribute_model(user))
                self.client.load_sample(self.server.distribute_sample(user))
                loss, samples = self.client.local_train(self.graph, user, self.local_epoch, self.dataload)
                losses.append(loss)
                client_model.append(self.client.upload_model())
                client_sample.append(samples)
                client_local_data_num.append(self.client.local_data_num())
            self.server.aggregation(select_users, client_model, client_sample, client_local_data_num, losses,)
            torch.cuda.empty_cache()
                
        logging.info("********* Test *********")
        results = self.server.evaluate(self.graph, self.dataload, range(self.user_num))
        return results