# Evaluation script for a GKNetwork (XGKN) graph classifier: builds SHAP-based
# explanations on a test split and computes the metrics chosen via --metrics.
import argparse
from tqdm.notebook import tqdm  # NOTE(review): notebook widget bar; outside Jupyter tqdm.auto may render better
import sys

# NOTE(review): placeholder user name. The hard-coded XGKN paths below only
# resolve once this is set to the real home-directory name.
usr_dir = "none"

# Make the XGKN project directory importable (presumably the source of the
# `models` / `distances` imports below). Inserted at index 1 so the script's own
# directory (sys.path[0]) keeps precedence.
sys.path.insert(1, f"/home/{usr_dir}/workspace/AIM & XGKN/XGKN")

import importlib.util

# Load XGKN/utils.py explicitly under the name `gkn_utils`; presumably done this
# way to avoid a clash with the `utils` module star-imported further down.
spec1 = importlib.util.spec_from_file_location(
    "gkn_utils", f"/home/{usr_dir}/workspace/AIM & XGKN/XGKN/utils.py"
)
gkn_utils = importlib.util.module_from_spec(spec1)
spec1.loader.exec_module(gkn_utils)
from models import GKNetwork

from distances import compare_size, iou_distance, weighted_ged_distance
from torch_geometric.loader import DataLoader

from torch.utils.data import Subset
from torch_geometric.utils import dense_to_sparse

# NOTE(review): this star import is assumed to provide torch, np, copy, Data,
# subgraph, METRICS and the metric helpers (get_acc, get_shap, run_shap, ...)
# used in this file; none of those names is imported explicitly. Confirm.
from utils import *

GKN_PATH = f"../XGKN"  # NOTE(review): not referenced anywhere in the visible code
METRICS_NAMES = list(METRICS.keys())  # allowed values (and default) for --metrics
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def args_parser():
    """Parse the command-line options of the evaluation script.

    Options (declared in this order, which is also the ``--help`` order):
    ``--model_path`` and ``--dataset`` (required strings), ``--split`` (int,
    default -1), ``--seed`` (int, default 0), ``--metrics`` (one or more of
    ``METRICS_NAMES``, default all), ``--batch_size`` (int, default 32),
    ``--node_mask_fn`` / ``--edge_mask_fn`` (thresholding spec strings).

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    cli = argparse.ArgumentParser()
    option_specs = (
        ("--model_path", dict(type=str, required=True, help="Model path")),
        ("--dataset", dict(type=str, required=True, help="Dataset name")),
        ("--split", dict(type=int, default=-1, help="Split index")),
        ("--seed", dict(type=int, default=0, help="Seed")),
        (
            "--metrics",
            dict(
                type=str,
                nargs="+",
                choices=METRICS_NAMES,
                default=METRICS_NAMES,
                help="Metrics to evaluate",
            ),
        ),
        ("--batch_size", dict(type=int, default=32, help="Batch size")),
        ("--node_mask_fn", dict(type=str, default="elbow_softmax:0", help="Thresholding of nodes")),
        ("--edge_mask_fn", dict(type=str, default="none", help="Thresholding for edges")),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()


def get_responses_and_logits(model, dataloader, forward_fn):
    """Run ``forward_fn`` over every batch and concatenate the outputs.

    Args:
        model: network passed through to ``forward_fn``.
        dataloader: iterable of batches; each batch is moved to the
            module-level ``device`` before the forward pass.
        forward_fn: callable ``(model, batch) -> (logits, responses)``.

    Returns:
        ``(responses, logits)`` -- note the order is swapped with respect to
        ``forward_fn`` -- each the dim-0 concatenation of the per-batch tensors.
    """
    logits, responses = [], []
    # Inference only. Without no_grad every batch's autograd graph (with its
    # saved activations) would stay referenced by the accumulated outputs.
    with torch.no_grad():
        for data in dataloader:
            batch_logits, batch_responses = forward_fn(model, data.to(device))
            logits.append(batch_logits)
            responses.append(batch_responses)
    return torch.cat(responses, 0), torch.cat(logits)


def get_truncated_explanations(explanations, node_mask_fn, edge_mask_fn):
    """Cut every explanation down to the sub-graph selected by its masks.

    For each explanation, node scores (``node_mask`` summed over dim 1, last
    node dropped) and edge scores (``edge_mask`` summed over dim 1) are
    thresholded by ``extract_subset_and_edge_index`` with the given mask
    functions. The kept edges are then relabelled to the kept nodes.

    Args:
        explanations: iterable of objects exposing ``x``, ``edge_index``,
            ``node_mask``, ``edge_mask`` and ``pred``.
        node_mask_fn: thresholding function applied to the node scores.
        edge_mask_fn: thresholding function applied to the edge scores.

    Returns:
        A list (same order as the input) of ``Data`` holding the relabelled
        ``edge_index``, the kept node features ``x``, ``y`` derived from the
        argmax of ``pred``, and the kept node ids ``nodes_subset``.
    """
    truncated = []
    for expl in explanations:
        n_total = len(expl.x)
        # NOTE(review): the last node is excluded from scoring (``[:-1]`` and a
        # node count of n_total - 1); presumably it is an auxiliary node added by
        # preprocessing -- confirm.
        kept_nodes, kept_edges = extract_subset_and_edge_index(
            expl.edge_index,
            n_total - 1,
            expl.node_mask.sum(1)[:-1],
            node_mask_fn,
            expl.edge_mask.sum(1),
            edge_mask_fn,
        )
        kept_edges, _unused_attr = subgraph(
            kept_nodes, kept_edges, relabel_nodes=True, num_nodes=n_total
        )
        predicted_label = expl.pred.argmax(keepdim=True)[0]
        truncated.append(
            Data(
                edge_index=kept_edges,
                x=expl.x[kept_nodes],
                y=predicted_label,
                nodes_subset=kept_nodes,
            )
        )
    return truncated


def get_prototypes(model, dataset):
    """Turn every hidden graph of every kernel layer into a readable prototype.

    Each hidden node vector is replaced by the distinct input feature row (among
    all ``x`` rows of ``dataset``) whose encoding has the largest dot product
    with it. Edges are the strictly positive entries of the hidden adjacency.

    Args:
        model: network exposing ``ker_layers``; each layer must provide
            ``encoder``, ``x_hidden()`` and ``adj_hidden()``.
        dataset: iterable of graphs with an ``x`` node-feature tensor.

    Returns:
        A list of ``Data(edge_index, x, edge_attr)``, ordered layer by layer and,
        within a layer, by hidden-graph index. ``edge_attr`` holds
        ``float(min(a, 1) > 0)`` for each selected entry ``a``.
    """
    prototypes = []
    all_x = torch.cat([data.x for data in dataset])
    possible_x = torch.unique(all_x, sorted=False, dim=0).to(device)

    # Pure read-out of learned parameters: no autograd graph is needed.
    with torch.no_grad():
        for ker_layer in model.ker_layers:
            encoded_possible_x = ker_layer.encoder(possible_x)
            x = ker_layer.x_hidden()
            x_shape = x.shape
            x = x.reshape(-1, x.shape[-1])
            sim = encoded_possible_x @ x.T
            # Closest real feature row (max dot product) for every hidden node.
            x = possible_x[sim.argmax(0)]
            x = x.reshape(x_shape[0], x_shape[1], -1)
            adj = ker_layer.adj_hidden()
            for i in range(adj.shape[0]):
                edge_index = dense_to_sparse(adj[i, :, :] > 0.0)[0]
                weights = adj[i, edge_index[0], edge_index[1]]
                # Vectorised form of the former per-edge loop float(min(w, 1) > 0):
                # one tensor op instead of one scalar device sync per edge.
                edge_attr = (weights.clamp(max=1) > 0).to(torch.get_default_dtype())
                prototypes.append(Data(edge_index=edge_index, x=x[i], edge_attr=edge_attr))
    return prototypes


def perturb_model(model, dataset, p_x=0.2, p_edge=0.2):
    """Return a randomly perturbed deep copy of ``model``; ``model`` is untouched.

    For every kernel layer of the copy:

    * if ``p_x > 0``: the rows of ``_x_hidden.permute(0, 2, 1)`` (flattened over
      the first two dims) are each replaced, with probability ``p_x``, by the
      encoding of a node feature row drawn uniformly from ``dataset``;
    * if ``p_edge > 0``: each ``_adj_hidden`` entry ``a`` is, with probability
      ``p_edge``, set to ``float(a < 0)`` (negative -> 1, non-negative -> 0).

    Randomness comes from NumPy's global RNG; nothing is seeded here.

    Args:
        model: network exposing ``ker_layers`` with ``_x_hidden``,
            ``_adj_hidden`` and ``encoder``.
        dataset: iterable of graphs with an ``x`` node-feature tensor.
        p_x: per-row replacement probability for hidden node features.
        p_edge: per-entry flip probability for the hidden adjacency.

    Returns:
        The perturbed copy, in eval mode.
    """
    model_copy = copy.deepcopy(model)
    model_copy.eval()
    all_x = torch.cat([data.x for data in dataset])
    # Parameters are overwritten through `.data`; the encoder calls are
    # inference only, so no autograd graph is built.
    with torch.no_grad():
        for ker_layer in model_copy.ker_layers:
            x = ker_layer._x_hidden.detach()
            adj = ker_layer._adj_hidden.detach()
            if p_x > 0:
                x = x.permute((0, 2, 1))
                x_shape = x.shape
                x = x.reshape(-1, x.shape[-1])
                keys = (
                    torch.tensor(np.random.choice([0, 1], p=[1 - p_x, p_x], size=len(x)))
                    .bool()
                    .to(x.device)
                )
                size = keys.sum().item()
                if size > 0:
                    sampled = all_x[np.random.randint(low=0, high=len(all_x), size=size)]
                    x[keys] = ker_layer.encoder(sampled.to(x.device))
                ker_layer._x_hidden.data = x.reshape(x_shape[0], x_shape[1], -1).permute((0, 2, 1))
            if p_edge > 0:
                adj_shape = adj.shape
                adj = adj.reshape(-1)
                keys = (
                    torch.tensor(np.random.choice([0, 1], p=[1 - p_edge, p_edge], size=len(adj)))
                    .bool()
                    .to(adj.device)
                )
                # Flip selected entries: negative -> 1.0, non-negative -> 0.0.
                adj[keys] = (adj[keys] < 0).float()
                ker_layer._adj_hidden.data = adj.reshape(adj_shape)
    return model_copy


def get_explanations(model, explainer, dataloader, forward_fn):
    """Compute a SHAP-based node-importance explanation for every graph.

    Per batch:

    1. Run ``forward_fn`` (under ``no_grad``) to get
       ``(logits, all_responses, response)``.
    2. Get SHAP values of ``response`` via ``run_shap(explainer, ...)``, indexed
       as ``importance[class][graph]``, and keep the predicted class's values.
    3. Rescale them per node as ``importance / response * node_responses``.

    Args:
        model: network passed to ``forward_fn``.
        explainer: SHAP explainer consumed by ``run_shap``.
        dataloader: batches with ``x``, ``edge_index``, ``batch`` and one ``y``
            entry per graph.
        forward_fn: callable ``(model, batch) -> (logits, all_responses, response)``.

    Returns:
        A list of CPU ``Data`` objects, one per graph in dataloader order, with
        ``x``, graph-local ``edge_index``, ``y``, ``pred`` (that graph's row of the
        logits, kept 2-D by slicing), ``node_mask`` and an all-zero ``edge_mask``.
    """
    explanations = []
    for data in dataloader:
        data = data.to(device)
        with torch.no_grad():
            logits, all_responses, response = forward_fn(model, data)
        pred = logits.argmax(dim=-1)
        importance = run_shap(explainer, response).to(device)
        # For each graph keep the SHAP values of the class the model predicted.
        importance = torch.stack([importance[pred[i]][i] for i in range(len(pred))], dim=0).to(device)
        # Guard the division by `r` below against zero responses.
        response[response == 0] = 1e-36

        # Graph id of every edge (taken from its source node). It is the same
        # for all graphs in the batch, so compute it once rather than per graph.
        edge_graph = data.batch[data.edge_index[0]]
        num_nodes = 0
        for b in tqdm(range(len(data.y)), total=len(data.y)):
            mask = data.batch == b
            im, r, ar = importance[b], response[b], all_responses[mask]
            im = (im / r) * ar
            x = data.x[mask]
            # Shift to graph-local node ids. Assumes graph b's nodes directly
            # follow those of graphs 0..b-1, as PyG batching lays them out.
            edge_index = data.edge_index[:, edge_graph == b] - num_nodes
            num_nodes += len(x)
            # Only nodes receive attributions; edges get an all-zero mask.
            edge_mask = torch.zeros(edge_index.shape[1], im.shape[-1], device=device)
            node_mask = im
            e = Data(
                x=x.cpu(),
                edge_index=edge_index.cpu(),
                y=data.y[b : b + 1].cpu(),
                pred=logits[b : b + 1].cpu(),
                node_mask=node_mask.cpu(),
                edge_mask=edge_mask.cpu(),
            )
            explanations.append(e)
    return explanations


def main():
    """Entry point: load a trained GKNetwork, explain its test predictions and
    compute/print every metric requested via ``--metrics``.

    Intermediate results are printed after each metric, and the full
    ``results`` dict is printed again at the end. Metrics that need ground-truth
    explanations (A1, A2) are set to ``None`` for datasets other than MUTAG,
    BA-2motif and BAMultiShapes.

    NOTE(review): no RNG seeding is visible in this function; perturb_model,
    add_noise_perturb and the shuffled I6 loader may therefore differ run to run.
    """
    args = args_parser()
    print(args)

    # Suffix identifying the reference model file loaded for metric I6.
    s = args.seed if args.split < 0 else args.split

    # NOTE(review): torch.load unpickles the checkpoint -- only load trusted files.
    params = torch.load(args.model_path, map_location=torch.device("cpu"))
    model = GKNetwork(**params["model_args"]).to(device)
    # strict=False: missing/unexpected state-dict keys are silently ignored.
    model.load_state_dict(params["state_dict"], strict=False)
    model.eval()

    # Checkpoints whose path contains "_ndl_" are loaded with degree_attr=True
    # (presumably degree-based node attributes -- confirm).
    degree_attr = "_ndl_" in args.model_path
    dataset = get_dataset(args.dataset, degree_attr=degree_attr, use_node_attr=True)
    k, subgraph_size = params["args"]["k"], params["args"]["subgraph_size"]

    def preprocess_graph(data):
        # Apply the XGKN preprocessing with the checkpoint's k / subgraph_size.
        return gkn_utils.transform(data, k, subgraph_size, degree_attr)

    def preprocess_dataset(dataset):
        return list(map(preprocess_graph, dataset))

    train_idxs, val_idxs, test_idxs = get_splits(
        args.dataset, size=len(dataset), seed=args.seed, split=args.split
    )

    # The raw (unpreprocessed) test subset is kept for the noise metrics I4/I5.
    dataset_train, dataset_val, init_dataset_test = (
        Subset(dataset, train_idxs),
        Subset(dataset, val_idxs),
        Subset(dataset, test_idxs),
    )
    dataset_test = preprocess_dataset(init_dataset_test)
    dataloader_test = DataLoader(dataset_test, batch_size=args.batch_size, shuffle=False)

    def forward_fn(model, data):
        # First model output (passed to get_acc; presumably the logits).
        return model(data)[0]

    results = {}

    node_mask_fn = get_fn(args.node_mask_fn)
    edge_mask_fn = get_fn(args.edge_mask_fn)

    if "A" in args.metrics:
        results["A"] = get_acc(model, dataloader_test, forward_fn)
        print_results(results)

    def logits_and_responses_fn(model, data):
        # Model returns 4 outputs; keep (logits, graph-level responses).
        logits, _, responses, _ = model(data)
        return logits, responses

    def logits_and_all_responses_fn(model, data):
        # Additionally expose all_responses[0][0], used by get_explanations as
        # the per-node responses.
        logits, _, responses, all_responses = model(data)
        return logits, all_responses[0][0], responses

    # Responses over the WHOLE dataset (all splits) are given to get_shap to
    # build the explainer for model.mlp.
    dataloader = DataLoader(preprocess_dataset(dataset), batch_size=args.batch_size, shuffle=False)
    responses, logits = get_responses_and_logits(model, dataloader, logits_and_responses_fn)

    shap_values, explainer = get_shap(model.mlp, responses)  # shap_values is unused here

    explanations = get_explanations(model, explainer, dataloader_test, logits_and_all_responses_fn)
    truncated_explanations = get_truncated_explanations(explanations, node_mask_fn, edge_mask_fn)
    prototypes = get_prototypes(model, dataset)

    # A1: IoU agreement between instance explanations and ground truth.
    if "A1" in args.metrics:
        if args.dataset in ["MUTAG", "BA-2motif", "BAMultiShapes"]:
            distances = compare_with_gt_instance(
                args.dataset, dataset_test, truncated_explanations, iou_distance
            )
            results["A1"] = 1 - np.mean(distances)
        else:
            results["A1"] = None
        print_results(results)

    # A2: weighted-GED agreement between prototypes and model-level ground truth.
    if "A2" in args.metrics:
        if args.dataset in ["MUTAG", "BA-2motif", "BAMultiShapes"]:
            ground_truths = get_gt_explanations_model(args.dataset)
            distances = compare_with_gt_model(prototypes, ground_truths, weighted_ged_distance)
            results["A2"] = 1 - np.mean(distances)
        else:
            results["A2"] = None
        print_results(results)

    # I1/I2: accuracy on graphs reduced to / stripped of their explanations.
    if bool({"I1", "I2"} & set(args.metrics)):
        with_explanations, without_explanations = remove_explanation_perturb(
            explanations, node_mask_fn, edge_mask_fn, skip_last_node=False
        )
        if "I1" in args.metrics:
            with_explanations = [data for data in with_explanations if data is not None]
            with_explanations = preprocess_dataset(with_explanations)
            dataloader_with = DataLoader(with_explanations, batch_size=args.batch_size, shuffle=False)
            results["I1"] = get_acc(model, dataloader_with, forward_fn)
            print_results(results)
        if "I2" in args.metrics:
            without_explanations = [data for data in without_explanations if data is not None]
            without_explanations = preprocess_dataset(without_explanations)
            dataloader_without = DataLoader(without_explanations, batch_size=args.batch_size, shuffle=False)
            # Error rate, scaled by the fraction of graphs that survived removal.
            acc_without = 1 - get_acc(model, dataloader_without, forward_fn)
            results["I2"] = acc_without * len(without_explanations) / len(explanations)
            print_results(results)

    # I3: recompute explanations on the same inputs and compare (stability).
    if "I3" in args.metrics:
        new_explanations = get_explanations(model, explainer, dataloader_test, logits_and_all_responses_fn)
        new_truncated_explanations = get_truncated_explanations(new_explanations, node_mask_fn, edge_mask_fn)
        distances = compare_lists(truncated_explanations, new_truncated_explanations, iou_distance)
        results["I3"] = 1 - np.mean(distances)
        print_results(results)

    # I4: explanation agreement after node-feature noise (p_x=0.05).
    if "I4" in args.metrics:
        dataset_noisy_node = add_noise_perturb(init_dataset_test, p_x=0.05, p_edges_add=0.0, p_edges_del=0.0, explanations=truncated_explanations)
        dataloader_new = DataLoader(
            preprocess_dataset(dataset_noisy_node),
            batch_size=args.batch_size,
            shuffle=False,
        )
        new_explanations = get_explanations(model, explainer, dataloader_new, logits_and_all_responses_fn)
        new_truncated_explanations = get_truncated_explanations(new_explanations, node_mask_fn, edge_mask_fn)
        distances = compare_lists(truncated_explanations, new_truncated_explanations, iou_distance)
        results["I4"] = 1 - np.mean(distances)
        print_results(results)

    # I5: explanation agreement after edge noise (add/delete with p=0.005 each).
    if "I5" in args.metrics:
        dataset_noisy_edge = add_noise_perturb(
            init_dataset_test, p_x=0.0, p_edges_add=0.005, p_edges_del=0.005, skip_last=False, explanations=truncated_explanations
        )
        dataloader_new = DataLoader(
            preprocess_dataset(dataset_noisy_edge),
            batch_size=args.batch_size,
            shuffle=False,
        )
        new_explanations = get_explanations(model, explainer, dataloader_new, logits_and_all_responses_fn)
        new_truncated_explanations = get_truncated_explanations(new_explanations, node_mask_fn, edge_mask_fn)
        distances = compare_lists(truncated_explanations, new_truncated_explanations, iou_distance)
        results["I5"] = 1 - np.mean(distances)
        print_results(results)

    # I6: train a fresh model only on truncated explanations and report its
    # test accuracy on the truncated test explanations.
    if "I6" in args.metrics:
        dataloader_train = DataLoader(
            preprocess_dataset(dataset_train), batch_size=args.batch_size, shuffle=True
        )
        dataloader_val = DataLoader(
            preprocess_dataset(dataset_val), batch_size=args.batch_size, shuffle=False
        )

        explanations_train = get_explanations(
            model,
            explainer,
            dataloader_train,
            logits_and_all_responses_fn,
        )
        truncated_explanations_train = get_truncated_explanations(
            explanations_train, node_mask_fn, edge_mask_fn
        )
        explanations_val = get_explanations(model, explainer, dataloader_val, logits_and_all_responses_fn)
        truncated_explanations_val = get_truncated_explanations(explanations_val, node_mask_fn, edge_mask_fn)
        new_dataloader_train = DataLoader(
            truncated_explanations_train, batch_size=args.batch_size, shuffle=True
        )
        new_dataloader_val = DataLoader(truncated_explanations_val, batch_size=args.batch_size, shuffle=False)
        new_dataloader_test = DataLoader(truncated_explanations, batch_size=args.batch_size, shuffle=False)
        # Architecture args come from a separate reference checkpoint in models/.
        model_args = torch.load(f"models/{args.dataset}_s{s}.pt", map_location=torch.device("cpu"))["args"]
        model_args["num_node_features"] = dataset.num_node_features
        new_model = get_model(**model_args).to(device)
        new_model = train_model(
            new_model,
            new_dataloader_train,
            new_dataloader_val,
            new_dataloader_test,
            epochs=500 if args.dataset == "MUTAG" else 200,
            lr=0.001,
            weight_decay=0.005,
        )
        results["I6"] = get_acc(
            new_model,
            new_dataloader_test,
            lambda model, data: model(data.x, data.edge_index, batch=data.batch),
        )
        print_results(results)

    # I7: explanation size relative to the input graph (via compare_size).
    if "I7" in args.metrics:
        sizes = compare_lists(dataset_test, truncated_explanations, compare_size)
        results["I7"] = 1 - np.mean(sizes)
        print_results(results)

    # M1: perturb hidden node features (p_x=0.5), rebuild the explainer, compare.
    # Unlike the A/I metrics, M1/M2 report the mean distance itself, not 1 - mean.
    if "M1" in args.metrics:
        perturbed_model = perturb_model(model, dataset, p_x=0.5, p_edge=0)
        perturbed_responses, perturbed_logits = get_responses_and_logits(
            perturbed_model, dataloader, logits_and_responses_fn
        )
        _, perturbed_explainer = get_shap(perturbed_model.mlp, perturbed_responses)

        perturbed_explanations = get_explanations(
            perturbed_model,
            perturbed_explainer,
            dataloader_test,
            logits_and_all_responses_fn,
        )
        perturbed_explanations = get_truncated_explanations(
            perturbed_explanations, node_mask_fn, edge_mask_fn
        )
        distances = compare_lists(truncated_explanations, perturbed_explanations, iou_distance)
        results["M1"] = np.mean(distances)
        print_results(results)

    # M2: same as M1 but perturbing the hidden adjacency (p_edge=0.25).
    if "M2" in args.metrics:
        perturbed_model = perturb_model(model, dataset, p_x=0, p_edge=0.25)
        perturbed_responses, perturbed_logits = get_responses_and_logits(
            perturbed_model, dataloader, logits_and_responses_fn
        )
        _, perturbed_explainer = get_shap(perturbed_model.mlp, perturbed_responses)
        perturbed_explanations = get_explanations(
            perturbed_model,
            perturbed_explainer,
            dataloader_test,
            logits_and_all_responses_fn,
        )
        perturbed_explanations = get_truncated_explanations(
            perturbed_explanations, node_mask_fn, edge_mask_fn
        )
        distances = compare_lists(truncated_explanations, perturbed_explanations, iou_distance)
        results["M2"] = np.mean(distances)
        print_results(results)

    # M3: 1 - pairwise correlation (absolute values) of the response columns.
    if "M3" in args.metrics:
        results["M3"] = 1 - pairwise_list_corr(responses.cpu().detach().numpy(), to_abs=True)
        print_results(results)

    print("Final")
    print_results(results)


# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
