import torch
import copy
import numpy as np
import torch.nn.functional as F

try:
    import habana_frameworks.torch.core as htcore
except ImportError:
    htcore = None

from federatedscope.gfl.loss import GreedyLoss
from federatedscope.gfl.trainer.nodetrainer import NodeFullBatchTrainer


class LocalGenTrainer(NodeFullBatchTrainer):
    """Full-batch node trainer for the FedSage+ neighbour generator.

    The per-batch objective is a weighted sum of three terms whose weights
    come from ``cfg.fedsageplus.{a, b, c}``:

    * ``a`` * smooth-L1 loss between the predicted and true number of
      missing nodes (``batch.num_missing``);
    * ``b`` * ``GreedyLoss`` between generated features and
      ``batch.x_missing``;
    * ``c`` * ``ctx.criterion`` applied to the third model output against
      ``batch.y`` (presumably a node-classification loss -- confirm).
    """

    def __init__(self,
                 model,
                 data,
                 device,
                 config,
                 only_for_eval=False,
                 monitor=None):
        super(LocalGenTrainer, self).__init__(model, data, device, config,
                                              only_for_eval, monitor)
        # Regression loss on the predicted count of missing nodes.
        self.criterion_num = F.smooth_l1_loss
        # Matching loss between generated and true missing-node features.
        self.criterion_feat = GreedyLoss

    def _hook_on_batch_forward(self, ctx):
        """Forward the batch and store the combined generator loss on ``ctx``.

        Only nodes selected by ``batch['<cur_mode>_mask']`` contribute.
        ``ctx.model(batch)`` is expected to return three tensors:
        (predicted missing count, generated features, classifier output).

        Sets ``ctx.batch_size`` (number of masked nodes), ``ctx.loss_batch``
        (weighted loss as float), ``ctx.y_true`` (true missing counts) and
        ``ctx.y_prob`` (predicted missing counts).
        """
        fs_cfg = self.cfg.fedsageplus
        batch = ctx.data_batch.to(ctx.device)
        split_mask = batch['{}_mask'.format(ctx.cur_mode)]

        # Restrict every model output to the nodes of the current split.
        missing_hat, feat_hat, logits = (
            out[split_mask] for out in ctx.model(batch))
        missing_true = batch.num_missing[split_mask]

        loss_num = self.criterion_num(missing_hat, missing_true)
        # GreedyLoss may yield a tensor detached from the graph (e.g. when no
        # missing nodes are predicted); force it to participate in autograd.
        loss_feat = self.criterion_feat(pred_feats=feat_hat,
                                        true_feats=batch.x_missing[split_mask],
                                        pred_missing=missing_hat,
                                        true_missing=missing_true,
                                        num_pred=fs_cfg.num_pred)
        loss_feat = loss_feat.requires_grad_()
        loss_clf = ctx.criterion(logits, batch.y[split_mask])

        weighted = (fs_cfg.a * loss_num + fs_cfg.b * loss_feat +
                    fs_cfg.c * loss_clf)

        ctx.batch_size = split_mask.sum().item()
        ctx.loss_batch = weighted.float()
        ctx.y_true = batch.num_missing[split_mask]
        ctx.y_prob = missing_hat


class FedGenTrainer(LocalGenTrainer):
    """Generator trainer used for FedSage+ cross-client optimisation.

    Compared with :class:`LocalGenTrainer`:

    * the local loss is scaled by ``1 / cfg.federate.client_num``;
    * :meth:`cal_grad` evaluates another client's generator parameters on
      this client's graph and returns the resulting gradients;
    * :meth:`update_by_grad` adds received gradients to the local ones and
      performs an optimizer step.
    """

    def _hook_on_batch_forward(self, ctx):
        """Compute the local generator loss, scaled by ``1 / client_num``.

        Delegates to :meth:`LocalGenTrainer._hook_on_batch_forward`, which
        fills ``ctx.batch_size``, ``ctx.loss_batch``, ``ctx.y_true`` and
        ``ctx.y_prob``. ``ctx.loss_batch`` is then divided by
        ``cfg.federate.client_num``.
        """
        super(FedGenTrainer, self)._hook_on_batch_forward(ctx)
        ctx.loss_batch = ctx.loss_batch / self.cfg.federate.client_num

    def update_by_grad(self, grads):
        """
        Add gradients received from other clients and take one optimizer step.

        Arguments:
            grads: grads of other clients to optimize the local model; a
                mapping from parameter name to a tensor or (nested) list.
                ``None`` entries are ignored.
        :returns:
            state_dict of generation model. Note: the model is moved to CPU
            as a side effect.
        """
        for key, value in self.ctx.model.named_parameters():
            grad = grads[key]
            if isinstance(grad, list):
                grad = torch.FloatTensor(grad)
            if grad is None:
                continue
            # Align with the parameter itself rather than assuming ctx.device.
            grad = grad.to(value.device)
            if value.grad is None:
                # No local gradient accumulated yet: `+=` on None would fail.
                value.grad = grad.clone()
            else:
                value.grad += grad
        self.ctx.optimizer.step()
        if htcore is not None:
            htcore.mark_step()
        return self.ctx.model.cpu().state_dict()

    @staticmethod
    def _sample_target_feats(raw_data, num_samples, num_pred):
        """Sample ``num_pred`` neighbour features for ``num_samples`` nodes.

        Nodes are drawn uniformly with replacement. For each drawn node,
        ``num_pred`` of its out-neighbours along ``edge_index`` are drawn
        with replacement and their rows of ``raw_data.x`` are collected. A
        node without out-neighbours is replaced by fresh random draws until
        one with neighbours is found.

        Returns:
            np.ndarray of shape
            ``(num_samples, num_pred, raw_data.num_node_features)``.

        Raises:
            ValueError: if ``raw_data`` has no edges. The redraw loop could
                never terminate in that case.
        """
        src, dst = raw_data.edge_index[0], raw_data.edge_index[1]
        if src.numel() == 0:
            raise ValueError(
                'cal_grad needs a graph with at least one edge to sample '
                'neighbour features from, but raw_data.edge_index is empty.')
        choice = np.random.choice(raw_data.num_nodes, num_samples)
        target_feats = []
        for c_i in choice:
            neighbors_ids = dst[torch.where(src == c_i)[0]]
            while len(neighbors_ids) == 0:
                id_i = np.random.choice(raw_data.num_nodes, 1)[0]
                neighbors_ids = dst[torch.where(src == id_i)[0]]
            choice_i = np.random.choice(neighbors_ids.detach().cpu().numpy(),
                                        num_pred)
            for ch_i in choice_i:
                target_feats.append(raw_data.x[ch_i].detach().cpu().numpy())
        return np.asarray(target_feats).reshape(
            (num_samples, num_pred, raw_data.num_node_features))

    def cal_grad(self, raw_data, model_para, embedding, true_missing):
        """
        Compute gradients of another client's generator on the local graph.

        The local model temporarily takes ``model_para``. The weighted
        feature loss against randomly sampled local neighbour features is
        back-propagated, and then both the local parameters and the local
        accumulated ``.grad`` values are restored. The gradients returned
        are computed from scratch, so stale local gradients never leak into
        them.

        Arguments:
            raw_data (Pyg.Data): raw graph
            model_para: model parameters (values may be tensors or lists;
                the caller's mapping is not modified)
            embedding: output embeddings after local encoder
            true_missing: number of missing node
        :returns:
            grads: grads to optimize the model of other clients; a dict of
                detached tensor copies keyed by parameter name. Parameters
                the loss does not reach get zero tensors.
        """
        model = self.ctx.model
        device = self.ctx.device
        num_pred = self.cfg.fedsageplus.num_pred

        # Snapshot local parameters *and* gradients for a complete rollback.
        para_backup = copy.deepcopy(model.cpu().state_dict())
        grad_backup = {
            key: None if value.grad is None else value.grad.detach().clone()
            for key, value in model.named_parameters()
        }

        # Convert list payloads without mutating the caller's dict.
        model_para = {
            key: torch.FloatTensor(value) if isinstance(value, list) else value
            for key, value in model_para.items()
        }

        try:
            model.load_state_dict(model_para)
            model = self.ctx.model = model.to(device)
            model.train()
            # Start from empty gradients: accumulating on top of leftover
            # local .grad values would corrupt the gradients we send out.
            for param in model.parameters():
                param.grad = None

            raw_data = raw_data.to(device)
            embedding = torch.FloatTensor(embedding).to(device)
            true_missing = true_missing.long().to(device)
            pred_missing = model.reg_model(embedding)
            pred_feat = model.gen(embedding)

            # Random pick node and compare its neighbors with predicted nodes
            global_target_feat = self._sample_target_feats(
                raw_data, embedding.shape[0], num_pred)
            loss_feat = self.criterion_feat(pred_feats=pred_feat,
                                            true_feats=global_target_feat,
                                            pred_missing=pred_missing,
                                            true_missing=true_missing,
                                            num_pred=num_pred)
            loss = self.cfg.fedsageplus.b * loss_feat
            loss = (1.0 / self.cfg.federate.client_num *
                    loss).requires_grad_()
            loss.backward()
            if htcore is not None:
                htcore.mark_step()
            # Copy the gradients so that later in-place zeroing of the
            # model's .grad cannot alias what is returned.
            grads = {
                key: (torch.zeros_like(value) if value.grad is None else
                      value.grad.detach().clone())
                for key, value in model.named_parameters()
            }
        finally:
            # Rollback parameters and gradients, even if the loss failed.
            model.load_state_dict(para_backup)
            for key, value in model.named_parameters():
                saved = grad_backup[key]
                value.grad = None if saved is None else saved.to(value.device)
        return grads

    @torch.no_grad()
    def embedding(self):
        """Encode the full local graph with ``model.encoder_model``.

        Moves the model and ``ctx.data['data']`` to ``ctx.device``, runs the
        encoder without gradient tracking and returns the result on CPU.
        """
        model = self.ctx.model.to(self.ctx.device)
        data = self.ctx.data['data'].to(self.ctx.device)
        return model.encoder_model(data).to('cpu')
