# https://medium.com/@garraypierce/train-the-first-gnn-model-for-cora-data-with-pytorch-58a7f3706183
import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid, Coauthor, HeterophilousGraphDataset
import wandb


import argparse




def get_masks(length_y, train_fraction):
    """Randomly split ``length_y`` positions into train/test boolean masks.

    A shuffled index array (drawn from NumPy's global random state) is cut
    at ``int(length_y * train_fraction)``: positions before the cut are
    marked in the train mask, all remaining positions in the test mask.

    Args:
        length_y: Number of positions (labels) to split.
        train_fraction: Fraction of positions assigned to the train mask.

    Returns:
        A ``(train_mask, test_mask)`` tuple of boolean NumPy arrays, each of
        length ``length_y``.
    """
    order = np.arange(length_y)
    np.random.shuffle(order)

    n_train = int(length_y * train_fraction)

    train_mask = np.zeros(length_y, dtype=bool)
    train_mask[order[:n_train]] = True

    # ``order`` is a permutation, so whatever is not in the train slice is
    # exactly the test slice: the test mask is the complement.
    test_mask = ~train_mask

    return train_mask, test_mask

def _parse_args():
    """Build the command-line parser, parse ``sys.argv`` and validate it.

    Returns:
        The parsed ``argparse.Namespace``.

    Raises:
        SystemExit: Via ``argparser.error`` when ``--n_layers`` or
            ``--n_partitions`` is smaller than 1.
    """
    argparser = argparse.ArgumentParser(
        "GCN on Torch Datasets",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    argparser.add_argument("--lr", type=float, default=0.01, help="learning rate")
    argparser.add_argument(
        "--weight_decay", type=float, default=0., help="weight decay"
    )
    argparser.add_argument("--n_layers", type=int, default=3, help="number of layers")
    argparser.add_argument(
        "--n_epochs", type=int, default=1000, help="number of training epochs"
    )
    argparser.add_argument(
        "--run",
        type=int,
        default=0,
        help="run index (logged in the run name; selects the split column "
             "for the Roman/Amazon datasets)",
    )
    argparser.add_argument(
        "--n_hidden", type=int, default=256, help="number of hidden units"
    )
    argparser.add_argument("--dropout", type=float, default=0.5, help="dropout rate")
    # NOTE(review): the default "ogbn-arxiv" is not one of the datasets this
    # script can load, so --dataset effectively has to be passed explicitly.
    argparser.add_argument("--dataset", type=str, default="ogbn-arxiv", help="Dataset.")
    argparser.add_argument(
        "--n_partitions", type=int, default=2, help="Number of Partitions."
    )

    args = argparser.parse_args()

    if args.n_layers < 1:
        argparser.error("--n_layers must be at least 1")
    if args.n_partitions < 1:
        argparser.error("--n_partitions must be at least 1")
    return args


def _load_data(name, run, device):
    """Load the dataset called ``name`` and move its first graph to ``device``.

    Args:
        name: One of ``Cora``, ``CiteSeer``, ``PubMed``, ``CS``, ``Physics``,
            ``Roman`` or ``Amazon``.
        run: Column of the train/test mask matrices to use for the
            ``Roman``/``Amazon`` datasets; ignored otherwise.
        device: Torch device the graph is moved to.

    Returns:
        A ``(dataset, data)`` tuple where ``data`` is ``dataset[0]`` on
        ``device`` with ``train_mask`` and ``test_mask`` set.

    Raises:
        NotImplementedError: If ``name`` is not a supported dataset.
    """
    planetoid = ("Cora", "CiteSeer", "PubMed")
    coauthor = ("CS", "Physics")
    heterophilous = {"Roman": "Roman-empire", "Amazon": "Amazon-ratings"}

    if name in planetoid:
        dataset = Planetoid(root="./data", name=name)
        data = dataset[0].to(device)
    elif name in coauthor:
        dataset = Coauthor(root="./data", name=name)
        data = dataset[0].to(device)
        # Masks are generated here with a random 90/10 train/test split.
        data.train_mask, data.test_mask = get_masks(data.y.shape[0], 0.9)
    elif name in heterophilous:
        dataset = HeterophilousGraphDataset(root="./data", name=heterophilous[name])
        data = dataset[0].to(device)
        # The masks are indexed as 2-D here; keep only the column for this run.
        data.train_mask, data.test_mask = (
            data.train_mask[:, run],
            data.test_mask[:, run],
        )
    else:
        supported = ", ".join((*planetoid, *coauthor, *heterophilous))
        raise NotImplementedError(
            f"Unsupported dataset {name!r}; choose one of: {supported}"
        )
    return dataset, data


class _GCN(torch.nn.Module):
    """Stack of ``n_layers`` GCNConv layers followed by ``log_softmax``.

    Every layer except the last is followed by ReLU and dropout. Layer sizes
    are ``in -> hidden -> ... -> hidden -> out``; with a single layer it is
    ``in -> out``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, n_layers,
                 dropout=0.5):
        super().__init__()
        dims = [in_channels] + [hidden_channels] * (n_layers - 1) + [out_channels]
        self.convs = torch.nn.ModuleList(
            [GCNConv(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])]
        )
        self.dropout = dropout

    def forward(self, x, edge_index):
        for conv in self.convs[:-1]:
            x = F.relu(conv(x, edge_index))
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.convs[-1](x, edge_index)
        if len(self.convs) == 1:
            # NOTE(review): the original single-layer model applied ReLU and
            # dropout to the output before log_softmax; kept as-is so that
            # existing results stay comparable -- confirm this is intended.
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        return F.log_softmax(x, dim=1)


def main():
    """Train a GCN on a random subset of a dataset's training nodes.

    Parses the command line, loads the requested dataset, keeps a random
    ``1 / n_partitions`` share of the training nodes, trains a GCN with Adam
    and NLL loss for ``n_epochs`` full-batch epochs (printing loss and test
    accuracy every 100 epochs), and logs the final train/test loss and
    accuracy to wandb.

    Raises:
        NotImplementedError: If ``--dataset`` is not supported.
        ValueError: If the sub-sampled training set is empty.
    """
    args = _parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dataset, data = _load_data(args.dataset, args.run, device)

    # Sub-sample the training nodes. The mask is either a torch tensor (from
    # the dataset) or a NumPy array (from get_masks), so normalise it first.
    full_train_mask = data.train_mask
    if torch.is_tensor(full_train_mask):
        full_train_mask = full_train_mask.cpu().numpy()
    full_train_mask = np.asarray(full_train_mask)

    train_idx = np.flatnonzero(full_train_mask)
    n_nodes_train = len(train_idx)
    keep = np.random.permutation(n_nodes_train)[:n_nodes_train // args.n_partitions]
    train_idx = train_idx[keep]

    if len(train_idx) == 0:
        raise ValueError(
            f"No training nodes left after keeping 1/{args.n_partitions} "
            f"of {n_nodes_train} training nodes"
        )

    sub_mask = np.zeros(len(full_train_mask), dtype=bool)
    sub_mask[train_idx] = True

    # Keep both masks as boolean tensors on the same device as the features,
    # so indexing and ``.sum().item()`` behave uniformly for every dataset.
    data.train_mask = torch.from_numpy(sub_mask).to(device)
    data.test_mask = torch.as_tensor(data.test_mask, dtype=torch.bool, device=device)

    model = _GCN(
        dataset.num_features,
        args.n_hidden,
        dataset.num_classes,
        args.n_layers,
        dropout=args.dropout,
    ).to(device)

    optimizer = torch.optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay
    )

    wandb.init(
        # set the wandb project where this run will be logged
        project="GCN on torch datasets",
        # set the name of the run
        name = f"{args.dataset}_part_{args.n_partitions}_nodes_{len(train_idx)}_lr_{args.lr}_layers_{args.n_layers}_hidden_{args.n_hidden}_epoch_{args.n_epochs}_run_{args.run}",
        # track hyperparameters and run metadata
        config=args
    )

    def train():
        """Run one full-batch optimisation step and return the train loss."""
        model.train()
        optimizer.zero_grad()
        out = model(data.x, data.edge_index)
        loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()
        return loss.item()

    @torch.no_grad()
    def loss_and_acc():
        """Return (train loss, test loss, train accuracy, test accuracy)."""
        model.eval()
        out = model(data.x, data.edge_index)
        loss_train = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
        loss_test = F.nll_loss(out[data.test_mask], data.y[data.test_mask])
        pred = out.argmax(dim=1)
        acc_train = (pred[data.train_mask] == data.y[data.train_mask]).sum().item() / data.train_mask.sum().item()
        acc_test = (pred[data.test_mask] == data.y[data.test_mask]).sum().item() / data.test_mask.sum().item()

        return loss_train.item(), loss_test.item(), acc_train, acc_test

    @torch.no_grad()
    def test():
        """Return the accuracy on the test mask."""
        model.eval()
        out = model(data.x, data.edge_index)
        pred = out.argmax(dim=1)
        acc = (pred[data.test_mask] == data.y[data.test_mask]).sum().item() / data.test_mask.sum().item()
        return acc

    for epoch in range(1, args.n_epochs + 1):
        loss = train()
        # The accuracy is only reported every 100 epochs, so the extra
        # evaluation forward pass is only run then.
        if epoch % 100 == 0:
            acc = test()
            print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Accuracy: {acc:.4f}')

    train_loss, test_loss, train_acc, test_acc = loss_and_acc()
    wandb.log({
        "Train Accuracy": train_acc,
        "Test Accuracy": test_acc,
        "Train Loss": train_loss,
        "Test Loss": test_loss,
    })

# Script entry point: run the experiment only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
    # Called only after main() has returned; if main() raises, this line is
    # not reached.
    wandb.finish()