"""

"""

import copy
import numpy as np
import torch.nn as nn
import torch
import logging
from dataloaders.BaseDataLoader import *
from framework.fed.client import ClientBase
from framework.fed.server import ServerBase
from framework.modules.models import BaseModel, AE
from framework.modules.layers import MLP_Block
from framework.utils import calculate_model_size
from thop import profile

class model(BaseModel):
    """Matrix-factorization scorer built from one user and one item embedding table.

    The score of a (user, item) pair is the dot product of the two embeddings.
    For every task except ``"triple"`` the score is passed through a sigmoid,
    and for ``"regression"`` it is additionally rescaled with ``* 4.0 + 1.0``
    (i.e. into the open interval (1, 5)).

    Args:
        user_num: Number of rows of the user embedding table.
        item_num: Number of rows of the item embedding table.
        embedding_dim: Width of both embedding tables.
        task: Task name; compared against ``"triple"`` and ``"regression"``.
        device: Forwarded to ``BaseModel``.
        embedding_regularizer: Forwarded to ``BaseModel``.
        net_regularizer: Forwarded to ``BaseModel``.
        learning_rate: Forwarded to ``self.compile`` as ``lr``.
        optimizer: Forwarded to ``self.compile``.
        loss_fn: Forwarded to ``self.compile`` as ``loss``.
        metrics: Forwarded to ``BaseModel``.
    """

    def __init__(self, 
                 user_num, 
                 item_num, 
                 embedding_dim, 
                 task,
                 device, 
                 embedding_regularizer, 
                 net_regularizer, 
                 learning_rate,
                 optimizer,
                 loss_fn,
                 metrics,
                 *args, **kwargs):
        super().__init__(device=device,
                         embedding_regularizer=embedding_regularizer,
                         net_regularizer=net_regularizer,
                         metrics=metrics)
        self.embedding_user = nn.Embedding(num_embeddings=user_num, embedding_dim=embedding_dim)
        self.embedding_item = nn.Embedding(num_embeddings=item_num, embedding_dim=embedding_dim)
        self.task = task
        self.output_activation = nn.Sigmoid()
        self.reset_parameters()
        # Xavier initialisation runs after reset_parameters(), so it is the one that sticks.
        self.__init_weight()
        self.compile(optimizer=optimizer, loss=loss_fn, lr=learning_rate)
        self.model_to_device()

    def __init_weight(self, ):
        """Xavier-uniform initialise both embedding tables (user first, then item)."""
        nn.init.xavier_uniform_(self.embedding_user.weight)
        nn.init.xavier_uniform_(self.embedding_item.weight)

    @staticmethod
    def _as_batch(pred):
        """Return ``pred`` with at least one dimension (a 0-d tensor becomes shape ``(1,)``).

        ``squeeze()`` collapses a single-sample batch to a scalar tensor, which
        can no longer be indexed with a mask, concatenated, or matched against
        a ``(1,)`` label; this restores the batch dimension in that case only.
        """
        return pred.unsqueeze(0) if pred.dim() == 0 else pred

    def forward(self, user_id, item_id):
        """Score (user, item) pairs.

        Returns the raw dot product for the ``"triple"`` task, the sigmoid of
        it otherwise, and ``sigmoid * 4.0 + 1.0`` for ``"regression"``.
        """
        output = (self.embedding_user(user_id) * self.embedding_item(item_id)).sum(1)
        if self.task == "triple":
            return output

        output = self.output_activation(output)
        if self.task == "regression":
            output = output * 4.0 + 1.0
        return output

    def train_step(self, users, items, label):
        """Run one optimiser step on labelled pairs and return the loss tensor."""
        self.train()
        self.optimizer.zero_grad()
        pred = self._as_batch(self.forward(users, items).squeeze())
        loss = self.loss_fn(pred, label, reduction='mean') + self.add_regularization()
        loss.backward()
        self.optimizer.step()
        return loss
    
    def train_step_triple(self, users, pos, neg):
        """Run one optimiser step on (user, positive item, negative item) triples.

        The regularisation term is skipped for an empty batch, because it
        needs ``users[0]``. Returns the loss tensor.
        """
        self.train()
        self.optimizer.zero_grad()
        pred_pos = self.forward(users, pos)
        pred_neg = self.forward(users, neg)
        loss = self.loss_fn(pred_pos, pred_neg, )
        if len(users) > 0:
            # NOTE(review): only the embedding of the first user in the batch is
            # regularised -- confirm this is intended when a batch mixes users.
            loss = loss + self.add_regularization_triple(
                self.embedding_user.weight[users[0]],
                self.embedding_item(pos),
                self.embedding_item(neg),
            )
        loss.backward()
        self.optimizer.step()
        return loss

    def get_pred(self, users, items):
        """Switch to eval mode and return the scores for the given pairs.

        A single pair yields a tensor of shape ``(1,)`` rather than a scalar.
        Gradient tracking is not disabled here; callers wrap the call in
        ``torch.no_grad()`` when they need that.
        """
        self.eval()
        return self._as_batch(self.forward(users, items).squeeze())

    def get_pred_triple(self, users, pos, neg, global_model=None):
        """Switch to eval mode and return ``(pred_pos, pred_neg)`` for the triples.

        ``global_model`` is accepted but not used by this implementation.
        """
        self.eval()
        pred_pos = self.forward(users, pos)
        pred_neg = self.forward(users, neg)
        return pred_pos, pred_neg
    

class Client(ClientBase):
    """Federated client: trains the local model and produces samples to upload.

    ``task`` is lower-cased; ``"triple"`` selects the pairwise
    (user, positive, negative) branch, any other value the labelled
    (user, item, label) branch.
    """
    model:model
    def __init__(self, client_id, model, task, sample_ratio, swap_ratio):
        """
        Args:
            client_id: Forwarded to ``ClientBase``.
            model: Local model instance, forwarded to ``ClientBase``.
            task: Task name (case-insensitive).
            sample_ratio: Fraction of the local data drawn when building the
                uploaded samples (at least one element is always requested).
            swap_ratio: Fraction of the sampled entries that get swapped
                (at least one entry is swapped whenever any is available).
        """
        super().__init__(client_id, model)
        self.task = task.lower()
        self.sample_ratio = sample_ratio
        self.swap_ratio = swap_ratio
        self.samples = None
        # Defined up front so local_data_num() is usable before the first round.
        self.__local_data_num = 0

    def load_model(self, model):
        """Load ``model`` through ``ClientBase`` and move the local model to its device."""
        super().load_model(model)
        self.model.to(self.model.device)

    def load_sample(self, sample):
        """Store a deep copy of the server samples; ``None`` leaves the current ones untouched."""
        if sample is not None:
            self.samples = copy.deepcopy(sample)

    def local_train(self, user, local_epoch, dataload):
        """Train on the data of ``user`` and build the samples to upload.

        Args:
            user: Id passed to ``dataload.get_traindata``.
            local_epoch: Number of training steps to run.
            dataload: Object providing ``get_traindata(user)``.

        Returns:
            ``(loss, samples)`` where ``loss`` is the loss of the last training
            step (``None`` when ``local_epoch`` is 0) and ``samples`` is
            ``[users, pos, neg]`` for the triple task or
            ``[users, items, labels]`` otherwise.
        """
        if self.task == "triple":
            return self._local_train_triple(user, local_epoch, dataload)
        return self._local_train_pointwise(user, local_epoch, dataload)

    @staticmethod
    def _pad_by_resampling(values, target_len):
        """Grow ``values`` to ``target_len`` rows by appending randomly re-drawn rows."""
        missing = target_len - values.size(0)
        if missing <= 0:
            return values
        extra_idx = torch.randint(0, values.size(0), (missing,), device=values.device)
        return torch.cat([values, values[extra_idx]], dim=0)

    def _local_train_triple(self, user, local_epoch, dataload):
        """Pairwise branch of ``local_train``: train, then sample and swap items."""
        users, pos, neg = dataload.get_traindata(user)
        self.__local_data_num = users.size(0)
        if self.samples is not None:
            # Server samples are [pos_items, neg_items]; append them and pad the
            # shorter side (and the users) so all three tensors line up again.
            pos = torch.cat([pos, self.samples[0]], dim=0)
            neg = torch.cat([neg, self.samples[1]], dim=0)
            target_len = max(pos.size(0), neg.size(0))
            pos = self._pad_by_resampling(pos, target_len)
            neg = self._pad_by_resampling(neg, target_len)
            users = self._pad_by_resampling(users, target_len)

        loss = None
        self.model.train()
        for _ in range(local_epoch):
            loss = self.model.train_step_triple(users, pos, neg)

        # sample: the subset has to exist before it can be scored (the
        # prediction used to run first and hit unassigned variables).
        num_samples = users.size(0)
        sample_size = max(1, int(num_samples * self.sample_ratio))
        sampled_idx = torch.randperm(num_samples, device=users.device)[:sample_size]
        users_new = users[sampled_idx]
        pos_new = pos[sampled_idx]
        neg_new = neg[sampled_idx]

        self.model.eval()
        with torch.no_grad():
            pred_pos, pred_neg = self.model.get_pred_triple(users_new, pos_new, neg_new)

        all_scores = torch.cat([pred_pos.view(-1), pred_neg.view(-1)], dim=0)
        all_items = torch.cat([pos_new, neg_new], dim=0).view(-1)
        # Give every distinct item the mean of the scores of its occurrences so
        # that item_scores[i] really belongs to unique_items[i].
        unique_items, inverse, counts = torch.unique(
            all_items, return_inverse=True, return_counts=True)
        item_scores = torch.zeros(unique_items.size(0), dtype=all_scores.dtype,
                                  device=all_scores.device)
        item_scores.index_add_(0, inverse, all_scores)
        item_scores = item_scores / counts.to(item_scores.dtype)

        # Highest-scored items become the new positives, lowest-scored the negatives.
        k = min(sample_size, item_scores.size(0) // 2)
        topk_idx = torch.topk(item_scores, k=k, largest=True).indices
        lowk_idx = torch.topk(item_scores, k=k, largest=False).indices
        users_new = users_new[:k]
        pos_new = unique_items[topk_idx]
        neg_new = unique_items[lowk_idx]
        pred_pos_new = item_scores[topk_idx]

        # swap: exchange the best-scored positives with the negatives at the
        # same positions; clamped to k so an empty selection is left alone.
        swap_size = min(k, max(1, int(k * self.swap_ratio)))
        if swap_size > 0:
            swap_idx = torch.topk(pred_pos_new.view(-1), k=swap_size).indices
            pos_new[swap_idx], neg_new[swap_idx] = neg_new[swap_idx], pos_new[swap_idx]

        return loss, [users_new, pos_new, neg_new]

    def _local_train_pointwise(self, user, local_epoch, dataload):
        """Labelled branch of ``local_train``: train, then sample and swap scores."""
        users, items, labels = dataload.get_traindata(user)
        self.__local_data_num = users.size(0)
        if self.samples is not None:
            # Only add server samples for items this client does not already hold.
            items_new = self.samples[1]
            keep_idx = torch.nonzero(~torch.isin(items_new, items)).view(-1)
            if keep_idx.numel() > 0:
                users = torch.cat([users, self.samples[0][keep_idx]], dim=0)
                items = torch.cat([items, items_new[keep_idx]], dim=0)
                labels = torch.cat([labels, self.samples[2][keep_idx]], dim=0)

        loss = None
        self.model.train()
        for _ in range(local_epoch):
            loss = self.model.train_step(users, items, labels)

        self.model.eval()
        with torch.no_grad():
            # reshape keeps a single-sample prediction indexable by the masks below.
            scores = self.model.get_pred(users, items,).reshape(-1)

        # sample
        pos_mask = labels > 0
        neg_mask = labels == 0
        pos_items, pos_scores = items[pos_mask], scores[pos_mask]
        neg_items, neg_scores = items[neg_mask], scores[neg_mask]

        sample_size = max(1, int(pos_items.size(0) * self.sample_ratio))
        perm_pos = torch.randperm(pos_items.size(0), device=users.device)[:sample_size]
        sampled_pos_items = pos_items[perm_pos]
        sampled_pos_scores = pos_scores[perm_pos]

        perm_neg = torch.randperm(neg_items.size(0), device=users.device)[:sample_size]
        sampled_neg_items = neg_items[perm_neg]
        sampled_neg_scores = neg_scores[perm_neg]

        # swap: exchange the k highest positive scores with k random negative
        # scores; k is clamped so that both sides can actually supply k entries.
        k = min(max(1, int(sample_size * self.swap_ratio)),
                sampled_pos_scores.size(0),
                sampled_neg_scores.size(0))
        if k > 0:
            topk_idx = torch.topk(sampled_pos_scores.view(-1), k=k, largest=True).indices
            rand_idx = torch.randperm(perm_neg.size(0), device=users.device)[:k]
            sampled_pos_scores[topk_idx], sampled_neg_scores[rand_idx] = \
                sampled_neg_scores[rand_idx], sampled_pos_scores[topk_idx]

        users_new = users[:sampled_pos_items.size(0) + sampled_neg_items.size(0)]
        items_new = torch.cat([sampled_pos_items, sampled_neg_items], dim=0)
        labels_new = torch.cat([sampled_pos_scores, sampled_neg_scores], dim=0)

        return loss, [users_new, items_new, labels_new]
    
    def local_data_num(self):
        """Return the number of training rows fetched in the last ``local_train`` (0 before any)."""
        return self.__local_data_num

class Server(ServerBase):
    """Federated server: keeps per-client weights, trains on uploaded samples
    and hands samples back to the clients.
    """
    model:model
    def __init__(self, model, server_epoch, sample_size, sample_portion):
        """
        Args:
            model: Server model, forwarded to ``ServerBase``.
            server_epoch: Number of training steps run per aggregation.
            sample_size: Total number of samples to distribute.
            sample_portion: Share of ``sample_size`` chosen by item frequency;
                the remainder is chosen by predicted score.
        """
        super().__init__(model)
        self.task = self.model.task
        self.client_models = dict()
        self.server_epoch = server_epoch
        self.sample_freq = int(sample_size * sample_portion)
        self.sample_pred = sample_size - self.sample_freq
        self.samples = None
        self.items = None
    
    def count_parameters(self):
        """Log the size of every state-dict entry except the user embedding, and the total.

        The second value returned by ``calculate_model_size`` is what gets
        logged (labelled as MB in the log message).
        """
        # flops, params = profile(self.model, inputs=(torch.tensor(0, dtype=torch.int64, device=self.model.device),
        #                                             torch.tensor(1, dtype=torch.int64, device=self.model.device)))
        # logging.info("FLOPs: {:.8f} MFLOPs".format(flops/ 1e6))
        # logging.info("Param: {:.8f} M".format(params/ 1e6))
        self.model.eval()
        base_model_dict = copy.deepcopy(self.model.state_dict())
        model_size = 0.
        for name, tensor in base_model_dict.items():
            if "embedding_user" in name:
                continue
            _, param_size = calculate_model_size(tensor)
            logging.info("Model {} size: {:.8f}MB".format(name, param_size))
            model_size += param_size
        self.model.load_weights(copy.deepcopy(base_model_dict))
        logging.info("Model all size: {:.8f}MB".format(model_size))
    
    def distribute_model(self, user):
        """Return the weights last uploaded by ``user``, or the server weights if there are none."""
        if user in self.client_models:
            return self.client_models[user]
        return self.model.state_dict()

    def distribute_sample(self, user):
        """Build the samples handed to ``user`` before local training.

        For the triple task, or before any aggregation has stored items, the
        stored ``self.samples`` is returned unchanged. Otherwise the result is
        ``[users, items, scores]`` made of the most frequent uploaded items
        followed by the items the server model scores highest for ``user``;
        the three tensors always have the same length, which is smaller than
        ``sample_size`` when fewer distinct items are available.
        """
        if self.task == "triple" or self.items is None:
            return self.samples

        self.model.eval()
        device = self.model.device
        unique_it, counts = torch.unique(self.items, return_counts=True)
        freq_sorted_idx = torch.argsort(counts, descending=True)
        item_freq = unique_it[freq_sorted_idx][:self.sample_freq]

        # The user tensors are sized from the items actually selected, so the
        # model never receives user/item batches of different lengths.
        user_freq = torch.tensor([user] * item_freq.size(0), dtype=torch.int64, device=device)
        user_all = torch.tensor([user] * unique_it.size(0), dtype=torch.int64, device=device)
        with torch.no_grad():
            # reshape(-1) keeps single-element predictions sliceable and concatenable.
            score_freq = self.model.get_pred(user_freq, item_freq).reshape(-1)
            scores_all = self.model.get_pred(user_all, unique_it).reshape(-1)

        score_sorted_idx = torch.argsort(scores_all, descending=True)[:self.sample_pred]
        item_pred = unique_it[score_sorted_idx]
        score_pred = scores_all[score_sorted_idx]
        user_pred = torch.tensor([user] * item_pred.size(0), dtype=torch.int64, device=device)

        users = torch.cat([user_freq, user_pred], dim=0).view(-1)
        items = torch.cat([item_freq, item_pred], dim=0).view(-1)
        scores = torch.cat([score_freq, score_pred], dim=0).view(-1)
        return [users, items, scores]
    
    def aggregation(self, user_list, model_list, sample_list, num_list, loss_list):
        """Store the uploaded client weights and train the server model on the uploaded samples.

        Args:
            user_list: Ids of the clients of this round.
            model_list: Uploaded weights, aligned with ``user_list``.
            sample_list: Uploaded samples, one three-element list per client.
            num_list: Local data sizes; not used by this implementation.
            loss_list: Client losses, only used for logging.
        """
        for i, user in enumerate(user_list):
            self.client_models[user] = model_list[i]

        # Stays None when server_epoch is 0 so the log line below still works.
        loss = None
        if self.task == "triple":
            users = torch.cat([sample[0] for sample in sample_list], dim=0).view(-1)
            pos = torch.cat([sample[1] for sample in sample_list], dim=0).view(-1)
            neg = torch.cat([sample[2] for sample in sample_list], dim=0).view(-1)
            self.model.train()
            for _ in range(self.server_epoch):  # server local epoch
                loss = self.model.train_step_triple(users, pos, neg)

            self.model.eval()
            with torch.no_grad():
                pred_pos, pred_neg = self.model.get_pred_triple(users, pos, neg)

            # Candidates by frequency: the most frequently uploaded items, split
            # by whether they were uploaded as positives and/or negatives.
            all_items = torch.cat([pos, neg], dim=0)
            unique, counts = torch.unique(all_items, return_counts=True)
            freq_topk = torch.topk(counts, k=min(self.sample_freq, len(unique)))
            top_freq_items = unique[freq_topk.indices]
            pos_new_from_freq = top_freq_items[torch.isin(top_freq_items, pos)]
            neg_new_from_freq = top_freq_items[torch.isin(top_freq_items, neg)]

            # Candidates by score: the highest-scored entries, split by the
            # side (1 = positive, 0 = negative) they were uploaded on.
            scores = torch.cat([pred_pos.view(-1), pred_neg.view(-1)], dim=0)
            source = torch.cat([torch.ones_like(pos), torch.zeros_like(neg)], dim=0)
            score_topk = torch.topk(scores, k=min(self.sample_pred, len(scores)))
            top_score_items = all_items[score_topk.indices]
            top_score_src = source[score_topk.indices]
            pos_new_from_score = top_score_items[top_score_src == 1]
            neg_new_from_score = top_score_items[top_score_src == 0]

            pos_new = torch.unique(torch.cat([pos_new_from_freq, pos_new_from_score], dim=0))
            neg_new = torch.unique(torch.cat([neg_new_from_freq, neg_new_from_score], dim=0))

            self.samples = [pos_new, neg_new]
        else:
            users = torch.cat([sample[0] for sample in sample_list], dim=0).view(-1)
            items = torch.cat([sample[1] for sample in sample_list], dim=0).view(-1)
            scores = torch.cat([sample[2] for sample in sample_list], dim=0).view(-1)
            self.model.train()
            for _ in range(self.server_epoch):  # server local epoch
                loss = self.model.train_step(users, items, scores)
            
            # Samples for this task are built per user in distribute_sample().
            self.items = items
            self.samples = None

        logging.info("Clients average loss: {}, Server average loss: {}".format(torch.mean(torch.tensor(loss_list)), loss))

    def evaluate(self, dataload, user_list):
        """Score the test data of every user in ``user_list`` and log/return the metrics.

        Args:
            dataload: Object providing ``get_testdata(user)``.
            user_list: Iterable of user ids to evaluate.

        Returns:
            The mapping returned by ``self.model.evaluate_metrics``.
        """
        self.model.eval()
        y_pred = []
        y_true = []
        group_id = []
        # Inference only: no autograd graph needs to be built.
        with torch.no_grad():
            for user in user_list:
                users, items, labels = dataload.get_testdata(user)
                y_pred.extend(self.model.forward(users, items).data.cpu().numpy().reshape(-1))
                y_true.extend(labels.data.cpu().numpy().reshape(-1))
                group_id.extend(users.data.cpu().numpy().reshape(-1))
        y_pred = np.array(y_pred, np.float64)
        y_true = np.array(y_true, np.float64)
        group_id = np.array(group_id) if len(group_id) > 0 else None
        # NOTE(review): "metrcis" looks like a typo of "metrics" but is left as
        # is because the attribute is defined outside this file -- confirm.
        val_logs = self.model.evaluate_metrics(y_true, y_pred, self.model.metrcis, group_id)
        logging.info('[Metrics] ' + ' - '.join('{}: {:.6f}'.format(k, v) for k, v in val_logs.items()))
        return val_logs
    

class FedMF_PTF:
    """Training driver wiring one server, one reusable client and an auto-encoder.

    The server and the client each get their own ``model`` instance built from
    the same settings. The auto-encoder (``AE``) is trained on the item
    features first and its latent output replaces the server's item embedding
    weights before the federated rounds start.

    Required keyword arguments (read from ``kwargs``): ``g_hidden_units``,
    ``g_hidden_activations``, ``sen_embedding_dim`` and ``pre_epoch``.
    """

    def __init__(self, 
                 dataload:BaseDataLoaderFL,
                 clients_num_per_turn, 
                 local_epoch, 
                 train_turn,
                 user_num,
                 item_num,
                 embedding_dim,
                 server_epoch,
                 sample_ratio,
                 swap_ratio,
                 sample_size, 
                 sample_portion,
                 device, 
                 embedding_regularizer, 
                 net_regularizer, 
                 learning_rate,
                 optimizer, 
                 loss_fn,
                 metrics,
                 task,
                 *args, **kwargs
                 ):
        task_name = task.lower()
        # Settings shared by the server-side and the client-side model.
        model_settings = dict(
            user_num=user_num,
            item_num=item_num,
            embedding_dim=embedding_dim,
            task=task_name,
            device=device,
            embedding_regularizer=embedding_regularizer,
            net_regularizer=net_regularizer,
            learning_rate=learning_rate,
            optimizer=optimizer,
            loss_fn=loss_fn,
            metrics=metrics,
        )

        # Server model first, then the client model, then the auto-encoder.
        server_model = model(**model_settings)
        server_model.reset_parameters()
        self.server = Server(server_model, server_epoch, sample_size, sample_portion)

        client_model = model(**model_settings)
        self.client = Client(
            client_id=0,
            model=client_model,
            task=task_name,
            sample_ratio=sample_ratio,
            swap_ratio=swap_ratio,
        )

        self.g_model = AE(
            hidden_units=kwargs["g_hidden_units"],
            hidden_activations=kwargs["g_hidden_activations"],
            embedding_dim=kwargs["sen_embedding_dim"],
            embedding_dim_latent=embedding_dim,
            device=device,
            embedding_regularizer=0.,
            net_regularizer=1e-2,
            learning_rate=1e-4,
            optimizer="adam",
            loss_fn="mse_loss",
        )

        self.clients_num_per_turn = clients_num_per_turn
        self.local_epoch = local_epoch
        self.train_turn = train_turn
        self.user_num = user_num
        self.task = task_name
        self.device = device
        self.dataload = dataload
        self.pre_epoch = kwargs["pre_epoch"]

    def fit(self,):
        """Pre-train the auto-encoder, run all federated rounds, then evaluate.

        Returns:
            The result of ``self.server.evaluate`` over all ``user_num`` users.
        """
        self.server.count_parameters()

        # Pre-training: fit the auto-encoder on the item features and use its
        # latent codes as the server's item embedding weights.
        item_feature = self.dataload.get_item_feature()
        for _ in range(self.pre_epoch):
            self.g_model.train_step(item_feature)
        latent = self.g_model.get_latent(item_feature)
        self.server.model.embedding_item.weight.data = copy.deepcopy(latent.detach())

        for turn in range(self.train_turn):
            logging.info("********* Train Turn {} *********".format(turn))
            chosen_users = self.server.select_clients(self.user_num, self.clients_num_per_turn)
            round_samples, round_models, round_data_nums, round_losses = [], [], [], []
            for uid in chosen_users:
                # The single client object is re-targeted to each selected user.
                self.client.load_client(uid)
                self.client.load_model(self.server.distribute_model(uid))
                self.client.load_sample(self.server.distribute_sample(uid))
                client_loss, uploaded = self.client.local_train(uid, self.local_epoch, self.dataload)
                round_losses.append(client_loss)
                round_models.append(self.client.upload_model())
                round_samples.append(uploaded)
                round_data_nums.append(self.client.local_data_num())
            self.server.aggregation(chosen_users, round_models, round_samples, round_data_nums, round_losses)
            torch.cuda.empty_cache()

        logging.info("********* Test *********")
        return self.server.evaluate(self.dataload, range(self.user_num))

