import numpy as np
import torch


def accuracy(output, labels, return_idx=False):
    """Return classification accuracy (in percent) of ``output`` against ``labels``.

    The predicted class of each row is the index of its maximum along dim 1.

    Args:
        output: score tensor whose dim 1 holds per-class scores.
        labels: ground-truth class indices, one per row of ``output``.
        return_idx: when truthy, also return the positions of correct rows.

    Returns:
        The accuracy as a Python float in ``[0, 100]``. If ``return_idx`` is
        truthy, a pair ``(accuracy, positions)`` where ``positions`` is the
        tuple produced by ``np.where`` over the correct-prediction mask.
    """
    predicted = output.max(dim=1).indices.type_as(labels)
    hit_mask = predicted.eq(labels)
    score = (hit_mask.double().sum() / len(labels) * 100.0).item()
    if return_idx:
        return score, np.where(hit_mask.cpu().numpy())
    return score

def node_cls_train(model, train_idx, labels, device, optimizer, loss_fn, loss_weight):
    """Run a single full-batch training step on the nodes in ``train_idx``.

    Args:
        model: object exposing ``train()`` and ``model_forward(idx, device, ...)``.
            With ``train_flag=1`` it must return ``(logits, contrastive_loss)``.
        train_idx: indices of the training nodes.
        labels: label tensor indexed by node id.
        device: forwarded unchanged to ``model_forward``.
        optimizer: torch optimizer over the model's parameters.
        loss_fn: supervised criterion ``loss_fn(logits, targets)``.
        loss_weight: weight of the contrastive term, or None to train with the
            supervised loss only.

    Returns:
        ``(loss, acc)``: the step's loss as a float and training accuracy in percent.
    """
    model.train()
    optimizer.zero_grad()

    use_contrastive = loss_weight is not None
    if use_contrastive:
        logits, aux_loss = model.model_forward(train_idx, device, train_flag=1)
    else:
        logits = model.model_forward(train_idx, device)

    target = labels[train_idx]
    objective = loss_fn(logits, target)
    if use_contrastive:
        objective = objective + loss_weight * aux_loss

    train_acc = accuracy(logits, target)
    objective.backward()
    optimizer.step()
    return objective.item(), train_acc

def ood_node_cls_train(model, train_idx, labels, device, optimizer, loss_fn, loss_weight, walk_times):
    """Take one optimizer step using a loss aggregated over several forward passes.

    The model is run ``walk_times`` times on ``train_idx``. Presumably each pass
    samples different random walks; this is an assumption, so confirm it against
    the model. The per-pass losses are reduced to ``mean + variance``, so the step
    also penalises disagreement between passes.

    Args:
        model: object exposing ``train()`` and ``model_forward(device, idx)``.
            When ``loss_weight`` is not None, ``model_forward`` must return
            ``(logits, contrastive_loss)``; otherwise it returns just ``logits``.
        train_idx: indices of the training nodes.
        labels: label tensor indexed by node id.
        device: forwarded unchanged to ``model_forward``.
        optimizer: torch optimizer over the model's parameters.
        loss_fn: supervised criterion ``loss_fn(logits, targets)``.
        loss_weight: weight of the contrastive term, or None to disable it.
        walk_times: number of forward passes (must be >= 1).

    Returns:
        ``(loss, acc)``: the aggregated loss as a float, and the mean training
        accuracy in percent over all ``walk_times`` passes.

    Raises:
        ValueError: if ``walk_times`` < 1.
    """
    if walk_times < 1:
        raise ValueError(f"walk_times must be >= 1, got {walk_times}")
    model.train()
    optimizer.zero_grad()
    targets = labels[train_idx]
    walk_losses = []
    walk_accs = []
    for _ in range(walk_times):
        # Every pass uses the same (device, idx) argument order. When the
        # contrastive term is enabled, each pass adds its own contrastive loss.
        if loss_weight is not None:
            output, contrastive_loss = model.model_forward(device, train_idx)
            walk_loss = loss_fn(output, targets) + loss_weight * contrastive_loss
        else:
            output = model.model_forward(device, train_idx)
            walk_loss = loss_fn(output, targets)
        # torch.cat cannot concatenate 0-d tensors, so give each loss a leading dim.
        walk_losses.append(walk_loss.unsqueeze(0))
        # Record accuracy for every pass, including the first.
        walk_accs.append(accuracy(output, targets))
    losses = torch.cat(walk_losses, dim=0)
    # torch.var is unbiased by default, so a single element yields NaN.
    # Treat that case as zero spread instead of poisoning the gradient.
    spread = losses.var() if losses.numel() > 1 else 0.0
    loss_train = (1.0 / walk_times) * losses.sum() + spread
    loss_train.backward()
    optimizer.step()

    return loss_train.item(), np.mean(walk_accs)

def unsupervised_node_cls_train_v2(model, device, optimizer, epoch, train_idx, walk_time):
    """Run one unsupervised full-batch training step.

    ``model.model_forward(device, train_idx, walk_time)`` must return a tensor
    of losses along dim 0. Presumably that is one loss per sampled walk; confirm
    against the model. The losses are reduced to ``mean + variance`` before
    backpropagating.

    Args:
        model: object exposing ``train()`` and ``model_forward``.
        device: forwarded unchanged to ``model_forward``.
        optimizer: torch optimizer over the model's parameters.
        epoch: unused; kept for call-site compatibility.
        train_idx: forwarded unchanged to ``model_forward``.
        walk_time: forwarded unchanged to ``model_forward``.

    Returns:
        The reduced loss as a Python float.
    """
    model.train()
    optimizer.zero_grad()
    losses = model.model_forward(device, train_idx, walk_time)
    # Unbiased variance of a single element is NaN; use zero spread instead.
    spread = losses.var() if losses.numel() > 1 else 0.0
    loss = (1.0 / losses.size(0)) * losses.sum() + spread
    loss.backward()
    optimizer.step()

    return loss.item()

def unsupervised_node_cls_mini_batch_train_v2(model, device, optimizer, epoch, train_idx, train_loader, walk_time):
    """Run one epoch of unsupervised mini-batch training.

    For each batch, ``model.model_forward(device, batch, walk_time)`` must
    return a tensor of losses along dim 0. They are reduced to
    ``mean + variance`` and one optimizer step is taken per batch.

    Args:
        model: object exposing ``train()`` and ``model_forward``.
        device: forwarded unchanged to ``model_forward``.
        optimizer: torch optimizer over the model's parameters.
        epoch: unused; kept for call-site compatibility.
        train_idx: unused; kept for call-site compatibility.
        train_loader: any iterable of batches. It need not support ``len()``.
        walk_time: forwarded unchanged to ``model_forward``.

    Returns:
        The average reduced loss over the batches actually processed, or
        ``0.0`` when the loader yields no batches.
    """
    model.train()
    loss_total = 0.0
    batch_num = 0
    for batch in train_loader:
        optimizer.zero_grad()
        losses = model.model_forward(device, batch, walk_time)
        # Unbiased variance of a single element is NaN; use zero spread instead.
        spread = losses.var() if losses.numel() > 1 else 0.0
        loss = (1.0 / losses.size(0)) * losses.sum() + spread
        loss_total += loss.item()
        loss.backward()
        optimizer.step()
        batch_num += 1

    # Divide by the batches actually seen: works for iterators and empty loaders.
    if batch_num == 0:
        return 0.0
    return loss_total / batch_num

def unsupervised_node_cls_evaluate(model, device, optimizer, val_idx, labels):
    """Compute a per-node loss for each node in ``val_idx`` and print its walks.

    For every node, ``model.model_forward(device, node)`` yields
    ``(walks, rw)``, which are scored by ``model.loss(walks, rw, device)``. The
    sampled walks and the labels of the nodes they visit are printed for
    inspection.

    Args:
        model: object exposing ``eval()``, ``model_forward`` and ``loss``.
        device: forwarded unchanged to the model.
        optimizer: unused; kept for call-site compatibility.
        val_idx: iterable of node ids to evaluate.
        labels: label tensor indexed by node id (walks are cast to long first).

    Returns:
        A list of per-node losses as Python floats, in ``val_idx`` order.
    """
    model.eval()
    per_node_losses = []
    for node in val_idx:
        sampled_walks, walk_repr = model.model_forward(device, node)
        node_loss = model.loss(sampled_walks, walk_repr, device)
        print(f"The sampled random walks of node {node} are {sampled_walks}")
        visited_labels = labels[sampled_walks.long()]
        print(f"The labels of the random walk are {visited_labels}")
        per_node_losses.append(node_loss.item())
    return per_node_losses

def unsupervised_node_cls_evaluate_v2(model, device, optimizer, val_idx, labels):
    """Compute per-node walk losses, skipping nodes whose neighbor lookup is None.

    For each node whose ``model.base_model.graph_a.neighbors(node.item())``
    is not None, one walk batch is sampled with
    ``model.model_forward(device, node, 1)``. The walks and the labels they
    visit are printed, and the result is scored with ``model.base_model.loss``.

    Args:
        model: object exposing ``eval()``, ``model_forward`` and ``base_model``.
        device: forwarded unchanged to the model.
        optimizer: unused; kept for call-site compatibility.
        val_idx: iterable of 0-d tensors (``.item()`` is called on each).
        labels: label tensor indexed by node id.

    Returns:
        A list of losses as Python floats for the nodes that were not skipped.
    """
    model.eval()
    scores = []
    graph = model.base_model.graph_a
    for node in val_idx:
        has_neighbors = graph.neighbors(node.item()) is not None
        if has_neighbors:
            sampled_walks, walk_repr = model.model_forward(device, node, 1)
            print(f"The sampled random walks of node {node} are {sampled_walks}")
            print(f"The labels of the random walk are {labels[sampled_walks.long()]}")
            node_loss = model.base_model.loss(sampled_walks, walk_repr, device)
            scores.append(node_loss.item())
    return scores

def node_cls_mini_batch_train(model, train_idx, train_loader, labels, device, optimizer, loss_fn, loss_weight):
    """Run one epoch of supervised mini-batch training.

    Args:
        model: object exposing ``train()`` and ``model_forward(batch, device, ...)``.
            With ``train_flag=1`` it must return ``(logits, contrastive_loss)``.
        train_idx: all training node ids. Its length is the accuracy denominator.
        train_loader: iterable of index batches into ``labels``.
        labels: label tensor indexed by node id.
        device: forwarded unchanged to ``model_forward``.
        optimizer: torch optimizer over the model's parameters.
        loss_fn: supervised criterion ``loss_fn(logits, targets)``.
        loss_weight: weight of the contrastive term, or None to disable it.

    Returns:
        ``(loss, acc)``: the mean batch loss, and the training accuracy in percent
        over ``train_idx``. Each is ``0.0`` when there is nothing to average.
    """
    model.train()
    correct_num = 0
    loss_train_sum = 0.0
    batch_num = 0
    for batch in train_loader:
        batch_labels = labels[batch]
        if loss_weight is not None:
            train_output, contrastive_loss = model.model_forward(batch, device, train_flag=1)
            loss_train = loss_fn(train_output, batch_labels) + loss_weight * contrastive_loss
        else:
            train_output = model.model_forward(batch, device)
            loss_train = loss_fn(train_output, batch_labels)
        pred = train_output.max(1)[1].type_as(labels)
        # Accumulate as a Python int so the result is well-defined even with
        # zero batches. loss_train.item() below already syncs once per batch.
        correct_num += int(pred.eq(batch_labels).sum().item())
        loss_train_sum += loss_train.item()
        optimizer.zero_grad()
        loss_train.backward()
        optimizer.step()
        batch_num += 1

    # Count batches seen rather than len(train_loader): works for iterators/empty loaders.
    loss_train = loss_train_sum / batch_num if batch_num else 0.0
    n_train = len(train_idx)
    acc_train = 100 * (correct_num / n_train) if n_train else 0.0
    return loss_train, acc_train

def node_cls_evaluate(model, val_idx, test_idx, labels, device, postpro):
    """Evaluate full-batch validation and test accuracy.

    Args:
        model: object exposing ``eval()`` and ``model_forward(idx=..., device=...)``.
        val_idx: validation node ids.
        test_idx: test node ids. When ``postpro`` is truthy it is indexed with
            the positions of correctly classified test nodes.
        labels: label tensor indexed by node id.
        device: forwarded unchanged to ``model_forward``.
        postpro: when truthy, also return the ids of correctly predicted test nodes.

    Returns:
        ``(val_acc, test_acc)`` in percent, or
        ``(val_acc, test_acc, correct_test_ids)`` when ``postpro`` is truthy.
    """
    model.eval()
    val_logits = model.model_forward(idx=val_idx, device=device)
    test_logits = model.model_forward(idx=test_idx, device=device)
    val_acc = accuracy(val_logits, labels[val_idx])
    # accuracy returns a (score, positions) pair exactly when postpro is truthy.
    test_result = accuracy(test_logits, labels[test_idx], postpro)
    if not postpro:
        return val_acc, test_result
    test_acc, hit_positions = test_result
    return val_acc, test_acc, test_idx[hit_positions]

def _mini_batch_correct_count(model, loader, labels, device):
    """Return how many nodes across all batches of ``loader`` are classified correctly."""
    correct = 0
    for batch in loader:
        output = model.model_forward(batch, device)
        pred = output.max(1)[1].type_as(labels)
        correct += int(pred.eq(labels[batch]).sum().item())
    return correct


def node_cls_mini_batch_evaluate(model, val_idx, val_loader, test_idx, test_loader, labels, device):
    """Evaluate mini-batch validation and test accuracy.

    Both accuracies are returned in percent, matching ``node_cls_evaluate``.
    Previously the validation accuracy was returned as a fraction in [0, 1]
    while the test accuracy was multiplied by 100.

    Args:
        model: object exposing ``eval()`` and ``model_forward(batch, device)``.
        val_idx: validation node ids. Its length is the validation denominator.
        val_loader: iterable of validation index batches.
        test_idx: test node ids. Its length is the test denominator.
        test_loader: iterable of test index batches.
        labels: label tensor indexed by node id.
        device: forwarded unchanged to ``model_forward``.

    Returns:
        ``(acc_val, acc_test)`` in percent; each is ``0.0`` for an empty id set.
    """
    model.eval()
    correct_num_val = _mini_batch_correct_count(model, val_loader, labels, device)
    correct_num_test = _mini_batch_correct_count(model, test_loader, labels, device)
    acc_val = 100 * (correct_num_val / len(val_idx)) if len(val_idx) else 0.0
    acc_test = 100 * (correct_num_test / len(test_idx)) if len(test_idx) else 0.0
    return acc_val, acc_test