import torch
import torch.nn.functional as F
from sklearn import metrics

def accuracy(output, labels):
    """Return the fraction of rows whose arg-max class matches ``labels``.

    Args:
        output: Score tensor reduced over dim 1 (presumably
            ``(num_nodes, num_classes)`` log-probabilities — TODO confirm).
        labels: Class labels, one per row of ``output``.

    Returns:
        A 0-dim float64 tensor holding the accuracy in ``[0, 1]``.
    """
    _, predicted = output.max(1)
    # Match dtype/device of labels so the element-wise comparison is valid.
    predicted = predicted.type_as(labels)
    num_correct = (predicted == labels).double().sum()
    return num_correct / len(labels)

def test(model, features, adj, labels, idx_test):
    """Evaluate ``model`` on the test split without tracking gradients.

    Runs one forward pass over the whole graph in eval mode and scores only
    the rows selected by ``idx_test``.

    Args:
        model: Module called as ``model(features, adj)``; assumed to return
            per-node log-probabilities, as ``F.nll_loss`` expects — TODO confirm.
        features: Node feature input passed straight to ``model``.
        adj: Adjacency input passed straight to ``model``.
        labels: Class labels for every node.
        idx_test: Index (or mask) selecting the test nodes.

    Returns:
        ``(loss, accuracy)`` on the test rows, as Python floats.

    Note:
        The model is left in eval mode; callers must switch it back to
        training mode before further optimization.
    """
    model.eval()
    # Evaluation never calls backward(), so building an autograd graph for
    # the full-graph forward pass would only waste memory and time.
    with torch.no_grad():
        output = model(features, adj)
        test_output = output[idx_test]
        test_labels = labels[idx_test]
        loss_test = F.nll_loss(test_output, test_labels)
        acc_test = accuracy(test_output, test_labels)
    return loss_test.item(), acc_test.item()



def train(epoch, model, optimizer, features, adj, labels, idx_train):  # Centralized or new FL
    """Run a single full-graph optimization step.

    The forward pass covers every node, but loss and accuracy are computed
    only over the rows selected by ``idx_train``.

    Args:
        epoch: Current epoch number; not used by this function.
        model: Module called as ``model(features, adj)``; assumed to return
            per-node log-probabilities, as ``F.nll_loss`` expects — TODO confirm.
        optimizer: Optimizer over ``model``'s parameters.
        features: Node feature input passed straight to ``model``.
        adj: Adjacency input passed straight to ``model``.
        labels: Class labels for every node.
        idx_train: Index (or mask) selecting the training nodes.

    Returns:
        ``(loss, accuracy)`` on the training rows as Python floats, measured
        on the forward pass before the parameter update.
    """
    model.train()
    optimizer.zero_grad()

    train_output = model(features, adj)[idx_train]
    train_labels = labels[idx_train]
    loss_train = F.nll_loss(train_output, train_labels)
    acc_train = accuracy(train_output, train_labels)

    loss_train.backward()
    optimizer.step()
    # Clear gradients again after the step so none linger between calls.
    optimizer.zero_grad()

    return loss_train.item(), acc_train.item()


def Lhop_Block_matrix_train(epoch, model, optimizer, features, adj, labels, communicate_index, in_com_train_data_index):
    """Run one optimization step on the subgraph induced by ``communicate_index``.

    Features, adjacency and labels are all restricted to the nodes in
    ``communicate_index``. The loss is then taken over the positions
    ``in_com_train_data_index``, which index into that restricted subgraph
    (local positions, not global node ids).

    Args:
        epoch: Current epoch number; not used by this function.
        model: Module called as ``model(features, adj)``; assumed to return
            per-node log-probabilities, as ``F.nll_loss`` expects — TODO confirm.
        optimizer: Optimizer over ``model``'s parameters.
        features: Feature rows for the full node set.
        adj: Adjacency for the full node set (indexed by rows, then columns).
        labels: Class labels for the full node set.
        communicate_index: Global node indices that form the subgraph.
        in_com_train_data_index: Training positions within the subgraph.

    Returns:
        ``(loss, accuracy)`` on the selected training rows as Python floats.
    """
    model.train()
    optimizer.zero_grad()

    sub_features = features[communicate_index]
    # Select rows, then columns, to obtain the induced sub-adjacency.
    sub_adj = adj[communicate_index][:, communicate_index]
    sub_labels = labels[communicate_index]

    output = model(sub_features, sub_adj)
    picked_output = output[in_com_train_data_index]
    picked_labels = sub_labels[in_com_train_data_index]

    loss_train = F.nll_loss(picked_output, picked_labels)
    acc_train = accuracy(picked_output, picked_labels)

    loss_train.backward()
    optimizer.step()
    # Clear gradients again after the step so none linger between calls.
    optimizer.zero_grad()
    return loss_train.item(), acc_train.item()

def FedSage_train(epoch, model, optimizer, features, adj, labels, communicate_index, in_com_train_data_index):
    """Run one optimization step on the subgraph induced by ``communicate_index``.

    Unlike the adjacency and labels, ``features`` is passed to the model
    unsliced. NOTE(review): presumably the caller already supplies one
    feature row per entry of ``communicate_index``, in the same order
    (e.g. locally generated features) — confirm against callers.

    Args:
        epoch: Current epoch number; not used by this function.
        model: Module called as ``model(features, adj)``; assumed to return
            per-node log-probabilities, as ``F.nll_loss`` expects — TODO confirm.
        optimizer: Optimizer over ``model``'s parameters.
        features: Feature input passed to ``model`` as-is.
        adj: Adjacency for the full node set (indexed by rows, then columns).
        labels: Class labels for the full node set.
        communicate_index: Global node indices that form the subgraph.
        in_com_train_data_index: Training positions within the subgraph.

    Returns:
        ``(loss, accuracy)`` on the selected training rows as Python floats.
    """
    model.train()
    optimizer.zero_grad()

    # Select rows, then columns, to obtain the induced sub-adjacency.
    sub_adj = adj[communicate_index][:, communicate_index]
    sub_labels = labels[communicate_index]

    output = model(features, sub_adj)
    picked_output = output[in_com_train_data_index]
    picked_labels = sub_labels[in_com_train_data_index]

    loss_train = F.nll_loss(picked_output, picked_labels)
    acc_train = accuracy(picked_output, picked_labels)

    loss_train.backward()
    optimizer.step()
    # Clear gradients again after the step so none linger between calls.
    optimizer.zero_grad()
    return loss_train.item(), acc_train.item()
