"""
Wu C, Wu F, Lyu L, et al. A federated graph neural network framework for privacy-preserving personalization
[J]. Nature Communications, 2022, 13(1): 3091.

TY  - JOUR
AU  - Wu, Chuhan
AU  - Wu, Fangzhao
AU  - Lyu, Lingjuan
AU  - Qi, Tao
AU  - Huang, Yongfeng
AU  - Xie, Xing
PY  - 2022
DA  - 2022/06/02
TI  - A federated graph neural network framework for privacy-preserving personalization
JO  - Nature Communications
SP  - 3091
VL  - 13
IS  - 1
AB  - Graph neural network (GNN) is effective in modeling high-order interactions and has been widely used in various personalized applications such as recommendation. However, mainstream personalization methods rely on centralized GNN learning on global graphs, which have considerable privacy risks due to the privacy-sensitive nature of user data. Here, we present a federated GNN framework named FedPerGNN for both effective and privacy-preserving personalization. Through a privacy-preserving model update method, we can collaboratively train GNN models based on decentralized graphs inferred from local data. To further exploit graph information beyond local interactions, we introduce a privacy-preserving graph expansion protocol to incorporate high-order information under privacy protection. Experimental results on six datasets for personalization in different scenarios show that FedPerGNN achieves 4.0% ~ 9.6% lower errors than the state-of-the-art federated personalization methods under good privacy protection. FedPerGNN provides a promising direction to mining decentralized graph data in a privacy-preserving manner for responsible and intelligent personalization.
SN  - 2041-1723
UR  - https://doi.org/10.1038/s41467-022-30714-9
DO  - 10.1038/s41467-022-30714-9
ID  - Wu2022
ER  - 

@article{wu2022federated,
  title={A federated graph neural network framework for privacy-preserving personalization},
  author={Wu, Chuhan and Wu, Fangzhao and Lyu, Lingjuan and Qi, Tao and Huang, Yongfeng and Xie, Xing},
  journal={Nature Communications},
  volume={13},
  number={1},
  pages={3091},
  year={2022},
  publisher={Nature Publishing Group UK London}
}

https://github.com/wuch15/FedPerGNN
"""
import copy
import numpy as np
import torch.nn as nn
import torch
import logging
from dataloaders.BaseDataLoader import *
from framework.fed.client import ClientBase
from framework.fed.server import ServerBase
from framework.modules.models import BaseModel, AE
from framework.modules.layers import MLP_Block
from framework.modules.utils import get_device, compute_divergence, compute_score_divergence, compute_score_similarity
import scipy.sparse as sp
from framework.utils import calculate_model_size
from thop import profile
    
class model(BaseModel):
    """LightGCN-style recommender shared by the federated server and its clients.

    User and item embedding tables are concatenated and multiplied ``layer_num``
    times by a sparse graph matrix (:meth:`propagate`).  A prediction combines
    the propagated user rows with the *raw* item embeddings (:meth:`forward`).

    The ``task`` string selects the prediction head (see :meth:`predict`):

    * ``"regression"`` - an MLP with a sigmoid output over ``[user ; item]``,
      rescaled by ``4 * x + 1``;
    * ``"rank"`` - sigmoid of the user/item dot product;
    * anything else - the raw dot product.
    """

    def __init__(self,
                 user_num,
                 item_num,
                 embedding_dim,
                 layer_num,
                 learning_rate,
                 optimizer,
                 loss_fn,
                 task,
                 light=True,
                 device=-1, embedding_regularizer=None, net_regularizer=None, metrics=None, *args, **kwargs):
        """Build the embedding tables (and the MLP head for regression) and compile.

        Args:
            user_num: number of rows of the user embedding table.
            item_num: number of rows of the item embedding table.
            embedding_dim: width of both embedding tables.
            layer_num: number of graph propagation steps in :meth:`propagate`.
            learning_rate, optimizer, loss_fn: forwarded to ``self.compile``.
            task: prediction head selector, see the class docstring.
            light: accepted for interface compatibility; not read by this class.
            device, embedding_regularizer, net_regularizer, metrics: forwarded
                to the ``BaseModel`` constructor.
        """
        super(__class__, self).__init__(device, embedding_regularizer, net_regularizer, metrics)

        self.embedding_user = nn.Embedding(num_embeddings=user_num, embedding_dim=embedding_dim)
        self.embedding_item = nn.Embedding(num_embeddings=item_num, embedding_dim=embedding_dim)
        self.user_num = user_num
        self.item_num = item_num
        self.embedding_dim = embedding_dim
        self.layer_num = layer_num
        self.task = task
        self.output_activation = nn.Sigmoid()
        if task == "regression":
            self.predictor = MLP_Block(input_dim=embedding_dim * 2, output_dim=1, output_activation="sigmoid")

        self.reset_parameters()
        # Runs after reset_parameters() so the embeddings end up N(0, 0.1)-initialised.
        self.__init_weight()
        self.compile(optimizer=optimizer, loss=loss_fn, lr=learning_rate)
        self.model_to_device()

    def __init_weight(self, ):
        """Initialise both embedding tables from a normal distribution with std 0.1."""
        nn.init.normal_(self.embedding_user.weight, std=0.1)
        nn.init.normal_(self.embedding_item.weight, std=0.1)

    def propagate(self, graph):
        """Apply ``layer_num`` sparse multiplications by ``graph`` to all embeddings.

        Args:
            graph: sparse matrix passed to ``torch.sparse.mm``; it must be
                applicable to the stacked ``[user ; item]`` embedding matrix.

        Returns:
            ``(users, items)``: the propagated matrix split back into its first
            ``user_num`` rows and its last ``item_num`` rows.
        """
        all_emb = torch.cat([self.embedding_user.weight, self.embedding_item.weight])
        for _ in range(self.layer_num):
            all_emb = torch.sparse.mm(graph, all_emb)
        users, items = torch.split(all_emb, [self.user_num, self.item_num])
        return users, items

    def forward(self, all_users, users, items):
        """Score ``(users[i], items[i])`` pairs.

        Uses the propagated user rows ``all_users[users]`` together with the
        un-propagated item embeddings looked up from ``embedding_item``.
        """
        user_agg = all_users[users]
        items_emb = self.embedding_item(items)
        return self.predict(user_agg, items_emb)

    def train_step_server(self, graph, sample_list, weight_list):
        """Run one optimiser step on a weighted mix of per-client sample losses.

        Args:
            graph: sparse propagation matrix.
            sample_list: one ``[users, items, labels]`` triple per client.
            weight_list: one scalar weight per client, aligned with ``sample_list``.

        Returns:
            ``(per_client_losses, total_loss)`` where ``per_client_losses`` is a
            detached 1-D tensor and ``total_loss`` is the weighted mean loss
            plus the regularisation term.
        """
        self.train()
        self.optimizer.zero_grad()
        all_users, _ = self.propagate(graph)
        losses = []
        for users, items, label in (sample[:3] for sample in sample_list):
            pred = self.forward(all_users, users, items).squeeze()
            losses.append(self.loss_fn(pred, label, reduction='mean'))
        weighted = sum(n * l for n, l in zip(weight_list, losses))
        # The epsilon keeps the division finite when every weight is zero.
        loss = weighted / (sum(weight_list) + 1.e-8) + self.add_regularization()
        loss.backward()
        self.optimizer.step()

        return torch.stack([l.detach() for l in losses]), loss

    def train_step(self, graph, users, items, labels):
        """Run one optimiser step on explicit ``(user, item, label)`` samples; return the loss."""
        self.train()
        self.optimizer.zero_grad()
        all_users, _ = self.propagate(graph)
        pred = self.forward(all_users, users, items)
        loss = self.loss_fn(pred, labels, ) + self.add_regularization()
        loss.backward()
        self.optimizer.step()
        return loss

    def train_step_triple(self, graph, users, pos, neg):
        """Run one optimiser step on ``(user, positive item, negative item)`` triples.

        The loss is ``loss_fn(pred_pos, pred_neg)`` plus a regulariser computed
        from the embeddings taking part in the batch.  Returns the loss.
        """
        self.train()
        # The previous version touched ``self.embedding_p`` here; this class
        # defines no such attribute, so the access raised AttributeError (and the
        # assignment to ``requires_grad_`` would not have frozen anything).
        self.optimizer.zero_grad()
        all_users, _ = self.propagate(graph)
        pred_pos = self.forward(all_users, users, pos)
        pred_neg = self.forward(all_users, users, neg)
        regularization = self.add_regularization_triple(
            all_users[users], self.embedding_item(pos), self.embedding_item(neg))
        loss = self.loss_fn(pred_pos, pred_neg, ) + regularization
        loss.backward()
        self.optimizer.step()
        return loss

    def predict(self, users_emb, items_emb):
        """Turn aligned user/item embedding rows into scores according to ``self.task``."""
        if self.task == "regression":
            pred = self.predictor(torch.cat([users_emb, items_emb], dim=-1)).squeeze(1)
            # Sigmoid output in (0, 1) is stretched to (1, 5)
            # (presumably a 1-5 rating scale - confirm against the datasets).
            gamma = pred * 4.0 + 1.0
        elif self.task == "rank":
            gamma = (users_emb * items_emb).sum(1)
            gamma = self.output_activation(gamma)
        else:
            gamma = (users_emb * items_emb).sum(1)
        return gamma

    def get_pred(self, graph, users, items):
        """Score pairs in eval mode.

        The result is ``squeeze()``-d, so a single pair yields a 0-dim tensor.
        Gradient tracking is not disabled here; callers wrap the call in
        ``torch.no_grad()`` when they need that.
        """
        self.eval()
        all_users, _ = self.propagate(graph)
        pred = self.forward(all_users, users, items).squeeze()
        return pred

    def get_pred_triple(self, graph, users, pos, neg, global_model=None):
        """Score the positive and the negative item of every triple in eval mode.

        ``global_model`` is accepted for interface compatibility and is not used.

        Returns:
            ``(pred_pos, pred_neg)``.
        """
        self.eval()
        all_users, _ = self.propagate(graph)
        pred_pos = self.forward(all_users, users, pos)
        pred_neg = self.forward(all_users, users, neg)
        return pred_pos, pred_neg
    
        
class Client(ClientBase):
    """Federated client: trains locally, then derives a small sample set to upload.

    Two modes are selected by ``task``:

    * ``"triple"`` - pairwise training on ``(user, positive, negative)`` triples;
    * anything else - pointwise training on ``(user, item, label)`` samples.

    In both modes the client can first merge samples handed down by the server
    (:meth:`load_sample`) into its local data, and it returns the training loss
    together with a perturbed sample list ``[users, items_a, items_b_or_scores]``.
    """
    model: model

    def __init__(self, client_id, model, task, sample_ratio, sample_bias, swap_ratio, alpha):
        """Store the sampling hyper-parameters.

        Args:
            client_id, model: forwarded to ``ClientBase``.
            task: mode selector; compared lower-cased against ``"triple"``.
            sample_ratio: fraction of local samples considered for upload.
            sample_bias: half-width of the uniform jitter applied to
                ``sample_ratio`` in pointwise mode.
            swap_ratio: fraction of uploaded samples whose positive/negative
                roles (triple mode) or scores (pointwise mode) are swapped.
            alpha: mixing weight between model scores and true labels.
        """
        super().__init__(client_id, model)
        self.task = task.lower()
        self.sample_ratio = sample_ratio
        self.sample_bias = sample_bias
        self.swap_ratio = swap_ratio
        self.alpha = alpha
        self.samples = None
        # Defined up front so local_data_num() is safe before the first round.
        self.__local_data_num = 0

    def load_model(self, model):
        """Load the weights distributed by the server and move the model to its device."""
        self.model.load_weights(model)
        self.model.to(self.model.device)

    def load_sample(self, sample):
        """Keep a private deep copy of the server's samples; ``None`` leaves the old ones in place."""
        if sample is not None:
            self.samples = copy.deepcopy(sample)

    def local_train(self, graph, user, local_epoch, dataload,):
        """Train on ``user``'s data for ``local_epoch`` steps and build the upload samples.

        Args:
            graph: sparse propagation matrix passed through to the model.
            user: key handed to ``dataload.get_traindata``.
            local_epoch: number of optimiser steps to run.
            dataload: object exposing ``get_traindata(user)``.

        Returns:
            ``(loss, samples)``: ``loss`` is the loss of the last step (``None``
            when ``local_epoch`` is 0) and ``samples`` is ``[users, pos, neg]``
            in triple mode or ``[users, items, scores]`` otherwise.
        """
        if self.task == "triple":
            return self._train_triple(graph, user, local_epoch, dataload)
        return self._train_pointwise(graph, user, local_epoch, dataload)

    @staticmethod
    def _pad_by_resampling(values, target_len):
        """Grow ``values`` to ``target_len`` rows by appending randomly re-drawn rows."""
        missing = target_len - values.size(0)
        if missing <= 0:
            return values
        extra_idx = torch.randint(0, values.size(0), (missing,), device=values.device)
        return torch.cat([values, values[extra_idx]], dim=0)

    def _train_triple(self, graph, user, local_epoch, dataload):
        """Pairwise training followed by construction of uploaded ``(user, pos, neg)`` samples."""
        users, pos, neg = dataload.get_traindata(user)
        self.__local_data_num = users.size(0)

        if self.samples is not None:
            # Server samples are [pos_items, neg_items]; after appending them the
            # three columns are padded to a common length by re-sampling.
            pos_all = torch.cat([pos, self.samples[0]], dim=0)
            neg_all = torch.cat([neg, self.samples[1]], dim=0)
            target_len = max(pos_all.size(0), neg_all.size(0))
            pos = self._pad_by_resampling(pos_all, target_len)
            neg = self._pad_by_resampling(neg_all, target_len)
            users = self._pad_by_resampling(users, target_len)

        self.model.train()
        loss = None
        for _ in range(local_epoch):
            loss = self.model.train_step_triple(graph, users, pos, neg)

        # --- sample -----------------------------------------------------------
        # The subset has to be drawn *before* it is scored; the previous version
        # scored ``users_new``/``pos_new``/``neg_new`` before assigning them.
        num_samples = len(users)
        sample_size = max(1, int(num_samples * self.sample_ratio))
        sampled_idx = torch.randperm(num_samples, device=users.device)[:sample_size]
        users_new = users[sampled_idx]
        pos_new = pos[sampled_idx]
        neg_new = neg[sampled_idx]

        self.model.eval()
        with torch.no_grad():
            pred_pos, pred_neg = self.model.get_pred_triple(graph, users_new, pos_new, neg_new)

        all_scores = torch.cat([pred_pos.view(-1), pred_neg.view(-1)], dim=0)
        all_items = torch.cat([pos_new, neg_new], dim=0)

        # Collapse duplicate items to one entry carrying the mean of their
        # scores, so ``item_scores[j]`` really belongs to ``unique_items[j]``.
        unique_items, inverse = torch.unique(all_items, return_inverse=True)
        score_sum = torch.zeros(unique_items.size(0), dtype=all_scores.dtype, device=all_scores.device)
        score_sum.index_add_(0, inverse, all_scores)
        occurrences = torch.zeros_like(score_sum)
        occurrences.index_add_(0, inverse, torch.ones_like(all_scores))
        item_scores = score_sum / occurrences

        # Highest-scored items become uploaded positives, lowest-scored negatives.
        k = min(sample_size, item_scores.size(0) // 2)
        topk_idx = torch.topk(item_scores, k=k, largest=True).indices
        lowk_idx = torch.topk(item_scores, k=k, largest=False).indices
        users_new = users_new[:k]
        pos_new = unique_items[topk_idx]
        neg_new = unique_items[lowk_idx]

        # --- swap -------------------------------------------------------------
        # Exchange the roles of the best-scored pairs.  At least one pair is
        # always swapped (kept from the original ``max(1, ...)``).  Skipped when
        # nothing was selected, where ``topk`` would fail on an empty tensor.
        if k > 0:
            swap_size = min(k, max(1, int(k * self.swap_ratio)))
            swap_idx = torch.topk(item_scores[topk_idx].view(-1), k=swap_size).indices
            swapped_pos, swapped_neg = neg_new[swap_idx], pos_new[swap_idx]
            pos_new[swap_idx] = swapped_pos
            neg_new[swap_idx] = swapped_neg

        return loss, [users_new, pos_new, neg_new]

    def _train_pointwise(self, graph, user, local_epoch, dataload):
        """Pointwise training followed by construction of uploaded ``(user, item, score)`` samples."""
        # Function-scope import: the module only gets ``random`` (if at all)
        # through a wildcard import, so bind it explicitly here.
        import random

        users, items, labels = dataload.get_traindata(user)
        if self.samples is not None:
            # Merge only the server samples whose item is not already local.
            shared_users, shared_items, shared_scores = self.samples[0], self.samples[1], self.samples[2]
            unseen = ~torch.isin(shared_items, items)
            if unseen.any():
                users = torch.cat([users, shared_users[unseen]], dim=0)
                items = torch.cat([items, shared_items[unseen]], dim=0)
                labels = torch.cat([labels, shared_scores[unseen]], dim=0)

        self.model.train()
        loss = None
        for _ in range(local_epoch):
            loss = self.model.train_step(graph, users, items, labels)

        self.model.eval()
        with torch.no_grad():
            # reshape(-1) keeps a 1-D tensor even for a single sample, where the
            # model's squeeze() would return a 0-dim tensor that cannot be masked.
            scores = self.model.get_pred(graph, users, items,).reshape(-1)

        alpha = self.alpha
        if self.swap_ratio == 0.:
            # Without swapping, every score is blended with its true label.
            scores = alpha * scores + (1. - alpha) * labels

        # --- sample -----------------------------------------------------------
        # Ratio jittered uniformly in [sample_ratio - bias, sample_ratio + bias].
        sample_ratio = self.sample_ratio - self.sample_bias + 2 * self.sample_bias * random.random()

        pos_mask = labels > 0
        neg_mask = labels == 0
        pos_items, pos_scores = items[pos_mask], scores[pos_mask]
        neg_items, neg_scores = items[neg_mask], scores[neg_mask]

        sample_size = max(1, int(pos_items.size(0) * sample_ratio))
        perm_pos = torch.randperm(pos_items.size(0), device=users.device)[:sample_size]
        sampled_pos_items = pos_items[perm_pos]
        sampled_pos_scores = pos_scores[perm_pos]

        # As many negatives as positives were requested.
        perm_neg = torch.randperm(neg_items.size(0), device=users.device)[:sample_size]
        sampled_neg_items = neg_items[perm_neg]
        sampled_neg_scores = neg_scores[perm_neg]

        # --- swap -------------------------------------------------------------
        if self.swap_ratio > 0.:
            sampled_pos_labels = labels[pos_mask][perm_pos]
            sampled_neg_labels = labels[neg_mask][perm_neg]

            # Never swap more entries than either side holds, otherwise the two
            # index sets differ in length and the exchange below cannot be done.
            k = min(max(1, int(sample_size * self.swap_ratio)),
                    sampled_pos_scores.size(0),
                    sampled_neg_scores.size(0))
            if k > 0:
                topk_idx = torch.topk(sampled_pos_scores.view(-1), k=k, largest=True).indices
                rand_idx = torch.randperm(perm_neg.size(0), device=users.device)[:k]
                # Snapshot the raw (un-blended) scores that will trade places.
                noise_pos = sampled_neg_scores[rand_idx].clone()
                noise_neg = sampled_pos_scores[topk_idx].clone()

            # Blend with the labels: positives only where the label equals 1,
            # negatives everywhere.
            label_pos = sampled_pos_labels == 1
            sampled_pos_scores[label_pos] = (alpha * sampled_pos_scores[label_pos]
                                             + (1. - alpha) * sampled_pos_labels[label_pos])
            sampled_neg_scores = alpha * sampled_neg_scores + (1. - alpha) * sampled_neg_labels

            if k > 0:
                sampled_pos_scores[topk_idx] = noise_pos
                sampled_neg_scores[rand_idx] = noise_neg

        users_new = users[:sampled_pos_items.size(0) + sampled_neg_items.size(0)]
        items_new = torch.cat([sampled_pos_items, sampled_neg_items], dim=0)
        labels_new = torch.cat([sampled_pos_scores, sampled_neg_scores], dim=0)

        # In this mode the reported size is the number of uploaded samples.
        self.__local_data_num = users_new.size(0)

        return loss, [users_new, items_new, labels_new]

    def local_data_num(self):
        """Return the sample count recorded by the last :meth:`local_train` call (0 before any)."""
        return self.__local_data_num
    
class Policy_Net(BaseModel):
    """Small MLP with three training entry points.

    The same class backs a softmax weighting policy (:meth:`forward` plus
    :meth:`train_step`), an actor (:meth:`train_step_actor`) and a critic
    (:meth:`train_step_critic`, which reads the raw MLP output).
    """

    def __init__(self,
                 input_dim,
                 hidden_activations,
                 hidden_units,
                 output_dim,
                 device,
                 embedding_regularizer,
                 net_regularizer,
                 learning_rate,
                 optimizer,
                 loss_fn,
                 metrics,
                 *args, **kwargs):
        """Create the MLP and compile it with the given optimiser settings."""
        super(__class__, self).__init__(device=device,
                                        embedding_regularizer=embedding_regularizer,
                                        net_regularizer=net_regularizer,
                                        metrics=metrics)
        mlp_config = dict(input_dim=input_dim,
                          output_dim=output_dim,
                          hidden_units=hidden_units,
                          hidden_activations=hidden_activations)
        self.mlp = MLP_Block(**mlp_config)
        self.output_activation = nn.Sigmoid()
        self.reset_parameters()
        self.compile(optimizer=optimizer, loss=loss_fn, lr=learning_rate)
        self.model_to_device()

    def _apply_gradients(self, loss):
        """Clear stale gradients, back-propagate ``loss`` and take one optimiser step."""
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def forward(self, state):
        """Return the MLP output (last axis squeezed) normalised by a softmax over dim 0."""
        logits = self.mlp(state).squeeze(-1)
        return torch.softmax(logits, dim=0)

    def train_step(self, weights, rewards):
        """One policy-gradient step: minimise ``sum(-log(w + 1e-8) * r)`` over paired entries."""
        loss = 0
        for weight, reward in zip(weights, rewards):
            loss = loss + (-torch.log(weight + 1e-8) * reward)
        self._apply_gradients(loss)
        return loss

    def train_step_actor(self, state, action, error):
        """One actor step on ``-log pi(action | state) * error`` with ``error`` detached."""
        smoothing = 0.1
        probs = self.forward(state)
        probs = probs.clamp(min=1e-6, max=1-1e-6)
        # Mix with the uniform distribution so no action probability collapses.
        probs = (1 - smoothing) * probs + smoothing / probs.size(-1)
        chosen_log_prob = torch.log(probs[action] + 1.e-8)
        loss = -chosen_log_prob * error.detach()
        self._apply_gradients(loss)
        return loss

    def train_step_critic(self, state, next_state, reward):
        """One critic step on the squared TD error with discount 0.99.

        Returns:
            ``(loss, td_error)`` with ``td_error`` detached from the graph.
        """
        discount = 0.99
        current_value = self.mlp(state).squeeze(-1)
        following_value = self.mlp(next_state).squeeze(-1)
        # The bootstrap target is not detached, so gradients also flow through it.
        td_target = reward + discount * following_value
        td_error = td_target - current_value
        loss = td_error.pow(2).mean()
        self._apply_gradients(loss)
        return loss, td_error.detach()

class Server(ServerBase):
    """Federated server: trains the global model on uploaded samples and decides what to send back.

    Besides the global recommender it owns three ``Policy_Net`` instances:

    * ``policy_net`` - maps per-client statistics to aggregation weights;
    * ``actor`` / ``critic`` - decide, item by item, which candidate items are
      handed to a client in :meth:`distribute_sample`, and are trained in
      :meth:`aggregation` from the rewards observed afterwards.
    """
    model: model

    def __init__(self, model, server_epoch, sample_size, sample_portion, graph, device,
                 embedding_regularizer,
                 net_regularizer,
                 learning_rate,
                 optimizer,
                 loss_fn,
                 metrics,
                 item_num):
        """Set up bookkeeping and build the policy, actor and critic networks.

        Args:
            model: global recommender, forwarded to ``ServerBase``.
            server_epoch: number of server-side optimiser steps per round.
            sample_size: maximum number of items handed to one client.
            sample_portion: share of ``sample_size`` filled by item frequency in
                triple mode; the remainder is filled by predicted score.
            graph: sparse propagation matrix of the global model.
            device, optimizer, loss_fn, metrics: forwarded to every ``Policy_Net``.
            embedding_regularizer, net_regularizer: used for ``policy_net`` only;
                actor and critic use a fixed ``1e-6``.
            learning_rate: accepted for interface compatibility; the three
                networks use fixed learning rates.
            item_num: number of items, passed to ``compute_divergence``.
        """
        super().__init__(model)
        self.task = self.model.task
        self.graph = graph
        self.client_models = dict()
        self.server_epoch = server_epoch
        self.sample_freq = int(sample_size * sample_portion)
        self.sample_pred = sample_size - self.sample_freq
        self.samples = None
        self.items = None
        self.item_num = item_num

        def build_net(input_dim, output_dim, lr, emb_reg, net_reg):
            """Create one Policy_Net with the hidden layout shared by all three networks."""
            return Policy_Net(
                input_dim=input_dim,
                hidden_activations='ReLU',
                hidden_units=[32, 16],
                output_dim=output_dim,
                device=device,
                embedding_regularizer=emb_reg,
                net_regularizer=net_reg,
                learning_rate=lr,
                optimizer=optimizer,
                loss_fn=loss_fn,
                metrics=metrics,)

        # Input: 4 statistics per client (see _client_states).
        self.policy_net = build_net(4, 1, 1e-3, embedding_regularizer, net_regularizer)

        self.sample_size = sample_size
        self.states_dict = dict()
        self.actions_dict = dict()

        # A selection state is [user emb ; item emb ; count ; score ; fill ratio]
        # (see distribute_sample), i.e. 2 * embedding_dim + 3 features.  This
        # equals the previously hard-coded 67 for embedding_dim == 32.
        state_dim = 2 * getattr(self.model, "embedding_dim", 32) + 3
        self.actor = build_net(state_dim, 2, 1e-4, 1e-6, 1e-6)
        self.critic = build_net(state_dim, 1, 1e-4, 1e-6, 1e-6)

        self.clients_translog = {}

    def count_parameters(self):
        """Log the size of every state-dict entry except the user embeddings, and the total."""
        self.model.eval()
        base_model_dict = copy.deepcopy(self.model.state_dict())
        model_size = 0.
        for name, tensor in base_model_dict.items():
            if "embedding_user" in name:
                continue
            _, param_size = calculate_model_size(tensor)
            logging.info("Model {} size: {:.8f}MB".format(name, param_size))
            model_size += param_size
        self.model.load_weights(copy.deepcopy(base_model_dict))
        logging.info("Model all size: {:.8f}MB".format(model_size))

    def distribute_model(self, user):
        """Return the weights stored for ``user``, or the global state dict if none are stored."""
        if user in self.client_models:
            return self.client_models[user]
        return self.model.state_dict()

    def distribute_sample(self, user):
        """Choose the samples handed to ``user`` for the next round.

        In triple mode, or before any items have been aggregated, this returns
        ``self.samples`` unchanged.  Otherwise up to ``2 * sample_size`` random
        candidate items are scored for ``user`` and offered one by one to the
        actor, which accepts or rejects each until ``sample_size`` are accepted.
        The visited states and actions are stored for :meth:`aggregation`.

        Returns:
            ``[users, items, scores]`` for the accepted items, or ``self.samples``.
        """
        if self.task == "triple" or self.items is None:
            return self.samples

        self.model.eval()
        unique_it, counts = torch.unique(self.items, return_counts=True)
        rand_idx = torch.randperm(unique_it.size(0), device=unique_it.device)[:self.sample_size * 2]
        unique_it, counts = unique_it[rand_idx], counts[rand_idx]
        all_users = torch.tensor([user] * unique_it.size(0), dtype=torch.int64, device=self.model.device)

        full_states = []
        select_items = []
        actions = []
        with torch.no_grad():
            users_emb = self.model.embedding_user(all_users)
            items_emb = self.model.embedding_item(unique_it)
            # reshape(-1): get_pred squeezes, which yields a 0-dim tensor for a
            # single candidate and would break the concatenation below.
            scores_all = self.model.get_pred(self.graph, all_users, unique_it).reshape(-1)

            states = torch.cat(
                [users_emb, items_emb, counts.unsqueeze(-1), scores_all.unsqueeze(-1)], dim=-1)
            for idx, state in enumerate(states):
                num_items = len(select_items)
                if num_items >= self.sample_size:
                    break
                # The last feature tells the actor how full the selection already is.
                fill_ratio = torch.tensor([num_items / self.sample_size],
                                          dtype=torch.float32, device=self.model.device)
                full_state = torch.cat([state, fill_ratio])
                full_states.append(full_state)
                probs = self.actor(full_state)
                action = torch.distributions.Categorical(probs).sample().item()
                if action > 0:
                    select_items.append(idx)
                actions.append(action)

        # An explicit index tensor also covers the "nothing selected" case.
        selected = torch.tensor(select_items, dtype=torch.int64, device=unique_it.device)
        items = unique_it[selected]
        users = torch.tensor([user] * items.size(0), dtype=torch.int64, device=self.model.device)
        scores = scores_all[selected]

        self.states_dict[user] = full_states
        self.actions_dict[user] = actions
        self.clients_translog[user] = [items, scores]
        return [users, items, scores]

    def aggregation(self, user_list, model_list, sample_list, num_list, loss_list):
        """Store the client models and train the server on the uploaded samples.

        Args:
            user_list: ids of the participating clients.
            model_list: uploaded client models, aligned with ``user_list``.
            sample_list: uploaded sample lists, aligned with ``user_list``.
            num_list: per-client sample counts used as base aggregation weights.
            loss_list: per-client training losses; only logged.
        """
        for user, client_model in zip(user_list, model_list):
            self.client_models[user] = client_model

        if self.task == "triple":
            loss = self._aggregate_triple(sample_list)
            t_num = 0.  # no per-client transfer log exists in triple mode
        else:
            loss, t_num = self._aggregate_pointwise(user_list, sample_list, num_list)

        logging.info("Clients average loss: {}, Server average loss: {}, Clients average transnum: {}".format(torch.mean(torch.tensor(loss_list)), loss, t_num))

    @staticmethod
    def _center_minmax(values):
        """Min-max scale ``values`` to [0, 1] and shift the result to [-0.5, 0.5]."""
        return (values - values.min()) / (values.max() - values.min() + 1e-8) - 0.5

    def _aggregate_triple(self, sample_list):
        """Train on all uploaded triples and rebuild the shared positive/negative item pools."""
        users = torch.cat([sample[0] for sample in sample_list], dim=0).view(-1)
        pos = torch.cat([sample[1] for sample in sample_list], dim=0).view(-1)
        neg = torch.cat([sample[2] for sample in sample_list], dim=0).view(-1)

        self.model.train()
        loss = None
        for _ in range(self.server_epoch):
            loss = self.model.train_step_triple(self.graph, users, pos, neg)

        self.model.eval()
        with torch.no_grad():
            # get_pred_triple takes the graph first; it was missing from this call.
            pred_pos, pred_neg = self.model.get_pred_triple(self.graph, users, pos, neg)

        items = torch.cat([pos, neg], dim=0)

        # Pool 1: the most frequently uploaded items, kept on the side(s) they came from.
        unique, counts = torch.unique(items, return_counts=True)
        freq_topk = torch.topk(counts, k=min(self.sample_freq, len(unique)))
        top_freq_items = unique[freq_topk.indices]
        pos_new_from_freq = top_freq_items[torch.isin(top_freq_items, pos)]
        neg_new_from_freq = top_freq_items[torch.isin(top_freq_items, neg)]

        # Pool 2: the highest-scored uploads, split by the side they were uploaded on.
        scores = torch.cat([pred_pos.view(-1), pred_neg.view(-1)], dim=0)
        source = torch.cat([torch.ones_like(pos), torch.zeros_like(neg)], dim=0)
        score_topk = torch.topk(scores, k=min(self.sample_pred, len(scores)))
        top_score_items = items[score_topk.indices]
        top_score_src = source[score_topk.indices]
        pos_new_from_score = top_score_items[top_score_src == 1]
        neg_new_from_score = top_score_items[top_score_src == 0]

        pos_new = torch.unique(torch.cat([pos_new_from_freq, pos_new_from_score], dim=0))
        neg_new = torch.unique(torch.cat([neg_new_from_freq, neg_new_from_score], dim=0))

        self.samples = [pos_new, neg_new]
        return loss

    def _client_states(self, sample_list, num_list, pre_losses):
        """Build the 4-feature state per client that is fed to ``policy_net``."""
        total = sum(num_list)
        rows = []
        for idx, sample in enumerate(sample_list):
            scores = sample[2]
            rows.append(torch.tensor([
                sample[1].size(0) / total * len(num_list),  # relative upload size
                scores.mean().item(),
                # std() of a single value is NaN and would poison the policy input.
                scores.std().item() if scores.numel() > 1 else 0.0,
                pre_losses[idx].item(),
            ], device=self.model.device))
        return torch.stack(rows)

    def _combined_rewards(self, user_list, sample_list, rewards):
        """Mix item- and score-divergence terms into the reward of clients with a transfer log.

        Clients without a log entry keep their plain loss-based reward, so the
        result always has one entry per client.
        """
        logged = [idx for idx, user in enumerate(user_list) if user in self.clients_translog]
        if not logged:
            return rewards

        divergence_list = [
            compute_divergence(sample_list[idx][1], self.clients_translog[user_list[idx]][0], self.item_num)
            for idx in logged
        ]
        divergences = self._center_minmax(
            torch.tensor(divergence_list, dtype=torch.float32, device=self.model.device))

        with torch.no_grad():
            # get_pred takes the graph first; it was missing from this call.
            score_divergence_list = [
                compute_score_similarity(
                    sample_list[idx][2],
                    self.model.get_pred(self.graph, sample_list[idx][0], sample_list[idx][1]))
                for idx in logged
            ]
        score_divergences = self._center_minmax(
            torch.tensor(score_divergence_list, dtype=torch.float32, device=self.model.device))

        logged_idx = torch.tensor(logged, dtype=torch.int64, device=rewards.device)
        rewards_norm = rewards.clone()
        rewards_norm[logged_idx] = (rewards[logged_idx] + divergences + (1 - score_divergences)) / 3.
        return rewards_norm

    def _aggregate_pointwise(self, user_list, sample_list, num_list):
        """Train on uploaded pointwise samples and update the policy, actor and critic.

        Returns:
            ``(loss, t_num)``: the last server loss and the mean number of items
            previously transferred to the clients that had a non-empty transfer.
        """
        items = torch.cat([sample[1] for sample in sample_list], dim=0).view(-1)

        # First step with the plain sample counts as weights; its per-client
        # losses are one of the policy inputs.
        self.model.train()
        pre_losses, loss = self.model.train_step_server(self.graph, sample_list, num_list)
        self.model.eval()

        weights = self.policy_net(self._client_states(sample_list, num_list, pre_losses))
        num_tensor = torch.tensor(num_list, dtype=weights.dtype, device=weights.device)
        sample_weights = weights * num_tensor
        sample_weights = sample_weights / (sample_weights.sum() + 1e-8)
        client_weights = sample_weights.detach().tolist()

        # Remaining steps use the policy weights.  With server_epoch == 1 the
        # loop is empty and the results of the first step are used instead.
        losses = pre_losses
        self.model.train()
        for _ in range(self.server_epoch - 1):
            losses, loss = self.model.train_step_server(self.graph, sample_list, client_weights)
        self.model.eval()

        # A lower loss means a higher reward, centred to [-0.5, 0.5].
        rewards = self._center_minmax(-losses)
        self.policy_net.train_step(weights, rewards)

        rewards_norm = self._combined_rewards(user_list, sample_list, rewards)

        client_trans_num_list = []
        for client_idx, user in enumerate(user_list):
            if user not in self.states_dict:
                continue
            trans_num = self.clients_translog[user][0].size(0)
            if trans_num == 0:
                continue
            visited_states = self.states_dict[user]
            actions = self.actions_dict[user]
            client_trans_num_list.append(trans_num)

            states_num = len(visited_states)
            reward = rewards_norm[client_idx]
            for step in range(states_num - 1):
                # Position-dependent shaping: accepted items are penalised and
                # rejected items rewarded by |reward * (step / n - 0.5)|.
                pos_factor = step / states_num - 0.5
                shaping = torch.abs(reward * 1.0 * pos_factor)
                action = actions[step]
                r = reward - shaping if action > 0 else reward + shaping
                _, error = self.critic.train_step_critic(visited_states[step], visited_states[step + 1], r)
                self.actor.train_step_actor(visited_states[step], action, error)

        self.items = items
        self.samples = None
        t_num = sum(client_trans_num_list) / len(client_trans_num_list) if client_trans_num_list else 0.
        return loss, t_num

    def get_user_model(self, user):
        """Return what is stored for ``user`` in ``client_models``, else the global model object."""
        if user in self.client_models:
            return self.client_models[user]
        return self.model

    def _evaluate(self, graph, dataload, user_list, personalized):
        """Collect test predictions for ``user_list`` and compute the configured metrics."""
        self.model.eval()
        y_pred = []
        y_true = []
        group_id = []
        scratch = None
        # Evaluation needs no autograd graph.
        with torch.no_grad():
            all_users, _ = self.model.propagate(graph)
            for user in user_list:
                scorer = self.model
                if personalized:
                    scorer = self.get_user_model(user)
                    if not isinstance(scorer, nn.Module):
                        # client_models holds whatever distribute_model hands to
                        # load_weights, i.e. weights rather than a module; load
                        # them into one reusable copy of the global model.
                        if scratch is None:
                            scratch = copy.deepcopy(self.model)
                        scratch.load_weights(scorer)
                        scratch.to(scratch.device)
                        scratch.eval()
                        scorer = scratch
                users, items, labels = dataload.get_testdata(user)
                y_pred.extend(scorer.forward(all_users, users, items).detach().cpu().numpy().reshape(-1))
                y_true.extend(labels.detach().cpu().numpy().reshape(-1))
                group_id.extend(users.detach().cpu().numpy().reshape(-1))
        y_pred = np.array(y_pred, np.float64)
        y_true = np.array(y_true, np.float64)
        group_id = np.array(group_id) if len(group_id) > 0 else None
        # NOTE(review): ``metrcis`` is spelled as the original code reads it from
        # the model; confirm it matches the attribute name defined in BaseModel.
        val_logs = self.model.evaluate_metrics(y_true, y_pred, self.model.metrcis, group_id)
        logging.info('[Metrics] ' + ' - '.join('{}: {:.6f}'.format(k, v) for k, v in val_logs.items()))
        return val_logs

    def evaluate(self, graph, dataload, user_list):
        """Evaluate the global model on the test data of ``user_list``; return the metric dict."""
        return self._evaluate(graph, dataload, user_list, personalized=False)

    def evaluate_personalize(self, graph, dataload, user_list):
        """Evaluate each user with their stored client model where one exists.

        The user representations always come from the global model's
        propagation; only the scoring step uses the per-user model.
        """
        return self._evaluate(graph, dataload, user_list, personalized=True)

class FedLightGCN_FLRL:
    """Training driver that wires a server, one reusable client and an item auto-encoder.

    On construction it builds the normalized user-item adjacency graph, a
    server-side model wrapped in ``Server``, a single ``Client`` object that is
    reloaded for every selected user, and an ``AE`` (``g_model``) whose latent
    output is used by ``fit`` to initialize the server's item embeddings.
    """

    def __init__(self,
                 dataload: "BaseDataLoaderFL",
                 clients_num_per_turn,
                 local_epoch,
                 train_turn,
                 user_num,
                 item_num,
                 embedding_dim,
                 layer_num,
                 light,
                 server_epoch,
                 sample_ratio,
                 sample_bias,
                 swap_ratio,
                 alpha,
                 sample_size,
                 sample_portion,
                 embedding_regularizer,
                 net_regularizer,
                 learning_rate,
                 optimizer,
                 loss_fn,
                 device,
                 metrics,
                 path,
                 task,
                 *args, **kwargs
                 ):
        """Build the graph, the server/client models and the item auto-encoder.

        Args:
            dataload: Data loader; provides ``get_aj_graph`` and
                ``get_item_feature`` and is handed to the client for local
                training and to the server for evaluation.
            clients_num_per_turn: Passed to ``Server.select_clients`` each turn.
            local_epoch: Passed to ``Client.local_train``.
            train_turn: Number of federated turns run by ``fit``.
            user_num, item_num: Sizes used for the adjacency matrix and models.
            embedding_dim, layer_num, light: Forwarded to ``model``; ``light``
                also selects which cached graph file is used.
            server_epoch, sample_size, sample_portion: Forwarded to ``Server``.
            sample_ratio, sample_bias, swap_ratio, alpha: Forwarded to ``Client``.
            embedding_regularizer, net_regularizer, learning_rate, optimizer,
                loss_fn, metrics: Forwarded to ``model`` and ``Server``.
            device: Resolved with ``get_device`` for the graph tensor; the raw
                value is forwarded to ``model``, ``Server`` and ``AE``.
            path: String prefix the cached graph file name is appended to.
            task: Task name; lower-cased before use.
            **kwargs: Must contain ``pre_epoch``, ``g_hidden_units``,
                ``g_hidden_activations`` and ``sen_embedding_dim``.

        Raises:
            KeyError: If one of the required ``kwargs`` entries is missing.
        """
        self.clients_num_per_turn = clients_num_per_turn
        self.task = task.lower()
        self.local_epoch = local_epoch
        self.train_turn = train_turn
        self.device = get_device(device)
        self.user_num = user_num
        self.item_num = item_num
        self.dataload = dataload
        self.pre_epoch = kwargs["pre_epoch"]
        self.graph = self.getSpareseGraph(path, light)

        # The server model and the client model are built from the same settings.
        model_kwargs = dict(
            user_num=user_num,
            item_num=item_num,
            embedding_dim=embedding_dim,
            layer_num=layer_num,
            light=light,
            task=self.task,
            device=device,
            embedding_regularizer=embedding_regularizer,
            net_regularizer=net_regularizer,
            learning_rate=learning_rate,
            optimizer=optimizer,
            loss_fn=loss_fn,
            metrics=metrics,
        )

        self.server_model = model(**model_kwargs)
        self.server_model.reset_parameters()
        self.server = Server(
            self.server_model, server_epoch, sample_size, sample_portion, self.graph,
            device=device,
            embedding_regularizer=embedding_regularizer,
            net_regularizer=net_regularizer,
            learning_rate=learning_rate,
            optimizer=optimizer,
            loss_fn=loss_fn,
            metrics=metrics,
            item_num=item_num,
        )
        # A single Client instance is reused for every selected user in fit().
        self.client = Client(
            client_id=0,
            model=model(**model_kwargs),
            task=self.task,
            sample_ratio=sample_ratio,
            swap_ratio=swap_ratio,
            sample_bias=sample_bias,
            alpha=alpha,
        )
        # Auto-encoder trained on the item features before federated training.
        self.g_model = AE(
            hidden_units=kwargs["g_hidden_units"],
            hidden_activations=kwargs["g_hidden_activations"],
            embedding_dim=kwargs["sen_embedding_dim"],
            embedding_dim_latent=embedding_dim,
            device=device,
            embedding_regularizer=0.,
            net_regularizer=1e-2,
            learning_rate=1e-4,
            optimizer="adam",
            loss_fn="mse_loss",
        )

    def getSpareseGraph(self, path, light):
        """Return the normalized adjacency matrix as a coalesced sparse tensor.

        A cached ``.npz`` file (``light_graph.npz`` when ``light`` is true,
        ``gcn_graph.npz`` otherwise) is loaded from ``path`` if possible. When
        loading fails the matrix is rebuilt from ``self.dataload`` and written
        back to the same file.

        Adapted from
        https://github.com/gusye1234/LightGCN-PyTorch/blob/master/code/dataloader.py#L332

        Args:
            path: String prefix the cache file name is appended to (plain
                concatenation, so it has to end with a path separator to
                denote a directory).
            light: Selects the cache file and whether self-loops are added
                when the matrix is rebuilt.

        Returns:
            A coalesced ``torch`` sparse COO tensor on ``self.device``.
        """
        cache_file = path + ('light_graph.npz' if light else 'gcn_graph.npz')
        norm_adj = None
        try:
            norm_adj = sp.load_npz(cache_file)
        except Exception as err:
            # Best-effort cache: a missing or unreadable file triggers a
            # rebuild, while KeyboardInterrupt/SystemExit still propagate.
            logging.info("Cached graph %s not loaded (%s); rebuilding it", cache_file, err)
        if norm_adj is None:
            norm_adj = self._build_norm_adj(light)
            sp.save_npz(cache_file, norm_adj)

        Graph = self._convert_sp_mat_to_sp_tensor(norm_adj)
        Graph = Graph.coalesce().to(self.device)
        return Graph

    def _build_norm_adj(self, light):
        """Build the symmetrically normalized (user+item) x (user+item) adjacency matrix."""
        adj_mat = sp.dok_matrix((self.user_num + self.item_num, self.user_num + self.item_num), dtype=np.int64)
        adj_mat = adj_mat.tolil()
        R = self.dataload.get_aj_graph()
        # Bipartite layout: users occupy the first block, items the second.
        adj_mat[:self.user_num, self.user_num:] = R
        adj_mat[self.user_num:, :self.user_num] = R.T
        adj_mat = adj_mat.todok()
        if not light:
            adj_mat = adj_mat + sp.eye(adj_mat.shape[0])

        rowsum = np.array(adj_mat.sum(axis=1))
        # Rows without any entry give 0 ** -0.5 == inf; that is expected and
        # zeroed right below, so the divide warning is silenced here.
        with np.errstate(divide='ignore'):
            d_inv = np.power(rowsum, -0.5).flatten()
        d_inv[np.isinf(d_inv)] = 0.
        d_mat = sp.diags(d_inv)

        # D^-1/2 * A * D^-1/2
        norm_adj = d_mat.dot(adj_mat)
        norm_adj = norm_adj.dot(d_mat)
        return norm_adj.tocsr()

    def _convert_sp_mat_to_sp_tensor(self, X):
        """Convert a SciPy sparse matrix into a float32 ``torch`` sparse COO tensor.

        Indices are built as int64 directly. Routing them through a float32
        tensor first would round every index above 2**24.
        """
        coo = X.tocoo().astype(np.float32)
        index = torch.tensor(np.vstack((coo.row, coo.col)).astype(np.int64), dtype=torch.long)
        data = torch.tensor(coo.data, dtype=torch.float32)
        return torch.sparse_coo_tensor(index, data, torch.Size(coo.shape), dtype=torch.float32)

    def fit(self):
        """Pre-train the item auto-encoder, run the federated turns and evaluate.

        Returns:
            The result of ``self.server.evaluate`` over all users.
        """
        self.server.count_parameters()

        # Stage 1: train the auto-encoder on the item features and use its
        # latent output as the server's item embedding weights.
        item_feature = self.dataload.get_item_feature()
        for _ in range(self.pre_epoch):
            self.g_model.train_step(item_feature)
        latent = self.g_model.get_latent(item_feature)
        self.server.model.embedding_item.weight.data = copy.deepcopy(latent.detach())

        # Stage 2: federated training turns.
        for turn in range(self.train_turn):
            logging.info("********* Train Turn {} *********".format(turn))
            select_users = self.server.select_clients(self.user_num, self.clients_num_per_turn)
            client_sample = []
            client_model = []
            client_local_data_num = []
            losses = []
            for user in select_users:
                # The single client object is re-initialized for each user.
                self.client.load_client(user)
                self.client.load_model(self.server.distribute_model(user))
                self.client.load_sample(self.server.distribute_sample(user))
                loss, samples = self.client.local_train(self.graph, user, self.local_epoch, self.dataload)
                losses.append(loss)
                client_model.append(self.client.upload_model())
                client_sample.append(samples)
                client_local_data_num.append(self.client.local_data_num())
            self.server.aggregation(select_users, client_model, client_sample, client_local_data_num, losses,)
            torch.cuda.empty_cache()

        logging.info("********* Test *********")
        # NOTE: self.server.evaluate_personalize takes the same arguments and
        # can be used here to evaluate with the per-user models instead.
        results = self.server.evaluate(self.graph, self.dataload, range(self.user_num))
        return results