import argparse
import os
import torch
import torch.nn as nn
from torch.utils.data.sampler import SubsetRandomSampler
import time
import logging
import dgl
from dgl.data import GINDataset
from dgl.dataloading import GraphDataLoader

from models.vn import GNN
from utils.utils import prepare_folder, set_seed, evaluate, init_logging


def train(model, data, optimizer, criterion):
    """Run a single optimization step on one batch.

    Args:
        model: Callable invoked as ``model(graph, node_features)`` that
            returns the predictions passed to ``criterion``.
        data: A ``(batched_graph, labels)`` pair. The ``"attr"`` entry is
            removed from ``batched_graph.ndata`` and used as the input
            features.
        optimizer: Optimizer whose gradients are cleared and which is
            stepped once.
        criterion: Loss callable invoked as ``criterion(out, labels)``.

    Returns:
        The scalar batch loss multiplied by ``len(labels)``.
    """
    graph, targets = data
    node_feat = graph.ndata.pop("attr")

    batch_loss = criterion(model(graph, node_feat), targets)

    # Clear stale gradients right before back-propagating this batch.
    optimizer.zero_grad()
    batch_loss.backward()
    optimizer.step()

    return batch_loss.item() * len(targets)


@torch.no_grad()
def test(model, data, return_h=False):
    batched_graph, labels = data
    # g = g.cuda()
    # labels = labels#.cuda()
    feat = batched_graph.ndata.pop("attr")
    if return_h:
        out, h = model(batched_graph, feat, return_h)
        return out.softmax(dim=-1), labels, h.cpu()
    out = model(batched_graph, feat)
    return out.softmax(dim=-1), labels


def _split_fold(index, fold, num_folds):
    """Split a shuffled index tensor into train/test indices for one fold.

    The test part is the ``fold``-th contiguous slice of size
    ``len(index) // num_folds``; the train part is everything else, so the
    two never overlap. Any remainder samples at the end of ``index`` always
    land in the train part.

    Args:
        index: 1-D tensor holding a permutation of sample indices.
        fold: Zero-based fold number, ``0 <= fold < num_folds``.
        num_folds: Total number of folds.

    Returns:
        ``(train_index, test_index)`` as 1-D tensors.
    """
    fold_size = len(index) // num_folds
    start, stop = fold * fold_size, (fold + 1) * fold_size
    test_index = index[start:stop]
    train_index = torch.cat((index[:start], index[stop:]))
    return train_index, test_index


def main():
    """Train and evaluate a GNN with 5-fold cross-validation.

    Parses command-line arguments, loads the requested ``GINDataset``,
    moves its graphs and labels to the selected device, and for each fold
    trains a fresh model for ``--epochs`` epochs. After every epoch the
    model is scored on the fold's held-out split with ``evaluate``; the best
    score seen during the fold is kept, and the mean of these per-fold best
    scores is printed as ``Test acc``.

    Note that the best epoch is selected on the same split that is
    reported, so the printed number is an optimistic estimate.

    Raises:
        ValueError: If the dataset has fewer samples than folds.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default="MUTAG")
    parser.add_argument("--ratio", type=float, default=0.6)
    parser.add_argument("--model", type=str, default="EGT")
    parser.add_argument("--device", type=int, default=0)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--hiddens", type=int, default=64)
    parser.add_argument("--layers", type=int, default=4)
    parser.add_argument("--heads", type=int, default=4)
    parser.add_argument("--dropout", type=float, default=0.1)
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--lr", type=float, default=0.005)
    parser.add_argument("--seed", type=int, default=3407)

    args = parser.parse_args()
    model_dir = prepare_folder(args.dataset, args.model, args.ratio)
    init_logging(logging.getLogger(), model_dir)
    logging.info(args)

    device = torch.device(
        f"cuda:{args.device}" if torch.cuda.is_available() else "cpu"
    )
    logging.info("model_dir: " + model_dir)
    set_seed(args.seed)

    dataset = GINDataset(name=args.dataset, self_loop=True)
    in_channels = dataset.dim_nfeats
    nlabels = dataset.gclasses

    # Move every graph and label to the target device once, up front, so
    # batches produced by the loaders are already on the device.
    data = dataset
    data.graphs = [graph.to(device) for graph in data.graphs]
    data.labels = [label.to(device) for label in data.labels]

    num_folds = 5
    index = torch.randperm(len(data))
    if len(index) < num_folds:
        raise ValueError(
            f"Dataset has {len(index)} samples; need at least {num_folds} "
            f"for {num_folds}-fold cross-validation"
        )

    acc = 0.0
    for fold in range(num_folds):
        # Train on everything except the held-out fold (no overlap).
        train_index, test_index = _split_fold(index, fold, num_folds)

        # Samplers get plain Python ints: sampler indices must not live on
        # the GPU, and this also keeps the script runnable on CPU-only hosts.
        train_dataloader = GraphDataLoader(
            data,
            sampler=SubsetRandomSampler(train_index.tolist()),
            batch_size=args.batch_size,
        )
        test_dataloader = GraphDataLoader(
            data,
            sampler=SubsetRandomSampler(test_index.tolist()),
            batch_size=args.batch_size,
        )

        # A fresh model/optimizer/scheduler per fold so folds are independent.
        model = GNN(
            in_channels,
            args.hiddens,
            nlabels,
            args.model,
            args.heads,
            args.layers,
            args.dropout,
        ).to(device)

        logging.info(f"Model {args.model} initialized")
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)
        criterion = nn.CrossEntropyLoss()
        best_acc = 0.0

        for epoch in range(1, args.epochs + 1):
            cur_time = time.time()
            loss = 0
            model.train()
            for step, batch in enumerate(train_dataloader, 1):
                # train() returns the batch loss scaled by the batch length,
                # so `loss` accumulates a per-sample sum.
                loss += train(model, batch, optimizer, criterion)
                if step % 100 == 0:
                    logging.info(f"Epoch {epoch:02d}, Step {step:02d}, Loss: {loss / step / args.batch_size:.4f}")
            loss /= len(train_index)
            scheduler.step()

            model.eval()
            pred_ys, true_ys = [], []
            for batch in test_dataloader:
                pred_y, true_y = test(model, batch)
                pred_ys.append(pred_y)
                true_ys.append(true_y)
            pred_ys = torch.cat(pred_ys)
            true_ys = torch.cat(true_ys)
            valid_result = evaluate(true_ys, pred_ys)

            # Only log epochs that improve on the fold's best score so far.
            if valid_result > best_acc:
                best_acc = valid_result
                logging.info(
                    f"Epoch: {epoch:02d}, "
                    f"Loss: {loss:.4f}, "
                    f"Valid: {valid_result:.2%}, "
                    f"Best: {best_acc:.4%}, "
                    f"Time: {time.time() - cur_time:.2f}s"
                )

        acc += best_acc / num_folds

    print(
        f"Test acc: {acc:.4%}"
    )

# Run the cross-validation experiment only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()
