import argparse
import os.path as osp
import wandb

import torch
import torch.nn.functional as F

from torch_geometric.datasets import Entities
from models.encoders import CompGCN, RGCN

# Command-line interface. Arguments are parsed at import time and the resulting
# namespace is read as the module-level global ``args`` by main().
parser = argparse.ArgumentParser(
    description="Node classification on an Entities dataset with an RGCN or CompGCN encoder."
)

# --dataset is mandatory: without it ``args.dataset`` would be None and main()
# would have no dataset name to load, failing far from the actual mistake.
parser.add_argument(
    "--dataset",
    type=str,
    required=True,
    choices=["AIFB", "MUTAG", "BGS", "AM"],
    help="Name of the Entities dataset to load.",
)
parser.add_argument("--compgcn", action="store_true", help="Build a CompGCN model instead of RGCN.")
parser.add_argument("--wandb", action="store_true", help="Log the run and per-epoch metrics to wandb.")
parser.add_argument("--lr", type=float, default=0.0003, help="Learning rate of the Adam optimizer.")
parser.add_argument("--drop", type=float, default=0.0, help="Dropout value; passed to the RGCN model only.")
parser.add_argument("--epochs", type=int, default=2000, help="Number of training epochs.")
parser.add_argument(
    "--dim",
    type=int,
    default=2,
    help="Width of the synthetic node features and of every model layer.",
)
parser.add_argument("--num_layers", type=int, default=2, help="Number of entries in the model's dims list.")
parser.add_argument("--wd", type=float, default=0.0005, help="Weight decay of the Adam optimizer.")

# Encoder options. Each flag below is forwarded unchanged to the model
# constructor in main(); their exact effect is defined in models.encoders.
parser.add_argument("--short_cut", action="store_true")  # both models
parser.add_argument("--msg_func", type=str, default="distmult")  # CompGCN only
parser.add_argument("--aggr_func", type=str, default="add", choices=["add", "mean", "max", "min", "pna"])  # both
parser.add_argument("--layer_norm", action="store_true")  # CompGCN only
parser.add_argument("--compgcn_no_dir", action="store_true")  # CompGCN: sets use_dir_weight=False
parser.add_argument("--compgcn_no_relupd", action="store_true")  # CompGCN: sets use_rel_update=False
parser.add_argument("--rgcn_fast", action="store_true")  # RGCN only
parser.add_argument(
    "--cont_feature",
    action="store_true",
    help="Give every node the same Xavier-initialised feature row instead of the constant [1, 0, ..., 0] row.",
)
parser.add_argument("--no_norm", action="store_true")  # CompGCN: sets use_norm=False
# NOTE(review): --no_loops is parsed but not referenced by main() in this file.
parser.add_argument("--no_loops", action="store_true")
parser.add_argument("--freeze_gnn", action="store_true")  # both models
parser.add_argument("--stripped", action="store_true")  # CompGCN only
parser.add_argument("--unique_weights", action="store_true")  # CompGCN only
parser.add_argument("--drop_composition", action="store_true")  # CompGCN: passed as drop_comp
parser.add_argument("--drop_msg_weight", action="store_true")  # CompGCN only
parser.add_argument("--drop_bias", action="store_true")  # RGCN only
parser.add_argument("--mod_rgcn", action="store_true")  # RGCN: passed as mod
parser.add_argument("--seed", type=int, default=42, help="Seed passed to torch.manual_seed.")
args = parser.parse_args()

# Seed PyTorch's random number generator. Only torch is seeded here; Python's
# ``random`` module and NumPy are left untouched.
torch.manual_seed(args.seed)

# Trade memory consumption for faster computation.
# if args.dataset in ['AIFB', 'MUTAG']:
#     RGCNConv = FastRGCNConv


def main():
    """Train and evaluate a node classifier on one Entities dataset.

    All settings come from the module-level ``args`` namespace. The function
    loads the dataset, replaces the node features with synthetic ones of width
    ``args.dim``, builds either a CompGCN or an RGCN model, and then runs
    ``args.epochs`` full-batch training epochs, evaluating after each one.
    Per-epoch metrics are printed and, when ``--wandb`` is set, logged to wandb.
    The best train/test accuracies seen over all epochs are printed at the end.
    """
    path = osp.join(osp.dirname(osp.realpath(__file__)), "..", "data", "Entities")
    dataset = Entities(path, args.dataset)
    data = dataset[0]

    # Every node receives the same feature row, so the features themselves
    # carry no per-node information.
    if args.cont_feature:
        # One Xavier-initialised row, repeated for every node.
        feature = torch.empty((1, args.dim))
        torch.nn.init.xavier_uniform_(feature)
        data.x = torch.repeat_interleave(feature, data.num_nodes, dim=0)
    else:
        # Constant row [1, 0, ..., 0].
        data.x = torch.zeros((data.num_nodes, args.dim))
        data.x[:, 0] = 1.0

    # Use the GPU when available, except for AM, which is always run on CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device = torch.device("cpu") if args.dataset == "AM" else device

    if args.compgcn:
        model = CompGCN(
            dims=[args.dim] * args.num_layers,
            num_relations=dataset.num_relations,
            num_classes=dataset.num_classes,
            message_func=args.msg_func,
            aggregate_func=args.aggr_func,
            layer_norm=args.layer_norm,
            short_cut=args.short_cut,
            use_dir_weight=not args.compgcn_no_dir,
            use_rel_update=not args.compgcn_no_relupd,
            use_norm=not args.no_norm,
            freeze_gnn=args.freeze_gnn,
            stripped=args.stripped,
            unique_weights=args.unique_weights,
            drop_comp=args.drop_composition,
            drop_msg_weight=args.drop_msg_weight,
        )
    else:
        model = RGCN(
            dims=[args.dim] * args.num_layers,
            num_relations=dataset.num_relations,
            num_classes=dataset.num_classes,
            dropout=args.drop,
            short_cut=args.short_cut,
            fast=args.rgcn_fast,
            aggr=args.aggr_func,
            freeze_gnn=args.freeze_gnn,
            drop_bias=args.drop_bias,
            mod=args.mod_rgcn,
        )

    model, data = model.to(device), data.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)

    # Stored on ``args`` so the parameter count ends up in the wandb config.
    num_params = sum(p.numel() for p in model.parameters())
    args.num_params = num_params
    print(f"Number of parameters: {num_params}")
    print("Named params: ")
    print("\n".join([f'{name} : {param.shape}' for name, param in model.named_parameters()]))

    print(f"Training nodes: {len(data.train_idx)}")
    print(f"Test nodes: {len(data.test_idx)}")

    def train():
        """Run one full-batch optimisation step and return the training loss."""
        model.train()
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.edge_type)
        # nll_loss expects log-probabilities; presumably the models end in
        # log_softmax -- confirm in models.encoders.
        loss = F.nll_loss(out[data.train_idx], data.train_y)
        loss.backward()
        optimizer.step()
        return loss.item()

    @torch.no_grad()
    def test():
        """Return ``(train_accuracy, test_accuracy)`` of the current model."""
        model.eval()
        out = model(data.x, data.edge_index, data.edge_type)
        pred = out.argmax(dim=-1)
        train_acc = pred[data.train_idx].eq(data.train_y).to(torch.float).mean()
        test_acc = pred[data.test_idx].eq(data.test_y).to(torch.float).mean()
        return train_acc.item(), test_acc.item()

    max_train, max_test = 0.0, 0.0

    if args.wandb:
        wandb.init(entity="YOUR_ENTITY", project="YOUR_PROJECT")
        wandb.config.update(vars(args))

    # Epochs are numbered 1..args.epochs inclusive, so exactly ``args.epochs``
    # optimisation steps are performed.
    for epoch in range(1, args.epochs + 1):
        loss = train()
        train_acc, test_acc = test()
        max_train = max(max_train, train_acc)
        max_test = max(max_test, test_acc)
        print(
            f"Epoch: {epoch:02d}, Loss: {loss:.4f}, Train: {train_acc:.4f} "
            f"Test: {test_acc:.4f}"
        )
        if args.wandb:
            wandb.log({"loss": loss, "train_acc": train_acc, "test_acc": test_acc})

    print(f"Max train: {max_train:.4f}, max test: {max_test:.4f}")

    # With zero epochs the loop never runs, so report the accuracy of the
    # untrained model instead.
    if args.epochs == 0:
        train_acc, test_acc = test()
        print(f"Train: {train_acc:.4f} Test: {test_acc:.4f}")

# Start the experiment only when this file is executed as a script.
# Note that the command-line arguments are still parsed at import time.
if __name__ == "__main__":
    main()
