import argparse

from torch.utils.data import Subset

from torch_geometric.explain import Explainer, GNNExplainer, PGExplainer
from torch_geometric.loader import DataLoader


from distances import compare_size, iou_distance
from utils import *

# Names of every metric this script knows about (keys of METRICS, which is
# presumably provided by the `from utils import *` above); used as both the
# allowed choices and the default for --metrics.
METRICS_NAMES = [*METRICS.keys()]
# Run on the first visible CUDA device when available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def args_parser():
    """Parse the command-line options of the explainer evaluation script.

    Returns:
        argparse.Namespace with the attributes ``dataset``, ``split``,
        ``seed``, ``explainer``, ``metrics``, ``batch_size``,
        ``node_mask_fn``, ``edge_mask_fn``, ``features_mask_fn``, ``epochs``
        and ``lr``.

    The three ``*_mask_fn`` options are plain strings; ``main`` turns them
    into thresholding callables via ``get_fn`` (defined outside this file).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, required=True, help="Dataset name")
    parser.add_argument("--split", type=int, default=-1, help="Split index")
    parser.add_argument("--seed", type=int, default=0, help="Seed")
    parser.add_argument(
        "--explainer",
        type=str,
        required=True,
        choices=["GNNExplainer", "PGExplainer"],
        help="Type of explainer",
    )
    parser.add_argument(
        "--metrics",
        type=str,
        nargs="+",
        choices=METRICS_NAMES,
        default=METRICS_NAMES,
        help="Metrics to evaluate",
    )
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
    parser.add_argument("--node_mask_fn", type=str, default="", help="Thresholding of nodes")
    parser.add_argument("--edge_mask_fn", type=str, default="", help="Thresholding of edges")
    parser.add_argument("--features_mask_fn", type=str, default="", help="Thresholding of features")
    parser.add_argument("--epochs", type=int, default=100, help="Number of epochs")
    # A learning rate is fractional: with type=int, argparse rejects values
    # such as "--lr 0.003" (int("0.003") raises), so it must be parsed as float.
    parser.add_argument("--lr", type=float, default=0.1, help="Learning rate")

    return parser.parse_args()


def _build_explainer(model, dataloader, args):
    """Create the ``Explainer`` selected by ``args.explainer`` (training it if needed).

    For ``PGExplainer`` the explanation network is trained here for
    ``args.epochs`` passes over ``dataloader``, using the model's own
    predicted classes as targets. ``GNNExplainer`` needs no training phase.

    Raises:
        ValueError: if ``args.explainer`` is not a supported explainer name.
    """
    model_config = dict(
        mode="multiclass_classification",
        task_level="graph",
        return_type="raw",
    )
    if args.explainer == "PGExplainer":
        explainer = Explainer(
            model=model,
            algorithm=PGExplainer(epochs=args.epochs, lr=args.lr).to(device),  # lr 0.003
            explanation_type="phenomenon",
            edge_mask_type="object",
            model_config=model_config,
        )
        # The epoch index drives PGExplainer's temperature schedule, which is
        # defined relative to the ``epochs`` the algorithm was built with, and
        # PyG will not explain before that many epochs have run. The loop must
        # therefore run exactly args.epochs times; it was hard-coded to 100.
        for epoch in range(args.epochs):
            for batch_data in dataloader:
                batch_data = batch_data.to(device)
                # Only the argmax is used as a training target: no graph needed.
                with torch.no_grad():
                    pred = model(batch_data.x, batch_data.edge_index, batch_data.batch)
                explainer.algorithm.train(
                    epoch,
                    model,
                    batch_data.x,
                    batch_data.edge_index,
                    batch=batch_data.batch,
                    target=pred.argmax(dim=-1),
                )
        return explainer

    if args.explainer == "GNNExplainer":
        return Explainer(
            model=model,
            algorithm=GNNExplainer(epochs=args.epochs, lr=args.lr),
            explanation_type="model",
            node_mask_type="object",
            edge_mask_type="object",
            model_config=model_config,
        )

    raise ValueError(args.explainer)


def _unbatch_explanation(explanation):
    """Split a batched explanation into one ``Data`` object per graph.

    Node-level tensors are selected by ``explanation.batch``; edges are
    assigned to the graph of their source node and re-indexed to graph-local
    node ids. ``node_mask``/``edge_mask``/``mask`` are passed as ``None`` when
    the explanation does not carry them, so they are then absent on the
    resulting ``Data``.
    """
    graphs = []
    # Running count of nodes in preceding graphs, used to make edge ids local.
    # NOTE(review): assumes each graph's nodes are contiguous in the batch,
    # as PyG's batching produces.
    node_offset = 0
    edge_graph_id = explanation.batch[explanation.edge_index[0]]
    has_node_mask = hasattr(explanation, "node_mask")
    has_feature_mask = hasattr(explanation, "mask")
    for b in range(int(explanation.batch.max()) + 1):
        node_sel = explanation.batch == b
        edge_sel = edge_graph_id == b
        graphs.append(
            Data(
                x=explanation.x[node_sel],
                edge_index=explanation.edge_index[:, edge_sel] - node_offset,
                y=explanation.y[b],
                pred=explanation.pred[b],
                node_mask=explanation.node_mask[node_sel] if has_node_mask else None,
                edge_mask=(
                    explanation.edge_mask[edge_sel]
                    if explanation.edge_mask is not None
                    else None
                ),
                mask=explanation.mask[node_sel] if has_feature_mask else None,
            )
        )
        node_offset += int(node_sel.sum())
    return graphs


def get_explanations(model, dataloader, args):
    """Explain every graph served by ``dataloader`` with ``args.explainer``.

    Args:
        model: trained graph classifier, called as ``model(x, edge_index, batch)``.
        dataloader: PyG ``DataLoader`` yielding batched graphs. For
            ``PGExplainer`` the same loader is also used to train the explainer.
        args: parsed CLI namespace; reads ``explainer``, ``epochs`` and ``lr``.

    Returns:
        A list of per-graph ``Data`` objects in loader order. Each carries:
        ``x``; ``edge_index`` with graph-local ids; ``y`` (the explanation
        target); ``pred`` (model output, on CPU); and, when produced,
        ``node_mask``/``edge_mask``/``mask``.
    """
    explainer = _build_explainer(model, dataloader, args)

    explanations = []
    for data in dataloader:
        data = data.to(device)
        # PGExplainer is a "phenomenon" explainer and is given the true labels;
        # GNNExplainer ("model" explanations) takes no explicit target.
        target = data.y if args.explainer == "PGExplainer" else None
        explanation = explainer(data.x, data.edge_index, batch=data.batch, target=target)
        explanation.y = explanation.target
        with torch.no_grad():
            explanation.pred = model(data.x, data.edge_index, data.batch).cpu()
        explanations.extend(_unbatch_explanation(explanation))
    return explanations


def remove_explanation_perturb(explanations, node_mask_fn, edge_mask_fn, features_mask_fn):
    """Build "with explanation" and "without explanation" variants of each graph.

    For every explanation ``e``, ``extract_subset_and_edge_index`` (defined
    outside this file) selects an explanatory node subset from the node/edge
    masks and their thresholding functions. Then:

    * the *with* graph is the subgraph induced by that subset (relabelled);
    * the *without* graph is the subgraph induced by the complementary nodes.

    If ``e`` has a feature mask ``e.mask`` and ``features_mask_fn`` returns a
    threshold, features scoring below the threshold are replaced by the mean
    feature value in the *with* graph. Features at or above it are replaced
    in the *without* graph. Both graphs are labelled with the model's
    predicted class. A variant with no nodes is returned as ``None``.

    Returns:
        A pair ``(with_explanations, without_explanations)`` of lists aligned
        with ``explanations``. The input explanations are not modified.
    """
    if not explanations:
        # torch.cat below cannot concatenate an empty list.
        return [], []

    # Mean feature vector over all nodes of all explanations: the neutral
    # value written into perturbed feature columns.
    default_features = torch.cat([e.x for e in explanations]).mean(0)

    def _remove_explanation_perturb(e):
        num_nodes = len(e.x)
        node_mask = e.node_mask.sum(1) if hasattr(e, "node_mask") else None
        subset, edge_index = extract_subset_and_edge_index(
            e.edge_index,
            num_nodes,
            node_mask,
            node_mask_fn,
            e.edge_mask,
            edge_mask_fn,
        )
        edge_index, _ = subgraph(subset, edge_index, relabel_nodes=True, num_nodes=num_nodes)

        # Clone: the perturbations below are in-place. Previously both names
        # aliased e.x itself, so every column was overwritten in both variants
        # and the caller's explanations were silently corrupted.
        with_x, without_x = e.x.clone(), e.x.clone()
        if hasattr(e, "mask"):
            thr = features_mask_fn(e.mask)
            if thr is not None:
                # NOTE(review): assumes e.mask is a 1-D per-feature score
                # (it is compared to thr and selects feature columns).
                keep = (e.mask >= thr).to(e.x.device)
                with_x[:, ~keep] = default_features[~keep]
                without_x[:, keep] = default_features[keep]

        y_pred = e.pred.argmax(keepdim=True)[0]
        with_explanation = (
            Data(edge_index=edge_index, x=with_x[subset], y=y_pred) if len(subset) > 0 else None
        )

        # Complement of the explanatory subset.
        outside = torch.ones(num_nodes, dtype=torch.bool, device=device)
        outside[subset] = False
        rest = torch.arange(num_nodes, device=device)[outside]
        if len(rest) == 0:
            without_explanation = None
        else:
            rest_edge_index, _ = subgraph(rest, e.edge_index, relabel_nodes=True, num_nodes=num_nodes)
            without_explanation = Data(edge_index=rest_edge_index, x=without_x[rest], y=y_pred)
        return with_explanation, without_explanation

    pairs = [_remove_explanation_perturb(e) for e in explanations]
    return [with_e for with_e, _ in pairs], [without_e for _, without_e in pairs]


def get_truncated_explanations(explanations, node_mask_fn, edge_mask_fn, features_mask_fn):
    """Reduce each explanation to the graph induced by its explanatory nodes.

    The node subset is chosen by ``extract_subset_and_edge_index`` (defined
    outside this file) from the thresholded node/edge masks. If a feature
    mask ``e.mask`` is present and ``features_mask_fn`` yields a threshold,
    only feature columns scoring at or above it are kept.

    Returns:
        A list of ``Data`` objects, one per explanation. Each has a relabelled
        ``edge_index``, the kept features of the subset nodes, ``y`` set to
        the model's predicted class, and ``nodes_subset`` holding the
        original node ids.

    Raises:
        ValueError: if an explanation selects no nodes.
    """

    def _get_truncated_explanations(e):
        node_mask = e.node_mask.sum(1) if hasattr(e, "node_mask") else None
        subset, edge_index = extract_subset_and_edge_index(
            e.edge_index,
            len(e.x),
            node_mask,
            node_mask_fn,
            e.edge_mask,
            edge_mask_fn,
        )
        # An explicit check instead of `assert`, which is stripped under -O.
        if len(subset) == 0:
            raise ValueError("explanation selects no nodes; cannot build a truncated graph")
        edge_index, _ = subgraph(subset, edge_index, relabel_nodes=True, num_nodes=len(e.x))

        x = e.x
        if hasattr(e, "mask"):
            thr = features_mask_fn(e.mask)
            if thr is not None:
                # Boolean column selection on x's own device. The previous
                # CPU arange indexed by a possibly-CUDA mask fails on GPU.
                x = x[:, (e.mask >= thr).to(x.device)]

        y_pred = e.pred.argmax(keepdim=True)[0]
        return Data(edge_index=edge_index, x=x[subset], y=y_pred, nodes_subset=subset)

    return [_get_truncated_explanations(e) for e in explanations]


def _explanation_similarity(
    model, dataloader, args, reference, node_mask_fn, edge_mask_fn, features_mask_fn
):
    """Return ``1 - mean IoU distance`` between ``reference`` and fresh explanations.

    Explanations are recomputed for the graphs in ``dataloader`` with the
    same explainer settings and thresholds, truncated, and compared pairwise
    against ``reference`` with ``compare_lists(..., iou_distance)``. This is
    shared by the I3 (re-run), I4 (feature noise) and I5 (edge noise) metrics.
    """
    new_explanations = get_explanations(model, dataloader, args)
    new_truncated = get_truncated_explanations(
        new_explanations, node_mask_fn, edge_mask_fn, features_mask_fn
    )
    distances = compare_lists(reference, new_truncated, iou_distance)
    return 1 - np.mean(distances)


def main():
    """Evaluate an explainer on a trained graph classifier and print the metrics.

    Loads the checkpoint ``models/{dataset}_s{s}.pt`` (``s`` is the split
    index if given, otherwise the seed). It explains the test split, then
    computes every metric named in ``--metrics``, printing the accumulated
    results after each one and once more at the end. A2, M1, M2 and M3 are
    placeholders that are always recorded as ``None``.
    """
    args = args_parser()
    print(args)
    set_seed(args.seed)
    degree_attr = False  # args.dataset in ["REDDIT-BINARY", "IMDB-BINARY", "IMDB-MULTI"]
    dataset = get_dataset(args.dataset, degree_attr=degree_attr)
    train_idxs, val_idxs, test_idxs = get_splits(
        args.dataset, size=len(dataset), seed=args.seed, split=args.split
    )
    dataset_train = Subset(dataset, train_idxs)
    dataset_val = Subset(dataset, val_idxs)
    dataset_test = Subset(dataset, test_idxs)

    dataloader_test = DataLoader(dataset_test, batch_size=args.batch_size, shuffle=False)

    # Checkpoints are keyed by the split index when one is given, else by seed.
    s = args.seed if args.split < 0 else args.split
    params = torch.load(f"models/{args.dataset}_s{s}.pt", map_location=torch.device("cpu"))
    model = get_model(**params["args"]).to(device)
    model.load_state_dict(params["state_dict"])
    model.eval()

    def forward_fn(model, data):
        """Run ``model`` on a batched PyG graph (the call shape get_acc expects)."""
        return model(data.x, data.edge_index, batch=data.batch)

    results = {}

    node_mask_fn = get_fn(args.node_mask_fn)
    edge_mask_fn = get_fn(args.edge_mask_fn)
    features_mask_fn = get_fn(args.features_mask_fn)

    if "A" in args.metrics:
        results["A"] = get_acc(model, dataloader_test, forward_fn)
        print_results(results)

    explanations = get_explanations(model, dataloader_test, args)
    truncated_explanations = get_truncated_explanations(
        explanations, node_mask_fn, edge_mask_fn, features_mask_fn
    )

    if "A1" in args.metrics:
        distances = compare_with_gt_instance(args.dataset, dataset_test, truncated_explanations, iou_distance)
        # No ground truth available for this dataset -> metric undefined.
        results["A1"] = None if len(distances) == 0 else 1 - np.mean(distances)
        print_results(results)

    if "A2" in args.metrics:
        results["A2"] = None
        print_results(results)

    if {"I1", "I2"} & set(args.metrics):
        with_explanations, without_explanations = remove_explanation_perturb(
            explanations, node_mask_fn, edge_mask_fn, features_mask_fn
        )
        if "I1" in args.metrics:
            with_explanations = [data for data in with_explanations if data is not None]
            dataloader_with = DataLoader(with_explanations, batch_size=args.batch_size, shuffle=False)
            results["I1"] = get_acc(model, dataloader_with, forward_fn)
            print_results(results)
        if "I2" in args.metrics:
            without_explanations = [data for data in without_explanations if data is not None]
            dataloader_without = DataLoader(without_explanations, batch_size=args.batch_size, shuffle=False)
            acc_without = 1 - get_acc(model, dataloader_without, forward_fn)
            # Graphs left empty after removal are excluded above; rescale by
            # the fraction of graphs that still had nodes.
            results["I2"] = acc_without * len(without_explanations) / len(explanations)
            print_results(results)

    if "I3" in args.metrics:
        results["I3"] = _explanation_similarity(
            model, dataloader_test, args, truncated_explanations,
            node_mask_fn, edge_mask_fn, features_mask_fn,
        )
        print_results(results)

    if "I4" in args.metrics:
        dataset_noisy_node = add_noise_perturb(dataset_test, p_x=0.05, p_edges_add=0.0, p_edges_del=0.0, explanations=truncated_explanations)
        dataloader_new = DataLoader(dataset_noisy_node, batch_size=args.batch_size, shuffle=False)
        results["I4"] = _explanation_similarity(
            model, dataloader_new, args, truncated_explanations,
            node_mask_fn, edge_mask_fn, features_mask_fn,
        )
        print_results(results)

    if "I5" in args.metrics:
        dataset_noisy_edge = add_noise_perturb(dataset_test, p_x=0.0, p_edges_add=0.005, p_edges_del=0.005, explanations=truncated_explanations)
        dataloader_new = DataLoader(dataset_noisy_edge, batch_size=args.batch_size, shuffle=False)
        results["I5"] = _explanation_similarity(
            model, dataloader_new, args, truncated_explanations,
            node_mask_fn, edge_mask_fn, features_mask_fn,
        )
        print_results(results)

    if "I6" in args.metrics:
        # Retrain a fresh model purely on truncated explanations and measure
        # its accuracy on the truncated test explanations.
        dataloader_train = DataLoader(dataset_train, batch_size=args.batch_size, shuffle=False)
        dataloader_val = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False)

        explanations_train = get_explanations(model, dataloader_train, args)
        truncated_explanations_train = get_truncated_explanations(
            explanations_train, node_mask_fn, edge_mask_fn, features_mask_fn
        )
        explanations_val = get_explanations(model, dataloader_val, args)
        truncated_explanations_val = get_truncated_explanations(
            explanations_val, node_mask_fn, edge_mask_fn, features_mask_fn
        )
        new_dataloader_train = DataLoader(
            truncated_explanations_train, batch_size=args.batch_size, shuffle=True
        )
        new_dataloader_val = DataLoader(truncated_explanations_val, batch_size=args.batch_size, shuffle=False)
        new_dataloader_test = DataLoader(truncated_explanations, batch_size=args.batch_size, shuffle=False)
        model_args = params["args"]
        new_model = get_model(**model_args).to(device)
        new_model = train_model(
            new_model,
            new_dataloader_train,
            new_dataloader_val,
            new_dataloader_test,
            epochs=200,
            lr=0.001,
            weight_decay=0.005,
        )
        new_model.eval()
        results["I6"] = get_acc(new_model, new_dataloader_test, forward_fn)
        print_results(results)

    if "I7" in args.metrics:
        sizes = compare_lists(dataset_test, truncated_explanations, compare_size)
        results["I7"] = 1 - np.mean(sizes)
        print_results(results)

    # Placeholder metrics: not implemented, recorded as None.
    for name in ("M1", "M2", "M3"):
        if name in args.metrics:
            results[name] = None

    print("Final")
    print_results(results)


if __name__ == "__main__":
    # Script entry point: parse CLI arguments and run the evaluation.
    main()
