# Training script for the dummy model.


import argparse
from datetime import datetime
from pathlib import Path

import dgl
import torch
import torch.nn as nn
from tqdm import tqdm
from dgl import save_graphs

from ex.dgl.models import GraphConvModel
from ex.dgl.gengraph import gen_syn1, gen_syn2, gen_syn3, gen_syn4, gen_syn5


# Dataset name -> generator returning a ``(networkx_graph, node_labels, name)``
# tuple.
_DATASET_GENERATORS = {
    'syn1': gen_syn1,
    'syn2': gen_syn2,
    'syn3': gen_syn3,
    'syn4': gen_syn4,
    'syn5': gen_syn5,
}


def _load_dataset(dataset):
    """Run the ``gen_syn*`` generator registered for ``dataset``.

    Returns:
        The ``(g, labels, name)`` tuple produced by the generator.

    Raises:
        NotImplementedError: if ``dataset`` has no registered generator.
    """
    try:
        generator = _DATASET_GENERATORS[dataset]
    except KeyError:
        raise NotImplementedError(f'Unknown dataset: {dataset!r}') from None
    return generator()


def _build_graph(g, labels, feat_dim):
    """Convert networkx graph ``g`` into a DGL graph with labels and features.

    Node data written:
        'label': ``labels`` converted to a ``torch.long`` tensor.
        'feat':  standard-normal random features of width ``feat_dim``.
    """
    graph = dgl.from_networkx(g)
    graph.ndata['label'] = torch.tensor(labels, dtype=torch.long)
    graph.ndata['feat'] = torch.randn(graph.number_of_nodes(), feat_dim)
    return graph


def _train(model, graph, epochs, lr, wd):
    """Train ``model`` full-batch on all nodes of ``graph`` with Adam.

    Uses cross-entropy loss. Each epoch's accuracy and loss are shown in the
    tqdm progress bar.
    """
    feats = graph.ndata['feat']
    labels = graph.ndata['label']
    loss_fn = nn.CrossEntropyLoss()
    optim = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)

    bar = tqdm(range(epochs))
    for epoch in bar:
        model.train()

        logits = model(graph, feats)
        loss = loss_fn(logits, labels)
        # Accuracy is measured on the same nodes being trained on (no split).
        acc = (logits.argmax(dim=1) == labels).sum().item() / len(labels)

        optim.zero_grad()
        loss.backward()
        optim.step()

        bar.set_description(f'[{epoch:05d}] A:{acc:.4f},L:{loss.item():.4f}')


def main(args):
    """Generate a synthetic dataset, train a GraphConvModel on it, save both.

    Files written to ``args.output_dir`` (created if missing):
      * ``<dataset>.bin``: the DGL graph, saved with ``hid_dim`` as a graph
        label.
      * ``<dataset>_<yymmddHHMMSS>.pt``: a dict holding the trained
        ``state_dict`` plus the dataset name and feature/hidden dims.

    Args:
        args: Parsed CLI namespace providing ``dataset``, ``feat_dim``,
            ``hidden_dim``, ``epochs``, ``lr``, ``wd`` and ``output_dir``
            (a ``pathlib.Path``).

    Raises:
        NotImplementedError: if ``args.dataset`` is not a known dataset.
    """
    args.output_dir.mkdir(parents=True, exist_ok=True)

    g, labels, _ = _load_dataset(args.dataset)
    # Random features are drawn here, before the model is built, so the
    # RNG consumption order stays the same as the original script.
    graph = _build_graph(g, labels, args.feat_dim)

    # Save the graph for later use, storing the hidden dim alongside it.
    bin_file = args.output_dir / f'{args.dataset}.bin'
    label_dict = {'hid_dim': torch.tensor(args.hidden_dim, dtype=torch.long)}
    save_graphs(filename=str(bin_file), g_list=[graph], labels=label_dict)

    # Tensor.max() reduces in a single op. The builtin max() would iterate
    # the tensor element by element in Python.
    num_classes = graph.ndata['label'].max().item() + 1

    dummy_model = GraphConvModel(args.feat_dim, args.hidden_dim, num_classes)
    _train(dummy_model, graph, args.epochs, args.lr, args.wd)

    data = dict(
        model_state=dummy_model.state_dict(),
        dataset=args.dataset,
        feat_dim=args.feat_dim,
        hidden_dim=args.hidden_dim,
    )
    timestamp = datetime.now().strftime('%y%m%d%H%M%S')
    model_path = args.output_dir / f'{args.dataset}_{timestamp}.pt'
    torch.save(data, model_path)


def _build_arg_parser():
    """Build the command-line parser for this training script.

    Options: the synthetic dataset to generate, the feature and hidden
    dimensions, the epoch count, the Adam learning rate and weight decay,
    and the output directory.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--dataset',
        type=str,
        default='syn1',
        choices=[f'syn{idx}' for idx in range(1, 6)],
    )
    cli.add_argument('--feat-dim', type=int, default=10)
    cli.add_argument('--hidden-dim', type=int, default=20)
    cli.add_argument('--epochs', type=int, default=1000)
    cli.add_argument('--lr', type=float, default=0.001)
    # argparse runs string defaults through `type`, so this becomes a Path.
    cli.add_argument('--output-dir', type=Path, default='./output')
    cli.add_argument('--wd', type=float, default=0.0)
    return cli


if __name__ == '__main__':
    main(_build_arg_parser().parse_args())
