#  Explicit pairwise factorized graph neural network for semi-supervised node classification, UAI 2021
# The source of model's main code: https://github.com/YuWang-1024/EPFGNN

import os
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.nn.parameter import Parameter
import math
import torch.optim as optim
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing, GCNConv
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_geometric.utils import degree
import copy
import time
from sklearn.metrics import accuracy_score as ACC



class EPFGNN(nn.Module):
    """Driver for the EPFGNN pipeline: optional GCN pre-training, then EM training.

    The module owns no layers itself. ``fit`` builds a ``GCNNet`` backbone
    (model q) and a ``PWGNN`` on top of it (model p), trains them through
    ``Trainer`` / ``Trainer_PWEM`` and stores the resulting per-node class
    scores in ``self.C``, which ``predict`` reads.
    """

    def __init__(
            self,
            in_features: int,
            class_num: int,
            device,
            args,
        ) -> None:
        """Store sizes and a private copy of the run configuration.

        Args:
            in_features: Input feature dimension handed to the GCN backbone.
            class_num: Number of target classes.
            device: Device that data and models are moved to in ``fit``.
            args: Object exposing ``configs`` (a dict) and ``run``.
        """
        super().__init__()
        #------------- Parameters ----------------
        self.device = device
        self.class_num = class_num
        self.in_features = in_features
        # Deep copy so the mutations done in ``fit`` never leak back into ``args``.
        self.configs = copy.deepcopy(args.configs)
        self.configs['count'] = args.run
        self.configs['device'] = device
        # Checkpoints are written under this directory.
        os.makedirs(self.configs['path'], exist_ok=True)


    def fit(self, graph, labels, train_mask, val_mask, test_mask):
        """Train model q (optionally) and model p, then cache the final scores.

        Args:
            graph: Graph object providing ``to``, ``ndata["feat"]`` and
                ``adj(scipy_fmt='csr')``.
            labels: Node labels.
            train_mask, val_mask, test_mask: Node masks for the three splits.

        Raises:
            ValueError: If ``configs['train_mode']`` has no associated
                model-p criterion.

        Side effects:
            Mutates ``self.configs`` and sets ``self.final_EM_test_acc`` and
            ``self.C``.
        """
        # model init
        graph = graph.to(self.device)
        labels = labels.to(self.device)
        self.train_mask = train_mask.to(self.device)
        self.valid_mask = val_mask.to(self.device)
        self.test_mask = test_mask.to(self.device)
        self.to(self.device)
        X = graph.ndata["feat"]
        # Convert the adjacency matrix into a 2 x E edge index.
        adj = graph.adj(scipy_fmt='csr')
        edge_index = torch.tensor(np.array(adj.nonzero()), device=self.device, dtype=torch.long)

        data = Data(edge_index=edge_index, test_mask=self.test_mask, train_mask=self.train_mask, val_mask=self.valid_mask, x=X, y=labels)

        configs = self.configs
        model_q = GCNNet(self.in_features, int(self.class_num), configs).to(self.device)
        if configs["soft"]:
            # Soft mode works on one-hot targets instead of class indices.
            data.y = torch.nn.functional.one_hot(data.y.to(torch.long), int(self.class_num))

        # ---- initialization of model q ----
        path_for_model_q = configs['path'] + f"/pwem_model_q_{configs['dataset_name']}_{configs['count']}.pt"
        configs['criterion_name'] = 'crossEntropy'
        configs['stopping_name'] = 'gcnStoppingCriterion'
        trainer_q = Trainer(model_q, configs, path_for_model_q)
        pretrain_test_acc = 0

        pretrain_modes = ('pre-training', 'center_redist', 'center_edge_rezero', 'average_edge_rezero')
        if configs['train_mode'] in pretrain_modes:
            # ---- pre-train procedure ----
            print("pre-trainning procedure")
            trainer_q.run(data, data.train_mask, data.y)
            pretrain_test_acc = trainer_q.test_acc(data, data.y, False)
            val_acc = trainer_q.evaluate_acc(data, data.y, False)
            print(f"val acc for pretrain: {val_acc}")
            print(f"test acc for pretrain: {pretrain_test_acc}")
            print(f"best acc: {trainer_q.give_best_test()}")
            configs["init_best_val_acc"] = val_acc

        # ---- EM training procedure ----
        print("EM trianing procedure")

        # ---- initialization of model p ----
        inputs_p = data.x
        targets_p = torch.zeros(data.x.shape[0], self.class_num).to(self.device)
        data_p = Data(edge_index=data.edge_index, test_mask=data.test_mask, train_mask=data.train_mask, val_mask=data.val_mask, x=inputs_p, y=targets_p)

        if configs['train_mode'] != 'joint training':
            # Start from the predictions of model q ...
            q = trainer_q.predict(data)
            data_p.y = q
        # ... and overwrite the training nodes with their given targets.
        data_p.y[data.train_mask] = data.y[data.train_mask].to(torch.float)

        # for 2 step training, use gold is in reverse degrade the performance.

        # Freezing the backbone is independent of which model p is built.
        # (Previously these shared one if/elif chain, so `fixed` left
        # `model_p` undefined and the trainer construction below crashed.)
        if configs['fixed'] == True:
            print("set the params in model q requires grade equals false")
            for p in model_q.parameters():
                p.requires_grad = False

        criterion_by_mode = {
            'pre-training': 'pwLoss',
            'center_redist': 'PWLoss_redis2',
            'joint_redist': 'PWLoss_redis2',
            'center_edge_rezero': 'PWLoss_redist2_edge_rezero',
            'average_edge_rezero': 'PWLoss_average_edge_rezero',
        }
        train_mode = configs['train_mode']
        if train_mode not in criterion_by_mode:
            raise ValueError(
                f"unsupported train_mode {train_mode!r}; expected one of {sorted(criterion_by_mode)}"
            )

        if train_mode in ('center_edge_rezero', 'average_edge_rezero'):
            # One rezero coefficient per edge; must be set before PWGNN is built.
            configs['rezero_size'] = data.edge_index.shape[1]
            print(configs['rezero_size'])

        model_p = PWGNN(model_q, configs).to(self.device)
        configs['criterion_name'] = criterion_by_mode[train_mode]

        # ---- change the weight decay / learning rate for the second procedure ----
        if configs["change_lr"] == True:
            configs["weight_decay"] = 5e-5
            configs["learning_rate"] = 0.01

        path_for_model_p = configs['path'] + f"/pwem_model_p_{configs['dataset_name']}_{configs['count']}.pt"

        configs["stopping_name"] = 'pwemStoppingCriterionForEM'
        # configs["stopping_name"] = 'pwemStoppingCriterion'

        configs['patience'] = 10
        if configs['change_EMPatience']:
            configs['patience'] = 5

        trainer_pwem = Trainer_PWEM(model_p, configs, path_for_model_p)
        row, _ = data_p.edge_index
        # +1 accounts for the self-loop.
        deg = degree(row, data.x.size(0), dtype=data.x.dtype).view(-1, 1) + 1

        num_m_steps = data.x.shape[0] // 30 + 1
        num_e_steps = 10
        configs['num_m_steps'] = num_m_steps
        configs['num_e_steps'] = num_e_steps

        # ---- call EM training procedure ----
        trainer_pwem.train(data_p, data.y, deg, configs, num_m_steps, num_e_steps)
        final_EM_test_acc = trainer_pwem.test_acc(data_p.y, data.y, data_p.test_mask, False)
        print(final_EM_test_acc)
        self.final_EM_test_acc = final_EM_test_acc
        self.C = data_p.y


    def predict(self, graph):
        """Return ``(argmax labels, scores, None)`` from the scores cached by ``fit``.

        ``graph`` is not used; ``fit`` must have been called first so that
        ``self.C`` exists.
        """
        y_pred = torch.argmax(self.C, dim=1)

        return y_pred.cpu(), self.C.cpu(), None






class GCNNet(torch.nn.Module):
    """Two-layer GCN: dropout -> GCNConv -> ReLU -> dropout -> GCNConv.

    Args:
        in_channels: Input feature dimension.
        num_classes: Output dimension of the second convolution.
        configs: Dict read for ``'hidden_size'`` and ``'dropout_prob'``.
    """

    def __init__(self, in_channels, num_classes, configs):
        super().__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.conv1 = GCNConv(in_channels, configs['hidden_size'])
        self.conv2 = GCNConv(configs['hidden_size'], num_classes)
        self.configs = configs

    def reset_parameters(self):
        """Re-initialize both convolution layers."""
        # The receiver used to be misspelled `selfs`, so the body's `self`
        # was an undefined name and every call raised NameError.
        self.conv1.reset_parameters()
        self.conv2.reset_parameters()

    def forward(self, data):
        """Return the raw (un-normalized) output of the second convolution.

        ``data`` must expose ``x`` and ``edge_index``. Dropout is only
        active in training mode.
        """
        x, edge_index = data.x, data.edge_index

        x = F.dropout(x, self.configs['dropout_prob'], training=self.training)
        x = self.conv1(x, edge_index)
        # x = self.norm(x)
        x = F.relu(x)
        x = F.dropout(x, self.configs['dropout_prob'], training=self.training)
        x = self.conv2(x, edge_index)

        return x


def accuracy(input, target, idx, verbose=True):
    """Fraction of the nodes selected by ``idx`` whose arg-max prediction equals ``target``.

    ``input`` holds per-class scores along dim 1; the denominator is
    ``idx.sum()``. Prints the value (4 decimals) when ``verbose`` is true.
    """
    predicted = input.max(dim=1)[1]
    hits = (predicted[idx] == target[idx]).sum().item()
    total = idx.sum().item()
    acc = float(hits) / total
    if verbose:
        print('Accuracy: {:.4f}'.format(acc))
    return acc


def accuracy_soft(input, target, idx, verbose = True):
    """Accuracy against soft / one-hot targets.

    The arg-max prediction is one-hot encoded and multiplied element-wise
    with ``target``; the summed overlap over the nodes selected by ``idx``
    is divided by ``idx.sum()``.
    """
    predicted = input.max(dim=1)[1]
    num_classes = target.shape[-1]
    one_hot_pred = torch.nn.functional.one_hot(predicted, num_classes).float()
    overlap = (one_hot_pred[idx] * target[idx]).sum().item()
    acc = float(overlap) / idx.sum().item()
    if verbose:
        print(f"Accuracy: {acc}")
    return acc

def get_optimizer(optimizer_name, parameters, lr, weight_decay):
    """Build the optimizer called ``optimizer_name``; only ``"adam"`` is known.

    Any other name yields ``None``.
    """
    if optimizer_name != "adam":
        return None
    return torch.optim.Adam(parameters, lr=lr, weight_decay=weight_decay)

def get_criterion(criterion_name):
    """Instantiate the loss registered under ``criterion_name``.

    Returns ``None`` when the name is not registered. Factories are wrapped
    in lambdas so a loss class is only looked up when its name is requested.
    """
    registry = (
        ("crossEntropy", lambda: nn.CrossEntropyLoss()),
        ("pwLoss", lambda: PWLoss()),
        ("pwLoss_J", lambda: PWLoss_J()),
        ("PWLoss_redis2", lambda: PWLoss_redis2()),
        ("PWLoss_redist2_edge_rezero", lambda: PWLoss_redis2_edge_rezero()),
        ("PWLoss_average_edge_rezero", lambda: PWLoss_edge_rezero()),
        ("PWLoss_average_mixture", lambda: PWLoss_average_mixture()),
        ("PWLoss_center_mixture", lambda: PWLoss_center_mixture()),
    )
    for known_name, build in registry:
        if criterion_name == known_name:
            return build()
    return None


def get_metric(metric_name, soft):
    """Return the metric function for ``metric_name`` (only ``'accuracy'`` is known).

    ``soft`` selects ``accuracy_soft`` over ``accuracy``. Unknown names
    give ``None``.
    """
    if metric_name != 'accuracy':
        return None
    return accuracy_soft if soft else accuracy

def get_early_stopping_criterion(stopping_name,patience, path, init_best_val_acc):
    """Instantiate the early-stopping criterion registered under ``stopping_name``.

    All criteria receive the same ``patience`` / ``path`` /
    ``init_best_val_acc`` keyword arguments. Unknown names give ``None``.
    """
    common = dict(patience=patience, path=path, init_best_val_acc=init_best_val_acc)
    if stopping_name == 'gcnStoppingCriterion':
        return GCNEarlyStoppingCriterion(**common)
    if stopping_name == 'pwemStoppingCriterion':
        return PWEMEarlyStoppingCriterion(**common)
    if stopping_name == 'pwemStoppingCriterionForEM':
        return PWEMEarlyStoppingCriterionForM(**common)
    return None



class Trainer(object):
    """Training / evaluation helper around a single model.

    Builds the optimizer, criterion, metric and early-stopping criterion
    from ``configs`` and keeps a per-epoch history in ``self.records``.
    """

    # NOTE: runs once at class-definition time and enables autograd anomaly
    # detection for the whole process.
    torch.autograd.set_detect_anomaly(True)

    def __init__(self, model, configs, path):
        """Wire up optimizer, loss, metric and early stopping.

        Args:
            model: Model to train; only parameters with ``requires_grad``
                are handed to the optimizer.
            configs: Dict read for ``optimizer_name``, ``learning_rate``,
                ``weight_decay``, ``stopping_name``, ``patience``,
                ``criterion_name``, ``metric``, ``soft``,
                ``init_best_val_acc`` and ``device``.
            path: Path forwarded to the early-stopping criterion.
        """
        optimizer_name = configs['optimizer_name']
        lr = configs['learning_rate']
        weight_decay = configs['weight_decay']
        parameters = [p for p in model.parameters() if p.requires_grad]
        stopping_name = configs['stopping_name']
        patience = configs['patience']
        self.path = path
        self.optimizer = get_optimizer(optimizer_name, parameters, lr, weight_decay)
        self.model = model
        self.criterion = get_criterion(configs['criterion_name'])
        self.configs = configs
        self.metric = get_metric(configs['metric'], configs['soft'])
        self.early_stopping_criterion = get_early_stopping_criterion(stopping_name, patience, path, configs['init_best_val_acc'])

        self.records = {
            "eval_acc": torch.tensor([]),
            "test_acc": torch.tensor([]),
            "train_loss": torch.tensor([]),
            "eval_loss": torch.tensor([]),
            "configs": configs}

        self.criterion.to(configs['device'])


    def run(self, data, mask, target):
        """Train for up to ``configs['num_run']`` epochs with early stopping.

        Each epoch performs one optimizer step, evaluates validation loss,
        validation accuracy and test accuracy, and appends them to
        ``self.records``. Returns whatever the early-stopping criterion's
        ``after_stopping_ops`` returns when it triggers, otherwise ``None``.
        """
        # Soft and hard targets differ only in the step / validation-loss
        # functions; the rest of the loop is shared.
        if self.configs['soft']:
            step, validation_loss = self.update_soft, self.evaluate_loss_soft
        else:
            step, validation_loss = self.update, self.evaluate_loss

        for i in range(self.configs['num_run']):
            loss = step(data, mask)
            val_loss = validation_loss(data, target)
            val_acc = self.evaluate_acc(data, target, self.configs['verbose'])
            test_acc = self.test_acc(data, target, self.configs['verbose'])
            if self.configs['verbose']:
                print(f"train_loss:{loss}, eval_acc:{val_acc}, test_acc:{test_acc}")
            self.records['eval_acc'] = torch.cat((self.records['eval_acc'], torch.tensor([val_acc])), dim=0)
            self.records['test_acc'] = torch.cat((self.records['test_acc'], torch.tensor([test_acc])), dim=0)
            self.records['train_loss'] = torch.cat((self.records['train_loss'], torch.tensor([loss])), dim=0)
            self.records['eval_loss'] = torch.cat((self.records['eval_loss'], torch.tensor([val_loss])), dim=0)

            if self.early_stopping_criterion.should_stop(epoch=i, val_accuracy=val_acc, model=self.model, optimizer=self.optimizer):
                return self.early_stopping_criterion.after_stopping_ops(self.model, self.optimizer)


    def update(self, data, mask):
        """One optimization step with ``self.criterion`` on the nodes in ``mask``.

        Returns the loss as a Python float.
        """
        self.model.train()
        self.optimizer.zero_grad()

        out = self.model(data)
        loss = self.criterion(out[mask], data.y[mask])
        loss.backward(retain_graph=True)
        self.optimizer.step()

        return loss.item()

    def update_soft(self, data, mask):
        """One optimization step with soft-target cross entropy on ``mask``.

        NOTE(review): unlike ``update`` this returns the loss tensor itself
        rather than ``loss.item()``; kept as is in case callers outside this
        chunk rely on it.
        """
        self.model.train()
        self.optimizer.zero_grad()

        out = self.model(data)
        logit = torch.log_softmax(out, dim = -1)
        loss = -torch.mean(torch.sum(data.y[mask]*logit[mask], dim = -1))
        loss.backward(retain_graph=True)
        self.optimizer.step()

        return loss

    def predict(self, data):
        """Return softmax class probabilities in eval mode, without gradients."""
        self.model.eval()
        with torch.no_grad():
            out = self.model(data)
            probs = torch.softmax(out, dim=-1)
        return probs

    def evaluate_acc(self, data, target, verbose):
        """Metric value on ``data.val_mask``."""
        self.model.eval()
        with torch.no_grad():
            result = self.metric(self.model(data), target, data.val_mask, verbose=verbose)
        return result

    def evaluate_loss_soft(self, data, target):
        """Soft-target cross entropy on ``data.val_mask`` as a float."""
        self.model.eval()
        with torch.no_grad():
            out = self.model(data)
            logit = torch.log_softmax(out, dim=-1)
            loss = -torch.mean(torch.sum(target[data.val_mask] * logit[data.val_mask], dim=-1))
        return loss.item()

    def evaluate_loss(self, data, target):
        """``self.criterion`` on ``data.val_mask`` as a float."""
        self.model.eval()
        with torch.no_grad():
            out = self.model(data)
            loss = self.criterion(out[data.val_mask], target[data.val_mask])
        return loss.item()


    def test_acc(self, data, target, verbose):
        """Metric value on ``data.test_mask``."""
        self.model.eval()
        with torch.no_grad():
            result = self.metric(self.model(data), target, data.test_mask, verbose=verbose)
        return result


    def reset(self): return 0

    def save(self, path_for_results):
        """Write ``self.records`` to ``path_for_results`` as JSON (overwrites the file).

        Tensors are converted to lists in a copy, so ``self.records`` keeps
        its tensors and ``run`` can still append to them afterwards.
        """
        # Local import: the module never imported `json`, which made this
        # method fail with NameError.
        import json

        serializable = {
            key: (value.cpu().tolist() if isinstance(value, torch.Tensor) else value)
            for key, value in self.records.items()
        }
        with open(path_for_results, 'w') as f:
            # default=str: the nested configs may hold values json cannot
            # encode natively (e.g. a device object); store their string form.
            json.dump(serializable, f, default=str)


    def give_best_test(self):
        """Return the test accuracy of the epoch with the best validation accuracy.

        Ties keep the earliest epoch. The value is also stored in
        ``self.records['best_acc_gmnn_criterion']``.
        """
        # similar as gmnn original code
        best_eval_acc = 0
        best_epoch = 0
        for i, acc in enumerate(self.records['eval_acc']):
            if acc > best_eval_acc:
                best_eval_acc = acc
                best_epoch = i
        best_test = self.records['test_acc'][best_epoch]
        self.records.update({'best_acc_gmnn_criterion': best_test})
        return best_test


class BPLeafToRoot(MessagePassing):
    """Per-node log-normalizer computed with sum aggregation (average redistribution)."""

    def __init__(self):
        super().__init__(aggr='add', flow="target_to_source")

    def forward(self, x_redistributed, edge_index, binary_redistributed):
        # To avoid over-counting, the *redistributed* unary and binary
        # energies are what gets propagated.
        num_nodes = x_redistributed.size(0)
        return self.propagate(
            edge_index,
            size=(num_nodes, num_nodes),
            x=x_redistributed,
            binary=binary_redistributed,
        )

    def message(self, x_j, binary):
        # Add the neighbour's energies to the pairwise table and
        # log-sum-exp over dim 1.
        num_edges, _ = x_j.shape
        combined = x_j.view(num_edges, -1, 1) + binary
        return torch.logsumexp(combined, dim=1)

    def update(self, aggr_out,x):
        # Normalizer for every piece, i.e. one value per node.
        return torch.logsumexp(x + aggr_out, dim=1)

class BPLeafToRoot_edge_rezero(MessagePassing):
    """Average redistribution + edge rezero: the pairwise table is scaled per edge."""

    def __init__(self):
        # super().__init__(aggr='add', flow="source_to_target")
        super().__init__(aggr='add', flow="target_to_source")

    def forward(self, x_redistributed, edge_index, binary_redistributed, rezero):
        # To avoid over-counting, the redistributed unary and binary
        # energies are what gets propagated.
        num_nodes = x_redistributed.size(0)
        return self.propagate(
            edge_index,
            size=(num_nodes, num_nodes),
            x=x_redistributed,
            binary=binary_redistributed,
            rezero=rezero,
        )

    def message(self, x_j, binary, rezero):
        # Only the pairwise term is scaled by the rezero coefficient.
        num_edges, _ = x_j.shape
        scaled_binary = rezero.view(-1, 1, 1) * binary
        combined = x_j.view(num_edges, -1, 1) + scaled_binary
        return torch.logsumexp(combined, dim=1)

    def update(self, aggr_out,x):
        # Normalizer for every piece, i.e. one value per node.
        return torch.logsumexp(x + aggr_out, dim=1)

class BPLeafToRoot_center(MessagePassing):
    """Per-node log-normalizer for center redistribution.

    Neighbour unaries are multiplied by zero inside ``message``, so only the
    pairwise table contributes to the messages; the node's own unary enters
    in ``update``.
    """

    def __init__(self):
        super().__init__(aggr='add', flow="target_to_source")

    def forward(self, unary, edge_index, binary_redistributed):
        num_nodes = unary.size(0)
        return self.propagate(
            edge_index,
            size=(num_nodes, num_nodes),
            x=unary,
            binary=binary_redistributed,
        )

    def message(self, x_j, binary):
        num_edges, _ = x_j.shape
        # Multiplying (rather than allocating zeros) keeps shape, dtype and device.
        zeroed = x_j * 0
        combined = zeroed.view(num_edges, -1, 1) + binary
        return torch.logsumexp(combined, dim=1)

    def update(self, aggr_out,x):
        # Normalizer for every piece, i.e. one value per node.
        return torch.logsumexp(x + aggr_out, dim=1)


class BPLeafToRoot_center_edge_rezero(MessagePassing):
    """Center redistribution + edge rezero variant of the log-normalizer pass."""

    def __init__(self):
        # super().__init__(aggr='add', flow="source_to_target")
        super().__init__(aggr='add', flow="target_to_source")

    def forward(self, unary, edge_index, binary_redistributed, rezero):
        num_nodes = unary.size(0)
        return self.propagate(
            edge_index,
            size=(num_nodes, num_nodes),
            x=unary,
            binary=binary_redistributed,
            rezero=rezero,
        )

    def message(self, x_j, binary, rezero):
        num_edges, _ = x_j.shape
        # Neighbour unaries are zeroed; the rezero factor scales the whole sum.
        zeroed = x_j * 0
        combined = zeroed.view(num_edges, -1, 1) + binary
        return torch.logsumexp(rezero.view(-1, 1, 1) * combined, dim=1)

    def update(self, aggr_out,x):
        # Normalizer for every piece, i.e. one value per node.
        return torch.logsumexp(x + aggr_out, dim=1)


class PWLoss(MessagePassing):
    """Piecewise objective for average redistribution.

    Returns one value per node; the caller is expected to sum them.
    """

    def __init__(self):
        super().__init__(aggr='add', flow="target_to_source")
        # super().__init__(aggr='add', flow="source_to_target")

    def forward(self, x_redistributed, edge_index, binary_redistributed, log_z_redistributed, q):
        # Pass along everything that message() and update() need.
        num_nodes = x_redistributed.size(0)
        return self.propagate(
            edge_index,
            size=(num_nodes, num_nodes),
            x=x_redistributed,
            binary=binary_redistributed,
            log_z=log_z_redistributed,
            q=q,
            edge_index_params=edge_index,
        )

    def message(self, x_j, edge_index_params, binary, q):
        row_i, row_j = edge_index_params
        q_at_i, q_at_j = q[row_i], q[row_j]  # each of shape E*C
        unary_part = (x_j * q_at_j).sum(dim=1)
        pairwise_part = (torch.mm(q_at_i, binary) * q_at_j).sum(dim=1)
        return (unary_part + pairwise_part).view(-1, 1)

    def update(self, aggr_out, x, log_z, q):
        # Own expected unary + aggregated edge terms - log normalizer.
        own_unary = (x * q).sum(dim=1)
        return own_unary + aggr_out.squeeze() - log_z

class PWLoss_edge_rezero(MessagePassing):
    """Average redistribution objective with a per-edge rezero factor on the pairwise term."""

    def __init__(self):
        # super().__init__(aggr='add', flow="source_to_target")
        super().__init__(aggr='add', flow="target_to_source")

    def forward(self, x_redistributed, edge_index, binary_redistributed, log_z_redistributed, q, rezero):
        # Pass along everything that message() and update() need.
        num_nodes = x_redistributed.size(0)
        return self.propagate(
            edge_index,
            size=(num_nodes, num_nodes),
            x=x_redistributed,
            binary=binary_redistributed,
            log_z=log_z_redistributed,
            q=q,
            edge_index_params=edge_index,
            rezero=rezero,
        )

    def message(self, x_j, edge_index_params, binary, q, rezero):
        row_i, row_j = edge_index_params
        q_at_i, q_at_j = q[row_i], q[row_j]  # each of shape E*C
        unary_part = (x_j * q_at_j).sum(dim=1)
        pairwise_part = (torch.mm(q_at_i, binary) * q_at_j).sum(dim=1)
        return (unary_part + rezero * pairwise_part).view(-1, 1)

    def update(self, aggr_out, x, log_z, q):
        # One value per piece; the final loss is their sum.
        own_unary = (x * q).sum(dim=1)
        return own_unary + aggr_out.squeeze() - log_z


class PWLoss_redis2(MessagePassing):
    """Piecewise objective for center redistribution: edges carry only the pairwise term."""

    def __init__(self):
        super().__init__(aggr='add', flow="target_to_source")

    def forward(self, unary, edge_index, binary_redistributed, log_z_redistributed, q):
        # Pass along everything that message() and update() need.
        num_nodes = unary.size(0)
        return self.propagate(
            edge_index,
            size=(num_nodes, num_nodes),
            x=unary,
            binary=binary_redistributed,
            log_z=log_z_redistributed,
            q=q,
            edge_index_params=edge_index,
        )

    def message(self, x_j, edge_index_params, binary, q):
        row_i, row_j = edge_index_params
        q_at_i, q_at_j = q[row_i], q[row_j]  # each of shape E*C
        pairwise_part = (torch.mm(q_at_i, binary) * q_at_j).sum(dim=1)
        return pairwise_part.view(-1, 1)

    def update(self, aggr_out, x, log_z, q):
        # One value per piece; the final loss is their sum.
        own_unary = (x * q).sum(dim=1)
        return own_unary + aggr_out.squeeze() - log_z

class PWLoss_redis2_edge_rezero(MessagePassing):
    """Center redistribution objective with a per-edge rezero factor."""

    def __init__(self):
        super().__init__(aggr='add', flow="target_to_source")

    def forward(self, unary, edge_index, binary_redistributed, log_z_redistributed, q, rezero):
        # Pass along everything that message() and update() need.
        num_nodes = unary.size(0)
        return self.propagate(
            edge_index,
            size=(num_nodes, num_nodes),
            x=unary,
            binary=binary_redistributed,
            log_z=log_z_redistributed,
            q=q,
            edge_index_params=edge_index,
            rezero=rezero,
        )

    def message(self, x_j, edge_index_params, binary, q, rezero):
        row_i, row_j = edge_index_params
        q_at_i, q_at_j = q[row_i], q[row_j]  # each of shape E*C
        pairwise_part = (torch.mm(q_at_i, binary) * q_at_j).sum(dim=1)
        return (rezero * pairwise_part).view(-1, 1)

    def update(self, aggr_out, x, log_z, q):
        # One value per piece; the final loss is their sum.
        own_unary = (x * q).sum(dim=1)
        return own_unary + aggr_out.squeeze() - log_z

class PWLoss_average_mixture(MessagePassing):
    """Average redistribution loss mixed with a soft cross-entropy term.

    ``update`` returns ``(1 - gamma) * entropy + gamma * piecewise`` per
    node, where ``gamma = sigmoid(rezero)``.
    """

    def __init__(self):
        super(PWLoss_average_mixture, self).__init__(aggr='add', flow="target_to_source")

    def forward(self, x_redistributed, edge_index, binary_redistributed, log_z_redistributed, q, unary, rezero):
        # ``unary`` and ``rezero`` are consumed by update(); message() does
        # not use them.
        return self.propagate(edge_index, size=(x_redistributed.size(0), x_redistributed.size(0)), x=x_redistributed, binary=binary_redistributed, log_z=log_z_redistributed, q=q,
                              edge_index_params=edge_index, unary=unary, rezero=rezero)

    def message(self, x_j, edge_index_params, binary, q):
        i, j = edge_index_params
        q_j, q_i = q[j], q[i]  # q_j and q_i are of shape E*C
        messages = torch.sum(x_j * q_j, dim=1) + torch.sum(torch.mm(q_i, binary) * q_j, dim=1)
        return messages.view(-1, 1)

    def update(self, aggr_out, x, log_z, q, unary, rezero):
        piecewise_loss = torch.sum(x * q, dim=1) + aggr_out.squeeze() - log_z

        # Expected log-probability of q under softmax(unary).
        logit = torch.log_softmax(unary, dim=-1)
        entropy_loss = torch.sum(q * logit, dim=-1)
        # torch.sigmoid replaces the deprecated torch.nn.functional.sigmoid.
        gamma = torch.sigmoid(rezero)
        result = (1 - gamma) * entropy_loss + gamma * piecewise_loss
        return result


class PWLoss_center_mixture(MessagePassing):
    """Center redistribution loss mixed with a soft cross-entropy term.

    ``update`` returns ``(1 - gamma) * entropy + gamma * piecewise`` per
    node, where ``gamma = sigmoid(rezero)``.
    """

    def __init__(self):
        super(PWLoss_center_mixture, self).__init__(aggr='add', flow="target_to_source")

    def forward(self, unary, edge_index, binary_redistributed, log_z_redistributed, q, rezero):
        # forward params needed to construct message and update
        return self.propagate(edge_index, size=(unary.size(0), unary.size(0)), x=unary, binary=binary_redistributed, log_z=log_z_redistributed, q=q,
                              edge_index_params=edge_index, rezero=rezero)

    def message(self, x_j, edge_index_params, binary, q):
        i, j = edge_index_params
        q_j, q_i = q[j], q[i]  # q_j and q_i are of shape E*C
        messages = torch.sum(torch.mm(q_i, binary) * q_j, dim=1)
        return messages.view(-1, 1)

    def update(self, aggr_out, x, log_z, q, rezero):
        # One value per piece; the final loss is their sum.
        piecewise_loss = torch.sum(x * q, dim=1) + aggr_out.squeeze() - log_z
        # Expected log-probability of q under softmax(x).
        logit = torch.log_softmax(x, dim=-1)
        entropy_loss = torch.sum(q * logit, dim=-1)
        # torch.sigmoid replaces the deprecated torch.nn.functional.sigmoid.
        gamma = torch.sigmoid(rezero)
        result = (1 - gamma) * entropy_loss + gamma * piecewise_loss
        return result

class MFUpdate(MessagePassing):
    """Update step for q: softmax over unary scores plus aggregated pairwise messages."""

    def __init__(self):
        super().__init__(aggr='add', flow="target_to_source")

    def forward(self, q, edge_index, binary, unary):
        # q is propagated under the name ``x``.
        num_nodes = q.size(0)
        return self.propagate(
            edge_index,
            size=(num_nodes, num_nodes),
            x=q,
            binary=binary,
            unary=unary,
            edge_index_params=edge_index,
        )

    def message(self, x_j, binary):
        # Neighbour beliefs pushed through the pairwise table.
        return torch.mm(x_j, binary)

    def update(self, aggr_out, unary):
        # New q: row-wise softmax of unary + summed messages.
        scores = unary + aggr_out
        return F.softmax(scores, dim=1)

class MFUpdate_edge_rezero(MessagePassing):
    """Edge-rezero inference: like ``MFUpdate`` but each message is scaled per edge."""

    def __init__(self):
        super().__init__(aggr='add', flow="target_to_source")

    def forward(self, q, edge_index, binary, unary, rezero):
        # q is propagated under the name ``x``.
        num_nodes = q.size(0)
        return self.propagate(
            edge_index,
            size=(num_nodes, num_nodes),
            x=q,
            binary=binary,
            unary=unary,
            edge_index_params=edge_index,
            rezero=rezero,
        )

    def message(self, x_j, binary, rezero):
        # One rezero coefficient per edge, broadcast over the class dimension.
        per_edge_scale = rezero.view(-1, 1)
        return per_edge_scale * torch.mm(x_j, binary)

    def update(self, aggr_out, unary):
        # New q: row-wise softmax of unary + summed messages.
        scores = unary + aggr_out
        return F.softmax(scores, dim=1)


class PWGNN(nn.Module):
    """Pairwise model built on a GNN backbone.

    The backbone produces unary scores; ``self.Binary`` is a learnable
    class-by-class pairwise table that is symmetrized in ``forward``.
    ``configs['redistribution']`` selects the normalizer pass (``up_BP``)
    and the inference operator (``inf``).
    """

    def __init__(self, gnnbackbone, configs):
        """
        Args:
            gnnbackbone: Module returning unary scores; must expose
                ``num_classes``.
            configs: Dict read for ``redistribution``, ``rezero`` and, when
                ``rezero`` is truthy, ``rezero_size`` and
                ``rezero_require_grad``.
        """
        super(PWGNN, self).__init__()

        # init backbone
        self.gnnbackbone = gnnbackbone
        num_classes = gnnbackbone.num_classes
        self.configs = configs
        self.inf = MFUpdate()

        if configs['redistribution'] == "center":
            self.up_BP = BPLeafToRoot_center()
        elif configs['redistribution'] == "average":
            self.up_BP = BPLeafToRoot()
        elif configs['redistribution'] == "average_edge_rezero":
            self.up_BP = BPLeafToRoot_edge_rezero()
            self.inf = MFUpdate_edge_rezero()
        elif configs['redistribution'] == "center_edge_rezero":
            self.up_BP = BPLeafToRoot_center_edge_rezero()
            self.inf = MFUpdate_edge_rezero()

        # Random table biased towards the identity, scaled by 1/num_classes.
        self.Binary = torch.nn.Parameter((torch.randn(num_classes, num_classes) + torch.eye(num_classes))/num_classes, requires_grad=True)  # identity between every piece

        if self.configs['rezero']:
            self.rezero_coefficients = torch.nn.Parameter(torch.zeros(self.configs['rezero_size']), requires_grad=self.configs['rezero_require_grad'])

    def forward(self, data, deg):
        """Compute unary / pairwise energies and the per-node log-normalizer.

        The returned tuple depends on ``configs['redistribution']``:

        * ``'average'``: ``(U, U/deg, B, B/2, log_z)``
        * ``'average_edge_rezero'``: ``(U, U/deg, B, B/2, log_z, rezero)``
        * ``'center'``: ``(U, B, log_z)``
        * ``'center_edge_rezero'``: ``(U, B, log_z, rezero)``

        Raises:
            ValueError: For any other redistribution name.

        NOTE(review): the ``*_edge_rezero`` branches use ``rezero``, which is
        only assigned when ``configs['rezero']`` is truthy and
        ``configs['rezero_change']`` is one of the handled values.
        """
        edge_index = data.edge_index
        Unary = self.gnnbackbone(data)
        # Force the pairwise table to be symmetric.
        Binary = (self.Binary + self.Binary.T)/2
        if self.configs['rezero']:
            if self.configs['rezero_change'] == 'auto':
                rezero = self.rezero_coefficients
            elif self.configs['rezero_change'] == 'linear':
                # Grows by 0.01 on every forward call; state lives in configs.
                rezero = self.configs['rezero_value'] + 0.01
                self.configs['rezero_value'] = rezero
                rezero = rezero.to(self.configs['device'])
            elif self.configs['rezero_change'] == 'pairwise':
                rezero = self.rezero_coefficients
                Binary = rezero * Binary

        redistribution = self.configs['redistribution']
        if redistribution == 'average':
            log_z_reditributed = self.up_BP(Unary/(deg), edge_index, Binary/2)
            return Unary, Unary / (deg), Binary, Binary / 2, log_z_reditributed

        elif redistribution == 'average_edge_rezero':
            log_z_reditributed = self.up_BP(Unary/(deg), edge_index, Binary/2, rezero)
            return Unary, Unary/(deg), Binary, Binary/2, log_z_reditributed, rezero

        elif redistribution == 'center':
            log_z_reditributed = self.up_BP(Unary, edge_index, Binary/2)
            return Unary, Binary, log_z_reditributed

        elif redistribution == 'center_edge_rezero':
            log_z_reditributed = self.up_BP(Unary, edge_index, Binary/2, rezero)
            return Unary, Binary, log_z_reditributed, rezero

        # The previous `assert '<message>'` tested a non-empty string, which is
        # always true, so unknown names silently returned None.
        raise ValueError(f"no such redistribution version: {redistribution!r}")

    def inference(self, data, deg):
        """Return the updated q computed from ``data.y`` without gradients.

        Switches the module to eval mode and does not switch it back.
        """
        self.eval()
        with torch.no_grad():
            # need to check whether B needs to be forced to be symmetric
            if self.configs['redistribution'] == 'center':
                U, B, _ = self.forward(data, deg)
                q = self.inf(data.y, data.edge_index, B, U)
            elif self.configs['redistribution'] == 'center_edge_rezero':
                U, B, Z, R = self.forward(data, deg)
                q = self.inf(data.y, data.edge_index, B, U, R)
            elif self.configs['redistribution'] == 'average_edge_rezero':
                U, U_redistributed, B, B_redistributed, Z, R = self.forward(data, deg)
                q = self.inf(data.y, data.edge_index, B, U, R)
            else:
                # log z is dropped: it is only a surrogate approximation.
                U, U_redistributed, B, B_redistributed, _ = self.forward(data, deg)
                q = self.inf(data.y, data.edge_index, B, U)

        return q


class Trainer_PWEM(Trainer):
    """Trainer that alternates parameter updates (M-steps) with refreshes of ``data.y`` (E-steps).

    NOTE(review): ``self.model``, ``self.optimizer``, ``self.criterion``,
    ``self.metric``, ``self.configs``, ``self.path`` and
    ``self.early_stopping_criterion`` are not assigned in this class; they are
    presumably provided by ``Trainer.__init__`` -- confirm against the base class.
    """

    def __init__(self, model, configs, path):
        super().__init__(model, configs, path)
        # Metric histories start empty and grow by one entry per recorded step.
        self.records = {"eval_acc": torch.tensor([]),
                        "test_acc": torch.tensor([]),
                        "train_loss": torch.tensor([]),
                        "eval_loss": torch.tensor([]),
                        "configs": configs}

        self.train_mode = configs['train_mode']

    def _append_record(self, key, value):
        """Append the scalar ``value`` to the history tensor stored in ``self.records[key]``."""
        self.records[key] = torch.cat((self.records[key], torch.tensor([value])), dim=0)

    def train(self, data, target, deg, configs, num_m_steps, num_e_steps):
        """Run up to ``configs['num_EM_steps']`` outer iterations with early stopping.

        Before the loop, ``data.y``, the model state and the optimizer state
        are written to ``self.path + 'q.tensor'``, ``self.path`` and
        ``self.path + '.opt'`` respectively.

        Returns:
            The result of ``self.early_stopping_criterion.after_stopping_ops``
            when the criterion fires; ``None`` if all iterations complete.
        """
        torch.save(data.y, self.path + 'q.tensor')
        torch.save(self.model.state_dict(), self.path)
        torch.save(self.optimizer.state_dict(), self.path + '.opt')

        if configs['stopping_name'] == 'pwemStoppingCriterionForEM':
            # Separate criteria for the inner M loop and the inner E loop.
            self.early_stopping_for_M = PWEMEarlyStoppingCriterionForM(patience=50, path=self.path, init_best_val_acc=configs['init_best_val_acc'])
            self.early_stopping_for_E = PWEMEarlyStoppingCriterionForE(patience=5, path=self.path, init_best_val_acc=configs['init_best_val_acc'])

            for i in range(configs['num_EM_steps']):
                loss, val_loss, val_acc, test_acc = self.run_with_early_stopping_in_EM(data, target, deg, num_m_steps, num_e_steps)

                print(f"number of EM step: {i}, train_loss: {loss:.2f}, val_loss: {val_loss:.2f}, val_acc: {val_acc:.4f}, test_acc: {test_acc:.4f}")
                if self.early_stopping_criterion.should_stop(epoch=i, val_accuracy=val_acc, model=self.model,
                                                             optimizer=self.optimizer, q=data.y):
                    # NOTE(review): this branch passes (model, optimizer) while the branch
                    # below passes no arguments -- confirm which signature the configured
                    # outer criterion actually has.
                    return self.early_stopping_criterion.after_stopping_ops(self.model, self.optimizer)

        else:
            for i in range(configs['num_EM_steps']):
                # NOTE(review): run_without_EM is not defined in this class; presumably inherited.
                loss, val_loss, val_acc, test_acc = self.run_without_EM(data, target, deg)
                print(f"number of EM step: {i}, train_loss: {loss}, val_loss: {val_loss}, val_acc: {val_acc}, test_acc: {test_acc}")
                if self.early_stopping_criterion.should_stop(epoch=i, val_accuracy=val_acc, model=self.model, optimizer=self.optimizer, q=data.y):
                    return self.early_stopping_criterion.after_stopping_ops()

    def run(self, data, target, deg, num_m_steps, num_e_steps):
        """Run ``num_m_steps`` updates followed by ``num_e_steps`` evaluation steps.

        Returns:
            ``(loss, val_loss, val_acc, test_acc)`` from the last executed steps.
        """
        for _ in range(num_m_steps):
            if self.train_mode == "joint training":
                loss = self.update_j(data)
                val_loss = self.evaluate_loss_j(data, data.val_mask)
            else:
                loss = self.update(data, deg)
                val_loss = self.evaluate_loss(data, deg, data.val_mask)

            self._append_record('train_loss', loss)
            self._append_record('eval_loss', val_loss)

        for _ in range(num_e_steps):
            if self.configs['mean_field'] == False:

                with torch.no_grad():  # only use node representation to do the inference test
                    val_acc = self.evaluate_acc(torch.softmax(self.model.gnnbackbone(data), dim=-1), target, data.val_mask, self.configs['verbose'])
                    test_acc = self.evaluate_acc(torch.softmax(self.model.gnnbackbone(data), dim=-1), target, data.test_mask, self.configs['verbose'])
            else:
                with torch.no_grad():
                    # inference() requires deg; only non-training entries of data.y are overwritten.
                    data.y[~data.train_mask] = self.model.inference(data, deg)[~data.train_mask]
                    val_acc = self.evaluate_acc(data.y, target, data.val_mask, self.configs['verbose'])
                    test_acc = self.evaluate_acc(data.y, target, data.test_mask, self.configs['verbose'])
            self._append_record('eval_acc', val_acc)
            self._append_record('test_acc', test_acc)

        return loss, val_loss, val_acc, test_acc

    def run_with_early_stopping_in_EM(self, data, target, deg, num_m_steps, num_e_steps):
        """Run one M phase and one E phase, each with its own early-stopping criterion.

        After the M phase the model and optimizer are reloaded from
        ``self.path`` / ``self.path + '.opt'``; after the E phase the
        non-training entries of ``data.y`` are reloaded from the E criterion.

        Returns:
            ``(loss, val_loss, val_acc, test_acc)``; the accuracies are
            recomputed from the reloaded ``data.y``.

        Raises:
            ValueError: if ``self.train_mode`` is not one of the supported modes.
        """
        pre_trained_steps = self.configs['pre_trained_steps']

        for i in range(num_m_steps):
            self.epoch = pre_trained_steps+i
            self.model.train()

            if self.train_mode == "pre-training":
                loss = self.update(data, deg)
                val_loss = self.evaluate_loss(data, deg, data.val_mask)
            elif self.train_mode in ("center_redist", "joint_redist"):
                loss = self.update_redist2(data, deg)
                val_loss = self.evaluate_loss_redist2(data, deg, data.val_mask)
            elif self.train_mode == "center_edge_rezero":
                loss = self.update_redist2_edge_rezero(data, deg)
                val_loss = self.evaluate_loss_edge_rezero(data, deg, data.val_mask)
            elif self.train_mode == "average_edge_rezero":
                loss = self.update_edge_rezero(data, deg)
                val_loss = self.evaluate_loss_average_edge_rezero(data, deg, data.val_mask)
            else:
                # Fail here with the offending value rather than with an unbound `loss` below.
                raise ValueError(f"unsupported train_mode: {self.train_mode!r}")

            self._append_record('train_loss', loss)
            self._append_record('eval_loss', val_loss)

            test_temp_q = self.model.inference(data, deg)
            val_acc = self.evaluate_acc(test_temp_q, target, data.val_mask, self.configs['verbose'])
            test_acc = self.evaluate_acc(test_temp_q, target, data.test_mask, self.configs['verbose'])
            if self.early_stopping_for_M.should_stop(epoch=i, val_accuracy=val_acc, model=self.model,
                                                     optimizer=self.optimizer, q=data.y):
                self.configs['pre_trained_steps'] = pre_trained_steps+i+1
                break

        # Restore the checkpoint written by the M criterion (or by train() initially).
        self.early_stopping_for_M.count = 0
        self.model.load_state_dict(torch.load(self.path))
        self.optimizer.load_state_dict(torch.load(self.path + '.opt'))

        for i in range(num_e_steps):
            self.model.eval()
            if self.configs['mean_field'] == False:
                with torch.no_grad():  # only use node representation to do the inference test
                    # Model is in eval mode, so one backbone pass serves both masks.
                    probs = torch.softmax(self.model.gnnbackbone(data), dim=-1)
                    val_acc = self.evaluate_acc(probs, target, data.val_mask, self.configs['verbose'])
                    test_acc = self.evaluate_acc(probs, target, data.test_mask, self.configs['verbose'])
            else:
                with torch.no_grad():
                    data.y[~data.train_mask] = self.model.inference(data, deg)[~data.train_mask]
                    val_acc = self.evaluate_acc(data.y, target, data.val_mask, self.configs['verbose'])
                    test_acc = self.evaluate_acc(data.y, target, data.test_mask, self.configs['verbose'])

            self._append_record('eval_acc', val_acc)
            self._append_record('test_acc', test_acc)
            if self.early_stopping_for_E.should_stop(epoch=i, val_accuracy=val_acc, q=data.y):
                break

        self.early_stopping_for_E.count = 0
        data.y[~data.train_mask] = self.early_stopping_for_E.after_stopping_ops()[~data.train_mask]
        val_acc = self.evaluate_acc(data.y, target, data.val_mask, self.configs['verbose'])
        test_acc = self.evaluate_acc(data.y, target, data.test_mask, self.configs['verbose'])

        return loss, val_loss, val_acc, test_acc

    def update(self, data, deg):  # update function for pre-training
        """One optimizer step on the redistributed potentials; loss is averaged over all entries."""
        self.model.train()
        self.optimizer.zero_grad()

        U, U_redistributed, B, B_redistributed, Z = self.model(data, deg)
        loss_for_pieces = self.criterion(U_redistributed, data.edge_index, B_redistributed, Z, data.y)
        # The criterion value is maximized, hence the sign flip.
        loss = -torch.mean(loss_for_pieces)
        loss.backward()
        self.optimizer.step()

        return loss

    def update_use_train_data(self, data, deg):  # update function for pre-pwem
        """Like ``update`` but the loss is averaged over ``data.train_mask`` entries only."""
        self.model.train()
        self.optimizer.zero_grad()
        U, U_redistributed, B, B_redistributed, Z = self.model(data, deg)
        loss_for_pieces = self.criterion(U_redistributed, data.edge_index, B_redistributed, Z, data.y)
        loss = -torch.mean(loss_for_pieces[data.train_mask])
        loss.backward()
        self.optimizer.step()

        return loss

    def update_redist2(self, data, deg):
        """One optimizer step for a model returning ``(U, B, Z)``; the criterion receives ``B/2``."""
        self.model.train()
        self.optimizer.zero_grad()
        U, B, Z = self.model(data, deg)
        loss_for_pieces = self.criterion(U, data.edge_index, B/2, Z, data.y)
        loss = -torch.mean(loss_for_pieces)
        loss.backward()
        self.optimizer.step()
        return loss

    def update_redist2_edge_rezero(self, data, deg):
        """One optimizer step for a model returning ``(U, B, Z, R)``; ``B/2`` and ``R`` go to the criterion."""
        self.model.train()
        self.optimizer.zero_grad()
        U, B, Z, R = self.model(data, deg)
        loss_for_pieces = self.criterion(U, data.edge_index, B/2, Z, data.y, R)
        loss = -torch.mean(loss_for_pieces)
        loss.backward()
        self.optimizer.step()

        return loss

    def update_edge_rezero(self, data, deg):
        """One optimizer step using the redistributed potentials plus the rezero term ``R``."""
        self.model.train()
        self.optimizer.zero_grad()
        U, U_redistribution, B, B_redistribution, Z, R = self.model(data, deg)
        loss_for_pieces = self.criterion(U_redistribution, data.edge_index, B_redistribution, Z, data.y, R)
        loss = -torch.mean(loss_for_pieces)
        loss.backward()
        self.optimizer.step()
        return loss

    def evaluate_loss(self, data, deg, mask):
        """Gradient-free counterpart of ``update``; the loss is averaged over ``mask`` entries."""
        self.model.eval()
        with torch.no_grad():
            U, U_redistributed, B, B_redistributed, Z = self.model(data, deg)
            loss_for_pieces = self.criterion(U_redistributed, data.edge_index, B_redistributed, Z, data.y)
            loss = -torch.mean(loss_for_pieces[mask])
            return loss

    def evaluate_loss_redist2(self, data, deg, mask):
        """Gradient-free counterpart of ``update_redist2``, averaged over ``mask`` entries."""
        self.model.eval()
        with torch.no_grad():
            U, B, log_z_redistributed = self.model(data, deg)
            loss_for_pieces = self.criterion(U, data.edge_index, B/2, log_z_redistributed,
                                             data.y)
            loss = -torch.mean(loss_for_pieces[mask])
            return loss

    def evaluate_loss_edge_rezero(self, data, deg, mask):
        """Gradient-free counterpart of ``update_redist2_edge_rezero``, averaged over ``mask`` entries."""
        self.model.eval()
        with torch.no_grad():
            U, B, log_z_redistributed, R = self.model(data, deg)
            loss_for_pieces = self.criterion(U, data.edge_index, B/2, log_z_redistributed,
                                             data.y, R)
            loss = -torch.mean(loss_for_pieces[mask])
            return loss

    def evaluate_loss_average_edge_rezero(self, data, deg, mask):
        """Gradient-free counterpart of ``update_edge_rezero``, averaged over ``mask`` entries."""
        self.model.eval()
        with torch.no_grad():
            U, U_redistribution, B, B_redistribution, log_z_redistributed, R = self.model(data, deg)
            loss_for_pieces = self.criterion(U_redistribution, data.edge_index, B_redistribution, log_z_redistributed,
                                             data.y, R)
            loss = -torch.mean(loss_for_pieces[mask])
            return loss

    def evaluate_acc(self, input, target, mask, verbose):
        """Return ``self.metric(input, target, mask, verbose)`` computed without gradients."""
        with torch.no_grad():
            acc = self.metric(input, target, mask, verbose)
        return acc

    def test_acc(self, input, target, mask, verbose):
        """Same computation as ``evaluate_acc``; kept as a separate entry point."""
        with torch.no_grad():
            acc = self.metric(input, target, mask, verbose)
        return acc


class EarlyStoppingCriterion:
    """Base class for early-stopping criteria.

    Stores the patience, the checkpoint path and the best validation accuracy
    seen so far. Every hook raises ``NotImplementedError`` and is meant to be
    overridden by subclasses.
    """

    def __init__(self, patience, path, init_best_val_acc):
        self.path = path
        self.patience = patience
        # Starting threshold that a validation accuracy must exceed to count as an improvement.
        self.best_val_acc = init_best_val_acc

    def should_stop(self, epoch, val_loss, val_accuracy, model):
        """Decide whether training should stop; not implemented in the base class."""
        raise NotImplementedError

    def after_stopping_ops(self):
        """Hook run once stopping has been decided; not implemented in the base class."""
        raise NotImplementedError

    def reset(self):
        """Clear accumulated state; not implemented in the base class."""
        raise NotImplementedError



class GCNEarlyStoppingCriterion(EarlyStoppingCriterion):
    """Early stopping on validation accuracy that checkpoints the model and optimizer."""

    def __init__(self, patience, path, init_best_val_acc):
        super().__init__(patience, path, init_best_val_acc)
        self.count = 0          # consecutive calls without a new best accuracy
        self.best_epoch = 0
        self.val_acces = []     # every validation accuracy seen, in call order

    def should_stop(self, epoch, val_accuracy, model, optimizer):
        """Record ``val_accuracy``, checkpoint on improvement and report whether to stop.

        Stops when there has been no improvement for at least ``patience``
        calls and the latest accuracy is not above the mean of the
        ``patience`` preceding ones, or when more than 100 calls passed
        without improvement.
        """
        self.val_acces.append(val_accuracy)

        if val_accuracy > self.best_val_acc:
            # New best: remember it and persist model + optimizer state.
            self.best_epoch = epoch
            self.best_val_acc = val_accuracy
            torch.save(model.state_dict(), self.path)
            torch.save(optimizer.state_dict(), self.path + '.opt')
            self.count = 0
        else:
            self.count += 1

        previous = self.val_acces[-(self.patience + 1):-1]
        plateaued = self.count >= self.patience and self.val_acces[-1] <= np.mean(previous)
        return plateaued or self.count > 100

    def after_stopping_ops(self, model, optimizer):
        """Reload the checkpointed model and optimizer states; returns both load results."""
        model_result = model.load_state_dict(torch.load(self.path))
        optimizer_result = optimizer.load_state_dict(torch.load(self.path + '.opt'))
        return model_result, optimizer_result

    def reset(self):
        """Forget the accuracy history (counter and best values are left untouched)."""
        self.val_acces = []

class PWEMEarlyStoppingCriterion(EarlyStoppingCriterion):
    """Early stopping on validation accuracy that checkpoints model, optimizer and ``q``."""

    def __init__(self, patience, path, init_best_val_acc):
        super().__init__(patience, path, init_best_val_acc)
        self.count = 0          # consecutive calls without a new best accuracy
        self.best_epoch = 0
        self.val_acces = []     # every validation accuracy seen, in call order
        print(self.patience)

    def should_stop(self, epoch, val_accuracy, model, optimizer, q):
        """Record ``val_accuracy``, checkpoint on improvement and report whether to stop.

        Stops when there has been no improvement for at least ``patience``
        calls and the latest accuracy is not above the mean of the
        ``patience`` preceding ones, or when more than 100 calls passed
        without improvement.
        """
        self.val_acces.append(val_accuracy)

        if val_accuracy > self.best_val_acc:
            # New best: persist model, optimizer and q alongside each other.
            self.best_epoch = epoch
            self.best_val_acc = val_accuracy
            torch.save(model.state_dict(), self.path)
            torch.save(optimizer.state_dict(), self.path + '.opt')
            torch.save(q, self.path + 'q.tensor')
            print(self.path + 'q.tensor')
            self.count = 0
        else:
            self.count += 1

        previous = self.val_acces[-(self.patience + 1):-1]
        plateaued = self.count >= self.patience and self.val_acces[-1] <= torch.mean(torch.tensor(previous))
        return plateaued or self.count > 100

    def after_stopping_ops(self):
        """Return the ``q`` tensor stored at ``path + 'q.tensor'``; nothing else is reloaded."""
        q_file = self.path + 'q.tensor'
        return torch.load(q_file)

    def reset(self):
        """Forget the accuracy history (counter and best values are left untouched)."""
        self.val_acces = []

class PWEMEarlyStoppingCriterionForM(EarlyStoppingCriterion):
    """Early stopping for the M phase; checkpoints the model and optimizer only."""

    def __init__(self, patience, path, init_best_val_acc):
        super().__init__(patience, path, init_best_val_acc)
        self.count = 0          # consecutive calls without a new best accuracy
        self.best_epoch = 0
        self.val_acces = []     # every validation accuracy seen, in call order

    def should_stop(self, epoch, val_accuracy, model, optimizer, q):
        """Record ``val_accuracy``, checkpoint on improvement and report whether to stop.

        ``q`` is accepted but not used or saved here. Stops when there has
        been no improvement for at least ``patience`` calls and the latest
        accuracy is not above the mean of the ``patience`` preceding ones, or
        when more than 50 calls passed without improvement.
        """
        self.val_acces.append(val_accuracy)

        if val_accuracy > self.best_val_acc:
            # New best: persist model and optimizer state.
            self.best_epoch = epoch
            self.best_val_acc = val_accuracy
            torch.save(model.state_dict(), self.path)
            torch.save(optimizer.state_dict(), self.path + '.opt')
            print(self.path)
            self.count = 0
        else:
            self.count += 1

        previous = self.val_acces[-(self.patience + 1):-1]
        plateaued = self.count >= self.patience and self.val_acces[-1] <= torch.mean(
            torch.tensor(previous))
        return plateaued or self.count > 50

    def after_stopping_ops(self, model, optimizer):
        """Reset the counter, reload model and optimizer, and return the stored ``q`` tensor.

        The ``q`` file at ``path + 'q.tensor'`` is not written by this class,
        so it must already exist when this is called.
        """
        self.count = 0
        model.load_state_dict(torch.load(self.path))
        optimizer.load_state_dict(torch.load(self.path + '.opt'))

        q_file = self.path + 'q.tensor'
        return torch.load(q_file)

    def reset(self):
        """Forget the accuracy history (counter and best values are left untouched)."""
        self.val_acces = []

class PWEMEarlyStoppingCriterionForE(EarlyStoppingCriterion):
    """Early stopping for the E phase; checkpoints only the ``q`` tensor."""

    def __init__(self, patience, path, init_best_val_acc):
        super().__init__(patience, path, init_best_val_acc)
        self.count = 0          # consecutive calls without a new best accuracy
        self.best_epoch = 0
        self.val_acces = []     # every validation accuracy seen, in call order

    def should_stop(self, epoch, val_accuracy, q):
        """Record ``val_accuracy``, save ``q`` on improvement and report whether to stop.

        Stops when there has been no improvement for at least ``patience``
        calls and the latest accuracy is not above the mean of the
        ``patience`` preceding ones, or when more than 50 calls passed
        without improvement.
        """
        self.val_acces.append(val_accuracy)

        if val_accuracy > self.best_val_acc:
            # New best: only q is persisted by this criterion.
            self.best_epoch = epoch
            self.best_val_acc = val_accuracy
            torch.save(q, self.path + 'q.tensor')
            print(self.path)
            self.count = 0
        else:
            self.count += 1

        previous = self.val_acces[-(self.patience + 1):-1]
        plateaued = self.count >= self.patience and self.val_acces[-1] <= np.mean(previous)
        return plateaued or self.count > 50

    def after_stopping_ops(self):
        """Reset the counter and return the ``q`` tensor stored at ``path + 'q.tensor'``."""
        self.count = 0
        q_file = self.path + 'q.tensor'
        return torch.load(q_file)

    def reset(self):
        """Forget the accuracy history (counter and best values are left untouched)."""
        self.val_acces = []


