import torch
import torch.nn.functional as F    

from torch_geometric.loader import NeighborSampler
from torch_geometric.utils import accuracy, f1_score
from math import floor
import pdb


def train(model, data, train_idx, optimizer, reference_points=None):
    """Run one full-batch optimisation step and return the training loss.

    The model is put in training mode, evaluated on the whole graph
    (``data.x`` / ``data.edge_index``) with ``test=False``, and the negative
    log-likelihood loss is computed only on the rows selected by
    ``train_idx``. No gradient clipping is applied.

    Args:
        model: Module called as ``model(x, edge_index=..., reference_points=...,
            test=...)``. Its output is fed to ``F.nll_loss``, so it is
            presumably log-probabilities per node -- confirm against the model.
        data: Object exposing ``x``, ``edge_index`` and ``y``; ``y`` must have
            a dimension 1 that can be squeezed.
        train_idx: Index used to select the training rows of the output and
            of the labels.
        optimizer: Optimizer whose gradients are zeroed before, and stepped
            after, the backward pass.
        reference_points: Forwarded unchanged to the model.

    Returns:
        The loss of this step as a Python float.
    """
    model.train()
    optimizer.zero_grad()

    log_probs = model(
        data.x,
        edge_index=data.edge_index,
        reference_points=reference_points,
        test=False,
    )
    targets = data.y.squeeze(1)
    nll = F.nll_loss(log_probs[train_idx], targets[train_idx])

    nll.backward()
    optimizer.step()
    return nll.item()

@torch.no_grad()
def test(model, data, split_idx, reference_points=None, save_att=False):
    """Evaluate the model on the full graph for all three splits.

    The model is switched to eval mode and called once with ``test=True``.
    Gradients are disabled by the ``torch.no_grad`` decorator.

    Args:
        model: Module called as ``model(x, edge_index=..., reference_points=...,
            test=True)``. When ``save_att`` is true its return value is
            unpacked into two items ``(out, attmap)``; otherwise the return
            value is used directly as ``out``.
        data: Object exposing ``x``, ``edge_index`` and ``y``.
        split_idx: Mapping with the keys ``'train'``, ``'valid'`` and
            ``'test'`` giving the index of each split.
        reference_points: Forwarded unchanged to the model.
        save_att: If true, the second model output is written with
            ``torch.save`` to the relative path ``attention_rand.pt``.

    Returns:
        ``(train_acc, valid_acc, test_acc, loss_val)`` where the accuracies
        come from ``accuracy`` and ``loss_val`` is the NLL loss on the
        validation split, returned as produced by ``F.nll_loss`` (a tensor,
        not a Python float).
    """
    model.eval()

    result = model(
        data.x,
        edge_index=data.edge_index,
        reference_points=reference_points,
        test=True,
    )
    if save_att:
        out, attmap = result
        torch.save(attmap, "attention_rand.pt")
    else:
        out = result

    y_pred = out.argmax(dim=-1, keepdim=True)

    valid_idx = split_idx['valid']
    loss_val = F.nll_loss(out[valid_idx].squeeze(1), data.y.squeeze(1)[valid_idx])

    # Evaluated lazily in the order train -> valid -> test during unpacking.
    train_acc, valid_acc, test_acc = (
        accuracy(y_pred[split_idx[name]], data.y[split_idx[name]])
        for name in ('train', 'valid', 'test')
    )
    return train_acc, valid_acc, test_acc, loss_val

def train_transformer(model, data, train_loader, optimizer, device):
    """Train for one pass over ``train_loader`` and return the mean batch loss.

    Each item of the loader is unpacked as ``(batch_size, n_id, adj)``. The
    model is run on the features ``data.x[n_id]`` with ``adj.edge_index``
    moved to ``device``; only the first ``batch_size`` output rows enter the
    loss, matched against the labels of ``n_id[:batch_size]``. Gradients are
    clipped to a total norm of 1 before every optimizer step.

    Args:
        model: Module called as ``model(x, edge_index)``. Its output is fed to
            ``F.nll_loss``, so it is presumably log-probabilities -- confirm.
        data: Object exposing ``x`` and ``y``; ``y`` is squeezed along dim 1.
        train_loader: Iterable yielding ``(batch_size, n_id, adj)`` triples.
        optimizer: Optimizer zeroed and stepped once per batch.
        device: Target device for the batch's edge index.

    Returns:
        The arithmetic mean of the per-batch losses as a Python float.

    Raises:
        ZeroDivisionError: If ``train_loader`` yields no batches.
    """
    model.train()
    batch_losses = []

    for n_target, node_ids, sampled in train_loader:
        edge_index = sampled.edge_index.to(device)
        optimizer.zero_grad()

        # Only the leading rows of the output are scored.
        log_probs = model(data.x[node_ids], edge_index)[:n_target]
        targets = data.y[node_ids[:n_target]].squeeze(1)
        batch_loss = F.nll_loss(log_probs, targets)

        batch_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
        optimizer.step()

        batch_losses.append(batch_loss.item())

    return sum(batch_losses) / len(batch_losses)


def _collect_predictions(model, data, loader, device):
    """Run the model over every batch of ``loader`` and gather the results.

    Each item of the loader is unpacked as ``(batch_size, n_id, adj)``. Only
    the first ``batch_size`` output rows of each batch are kept, paired with
    the labels ``data.y[n_id[:batch_size]]``.

    Returns:
        ``(y_pred, y_true)``: the per-batch argmax predictions (with the
        class dimension kept) and the matching labels, each concatenated
        along dim 0 in loader order.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    preds, labels = [], []
    for batch_size, n_id, adj in loader:
        edge_index = adj.edge_index.to(device)
        out = model(data.x[n_id], edge_index)[:batch_size]
        preds.append(out.argmax(dim=-1, keepdim=True))
        labels.append(data.y[n_id[:batch_size]])

    if not preds:
        raise ValueError("loader yielded no batches; cannot compute accuracy")

    # Concatenate once at the end instead of growing a tensor every batch,
    # which would re-copy all previously collected rows each iteration.
    return torch.cat(preds, 0), torch.cat(labels, 0)


@torch.no_grad()
def test_transformer(model, data, loaders, device):
    """Compute mini-batch accuracy on the train, validation and test loaders.

    The model is switched to eval mode and gradients are disabled by the
    ``torch.no_grad`` decorator (which also covers the helper it calls).
    All three loaders are fully evaluated, in the order train, valid, test,
    before any accuracy is computed.

    Args:
        model: Module called as ``model(x, edge_index)``.
        data: Object exposing ``x`` and ``y``.
        loaders: Sequence of exactly three iterables
            ``(train_loader, valid_loader, test_loader)``, each yielding
            ``(batch_size, n_id, adj)`` triples.
        device: Target device for each batch's edge index.

    Returns:
        ``(train_acc, valid_acc, test_acc)`` as returned by ``accuracy``.

    Raises:
        ValueError: If any of the loaders yields no batches.
    """
    model.eval()
    train_loader, valid_loader, test_loader = loaders

    collected = [
        _collect_predictions(model, data, loader, device)
        for loader in (train_loader, valid_loader, test_loader)
    ]
    train_acc, valid_acc, test_acc = (
        accuracy(y_pred, y_true) for y_pred, y_true in collected
    )

    return train_acc, valid_acc, test_acc