import argparse
import os
import torch
import torch.nn as nn
import time
import logging
import dgl
from dgl.data import PubmedGraphDataset, CoraGraphDataset, AmazonCoBuyComputerDataset, AmazonCoBuyPhotoDataset, CoauthorCSDataset, CoauthorPhysicsDataset, FlickrDataset
from dgl.dataloading import MultiLayerFullNeighborSampler, DataLoader

from models.egt import EGT
from utils.utils import prepare_folder, set_seed, evaluate, init_logging, add_edge_noise


def train(model, data, optimizer, criterion):
    """Perform a single optimisation step on one sampled mini-batch.

    Args:
        model: Callable invoked as ``model(blocks, features)``; its output is
            passed to ``criterion`` together with the labels.
        data: ``(input_nodes, output_nodes, blocks)`` triple. Features are
            read from ``blocks[0].srcdata['feat']`` and labels from
            ``blocks[-1].dstdata['label']``.
        optimizer: Optimizer whose gradients are reset and which is stepped
            once.
        criterion: Loss callable taking ``(predictions, labels)``.

    Returns:
        The batch loss as a Python float multiplied by the number of output
        nodes, so a caller can sum it over batches and normalise afterwards.
    """
    _, seeds, blocks = data
    features = blocks[0].srcdata['feat']
    targets = blocks[-1].dstdata['label']

    batch_loss = criterion(model(blocks, features), targets)

    optimizer.zero_grad()
    batch_loss.backward()
    optimizer.step()

    return batch_loss.item() * len(seeds)


@torch.no_grad()
def test(model, data, return_h=False):
    input_nodes, output_nodes, blocks = data
    # blocks = [b.cuda() for b in blocks]
    x = blocks[0].srcdata['feat']
    y = blocks[-1].dstdata['label']
    if return_h:
        out, h = model(blocks, x, return_h)
        return out.softmax(dim=-1), y, h.cpu()
    out= model(blocks, x)
    return out.softmax(dim=-1), y


def _load_graph(name):
    """Return the first graph of the DGL dataset selected by ``name``.

    Self-loops are added to the "computer" and "photo" graphs.

    Raises:
        ValueError: If ``name`` is not one of the supported dataset keys.
    """
    # Built inside the function so the dataset classes are only looked up
    # when a graph is actually requested.
    datasets = {
        "pubmed": PubmedGraphDataset,
        "cora": CoraGraphDataset,
        "computer": AmazonCoBuyComputerDataset,
        "photo": AmazonCoBuyPhotoDataset,
        "cs": CoauthorCSDataset,
        "physics": CoauthorPhysicsDataset,
        "flickr": FlickrDataset,
    }
    if name not in datasets:
        raise ValueError('Unsupported dataset!')
    graph = datasets[name]()[0]
    if name in ("computer", "photo"):
        graph = dgl.add_self_loop(graph)
    return graph


def _load_or_create_split(path, num_nodes):
    """Return ``(train_index, val_index, test_index)`` cached at ``path``.

    If ``path`` does not exist, a random 60/20/20 permutation split over
    ``num_nodes`` nodes is drawn, saved to ``path`` and returned.
    """
    if os.path.exists(path):
        split = torch.load(path)
        return split['train_index'], split['val_index'], split['test_index']

    index = torch.randperm(num_nodes)
    train_end = int(len(index) * 0.6)
    val_end = int(len(index) * 0.8)
    train_index = index[:train_end]
    val_index = index[train_end:val_end]
    test_index = index[val_end:]

    # torch.save does not create missing parent directories.
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save({
        'train_index': train_index,
        'val_index': val_index,
        'test_index': test_index
    }, path)
    logging.info('Index created.')
    return train_index, val_index, test_index


def _collect_predictions(model, dataloader, return_h=False):
    """Run ``test`` on every batch and concatenate each returned component.

    Returns a tuple ``(pred_ys, true_ys)``, or ``(pred_ys, true_ys, hs)``
    when ``return_h`` is truthy.
    """
    per_batch = [test(model, batch, return_h) for batch in dataloader]
    return tuple(torch.cat(parts) for parts in zip(*per_batch))


def main():
    """Train an EGT node classifier and report test metrics.

    Parses the command-line options, loads the selected graph, loads or
    creates a cached train/validation/test node split, trains for
    ``--epochs`` epochs while checkpointing the model with the best
    validation score, then reloads that checkpoint and prints the test
    metrics.

    ``--model`` and ``--ratio`` are only used to name the output folder (and
    ``--model`` in a log message); the split proportions are fixed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default="flickr", choices=['pubmed', 'cora',
        'computer', 'photo', 'physics', 'cs', 'flickr'])
    parser.add_argument("--ratio", type=float, default=0.6)
    parser.add_argument("--model", type=str, default="EGT")
    parser.add_argument("--device", type=int, default=0)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--hiddens", type=int, default=64)
    parser.add_argument("--layers", type=int, default=2)
    parser.add_argument("--heads", type=int, default=4)
    parser.add_argument("--dropout", type=float, default=0.1)
    parser.add_argument("--batch_size", type=int, default=65536)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--jitter", type=float, default=0)

    args = parser.parse_args()
    model_dir = prepare_folder(args.dataset, args.model, args.ratio)
    init_logging(logging.getLogger(), model_dir)
    logging.info(args)

    device = torch.device(f"cuda:{args.device}" if torch.cuda.is_available() else "cpu")
    logging.info("model_dir: " + model_dir)
    set_seed(args.seed)

    graph = _load_graph(args.dataset)
    in_channels = graph.ndata['feat'].shape[-1]
    edge_channels = graph.edata['feat'].shape[-1] if 'feat' in graph.edata else None
    nlabels = graph.ndata['label'].max().item() + 1

    if args.jitter > 0:
        graph = add_edge_noise(graph, args.jitter)

    data = graph.to(device)

    train_index, val_index, test_index = _load_or_create_split(
        f'data/{args.dataset}.pt', data.num_nodes()
    )

    sampler = MultiLayerFullNeighborSampler(args.layers, prefetch_node_feats=['feat'], prefetch_labels=['label'])
    # Seed indices follow the graph's device; a hard-coded .cuda() would crash
    # when the script falls back to CPU.
    train_dataloader = DataLoader(data, train_index.to(device), sampler, batch_size=args.batch_size, shuffle=True)
    valid_dataloader = DataLoader(data, val_index.to(device), sampler, batch_size=args.batch_size, shuffle=False)
    test_dataloader = DataLoader(data, test_index.to(device), sampler, batch_size=args.batch_size, shuffle=False)

    model = EGT(
        in_channels=in_channels,
        edge_channels=edge_channels,
        hidden_channels=args.hiddens,
        num_class=nlabels,
        num_layers=args.layers,
        num_heads=args.heads,
        dropout=args.dropout
    ).to(device)

    logging.info(f"Model {args.model} initialized")
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()
    checkpoint_path = os.path.join(model_dir, 'model.bin')
    best_auc = 0.0
    for epoch in range(1, args.epochs + 1):
        cur_time = time.time()
        loss = 0
        seen = 0
        model.train()
        for step, batch in enumerate(train_dataloader, 1):
            # train() returns the batch loss already scaled by the number of
            # output nodes, so dividing by the nodes seen so far gives the
            # exact running mean even when a batch is not full.
            loss += train(model, batch, optimizer, criterion)
            seen += len(batch[1])
            if step % 10 == 0:
                logging.info(f"Epoch {epoch:02d}, Step {step:02d}, Loss: {loss / seen:.4f}")
        loss /= len(train_index)

        model.eval()
        pred_ys, true_ys = _collect_predictions(model, valid_dataloader)
        valid_auc = evaluate(true_ys, pred_ys)

        if valid_auc >= best_auc:
            best_auc = valid_auc
            torch.save({
                'epoch': epoch,
                'model': model.state_dict(),
            }, checkpoint_path)

        # Logged every epoch so non-improving epochs are visible too.
        logging.info(
            f"Epoch: {epoch:02d}, "
            f"Loss: {loss:.4f}, "
            f"Valid: {valid_auc:.2%}, "
            f"Best: {best_auc:.4%}, "
            f"Time: {time.time() - cur_time:.2f}s"
        )

    params = torch.load(checkpoint_path, map_location=device)
    logging.info(f"Loading best model at epoch: {params['epoch']:02d}")
    model.load_state_dict(params['model'])
    model.eval()
    # Hidden representations are gathered as well but not used further here.
    pred_ys, true_ys, hs = _collect_predictions(model, test_dataloader, True)
    test_result = evaluate(true_ys, pred_ys, all=True)
    print(
        f"Test auroc: {test_result['auroc']:.4%}, "
        f"f1: {test_result['f1']:.4%}, "
        f"gmean: {test_result['gmean']:.4%}, "
        f"acc: {test_result['acc']:.4%}"
    )

# Run the training pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
