import argparse

import dgl
import dgl.nn as dglnn

import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl import AddSelfLoop
from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset
from config import get_args
from model import Model
from datasets import Dataset
from tqdm import tqdm
import numpy as np

class GCN(nn.Module):
    """Two-layer graph convolutional network built from DGL ``GraphConv`` layers.

    Architecture: ``GraphConv(in -> hid, ReLU) -> Dropout -> GraphConv(hid -> out)``.
    Dropout is applied to the input of every layer except the first, i.e. to
    hidden representations only, never to the raw input features.

    Args:
        in_size: dimensionality of the input node features.
        hid_size: width of the hidden layer.
        out_size: number of output logits per node (the script passes the
            number of classes).
        dropout: dropout probability applied between layers. Defaults to 0.5,
            the value that was previously hard-coded.
    """

    def __init__(self, in_size, hid_size, out_size, dropout=0.5):
        super().__init__()
        # two-layer GCN; layers are created in the same order as before so
        # parameter initialisation consumes the RNG identically.
        self.layers = nn.ModuleList(
            [
                dglnn.GraphConv(in_size, hid_size, activation=F.relu),
                dglnn.GraphConv(hid_size, out_size),
            ]
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, g, features):
        """Run every layer over graph ``g`` and return the per-node output logits."""
        h = features
        for i, layer in enumerate(self.layers):
            # Skip dropout on the raw input features; apply it between layers.
            if i != 0:
                h = self.dropout(h)
            h = layer(g, h)
        return h


def evaluate(g, features, labels, mask, model):
    """Return the accuracy of ``model`` on the nodes selected by ``mask``.

    The model is put in eval mode (disabling dropout) and run over the whole
    graph without gradient tracking; the predicted class of each node is the
    index of its largest logit.

    Args:
        g: graph passed straight through to ``model``.
        features: node feature tensor passed to ``model``.
        labels: ground-truth class index per node.
        mask: index tensor (or boolean mask) selecting the nodes to score.
        model: module called as ``model(g, features)`` that returns logits.

    Returns:
        The fraction of selected nodes whose prediction equals the label, as a
        Python float.
    """
    model.eval()
    with torch.no_grad():
        selected_logits = model(g, features)[mask]
        selected_labels = labels[mask]
        predictions = selected_logits.max(dim=1).indices
        num_correct = (predictions == selected_labels).sum().item()
        return num_correct / len(selected_labels)


def train(run, g, features, labels, masks, model, num_steps=None, verbose=None):
    """Train ``model`` full-batch on the training nodes of one data split.

    Runs ``num_steps`` epochs of Adam (lr=1e-2, weight_decay=5e-4) on the
    cross-entropy loss over the training nodes. After every epoch it evaluates
    validation accuracy and reports epoch, loss and accuracy on a tqdm bar.
    The model is updated in place; nothing is returned.

    Args:
        run: index of the current run; only used in the progress-bar label.
        g: graph passed to ``model``.
        features: node feature tensor passed to ``model``.
        labels: ground-truth class index per node.
        masks: sequence ``(train_mask, val_mask, ...)``; only the first two
            entries are used here.
        model: module called as ``model(g, features)`` returning logits.
        num_steps: number of training epochs. Defaults to the module-level
            ``args.num_steps`` set in the ``__main__`` block.
        verbose: forwarded to tqdm's ``disable`` flag. Defaults to the
            module-level ``args.verbose``.
    """
    # Fall back to the script-level CLI arguments so existing callers that
    # pass only the positional arguments keep working.
    if num_steps is None:
        num_steps = args.num_steps
    if verbose is None:
        verbose = args.verbose

    # define train/val samples, loss function and optimizer
    train_mask = masks[0]
    val_mask = masks[1]
    loss_fcn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)

    # Train for exactly ``num_steps`` epochs so the loop agrees with the
    # progress-bar total.
    # NOTE(review): the bar is *hidden* when verbose is truthy, which looks
    # inverted — confirm the intended meaning of ``args.verbose``.
    with tqdm(total=num_steps, desc=f'Run {run}', disable=verbose) as progress_bar:
        for epoch in range(num_steps):
            model.train()
            logits = model(g, features)
            loss = loss_fcn(logits[train_mask], labels[train_mask])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # evaluate() switches the model to eval mode; model.train() at the
            # top of the next iteration switches it back.
            acc = evaluate(g, features, labels, val_mask, model)
            progress_bar.update()
            progress_bar.set_postfix({"Epoch":"{:05d}".format(epoch),
                                      "Loss:":"{:.4f}".format(loss.item()),
                                      "Accuracy":"{:.4f}".format(acc)})

if __name__ == "__main__":
    # Script entry point: build the dataset, then train and test a fresh GCN
    # on each of ``args.num_runs`` train/val/test splits and report the mean
    # test accuracy. ``args`` stays module-level because train() reads it.
    args = get_args()

    torch.manual_seed(args.seed)

    dataset = Dataset(name=args.dataset,
                      model_name=args.model,
                      add_self_loops=True,
                      device=args.device,
                      use_sgc_features=args.use_sgc_features,
                      use_identity_features=args.use_identity_features,
                      use_adjacency_features=args.use_adjacency_features,
                      do_not_use_original_features=args.do_not_use_original_features,
                      seed=args.seed,
                      prefer_feat=args.prefer_feat,
                      rewrite_basis=args.rewrite_basis,
                      rewrite_construct=args.rewrite_construct,
                      rewrite_construct_param=args.rewrite_construct_param,
                      )
    # The graph and model go to the auto-detected device, while the dataset
    # was built on ``args.device``; move features and labels as well so every
    # tensor the model touches lives on the same device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    g = dataset.graph.int().to(device)
    features = dataset.node_features.to(device)
    labels = dataset.labels.to(device)

    in_size = dataset.num_node_features
    out_size = dataset.num_class

    acc_list = []
    for run in range(args.num_runs):
        # Create a fresh GCN for every run. Reusing one model would carry the
        # previous run's trained weights into the next split, so the runs
        # would not be independent and the averaged accuracy would be biased.
        model = GCN(in_size, args.hidden_dim, out_size).to(device)

        # model training
        print("Training...")
        masks = (
            dataset.train_idx_list[run],
            dataset.val_idx_list[run],
            dataset.test_idx_list[run],
        )
        train(run, g, features, labels, masks, model)

        # test the model
        print("Testing...")
        acc = evaluate(g, features, labels, masks[2], model)
        acc_list.append(acc)
        print("Test accuracy {:.4f}".format(acc))

    print("Average Test accuracy {:.4f}".format(np.array(acc_list).mean()))
