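"""
Federated matrix-factorization recommender with reinforcement-learned sample sharing.

Defines the matrix-factorization ``model``, the federated ``Client`` and ``Server``,
the ``Policy_Net`` used by the server for client weighting and actor/critic item
selection, and the ``FedMF_FLRL`` driver that runs the training rounds.
"""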

import copy
import numpy as np
import torch.nn as nn
import torch
import logging
from dataloaders.BaseDataLoader import *
from framework.fed.client import ClientBase
from framework.fed.server import ServerBase
from framework.modules.models import BaseModel, AE
from framework.modules.layers import MLP_Block
from framework.modules.utils import compute_divergence, compute_score_divergence, compute_score_similarity
from framework.utils import calculate_model_size
from thop import profile

class model(BaseModel):
    """Matrix-factorization scorer for (user, item) pairs.

    The raw score of a pair is the dot product of the user embedding and the
    item embedding. How that raw score is post-processed depends on ``task``
    (see :meth:`forward`).
    """

    def __init__(self,
                 user_num,
                 item_num,
                 embedding_dim,
                 task,
                 device,
                 embedding_regularizer,
                 net_regularizer,
                 learning_rate,
                 optimizer,
                 loss_fn,
                 metrics,
                 *args, **kwargs):
        """Create the two embedding tables, initialise them and compile the model.

        ``*args`` / ``**kwargs`` are accepted and ignored.
        """
        super().__init__(
            device=device,
            embedding_regularizer=embedding_regularizer,
            net_regularizer=net_regularizer,
            metrics=metrics,
        )
        self.embedding_user = nn.Embedding(num_embeddings=user_num,
                                           embedding_dim=embedding_dim)
        self.embedding_item = nn.Embedding(num_embeddings=item_num,
                                           embedding_dim=embedding_dim)
        self.task = task
        self.output_activation = nn.Sigmoid()
        # Base-class reset first, then the Xavier initialisation below
        # overwrites both embedding tables.
        self.reset_parameters()
        self.__init_weight()
        self.compile(optimizer=optimizer, loss=loss_fn, lr=learning_rate)
        self.model_to_device()

    def __init_weight(self):
        """Xavier-uniform initialisation of both embedding tables."""
        for table in (self.embedding_user, self.embedding_item):
            nn.init.xavier_uniform_(table.weight)

    def forward(self, user_id, item_id):
        """Score each (user, item) pair.

        For ``task == "triple"`` the raw dot product is returned. For every
        other task a sigmoid is applied, and for ``task == "regression"`` the
        sigmoid output is additionally mapped through ``x * 4.0 + 1.0``.
        """
        user_vec = self.embedding_user(user_id)
        item_vec = self.embedding_item(item_id)
        score = torch.sum(user_vec * item_vec, dim=1)

        if self.task == "triple":
            return score

        score = self.output_activation(score)
        if self.task == "regression":
            score = score * 4.0 + 1.0
        return score

    def train_step_server(self, sample_list, weight_list):
        """Run one optimiser step on a weighted average of per-client losses.

        Each entry of ``sample_list`` is indexed as ``[users, items, labels]``.
        The per-client mean losses are combined with ``weight_list`` and the
        regularisation term is added before back-propagation.

        Returns:
            A tuple ``(per_client_losses, total_loss)`` where the first element
            is a stacked tensor of the detached per-client losses.
        """
        self.train()
        self.optimizer.zero_grad()

        client_losses = []
        for sample in sample_list:
            prediction = self.forward(sample[0], sample[1]).squeeze()
            client_losses.append(self.loss_fn(prediction, sample[2], reduction='mean'))

        weighted_sum = sum(w * cl for w, cl in zip(weight_list, client_losses))
        total = weighted_sum / (sum(weight_list) + 1.e-8) + self.add_regularization()
        total.backward()
        self.optimizer.step()

        detached = torch.stack([cl.detach() for cl in client_losses])
        return detached, total

    def train_step(self, users, items, label):
        """Run one optimiser step on a batch of labelled (user, item) pairs."""
        self.train()
        self.optimizer.zero_grad()
        prediction = self.forward(users, items).squeeze()
        batch_loss = self.loss_fn(prediction, label, reduction='mean')
        batch_loss = batch_loss + self.add_regularization()
        batch_loss.backward()
        self.optimizer.step()
        return batch_loss

    def train_step_triple(self, users, pos, neg):
        """Run one optimiser step on (user, positive item, negative item) triples.

        When ``users`` is non-empty a regularisation term is added; it is built
        from the embedding row of ``users[0]`` only, together with the
        embeddings of all ``pos`` and ``neg`` items.
        """
        self.train()
        self.optimizer.zero_grad()
        score_pos = self.forward(users, pos)
        score_neg = self.forward(users, neg)

        triple_loss = self.loss_fn(score_pos, score_neg, )
        if len(users) > 0:
            triple_loss = triple_loss + self.add_regularization_triple(
                self.embedding_user.weight[users[0]],
                self.embedding_item(pos),
                self.embedding_item(neg),
            )

        triple_loss.backward()
        self.optimizer.step()
        return triple_loss

    def get_pred(self, users, items):
        """Switch to eval mode and return the squeezed scores for the pairs."""
        self.eval()
        return self.forward(users, items).squeeze()

    def get_pred_triple(self, users, pos, neg, global_model=None):
        """Switch to eval mode and score the positive and the negative items.

        ``global_model`` is accepted but not used.
        """
        self.eval()
        score_pos = self.forward(users, pos)
        score_neg = self.forward(users, neg)
        return score_pos, score_neg
    

class Client(ClientBase):
    """Federated client: trains the local model and builds the sample it uploads.

    ``local_train`` returns the last training loss together with a list of
    three tensors that the server consumes:

    * task ``"triple"``: ``[users, positive_items, negative_items]``
    * any other task: ``[users, items, scores]``
    """
    model:model

    def __init__(self, client_id, model, task, sample_ratio, sample_bias, swap_ratio, alpha):
        """Store the sampling hyper-parameters.

        Args:
            client_id: forwarded to ``ClientBase``.
            model: forwarded to ``ClientBase``.
            task: task name; lower-cased before being stored.
            sample_ratio: fraction of local data to upload.
            sample_bias: half-width of the uniform jitter applied to
                ``sample_ratio`` (non-triple tasks only).
            swap_ratio: fraction of the uploaded sample that is perturbed by
                swapping.
            alpha: mixing weight between predicted scores and true labels.
        """
        super().__init__(client_id, model)
        self.task = task.lower()
        self.sample_ratio = sample_ratio
        self.sample_bias = sample_bias
        self.swap_ratio = swap_ratio
        self.alpha = alpha
        self.samples = None
        # Initialised here so that local_data_num() is safe to call before the
        # first local_train().
        self.__local_data_num = 0

    def load_model(self, model):
        """Load the distributed model and move it to the model's device."""
        super().load_model(model)
        self.model.to(self.model.device)

    def load_sample(self, sample):
        """Keep a private deep copy of the server sample; ``None`` is ignored."""
        if sample is not None:
            self.samples = copy.deepcopy(sample)

    def local_train(self, user, local_epoch, dataload):
        """Train on the data of ``user`` and build the sample to upload.

        Args:
            user: identifier passed to ``dataload.get_traindata``.
            local_epoch: number of local training steps.
            dataload: object providing ``get_traindata(user)``.

        Returns:
            ``(loss, [tensor, tensor, tensor])``; see the class docstring for
            the meaning of the three tensors.
        """
        if self.task == "triple":
            return self._local_train_triple(user, local_epoch, dataload)
        return self._local_train_pointwise(user, local_epoch, dataload)

    @staticmethod
    def _pad_by_resampling(values, target_len):
        """Extend ``values`` to ``target_len`` rows by re-drawing existing rows."""
        missing = target_len - values.size(0)
        if missing <= 0:
            return values
        extra_idx = torch.randint(0, values.size(0), (missing,), device=values.device)
        return torch.cat([values, values[extra_idx]], dim=0)

    def _local_train_triple(self, user, local_epoch, dataload):
        """Local training and sample construction for the ``"triple"`` task."""
        users, pos, neg = dataload.get_traindata(user)
        self.__local_data_num = users.size(0)

        if self.samples is not None:
            # Append the server-provided items, then pad the three tensors to a
            # common length so they can be consumed as aligned triples.
            pos = torch.cat([pos, self.samples[0]], dim=0)
            neg = torch.cat([neg, self.samples[1]], dim=0)
            target_len = max(pos.size(0), neg.size(0))
            pos = self._pad_by_resampling(pos, target_len)
            neg = self._pad_by_resampling(neg, target_len)
            users = self._pad_by_resampling(users, target_len)

        self.model.train()
        for _ in range(local_epoch):
            loss = self.model.train_step_triple(users, pos, neg)

        # Draw the random subset first: the prediction below needs it. The
        # original code predicted on these tensors before creating them.
        num_samples = len(users)
        sample_size = max(1, int(num_samples * self.sample_ratio))
        sampled_idx = torch.randperm(num_samples, device=users.device)[:sample_size]
        users_new = users[sampled_idx]
        pos_new = pos[sampled_idx]
        neg_new = neg[sampled_idx]

        self.model.eval()
        with torch.no_grad():
            pred_pos, pred_neg = self.model.get_pred_triple(users_new, pos_new, neg_new)

        all_scores = torch.cat([pred_pos.view(-1), pred_neg.view(-1)], dim=0)
        all_items = torch.cat([pos_new, neg_new], dim=0)
        all_items, inverse = torch.unique(all_items, return_inverse=True)

        # One score per unique item: average the scores of all its occurrences
        # so that all_scores[j] really belongs to all_items[j].
        inverse = inverse.to(all_scores.device)
        score_sum = torch.zeros(all_items.size(0), dtype=all_scores.dtype,
                                device=all_scores.device)
        score_sum.index_add_(0, inverse, all_scores)
        occurrences = torch.zeros_like(score_sum)
        occurrences.index_add_(0, inverse, torch.ones_like(all_scores))
        all_scores = score_sum / occurrences

        # Highest-scored items become the uploaded positives, lowest-scored
        # items the uploaded negatives.
        k = min(sample_size, all_scores.size(0) // 2)
        topk_idx = torch.topk(all_scores, k=k, largest=True).indices
        lowk_idx = torch.topk(all_scores, k=k, largest=False).indices
        users_new = users_new[:k]
        pos_new = all_items[topk_idx]
        neg_new = all_items[lowk_idx]

        pred_pos_new = all_scores[topk_idx]

        # At least one pair is swapped whenever a pair exists (even for
        # swap_ratio == 0, as in the original); never more than k.
        swap_size = min(k, max(1, int(k * self.swap_ratio)))
        if swap_size > 0:
            swap_idx = torch.topk(pred_pos_new.view(-1), k=swap_size).indices
            pos_new[swap_idx], neg_new[swap_idx] = neg_new[swap_idx], pos_new[swap_idx]

        return loss, [users_new, pos_new, neg_new]

    def _local_train_pointwise(self, user, local_epoch, dataload):
        """Local training and sample construction for non-triple tasks."""
        # Local import: this file does not import ``random`` explicitly.
        import random

        users, items, labels = dataload.get_traindata(user)
        if self.samples is not None:
            # Only add server items that are not already in the local data.
            items_new = self.samples[1]
            keep_mask = ~torch.isin(items_new, items)
            if keep_mask.any():
                users = torch.cat([users, self.samples[0][keep_mask]], dim=0)
                items = torch.cat([items, items_new[keep_mask]], dim=0)
                labels = torch.cat([labels, self.samples[2][keep_mask]], dim=0)

        self.model.train()
        for _ in range(local_epoch):
            loss = self.model.train_step(users, items, labels)

        self.model.eval()
        with torch.no_grad():
            scores = self.model.get_pred(users, items,)

        alpha = self.alpha
        if self.swap_ratio == 0.:
            # Without swapping, every uploaded score is a blend of the
            # prediction and the true label.
            scores = alpha * scores + (1. - alpha) * labels

        # Sample ratio jittered uniformly in [ratio - bias, ratio + bias).
        sample_ratio = self.sample_ratio - self.sample_bias + 2 * self.sample_bias * random.random()

        # Labels > 0 are treated as positives, labels == 0 as negatives.
        pos_mask = labels > 0
        neg_mask = labels == 0
        pos_items, pos_scores = items[pos_mask], scores[pos_mask]
        neg_items, neg_scores = items[neg_mask], scores[neg_mask]

        sample_size = max(1, int(pos_items.size(0) * sample_ratio))
        perm_pos = torch.randperm(pos_items.size(0), device=users.device)[:sample_size]
        sampled_pos_items = pos_items[perm_pos]
        sampled_pos_scores = pos_scores[perm_pos]

        # As many negatives as positives were requested.
        perm_neg = torch.randperm(neg_items.size(0), device=users.device)[:sample_size]
        sampled_neg_items = neg_items[perm_neg]
        sampled_neg_scores = neg_scores[perm_neg]

        if self.swap_ratio > 0.:
            k = max(1, int(sample_size * self.swap_ratio))
            # Never ask for more elements than were actually sampled on either
            # side; the unclamped value raised in that situation.
            k = min(k, sampled_pos_scores.size(0), sampled_neg_scores.size(0))
            topk_idx = torch.topk(sampled_pos_scores.view(-1), k=k, largest=True).indices
            rand_idx = torch.randperm(perm_neg.size(0), device=users.device)[:k]

            sampled_pos_labels = labels[pos_mask][perm_pos]
            sampled_neg_labels = labels[neg_mask][perm_neg]

            # Remember the raw predictions that will be exchanged before the
            # label blending below overwrites them.
            noise_pos = sampled_neg_scores[rand_idx].clone()
            noise_neg = sampled_pos_scores[topk_idx].clone()

            label_pos = sampled_pos_labels == 1
            sampled_pos_scores[label_pos] = (alpha * sampled_pos_scores[label_pos]
                                             + (1. - alpha) * sampled_pos_labels[label_pos])
            sampled_neg_scores = alpha * sampled_neg_scores + (1. - alpha) * sampled_neg_labels

            # Top-scored positives receive scores of random negatives and
            # vice versa.
            sampled_pos_scores[topk_idx], sampled_neg_scores[rand_idx] = noise_pos, noise_neg

        users_new = users[:sampled_pos_items.size(0) + sampled_neg_items.size(0)]
        items_new = torch.cat([sampled_pos_items, sampled_neg_items], dim=0)
        labels_new = torch.cat([sampled_pos_scores, sampled_neg_scores], dim=0)

        self.__local_data_num = users_new.size(0)

        return loss, [users_new, items_new, labels_new]

    def local_data_num(self):
        """Return the data count recorded by the last ``local_train`` call.

        For the ``"triple"`` task this is the number of local rows before
        augmentation; otherwise it is the size of the uploaded sample. It is 0
        before the first ``local_train``.
        """
        return self.__local_data_num

class Policy_Net(BaseModel):
    """Small MLP used either as a policy (softmax output) or as a value function.

    ``forward`` returns a softmax over dimension 0 of the squeezed MLP output;
    ``train_step_critic`` uses the raw MLP output instead.
    """

    def __init__(self,
                 input_dim,
                 hidden_activations,
                 hidden_units,
                 output_dim,
                 device,
                 embedding_regularizer,
                 net_regularizer,
                 learning_rate,
                 optimizer,
                 loss_fn,
                 metrics,
                 *args, **kwargs):
        """Build the MLP and compile the network.

        ``*args`` / ``**kwargs`` are accepted and ignored.
        """
        super().__init__(
            device=device,
            embedding_regularizer=embedding_regularizer,
            net_regularizer=net_regularizer,
            metrics=metrics,
        )
        self.mlp = MLP_Block(
            input_dim=input_dim,
            output_dim=output_dim,
            hidden_units=hidden_units,
            hidden_activations=hidden_activations,
        )
        self.output_activation = nn.Sigmoid()
        self.reset_parameters()
        self.compile(optimizer=optimizer, loss=loss_fn, lr=learning_rate)
        self.model_to_device()

    def _backprop_step(self, objective):
        """Clear gradients, back-propagate ``objective`` and step the optimiser."""
        self.optimizer.zero_grad()
        objective.backward()
        self.optimizer.step()

    def forward(self, state):
        """Return ``softmax(mlp(state).squeeze(-1), dim=0)``."""
        logits = self.mlp(state).squeeze(-1)
        return torch.softmax(logits, dim=0)

    def train_step(self, weights, rewards):
        """Policy-gradient style step: minimise ``sum(-log(w + 1e-8) * r)``."""
        objective = sum(
            -torch.log(weight + 1e-8) * reward
            for weight, reward in zip(weights, rewards)
        )
        self._backprop_step(objective)
        return objective

    def train_step_actor(self, state, action, error):
        """Actor update weighted by the (detached) critic error.

        The action probabilities are clamped to ``[1e-6, 1 - 1e-6]`` and then
        mixed with a uniform distribution (weight 0.1) before the log is taken.
        """
        action_probs = self.forward(state).clamp(min=1e-6, max=1-1e-6)
        action_probs = (1 - 0.1) * action_probs + 0.1 / action_probs.size(-1)
        chosen_log_prob = torch.log(action_probs[action] + 1.e-8)
        objective = -chosen_log_prob * error.detach()
        self._backprop_step(objective)
        return objective

    def train_step_critic(self, state, next_state, reward):
        """Critic update on the squared TD error with discount 0.99.

        Returns:
            ``(loss, td_error)`` where ``td_error`` is detached.
        """
        current_value = self.mlp(state).squeeze(-1)
        following_value = self.mlp(next_state).squeeze(-1)
        td_error = (reward + 0.99 * following_value) - current_value
        objective = td_error.pow(2).mean()
        self._backprop_step(objective)
        return objective, td_error.detach()
        

class Server(ServerBase):
    """Federated server: trains the global model on uploaded samples and decides
    which items to send back to each client.

    For the ``"triple"`` task the server keeps a fixed ``[pos, neg]`` item
    sample. For every other task it uses three ``Policy_Net`` instances:

    * ``policy_net`` produces per-client weights for the weighted server loss;
    * ``actor`` decides, item by item, whether an item is sent to a client;
    * ``critic`` provides the TD error used to train the actor.
    """
    model:model

    def __init__(self, model, server_epoch, sample_size, sample_portion, device,
                 embedding_regularizer,
                 net_regularizer,
                 learning_rate,
                 optimizer,
                 loss_fn,
                 metrics,
                 item_num):
        """Create the server state and its three policy/value networks.

        Args:
            model: global model, forwarded to ``ServerBase``.
            server_epoch: number of server-side training steps per round.
            sample_size: maximum number of items sent to a client.
            sample_portion: share of ``sample_size`` filled by item frequency
                (``"triple"`` task); the rest is filled by predicted score.
            device: device for the policy networks.
            embedding_regularizer, net_regularizer: used by ``policy_net``.
            learning_rate: accepted for interface compatibility; the policy
                networks use fixed learning rates.
            optimizer, loss_fn, metrics: forwarded to the policy networks.
            item_num: number of items, passed to ``compute_divergence``.
        """
        super().__init__(model)
        self.task = self.model.task
        self.client_models = dict()
        self.server_epoch = server_epoch
        self.sample_freq = int(sample_size * sample_portion)
        self.sample_pred = sample_size - self.sample_freq
        self.samples = None
        self.items = None
        self.item_num = item_num
        self.policy_net = Policy_Net(
            input_dim=4,
            hidden_activations='ReLU',
            hidden_units=[32, 16],
            output_dim=1,
            device=device,
            embedding_regularizer=embedding_regularizer,
            net_regularizer=net_regularizer,
            learning_rate=1e-3,
            optimizer=optimizer,
            loss_fn=loss_fn,
            metrics=metrics,)

        self.sample_size = sample_size
        self.states_dict = dict()
        self.actions_dict = dict()

        # State built in distribute_sample: user embedding + item embedding
        # + item count + predicted score + selection progress. This equals the
        # previously hard-coded 67 when embedding_dim == 32.
        state_dim = 2 * self.model.embedding_user.embedding_dim + 3
        self.actor = Policy_Net(
            input_dim=state_dim,
            hidden_activations='ReLU',
            hidden_units=[32, 16],
            output_dim=2,
            device=device,
            embedding_regularizer=1e-6,
            net_regularizer=1e-6,
            learning_rate=1e-4,
            optimizer=optimizer,
            loss_fn=loss_fn,
            metrics=metrics,)

        self.critic = Policy_Net(
            input_dim=state_dim,
            hidden_activations='ReLU',
            hidden_units=[32, 16],
            output_dim=1,
            device=device,
            embedding_regularizer=1e-6,
            net_regularizer=1e-6,
            learning_rate=1e-4,
            optimizer=optimizer,
            loss_fn=loss_fn,
            metrics=metrics,)

        self.clients_translog = {}

    def count_parameters(self):
        """Log the size of every model tensor except the user embeddings."""
        self.model.eval()
        base_model_dict = copy.deepcopy(self.model.state_dict())
        model_size = 0.
        for name in base_model_dict.keys():
            if "embedding_user" in name:
                continue
            _, param_size = calculate_model_size(base_model_dict[name])
            logging.info("Model {} size: {:.8f}MB".format(name, param_size))
            model_size += param_size
        self.model.load_weights(copy.deepcopy(base_model_dict))
        logging.info("Model all size: {:.8f}MB".format(model_size))

    def distribute_model(self, user):
        """Return the stored model of ``user``, or the global state dict."""
        if user in self.client_models:
            return self.client_models[user]
        return self.model.state_dict()

    def distribute_sample(self, user):
        """Return the sample to send to ``user``.

        For the ``"triple"`` task, or before any items were aggregated, this is
        ``self.samples``. Otherwise up to ``2 * sample_size`` random candidate
        items are drawn from the last aggregated items and the actor decides
        for each candidate whether it is sent, until ``sample_size`` items are
        selected. The visited states and actions are stored per user for the
        actor/critic update in ``aggregation``.

        Returns:
            ``self.samples`` or ``[users, items, scores]``.
        """
        if self.task == "triple" or self.items is None:
            return self.samples

        self.model.eval()
        unique_it, counts = torch.unique(self.items, return_counts=True)
        rand_idx = torch.randperm(unique_it.size(0), device=unique_it.device)[:self.sample_size * 2]
        unique_it, counts = unique_it[rand_idx], counts[rand_idx]
        all_users = torch.tensor([user] * unique_it.size(0), dtype=torch.int64, device=self.model.device)

        full_states = []
        select_items = []
        actions = []
        with torch.no_grad():
            users_emb = self.model.embedding_user(all_users)
            items_emb = self.model.embedding_item(unique_it)
            # get_pred squeezes its output; flatten so that a single candidate
            # still yields a 1-D tensor.
            scores_all = self.model.get_pred(all_users, unique_it).reshape(-1)

            states = torch.cat([users_emb, items_emb, counts.unsqueeze(-1), scores_all.unsqueeze(-1)], dim=-1)
            for idx, state in enumerate(states):
                num_items = len(select_items)
                if num_items >= self.sample_size:
                    break
                # Last state feature: fraction of the budget already used.
                progress = torch.tensor([num_items / self.sample_size], dtype=torch.float32, device=self.model.device)
                full_state = torch.cat([state, progress])
                full_states.append(full_state)
                probs = self.actor(full_state)
                action = torch.distributions.Categorical(probs).sample().item()
                if action > 0.:
                    select_items.append(idx)
                actions.append(action)

        items = unique_it[select_items]
        users = torch.tensor([user] * items.size(0), dtype=torch.int64, device=self.model.device)
        scores = scores_all[select_items]

        self.states_dict[user] = full_states
        self.actions_dict[user] = actions
        self.clients_translog[user] = [items, scores]
        return [users, items, scores]

    def aggregation(self, user_list, model_list, sample_list, num_list, loss_list):
        """Store the client models, train the global model and log the losses.

        Args:
            user_list: users that took part in this round.
            model_list: uploaded models, aligned with ``user_list``.
            sample_list: uploaded samples, aligned with ``user_list``.
            num_list: per-client data counts used as loss weights.
            loss_list: per-client training losses (logging only).
        """
        for i, user in enumerate(user_list):
            self.client_models[user] = model_list[i]

        # Only the non-triple branch produces this statistic; it must still be
        # defined for the log line below.
        t_num = 0.
        if self.task == "triple":
            loss = self._aggregate_triple(sample_list)
        else:
            loss, t_num = self._aggregate_pointwise(user_list, sample_list, num_list)

        logging.info("Clients average loss: {}, Server average loss: {}, Clients average transnum: {}".format(torch.mean(torch.tensor(loss_list)), loss, t_num))

    def _aggregate_triple(self, sample_list):
        """Train on the uploaded triples and rebuild ``self.samples``."""
        users = torch.cat([sample[0] for sample in sample_list], dim=0).view(-1)
        pos = torch.cat([sample[1] for sample in sample_list], dim=0).view(-1)
        neg = torch.cat([sample[2] for sample in sample_list], dim=0).view(-1)

        loss = None
        self.model.train()
        for _ in range(self.server_epoch):
            loss = self.model.train_step_triple(users, pos, neg)

        self.model.eval()
        with torch.no_grad():
            pred_pos, pred_neg = self.model.get_pred_triple(users, pos, neg)

        # Part 1: the most frequently uploaded items.
        items = torch.cat([pos, neg], dim=0)
        unique, counts = torch.unique(items, return_counts=True)
        freq_topk = torch.topk(counts, k=min(self.sample_freq, len(unique)))
        top_freq_items = unique[freq_topk.indices]
        pos_new_from_freq = top_freq_items[torch.isin(top_freq_items, pos)]
        neg_new_from_freq = top_freq_items[torch.isin(top_freq_items, neg)]

        # Part 2: the items with the highest predicted score; `source` records
        # whether an entry was uploaded as positive (1) or negative (0).
        scores = torch.cat([pred_pos.view(-1), pred_neg.view(-1)], dim=0)
        source = torch.cat([torch.ones_like(pos), torch.zeros_like(neg)], dim=0)
        score_topk = torch.topk(scores, k=min(self.sample_pred, len(scores)))
        top_score_items = items[score_topk.indices]
        top_score_src = source[score_topk.indices]
        pos_new_from_score = top_score_items[top_score_src == 1]
        neg_new_from_score = top_score_items[top_score_src == 0]

        pos_new = torch.unique(torch.cat([pos_new_from_freq, pos_new_from_score], dim=0))
        neg_new = torch.unique(torch.cat([neg_new_from_freq, neg_new_from_score], dim=0))

        self.samples = [pos_new, neg_new]
        return loss

    def _aggregate_pointwise(self, user_list, sample_list, num_list):
        """Train on uploaded (user, item, score) samples and update the policies.

        Returns:
            ``(loss, t_num)``: the last server loss and the average number of
            items that were sent to the clients updated in this round.
        """
        users = torch.cat([sample[0] for sample in sample_list], dim=0).view(-1)
        items = torch.cat([sample[1] for sample in sample_list], dim=0).view(-1)
        scores = torch.cat([sample[2] for sample in sample_list], dim=0).view(-1)

        loss = None
        self.model.train()
        for _ in range(self.server_epoch):
            loss = self.model.train_step(users, items, scores)

        # Candidate pool for the next distribute_sample() calls.
        self.items = items
        self.samples = None

        # Per-client losses under the plain data-count weighting.
        self.model.train()
        pre_losses, _ = self.model.train_step_server(sample_list, num_list)
        self.model.eval()

        total_num = sum(num_list)
        client_count = len(num_list)
        state_rows = []
        for sample, pre_loss in zip(sample_list, pre_losses):
            client_scores = sample[2]
            # std() of a single element is NaN and would propagate into the
            # policy weights and the server loss.
            score_std = client_scores.std().item() if client_scores.numel() > 1 else 0.
            state_rows.append(torch.tensor([
                sample[1].size(0) / total_num * client_count,  # relative sample size
                client_scores.mean().item(),
                score_std,
                pre_loss.item(),
            ], device=self.model.device))
        states = torch.stack(state_rows)

        weights = self.policy_net(states)

        num_tensor = torch.tensor(num_list, dtype=weights.dtype, device=weights.device)
        sample_weights = weights * num_tensor
        sample_weights = sample_weights / (sample_weights.sum() + 1e-8)

        # With server_epoch <= 1 the loop below does not run; the rewards are
        # then based on the losses measured above.
        losses = pre_losses
        self.model.train()
        for _ in range(self.server_epoch - 1):
            losses, loss = self.model.train_step_server(sample_list, sample_weights.detach().tolist())
        self.model.eval()

        # Lower loss -> higher reward, min-max scaled to [-0.5, 0.5].
        rewards = -losses
        rewards = (rewards - rewards.min()) / (rewards.max() - rewards.min() + 1e-8) - 0.5
        self.policy_net.train_step(weights, rewards)

        rewards_norm = rewards

        # NOTE(review): the combination below assumes that either none or all
        # users of this round have a clients_translog entry - confirm.
        divergence_list = []
        for i, user in enumerate(user_list):
            if user in self.clients_translog:
                divergence_list.append(compute_divergence(sample_list[i][1], self.clients_translog[user][0], self.item_num))
        if len(divergence_list) > 0:
            divergences = torch.tensor(divergence_list, dtype=torch.float32, device=self.model.device)
            divergences = (divergences - divergences.min()) / (divergences.max() - divergences.min() + 1e-8) - 0.5

        with torch.no_grad():
            score_divergence_list = []
            for i, user in enumerate(user_list):
                if user in self.clients_translog:
                    score_divergence_list.append(compute_score_similarity(sample_list[i][2], self.model.get_pred(sample_list[i][0], sample_list[i][1])))

        if len(score_divergence_list) > 0:
            score_divergences = torch.tensor(score_divergence_list, dtype=torch.float32, device=self.model.device)
            score_divergences = (score_divergences - score_divergences.min()) / (score_divergences.max() - score_divergences.min() + 1e-8) - 0.5

            rewards_norm = (rewards + divergences + (1 - score_divergences)) / 3.

        # Actor/critic update along the selection trajectory of every client.
        client_trans_num_list = []
        for i, user in enumerate(user_list):
            if user not in self.states_dict:
                continue
            trans_num = self.clients_translog[user][0].size(0)
            if trans_num == 0:
                continue
            trajectory = self.states_dict[user]
            actions = self.actions_dict[user]
            client_trans_num_list.append(trans_num)

            states_num = len(trajectory)
            reward = rewards_norm[i]
            for step in range(states_num - 1):
                pos_factor = step / states_num - 0.5
                action = actions[step]
                # The reward is shifted by its position-scaled magnitude:
                # downwards for a selected item, upwards for a skipped one.
                shift = torch.abs(reward * 1.0 * pos_factor)
                r = reward - shift if action > 0. else reward + shift
                _, error = self.critic.train_step_critic(trajectory[step], trajectory[step + 1], r)
                self.actor.train_step_actor(trajectory[step], action, error)

        t_num = sum(client_trans_num_list) / len(client_trans_num_list) if len(client_trans_num_list) > 0 else 0.
        return loss, t_num

    def get_user_model(self, user):
        """Return the stored model of ``user``, or the global model.

        NOTE(review): ``distribute_model`` falls back to a state dict while this
        method falls back to the module itself; ``evaluate_personalize`` calls
        ``.forward`` on the result - confirm what ``client_models`` holds.
        """
        if user in self.client_models:
            return self.client_models[user]
        return self.model

    def _evaluate_with(self, dataload, user_list, model_for_user):
        """Score the test data of every user and compute the metrics.

        ``model_for_user`` maps a user to the object whose ``forward`` is used.
        """
        self.model.eval()
        y_pred = []
        y_true = []
        group_id = []
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for user in user_list:
                scorer = model_for_user(user)
                users, items, labels = dataload.get_testdata(user)
                y_pred.extend(scorer.forward(users, items).detach().cpu().numpy().reshape(-1))
                y_true.extend(labels.detach().cpu().numpy().reshape(-1))
                group_id.extend(users.detach().cpu().numpy().reshape(-1))
        y_pred = np.array(y_pred, np.float64)
        y_true = np.array(y_true, np.float64)
        group_id = np.array(group_id) if len(group_id) > 0 else None
        # `metrcis` is the attribute name used by the original code; presumably
        # defined (with this spelling) by BaseModel - verify.
        val_logs = self.model.evaluate_metrics(y_true, y_pred, self.model.metrcis, group_id)
        logging.info('[Metrics] ' + ' - '.join('{}: {:.6f}'.format(k, v) for k, v in val_logs.items()))
        return val_logs

    def evaluate(self, dataload, user_list):
        """Evaluate the global model on the test data of ``user_list``."""
        return self._evaluate_with(dataload, user_list, lambda user: self.model)

    def evaluate_personalize(self, dataload, user_list):
        """Evaluate with the per-user model returned by ``get_user_model``."""
        return self._evaluate_with(dataload, user_list, self.get_user_model)
    

class FedMF_FLRL:
    """Driver that wires a ``Server``, one reusable ``Client`` and an ``AE``
    item encoder together and runs the federated training rounds."""

    def __init__(self,
                 dataload:BaseDataLoaderFL,
                 clients_num_per_turn,
                 local_epoch,
                 train_turn,
                 user_num,
                 item_num,
                 embedding_dim,
                 server_epoch,
                 sample_ratio,
                 sample_bias,
                 swap_ratio,
                 alpha,
                 sample_size,
                 sample_portion,
                 device,
                 embedding_regularizer,
                 net_regularizer,
                 learning_rate,
                 optimizer,
                 loss_fn,
                 metrics,
                 task,
                 *args, **kwargs
                 ):
        """Build the server, the client and the auto-encoder.

        Required entries of ``kwargs``: ``g_hidden_units``,
        ``g_hidden_activations``, ``sen_embedding_dim`` and ``pre_epoch``.
        """
        task_name = task.lower()

        # Settings shared by the recommendation models and the server.
        training_cfg = dict(
            device=device,
            embedding_regularizer=embedding_regularizer,
            net_regularizer=net_regularizer,
            learning_rate=learning_rate,
            optimizer=optimizer,
            loss_fn=loss_fn,
            metrics=metrics,
        )
        model_cfg = dict(
            user_num=user_num,
            item_num=item_num,
            embedding_dim=embedding_dim,
            task=task_name,
            **training_cfg,
        )

        # Construction order is kept: server model, server, client model, AE.
        global_model = model(**model_cfg)
        global_model.reset_parameters()
        self.server = Server(global_model, server_epoch, sample_size, sample_portion,
                             item_num=item_num, **training_cfg)

        local_model = model(**model_cfg)
        self.client = Client(
            client_id=0,
            model=local_model,
            task=task_name,
            sample_ratio=sample_ratio,
            swap_ratio=swap_ratio,
            sample_bias=sample_bias,
            alpha=alpha,
        )

        self.g_model = AE(
            hidden_units=kwargs["g_hidden_units"],
            hidden_activations=kwargs["g_hidden_activations"],
            embedding_dim=kwargs["sen_embedding_dim"],
            embedding_dim_latent=embedding_dim,
            device=device,
            embedding_regularizer=0.,
            net_regularizer=1e-2,
            learning_rate=1e-4,
            optimizer="adam",
            loss_fn="mse_loss",
        )

        self.clients_num_per_turn = clients_num_per_turn
        self.local_epoch = local_epoch
        self.train_turn = train_turn
        self.user_num = user_num
        self.task = task_name
        self.device = device
        self.dataload = dataload
        self.pre_epoch = kwargs["pre_epoch"]

    def _run_round(self):
        """Run one federated round: local training on the selected clients
        followed by server aggregation."""
        chosen_users = self.server.select_clients(self.user_num, self.clients_num_per_turn)
        uploaded_samples = []
        uploaded_models = []
        data_counts = []
        client_losses = []
        for user in chosen_users:
            # The single Client object is re-used for every selected user.
            self.client.load_client(user)
            self.client.load_model(self.server.distribute_model(user))
            self.client.load_sample(self.server.distribute_sample(user))
            loss, samples = self.client.local_train(user, self.local_epoch, self.dataload,)
            client_losses.append(loss)
            uploaded_models.append(self.client.upload_model())
            uploaded_samples.append(samples)
            data_counts.append(self.client.local_data_num())
        self.server.aggregation(chosen_users, uploaded_models, uploaded_samples, data_counts, client_losses, )
        torch.cuda.empty_cache()

    def fit(self,):
        """Pre-train the item encoder, run all training rounds and evaluate.

        The auto-encoder is trained for ``pre_epoch`` steps on the item
        features; its latent output then replaces the server's item embedding
        weights before federated training starts.

        Returns:
            The metrics returned by ``Server.evaluate`` over all users.
        """
        self.server.count_parameters()

        item_feature = self.dataload.get_item_feature()
        for _ in range(self.pre_epoch):
            self.g_model.train_step(item_feature)
        latent = self.g_model.get_latent(item_feature)
        self.server.model.embedding_item.weight.data = copy.deepcopy(latent.detach())

        for turn in range(self.train_turn):
            logging.info("********* Train Turn {} *********".format(turn))
            self._run_round()

        logging.info("********* Test *********")
        return self.server.evaluate(self.dataload, range(self.user_num))

