import argparse

import torch
from torch_geometric.loader import DataLoader

from datasets.data_loader import load_dataset
from utils.export_for_ui import export_for_ui

# Compute device: CUDA when torch reports it available, otherwise the CPU.
# NOTE(review): nothing in this file moves the model or data onto it — confirm
# whether it is still needed.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')


def score_model_mask(m, d, mask):
    """Return the accuracy of model ``m`` on the nodes of ``d`` selected by ``mask``.

    Args:
        m: Model called as ``m(d.x, d.edge_index)``; its output is arg-maxed
            along the last dimension to obtain predicted classes.
        d: Graph-like object exposing ``x``, ``edge_index`` and ``y``.
        mask: Either a boolean mask (True = node is scored) or a sequence /
            tensor of node indices to score.

    Returns:
        Fraction of selected nodes whose prediction equals ``d.y``, or 0.0
        when the mask selects no nodes.
    """
    m.eval()
    with torch.no_grad():
        pred = m(d.x, d.edge_index).argmax(dim=-1)
    # Normalise lists/tensors to a tensor on the prediction's device so the
    # dtype check below works for every accepted mask form.
    mask = torch.as_tensor(mask, device=pred.device)
    correct = int((pred[mask] == d.y[mask]).sum())
    # len() of a boolean mask is the total node count, not the number of
    # selected nodes, so count the True entries instead.
    if mask.dtype == torch.bool:
        num_selected = int(mask.sum())
    else:
        num_selected = mask.numel()
    if num_selected == 0:
        return 0.0
    return correct / num_selected



def _load_model(path):
    """Load a pickled DT-GNN module from ``path`` onto the CPU.

    The whole ``nn.Module`` is unpickled (not just a state dict), so
    ``weights_only`` must be disabled on PyTorch releases whose default is
    ``True``. Unpickling can execute code: only load checkpoints you trust.
    Mapping to the CPU lets a model saved on a GPU load on a CPU-only host and
    keeps it on the same device as the (CPU) batches produced in this script.
    """
    try:
        return torch.load(path, map_location="cpu", weights_only=False)
    except TypeError:
        # Older PyTorch releases do not accept the ``weights_only`` keyword.
        return torch.load(path, map_location="cpu")


def main():
    """Parse CLI options, load a trained DT-GNN and pass it to ``export_for_ui``.

    The dataset is loaded with ``load_dataset`` and the same dataset is used
    for the test set and for all three data loaders handed to the exporter.
    """
    parser = argparse.ArgumentParser(
        description='Load a trained DT-GNN and export it for the UI',
        formatter_class=argparse.RawDescriptionHelpFormatter
    )
    parser.add_argument("--dataset", default="MUTAG", help="Name of the dataset to run the experiment on")
    parser.add_argument("--data_dir", default="datasets/data", help="Path to the dataset directory")
    # Required: torch.load(None) would otherwise fail later with an opaque error.
    parser.add_argument("--model_path", required=True, help="Path to the trained DT-GNN model")
    parser.add_argument("--batch_size", default=128, type=int)
    parser.add_argument("--state_space", default=10, type=int)
    parser.add_argument("--number_of_layers", default=5, type=int)
    args = parser.parse_args()

    params = {
        "state_space": args.state_space,
        "number_of_layers": args.number_of_layers,
        "gumbel": True,
        "skip_connection": True,
        "network": "mlp",
        "batch_size": args.batch_size,
        "data_dir": args.data_dir,
        "dataset_name": args.dataset
    }

    dataset_name = args.dataset
    dataset, dataset_args = load_dataset(dataset_name, args)

    def score_model(m, d, mask=None):
        """Return the accuracy of ``m`` over every batch yielded by loader ``d``.

        With ``use_pooling`` the number of correct predictions is divided by
        ``len(d.dataset)`` (presumably one label per graph); otherwise it is
        divided by the total number of labels seen across all batches, i.e.
        node-level accuracy. ``mask`` is accepted for call-signature
        compatibility and ignored. Returns 0.0 for an empty loader.
        """
        m.eval()
        total_correct = 0
        total_labels = 0
        with torch.no_grad():
            for data in d:
                out = m(data.x, data.edge_index, data.batch)
                total_correct += int((out.argmax(-1) == data.y).sum())
                total_labels += len(data.y)
        if dataset_args["use_pooling"]:
            denominator = len(d.dataset)
        else:
            # Pool label counts over all batches; summing per-batch accuracies
            # and dividing by the number of graphs would mix units.
            denominator = total_labels
        return total_correct / denominator if denominator else 0.0

    model = _load_model(args.model_path)
    loaders = [DataLoader(dataset, batch_size=params["batch_size"]) for _ in range(3)]
    # NOTE(review): masked datasets (dataset_args["dataset_mask"]) are scored
    # with score_model too; score_model_mask exists in this file but is not
    # used — confirm which scorer export_for_ui expects for them.
    export_for_ui(dataset_name, dataset, params, model, model.trees,
                  score_model, params["data_dir"], test_dataset=dataset,
                  loaders=loaders)


if __name__ == "__main__":
    # Run the export only when this file is executed as a script, not on import.
    main()
