import os
import argparse

import torch as th
import torch.nn as nn

from dgl import save_graphs
from models import Model

from dgl.data import BAShapeDataset, BACommunityDataset, TreeCycleDataset, TreeGridDataset

def main(args):
    """Train ``Model`` on the selected dataset and save its weights.

    The first graph of the dataset named by ``args.dataset`` is loaded, a
    ``Model`` is built from the node-feature width and the dataset's class
    count, and it is optimized for 500 epochs with Adam (lr=0.001) against a
    cross-entropy loss. Accuracy and loss are printed every epoch. Finally
    the model's ``state_dict`` is written to ``./model_<dataset>.pth``.

    Args:
        args: Object with a ``dataset`` attribute, one of ``'BAShape'``,
            ``'BACommunity'``, ``'TreeCycle'`` or ``'TreeGrid'``.

    Raises:
        ValueError: If ``args.dataset`` is not one of the names above.
    """
    if args.dataset == 'BAShape':
        dataset = BAShapeDataset(seed=0)
    elif args.dataset == 'BACommunity':
        dataset = BACommunityDataset(seed=0)
    elif args.dataset == 'TreeCycle':
        dataset = TreeCycleDataset(seed=0)
    elif args.dataset == 'TreeGrid':
        dataset = TreeGridDataset(seed=0)
    else:
        # Without this branch an unexpected name leaves `dataset` unbound and
        # the failure surfaces later as an opaque UnboundLocalError.
        raise ValueError(
            f'Unknown dataset {args.dataset!r}; expected one of '
            "'BAShape', 'BACommunity', 'TreeCycle', 'TreeGrid'"
        )

    graph = dataset[0]
    labels = graph.ndata['label']
    n_feats = graph.ndata['feat']
    num_classes = dataset.num_classes

    # Input width is taken from the last dimension of the node features.
    model = Model(n_feats.shape[-1], num_classes)
    loss_fn = nn.CrossEntropyLoss()
    optim = th.optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(500):
        model.train()
        # For demo purpose, we train the model on all datapoints
        # In practice, you should train only on the training datapoints
        logits = model(graph, n_feats)
        loss = loss_fn(logits, labels)
        # Accuracy is measured on this epoch's logits, i.e. before the
        # parameter update performed below.
        acc = th.sum(logits.argmax(dim=1) == labels).item() / len(labels)

        optim.zero_grad()
        loss.backward()
        optim.step()

        print(f'In Epoch: {epoch}; Acc: {acc}; Loss: {loss.item()}')

    # Only the parameters (state_dict) are saved, not the whole module.
    model_stat_dict = model.state_dict()
    model_path = os.path.join('./', f'model_{args.dataset}.pth')
    th.save(model_stat_dict, model_path)

if __name__ == '__main__':
    # Script entry point: read the dataset name from the command line,
    # echo the parsed arguments, then run training.
    dataset_choices = ['BAShape', 'BACommunity', 'TreeCycle', 'TreeGrid']

    cli = argparse.ArgumentParser(description='Dummy model training')
    cli.add_argument(
        '--dataset',
        type=str,
        default='BAShape',
        choices=dataset_choices,
    )

    args = cli.parse_args()
    print(args)

    main(args)
