import torch
import torch.nn.functional as F

from torch_geometric.nn import GCN, GraphSAGE
import torch_geometric.transforms as T
from torch_geometric.utils import to_dense_adj, index_to_mask, mask_to_index

from torch_geometric.datasets import (
    Planetoid,
    Coauthor,
    Amazon,
    HeterophilousGraphDataset,
)
from ogb.nodeproppred import PygNodePropPredDataset

from sklearn.metrics import matthews_corrcoef
from sklearn.metrics import adjusted_mutual_info_score

import pandas as pd

# Raise pandas' display limits (rows, columns, line width) so that printed
# frames/series such as the experiment configuration are not truncated.
pd.set_option("display.max_rows", 100)
pd.set_option("display.max_columns", 100)
pd.set_option("display.width", 1000)

from IPython.display import clear_output
from tqdm.autonotebook import tqdm

from contextlib import nullcontext

import os
import os.path as osp
import sys
import argparse

import glob

sys.path.insert(1, "../Code/src/")
from neural_sbm.pool import NeuralSBMPool, AutoPool
from neural_sbm.transformer import Transformer
from neural_sbm.mlp import MLP
from neural_sbm.regularization import adamw_regularization


# Dataset names grouped by the loader class that serves them.
_COAUTHOR_DATASETS = ("CS", "Physics")
_AMAZON_DATASETS = ("Computers", "Photo")
_PLANETOID_DATASETS = ("Cora", "CiteSeer", "PubMed")
_HETEROPHILOUS_DATASETS = (
    "Roman-empire",
    "Amazon-ratings",
    "Minesweeper",
    "Tolokers",
    "Questions",
)
_OPTIMIZERS = ("Adam", "AdamW")


def _query_gpu_name():
    """Return the second output line of the `nvidia-smi` name query, or None.

    The first line of the CSV output is the header, so the second line is the
    name of the first GPU. None is returned when the command prints fewer
    than two lines (e.g. `nvidia-smi` is not installed).
    """
    with os.popen("nvidia-smi --query-gpu=gpu_name --format=csv") as pipe:
        lines = pipe.read().split("\n")
    return lines[1] if len(lines) > 1 else None


def _sparsify_split(data, num_train_nodes_per_class, num_val_nodes):
    """Shrink the train/val masks of `data` in place to a sparse random subset."""
    per_class = []
    for c in data.y.unique():
        # Global node indices of training nodes with label `c`. Indexing
        # `data.y[data.train_mask]` first would yield positions relative to
        # the masked subset, which are not valid node indices.
        candidates = torch.where(data.train_mask & (data.y == c))[0]
        shuffled = candidates[torch.randperm(candidates.size(0))]
        per_class.append(shuffled[:num_train_nodes_per_class])
    data.train_mask = index_to_mask(torch.cat(per_class, dim=0), size=data.num_nodes)

    val_idx = mask_to_index(data.val_mask)
    val_idx = val_idx[torch.randperm(val_idx.size(0))][:num_val_nodes]
    data.val_mask = index_to_mask(val_idx, size=data.num_nodes)


def _load_dataset(name, root, split, run_id, num_train_nodes_per_class, num_val_nodes):
    """Load the dataset called `name` and return `(dataset, data)`.

    `data` is the first graph of the dataset with `train_mask`, `val_mask`
    and `test_mask` set. Raises ValueError for an unrecognised name.
    """
    is_ogb = isinstance(name, str) and "ogbn" in name
    is_coauthor = name in _COAUTHOR_DATASETS
    is_amazon = name in _AMAZON_DATASETS
    is_heterophilous = name in _HETEROPHILOUS_DATASETS
    is_planetoid = name in _PLANETOID_DATASETS
    if not (is_ogb or is_coauthor or is_amazon or is_heterophilous or is_planetoid):
        raise ValueError(f"Unknown dataset: {name!r}")

    transforms = [
        T.ToUndirected(),
        T.NormalizeFeatures(),
    ]
    if is_coauthor or is_amazon:
        # These loaders are given no split arguments here, so a random
        # split is attached through the transform.
        transforms.append(
            T.RandomNodeSplit(
                split="test_rest",
                num_train_per_class=num_train_nodes_per_class,
                num_val=num_val_nodes,
            )
        )
    transform = T.Compose(transforms)

    if is_coauthor:
        dataset = Coauthor(root=root, name=name, transform=transform)
        data = dataset[0]
    elif is_amazon:
        dataset = Amazon(root=root, name=name, transform=transform)
        data = dataset[0]
    elif is_ogb:
        dataset = PygNodePropPredDataset(root=root, name=name, transform=transform)
        data = dataset[0]

        split_idx = dataset.get_idx_split()
        data.train_mask = index_to_mask(split_idx["train"], size=data.num_nodes)
        data.val_mask = index_to_mask(split_idx["valid"], size=data.num_nodes)
        data.test_mask = index_to_mask(split_idx["test"], size=data.num_nodes)

        data.y = data.y.squeeze(dim=-1)
    elif is_heterophilous:
        dataset = HeterophilousGraphDataset(root=root, name=name, transform=transform)
        data = dataset[0]

        # The masks hold one column per predefined split; pick column `run_id`.
        data.train_mask = data.train_mask.T[run_id]
        data.val_mask = data.val_mask.T[run_id]
        data.test_mask = data.test_mask.T[run_id]
    else:
        dataset = Planetoid(
            root=root,
            name=name,
            transform=transform,
            split=split,
            num_train_per_class=num_train_nodes_per_class,
            num_val=num_val_nodes,
            num_test=1000,
        )
        data = dataset[0]

    # Decided on the dataset *name*: by now `dataset` is a Dataset object,
    # so string tests against it would never match.
    if (is_ogb or is_heterophilous) and split == "sparse":
        _sparsify_split(data, num_train_nodes_per_class, num_val_nodes)

    return dataset, data


def _make_network(kind, in_channels, out_channels):
    """Build a Transformer/MLP for the given `kind`; pass anything else through."""
    if kind == "Transformer":
        return Transformer(in_channels=in_channels, out_channels=out_channels)
    if kind == "MLP":
        return MLP(in_channels=in_channels, out_channels=out_channels)
    return kind


def _build_model(
    gnn,
    xnn,
    snn,
    bnn,
    data,
    num_features,
    max_clusters,
    pooling,
    regularization,
    graph_reconstruction_method,
    reconstruct_attributes,
    device,
):
    """Build an AutoPool model (when `gnn` is given) or a NeuralSBMPool model."""
    # A dense adjacency matrix is only handed over for "MapEqPool".
    adj = None
    if pooling == "MapEqPool":
        adj = to_dense_adj(edge_index=data.edge_index, edge_attr=data.edge_weight)[
            0
        ].to(device)

    bnn = _make_network(bnn, max_clusters, max_clusters)

    # NeuralSBMPool
    if gnn is None:
        return NeuralSBMPool(
            nn=_make_network(snn, num_features, max_clusters),
            bnn=bnn,
            in_channels=num_features,
            max_clusters=max_clusters,
            pooling_method=pooling,
            graph_reconstruction_method=graph_reconstruction_method,
            regularization_method=regularization,
            adj=adj,
        )

    # AutoPool (contextual)
    if gnn == "GCN":
        gnn_cls = GCN
    elif gnn == "GraphSAGE":
        gnn_cls = GraphSAGE
    else:
        raise NotImplementedError(f"Unsupported gnn: {gnn!r}")

    return AutoPool(
        gnn=gnn_cls(
            in_channels=num_features,
            hidden_channels=256,
            num_layers=3,
            out_channels=max_clusters,
            dropout=0.5,
        ),
        xnn=_make_network(xnn, max_clusters, num_features),
        bnn=bnn,
        order="contextual",
        pooling_method=pooling,
        regularization_method=regularization,
        max_clusters=max_clusters,
        num_attributes=num_features,
        graph_reconstruction_method=graph_reconstruction_method,
        reconstruct_attributes=reconstruct_attributes,
        adj=adj,
    )


def supervised_clustering(
    save_path="./",
    slurm_id=0,
    experiment_id=0,
    run_id=0,
    dataset="Cora",
    split="public",
    max_clusters=7,
    num_train_nodes_per_class=20,
    num_val_nodes=500,
    pooling=None,
    reconstruct_attributes=False,
    gnn="GCN",
    xnn=None,
    snn=None,
    bnn=None,
    graph_reconstruction_method="sample",
    regularization=None,
    supervised=True,
    optimization="AdamW",
    learning_rate=0.001,
    weight_decay=0.01,
    stopping_criterion="mcc",
):
    """Train a pooling model on a node-classification dataset and save the result.

    The model is trained for at most 1000 epochs with early stopping
    (patience 100). The configuration, the metrics of the best epoch and the
    predicted labels are written to `<save_path>/hpc-results-<slurm_id>.zip`.

    Args:
        save_path: Directory used both as dataset root and as output directory.
        slurm_id: Identifier stored in the result and used in the file name.
        experiment_id: Identifier stored in the result.
        run_id: Identifier stored in the result; also selects the split
            column for the heterophilous datasets.
        dataset: Dataset name: "CS", "Physics", "Computers", "Photo", "Cora",
            "CiteSeer", "PubMed", one of the heterophilous dataset names, or
            any name containing "ogbn".
        split: Forwarded to `Planetoid`. For ogbn/heterophilous datasets the
            value "sparse" subsamples the train and validation masks.
        max_clusters: Output width of the model; the first `num_classes`
            outputs are used as class logits.
        num_train_nodes_per_class: Training nodes per class for random,
            Planetoid and sparse splits.
        num_val_nodes: Validation nodes for random, Planetoid and sparse splits.
        pooling: Forwarded as `pooling_method`; "MapEqPool" additionally
            passes a dense adjacency matrix to the model.
        reconstruct_attributes: Forwarded to `AutoPool`.
        gnn: "GCN" or "GraphSAGE" to build an `AutoPool` model, or None to
            build a `NeuralSBMPool` model.
        xnn: "Transformer" or "MLP" (AutoPool only); other values are
            forwarded unchanged.
        snn: "Transformer" or "MLP" (NeuralSBMPool only); other values are
            forwarded unchanged.
        bnn: "Transformer" or "MLP"; other values are forwarded unchanged.
        graph_reconstruction_method: Forwarded to the model.
        regularization: Forwarded as `regularization_method`; "L2" also adds
            `adamw_regularization(model)` to the loss and sets the AdamW
            weight decay to 1.0.
        supervised: Whether to add the cross-entropy loss on training nodes.
        optimization: "Adam" or "AdamW".
        learning_rate: Learning rate for AdamW.
        weight_decay: Weight decay for AdamW (unless `regularization == "L2"`).
        stopping_criterion: "loss", "mcc", "acc" or "ami"; decides which
            quantity defines the best epoch. Any other value never registers
            a best epoch, so training stops once patience is exhausted.

    Returns:
        A single-row `pandas.DataFrame` with the configuration and the best
        metrics. Its "pred_labels" column is None; the labels are only
        contained in the saved file.

    Raises:
        ValueError: If `dataset` or `optimization` is not recognised.
        NotImplementedError: If `gnn` is neither None, "GCN" nor "GraphSAGE".
    """
    result = {
        "gpu": _query_gpu_name(),
        "slurm_id": slurm_id,
        "experiment_id": experiment_id,
        "run_id": run_id,
        "dataset": dataset,
        "split": split,
        "max_clusters": max_clusters,
        "num_train_nodes_per_class": num_train_nodes_per_class,
        "num_val_nodes": num_val_nodes,
        "pooling": pooling,
        "reconstruct_attributes": reconstruct_attributes,
        "gnn": gnn,
        "xnn": xnn,
        "snn": snn,
        "bnn": bnn,
        "graph_reconstruction_method": graph_reconstruction_method,
        "regularization": regularization,
        "supervised": supervised,
        "stopping_criterion": stopping_criterion,
        "optimization": optimization,
    }

    # print gpu
    print("\nGPU:", result["gpu"], "\n")

    # Fail early with a clear message instead of a NameError after the
    # dataset has been loaded and the model has been built.
    if optimization not in _OPTIMIZERS:
        raise ValueError(
            f"Unknown optimization: {optimization!r} (expected one of {_OPTIMIZERS})"
        )

    dataset, data = _load_dataset(
        name=dataset,
        root=save_path,
        split=split,
        run_id=run_id,
        num_train_nodes_per_class=num_train_nodes_per_class,
        num_val_nodes=num_val_nodes,
    )

    if not hasattr(data, "num_classes"):
        data.num_classes = data.y.unique().shape[0]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = _build_model(
        gnn=gnn,
        xnn=xnn,
        snn=snn,
        bnn=bnn,
        data=data,
        num_features=dataset.num_features,
        max_clusters=max_clusters,
        pooling=pooling,
        regularization=regularization,
        graph_reconstruction_method=graph_reconstruction_method,
        reconstruct_attributes=reconstruct_attributes,
        device=device,
    )

    model = model.to(device)
    data = data.to(device)

    if optimization == "Adam":
        # NOTE(review): this branch uses fixed hyper-parameters (it ignores
        # `learning_rate` / `weight_decay`) and requires the model to expose
        # `convs` and `lins` -- confirm this is intended.
        optimizer = torch.optim.Adam(
            [
                dict(params=model.convs.parameters(), weight_decay=0.01),
                dict(params=model.lins.parameters(), weight_decay=5e-4),
            ],
            lr=0.01,
        )
    else:
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=learning_rate,
            weight_decay=1.0 if regularization == "L2" else weight_decay,
        )

    # Only the first `num_classes` outputs are treated as class logits.
    split_sizes = [data.num_classes, max_clusters - data.num_classes]

    def train():
        """Run one optimisation step; return (predictions, loss value)."""
        model.train()
        optimizer.zero_grad()

        out, _, _, unsupervised_loss = model(
            x=data.x,
            edge_index=data.edge_index,
            edge_weight=data.edge_weight,
        )

        out, _ = out.split(split_sizes, dim=-1)

        # l2 regularization
        if regularization == "L2":
            l2_loss = adamw_regularization(model)
            unsupervised_loss += l2_loss

        if supervised:
            supervised_loss = F.cross_entropy(
                out[data.train_mask], data.y[data.train_mask]
            )
        else:
            supervised_loss = 0.0

        loss = supervised_loss + unsupervised_loss

        loss.backward()

        optimizer.step()
        return out.argmax(dim=-1), loss.item()

    @torch.no_grad()
    def evaluate():
        """Return (amis, accs, mccs), each ordered as [train, val, test]."""
        model.eval()

        out, _, _, _ = model(
            x=data.x,
            edge_index=data.edge_index,
            edge_weight=data.edge_weight,
        )

        out, _ = out.split(split_sizes, dim=-1)

        pred, mccs, accs, amis = out.argmax(dim=-1), [], [], []
        for _, mask in data("train_mask", "val_mask", "test_mask"):
            pred_np = pred[mask].cpu().numpy()
            true_np = data.y[mask].cpu().numpy()

            mccs.append(matthews_corrcoef(pred_np, true_np))
            accs.append((pred[mask] == data.y[mask]).sum().item() / mask.sum().item())
            amis.append(adjusted_mutual_info_score(pred_np, true_np))

        return amis, accs, mccs

    # Metrics of the best epoch so far; the key order defines the order of
    # the corresponding result columns.
    best = {
        "loss": float("inf"),
        "train_mcc": 0.0,
        "val_mcc": 0.0,
        "test_mcc": 0.0,
        "train_acc": 0.0,
        "val_acc": 0.0,
        "test_acc": 0.0,
        "train_ami": 0.0,
        "val_ami": 0.0,
        "test_ami": 0.0,
    }
    best_pred = torch.randint_like(data.y, 0, max_clusters)
    wait = 0
    patience = 100
    for epoch in tqdm(range(1, 1001), disable=True, ncols=100, leave=False):
        pred, loss = train()

        amis, accs, mccs = evaluate()
        train_ami, val_ami, test_ami = amis
        train_acc, val_acc, test_acc = accs
        train_mcc, val_mcc, test_mcc = mccs

        if stopping_criterion == "loss":
            new_best = loss < best["loss"]
        elif stopping_criterion == "mcc":
            new_best = abs(val_mcc) > abs(best["val_mcc"])
        elif stopping_criterion == "acc":
            new_best = val_acc > best["val_acc"]
        elif stopping_criterion == "ami":
            new_best = abs(train_ami) > abs(best["train_ami"])
        else:
            new_best = False

        if new_best:
            best.update(
                loss=loss,
                train_mcc=train_mcc,
                val_mcc=val_mcc,
                test_mcc=test_mcc,
                train_acc=train_acc,
                val_acc=val_acc,
                test_acc=test_acc,
                train_ami=train_ami,
                val_ami=val_ami,
                test_ami=test_ami,
            )
            # NOTE(review): `pred` stems from the training-mode forward pass
            # (before the optimizer step) whereas the metrics above stem from
            # the evaluation pass -- confirm this is intended.
            best_pred = pred
            wait = 0
        else:
            wait += 1
            if wait > patience:
                break

        print("Epoch: {:03d}, Loss: {:.4f}".format(epoch, loss))
        sys.stdout.flush()

    result.update(best)
    pred_labels = best_pred.tolist()

    # The returned frame carries no labels; the saved frame stores the whole
    # label list in a single object cell.
    result["pred_labels"] = None
    result = pd.DataFrame(result, index=[0])

    results = result.copy()
    results["pred_labels"] = pd.Series(
        [pred_labels], index=results.index, dtype=object
    )

    os.makedirs(save_path, exist_ok=True)

    file_name = "hpc-results-" + str(slurm_id) + ".zip"
    results.to_csv(osp.join(save_path, file_name), index=False, compression="zip")

    return result


def check_gpu():
    """Print a banner and then run `nvidia-smi` through the system shell."""
    banner = "\nChecking GPU...\n"
    # Flush so the banner precedes the output written by the child process.
    print(banner, flush=True)
    command = "nvidia-smi"
    os.system(command)


def check_cpu():
    """Print a banner and then run `lscpu` through the system shell."""
    banner = "\nChecking CPU...\n"
    # Flush so the banner precedes the output written by the child process.
    print(banner, flush=True)
    command = "lscpu"
    os.system(command)


if __name__ == "__main__":
    # Directory that holds the experiment queue files and receives results.
    save_path = "../Data/Supervised_Clustering/"

    check_cpu()
    check_gpu()

    parser = argparse.ArgumentParser()
    parser.add_argument("--slurm_id", type=int, default=0)
    args = parser.parse_args()
    slurm_id = args.slurm_id
    print("\nSLURM_ARRAY_TASK_ID:", slurm_id, "\n")
    sys.stdout.flush()

    # read remaining experiment ids from text file (one id per line);
    # blank lines are skipped instead of crashing int()
    with open(osp.join(save_path, "remaining.txt"), "r") as f:
        remaining_ids = {int(line) for line in f if line.strip()}

    if slurm_id in remaining_ids:
        print("\nPending experiment found.\n")

        # read dtype_dict from csv file: first column is the experiment
        # column name, second column is its dtype
        dtype_table = pd.read_csv(osp.join(save_path, "dtype_dict.csv"))
        dtype_dict = {row[0]: row[1] for row in dtype_table.values}

        # read experiments from csv file
        experiments = pd.read_csv(
            osp.join(save_path, "experiments.csv"), dtype=dtype_dict
        )

        # ensure null values are consistent
        experiments = experiments.where(pd.notnull(experiments), None)

        # get experiment configuration
        matching = experiments[experiments["slurm_id"] == slurm_id]
        if matching.empty:
            raise IndexError(
                f"slurm_id {slurm_id} is listed in remaining.txt "
                "but has no row in experiments.csv"
            )
        # Copy the row so adding a key does not touch the filtered frame.
        experiment = matching.iloc[0].copy()

        # set save_path
        experiment["save_path"] = save_path

        # print experiment configuration
        print("\nExperiment configuration:\n")
        print(experiment)
        print("\n")

        # run experiment; every entry of the row is passed as a keyword
        # argument
        print("\nRunning experiment...\n")
        result = supervised_clustering(**experiment.to_dict())
        print("\nExperiment completed.\n")
    else:
        print("\nNo pending experiments found.\n")
