import dgl
import argparse
import torch as th
import torch.optim as optim
import torch.nn.functional as F
from dataloader import GASDataset
from model_sampling import GAS
from sklearn.metrics import f1_score, precision_recall_curve, roc_auc_score


def evaluate(model, loss_fn, dataloader, device='cpu', precision=None):
    """Evaluate ``model`` on every mini-batch yielded by ``dataloader``.

    Each metric is computed per mini-batch and then averaged over the number
    of batches (not over the number of edges).

    Args:
        model: GAS model, called as
            ``model(edge_subgraph, blocks, f_feat, b_feat, u_feat, v_feat)``.
        loss_fn: Loss taking ``(logits, labels)`` (``CrossEntropyLoss`` in
            ``main``).
        dataloader: Iterable of ``(input_nodes, edge_subgraph, blocks)``
            tuples, such as the ``EdgeDataLoader`` built in ``main``.
        device: Device the blocks and edge subgraph are moved to.
        precision: Threshold ``p`` for recall@precision. When ``None`` (the
            default), falls back to ``args.precision`` from the module-level
            ``args`` parsed under ``__main__``, which keeps the script's
            existing calls working unchanged.

    Returns:
        Tuple ``(recall_at_precision, f1, auc, loss)``, each averaged over
        batches.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    if precision is None:
        # Backward compatibility: the script calls evaluate() without a
        # precision and relies on the globally parsed command-line args.
        precision = args.precision

    loss = 0
    f1 = 0
    auc = 0
    rap = 0
    num_blocks = 0
    # Evaluation never back-propagates, so skip building autograd graphs.
    with th.no_grad():
        for _, edge_subgraph, blocks in dataloader:
            blocks = [b.to(device) for b in blocks]
            edge_subgraph = edge_subgraph.to(device)
            u_feat = blocks[0].srcdata['feat']['u']
            v_feat = blocks[0].srcdata['feat']['v']
            f_feat = blocks[0].edges['forward'].data['feat']
            b_feat = blocks[0].edges['backward'].data['feat']
            labels = edge_subgraph.edges['forward'].data['label'].long()
            logits = model(edge_subgraph, blocks, f_feat, b_feat, u_feat, v_feat)

            loss += loss_fn(logits, labels).item()

            # Copy to CPU once for all the sklearn metrics.
            labels_cpu = labels.cpu()
            # NOTE(review): assumes column 1 of logits is the positive-class
            # score (main treats this as binary classification) — confirm.
            scores = logits[:, 1].cpu()
            preds = logits.argmax(dim=1).cpu()

            f1 += f1_score(labels_cpu, preds)
            auc += roc_auc_score(labels_cpu, scores)
            prec, rec, _ = precision_recall_curve(labels_cpu, scores)
            # Best recall among thresholds whose precision exceeds `precision`.
            rap += rec[prec > precision].max()
            num_blocks += 1

    if num_blocks == 0:
        raise ValueError("evaluate() received a dataloader that yielded no batches")

    return rap / num_blocks, f1 / num_blocks, auc / num_blocks, loss / num_blocks


def _make_forward_edge_loader(graph, mask_name, sampler, args, shuffle):
    """Build an EdgeDataLoader over the 'forward' edges selected by ``mask_name``."""
    # nonzero(as_tuple=True)[0] is always a 1-D id tensor; the former
    # nonzero().squeeze() collapsed to a 0-d scalar when the mask selected
    # exactly one edge.
    eids = graph.edges['forward'].data[mask_name].nonzero(as_tuple=True)[0]
    return dgl.dataloading.EdgeDataLoader(graph,
                                          {'forward': eids},
                                          sampler,
                                          batch_size=args.batch_size,
                                          shuffle=shuffle,
                                          drop_last=False,
                                          num_workers=args.num_workers)


def _unpack_batch(edge_subgraph, blocks, device):
    """Move a sampled mini-batch to ``device`` and extract model inputs and labels.

    Returns ``(edge_subgraph, blocks, (f_feat, b_feat, u_feat, v_feat), labels)``;
    the feature tuple is ordered as the GAS model's forward() expects after
    ``(edge_subgraph, blocks)``.
    """
    blocks = [b.to(device) for b in blocks]
    edge_subgraph = edge_subgraph.to(device)
    u_feat = blocks[0].srcdata['feat']['u']
    v_feat = blocks[0].srcdata['feat']['v']
    f_feat = blocks[0].edges['forward'].data['feat']
    b_feat = blocks[0].edges['backward'].data['feat']
    labels = edge_subgraph.edges['forward'].data['label'].long()
    return edge_subgraph, blocks, (f_feat, b_feat, u_feat, v_feat), labels


def main(args):
    """Train a GAS model and print train/validation metrics per epoch, then test metrics.

    Args:
        args: Namespace parsed in the ``__main__`` block. Reads ``dataset``,
            ``gpu``, ``batch_size``, ``num_workers``, ``e_hid_dim``,
            ``u_hid_dim``, ``v_hid_dim``, ``num_layers``, ``dropout``, ``lr``,
            ``weight_decay``, ``max_epoch`` and ``precision``.
    """
    # Step 1: Prepare graph data and retrieve train/validation/test index ============================= #
    # Load dataset
    dataset = GASDataset(args.dataset)
    graph = dataset[0]

    # Generate mini-batches only for forward edges. Use one fanout per GNN
    # layer so the number of sampled blocks follows --num_layers
    # (the default of 2 gives the original [10, 10]).
    sampler = dgl.dataloading.MultiLayerNeighborSampler([10] * args.num_layers)
    tr_loader = _make_forward_edge_loader(graph, "train_mask", sampler, args, shuffle=True)
    # Evaluation does not benefit from shuffling; keep batch order fixed.
    val_loader = _make_forward_edge_loader(graph, "val_mask", sampler, args, shuffle=False)
    test_loader = _make_forward_edge_loader(graph, "test_mask", sampler, args, shuffle=False)

    # check cuda
    if args.gpu >= 0 and th.cuda.is_available():
        device = 'cuda:{}'.format(args.gpu)
    else:
        device = 'cpu'

    # binary classification
    num_classes = dataset.num_classes

    # Input feature dimensions
    e_feats = graph.edges['forward'].data['feat'].shape[-1]
    u_feats = graph.nodes['u'].data['feat'].shape[-1]
    v_feats = graph.nodes['v'].data['feat'].shape[-1]

    # Step 2: Create model =================================================================== #
    model = GAS(e_in_dim=e_feats,
                u_in_dim=u_feats,
                v_in_dim=v_feats,
                e_hid_dim=args.e_hid_dim,
                u_hid_dim=args.u_hid_dim,
                v_hid_dim=args.v_hid_dim,
                out_dim=num_classes,
                num_layers=args.num_layers,
                dropout=args.dropout,
                activation=F.relu)

    model = model.to(device)

    # Step 3: Create training components ===================================================== #
    loss_fn = th.nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    # Step 4: training epochs =============================================================== #
    for epoch in range(args.max_epoch):
        model.train()
        tr_loss = tr_f1 = tr_auc = tr_rap = 0
        tr_blocks = 0
        for _, edge_subgraph, blocks in tr_loader:
            edge_subgraph, blocks, feats, labels = _unpack_batch(edge_subgraph, blocks, device)
            logits = model(edge_subgraph, blocks, *feats)

            # compute loss and batch metrics (taken before this batch's update)
            batch_loss = loss_fn(logits, labels)
            tr_loss += batch_loss.item()
            labels_cpu = labels.cpu()
            scores = logits[:, 1].detach().cpu()
            tr_f1 += f1_score(labels_cpu, logits.argmax(dim=1).cpu())
            tr_auc += roc_auc_score(labels_cpu, scores)
            tr_pre, tr_re, _ = precision_recall_curve(labels_cpu, scores)
            tr_rap += tr_re[tr_pre > args.precision].max()
            tr_blocks += 1

            # backward
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

        # validation (no gradients needed)
        model.eval()
        with th.no_grad():
            val_rap, val_f1, val_auc, val_loss = evaluate(model, loss_fn, val_loader, device)

        # Print out performance
        print("In epoch {}, Train R@P: {:.4f} | Train F1: {:.4f} | Train AUC: {:.4f} | Train Loss: {:.4f}; "
              "Valid R@P: {:.4f} | Valid F1: {:.4f} | Valid AUC: {:.4f} | Valid loss: {:.4f}".
              format(epoch, tr_rap / tr_blocks, tr_f1 / tr_blocks, tr_auc / tr_blocks, tr_loss / tr_blocks,
                     val_rap, val_f1, val_auc, val_loss))

    # Test with mini batch after all epoch
    model.eval()
    with th.no_grad():
        test_rap, test_f1, test_auc, test_loss = evaluate(model, loss_fn, test_loader, device)
    print("Test R@P: {:.4f} | Test F1: {:.4f} | Test AUC: {:.4f} | Test loss: {:.4f}".
          format(test_rap, test_f1, test_auc, test_loss))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='GCN-based Anti-Spam Model')
    parser.add_argument("--dataset", type=str, default="pol", help="'pol', or 'gos'")
    parser.add_argument("--gpu", type=int, default=-1, help="GPU Index. Default: -1, using CPU.")
    parser.add_argument("--e_hid_dim", type=int, default=128, help="Hidden layer dimension for edges")
    parser.add_argument("--u_hid_dim", type=int, default=128, help="Hidden layer dimension for source nodes")
    parser.add_argument("--v_hid_dim", type=int, default=128, help="Hidden layer dimension for destination nodes")
    parser.add_argument("--num_layers", type=int, default=2, help="Number of GCN layers")
    parser.add_argument("--max_epoch", type=int, default=100, help="The max number of epochs. Default: 100")
    parser.add_argument("--lr", type=float, default=0.001, help="Learning rate. Default: 1e-3")
    parser.add_argument("--dropout", type=float, default=0.0, help="Dropout rate. Default: 0.0")
    parser.add_argument("--batch_size", type=int, default=64, help="Size of mini-batches. Default: 64")
    parser.add_argument("--num_workers", type=int, default=4,
                        help="Number of worker processes for each data loader. Default: 4")
    parser.add_argument("--weight_decay", type=float, default=5e-4, help="Weight Decay. Default: 0.0005")
    parser.add_argument("--precision", type=float, default=0.9, help="The value p in recall@p precision. Default: 0.9")

    args = parser.parse_args()
    # Recall@p takes the max recall over thresholds whose precision is
    # strictly greater than p. Precision never exceeds 1, so p >= 1 leaves an
    # empty selection and the .max() call would crash in the first batch.
    if args.precision >= 1:
        parser.error("--precision must be less than 1")
    print(args)
    main(args)
