import argparse
import copy
import json
import os
from tqdm import tqdm

import numpy as np
import pandas as pd
import torch

from torch.utils.data import Subset
from torch_geometric.data import Data, Batch
from torch_geometric.explain import Explainer, GNNExplainer
from torch_geometric.loader import DataLoader
from torch_geometric.utils import to_dense_adj, subgraph
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.conv import GINConv


from sklearn.metrics import roc_auc_score

from torch_geometric.explain import (
    Explainer,
    GNNExplainer,
    CaptumExplainer,
)

import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from datasetss.dataset import GraphFeaturizer, build_dataset

from train_model import get_model, test

import sys
#INCLUDE PATHS TO PROTGNN
from libs.ProtGNN.Configures import model_args
from libs.ProtGNN.models import GnnNets


EXPLAINERS = [
    "GNNExplainer",
    "IntegratedGradients",
    "Saliency",
    "InputXGradient",
    "Deconvolution",
    "GuidedBackprop",
    "PGExplainer",
]

def args_parser():
    parser = argparse.ArgumentParser(description="")
    parser.add_argument("--seed", type=int, default=123, help="seed")
    parser.add_argument("--trials", type=int, default=1, help="number of trials")
    parser.add_argument("--save_path", type=str ,required=True, help="path to save explanations")
    parser.add_argument("--batch_size", type=int, default=1, help="batch size")
    parser.add_argument("--model_path", type=str, required=True, help="model path")
    parser.add_argument(
        "--explainer_type",
        type=str,
        required=True,
        choices=EXPLAINERS,
        help="explainer type",
    )
    parser.add_argument("--explanation_type", type=str, default="model", required=False, choices=["model", "phenomenon"])
    parser.add_argument(
        "--node_mask_type",
        type=str,
        default="attributes",
        required=False,
        choices=["object", "none", "attributes", "common_attributes"],
    )
    parser.add_argument(
        "--edge_mask_type",
        type=str,
        required=False,
        choices=["object", "none", "attributes", "common_attributes"],
    )
    parser.add_argument("--lr", type=float, default=0.001, help="learning rate")
    parser.add_argument("--epochs", type=int, default=150, help="number of epochs")
    parser.add_argument("--save_all", type=bool, default=False)
    args = parser.parse_args()
    return args


class ModelWrapper(torch.nn.Module):
    """Expose a ``Data``-consuming model through an ``(x, edge_index, batch)`` forward.

    The wrapped model is called with a single PyG ``Data`` object built from
    the three tensors, and only element ``[0]`` of its output is returned.
    NOTE(review): presumably the wrapped model returns a tuple whose first
    element is the raw prediction -- confirm against the wrapped class.
    """

    def __init__(self, model):
        super().__init__()
        # When ``model`` is an nn.Module, assigning it here registers it as a
        # submodule, so .to()/.eval() on the wrapper reach it as well.
        self.model = model

    def forward(self, x, edge_index, batch):
        graph = Data(x=x, edge_index=edge_index, batch=batch)
        outputs = self.model(graph)
        first_output = outputs[0]
        return first_output

def get_node_importance(node_mask):
    """Collapse an explainer node mask into one detached score per node.

    Args:
        node_mask: Tensor of rank 1 (one score per node) or rank 2, in which
            case the last dimension is summed (e.g. per-feature attributions
            are added up per node).

    Returns:
        A detached rank-1 tensor with one entry per row of ``node_mask``.

    Raises:
        ValueError: if ``node_mask`` is ``None`` or is not rank 1 or 2.
    """
    if node_mask is None:
        # Previously this surfaced as an opaque AttributeError on .detach().
        raise ValueError(
            "node_mask is None: the explainer produced no node mask "
            "(was node_mask_type set to None?)"
        )
    node_mask = node_mask.detach()
    if node_mask.dim() == 2:
        node_mask = node_mask.sum(dim=-1)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if node_mask.dim() != 1:
        raise ValueError(
            f"expected a rank-1 or rank-2 node mask, got shape {tuple(node_mask.shape)}"
        )
    return node_mask



def _dataset_kwargs_from_args(loaded_args):
    """Build the kwargs passed to ``build_dataset`` and the featurizer from checkpoint args."""
    # "sol" uses fixed mean/std values (presumably its target normalisation
    # statistics -- confirm); every other dataset uses mean 0 / std 1.
    is_sol = loaded_args["data_set"] == "sol"
    return {
        "data_set": loaded_args["data_set"],
        "mean": -2.86 if is_sol else 0.0,
        "std": 2.38 if is_sol else 1.0,
        "y_column": "Y",
        "smiles_col": "Drug",
        "task": loaded_args["task"],
        "split": loaded_args.get("split", 0),
    }


def _load_explainable_model(loaded, device):
    """Rebuild the trained model stored in checkpoint ``loaded`` on ``device``.

    First tries ``get_model(**loaded["model_args"])`` with ``loaded["state_dict"]``.
    If anything in that path raises (e.g. no ``"model_args"`` key, or a
    state-dict mismatch), falls back to a ProtGNN ``GnnNets`` configured from
    ``loaded["args"]`` and wrapped in ``ModelWrapper``. In both paths
    ``in_channels``/``out_channels`` are filled in on message-passing layers
    that lack them, because the explainer setup uses them.
    """
    try:
        model = get_model(**loaded["model_args"]).to(device)
        model.load_state_dict(loaded["state_dict"])
        for module in model.modules():
            if isinstance(module, MessagePassing) and not hasattr(module, "in_channels"):
                # MLP-wrapping convs: read the sizes from the inner MLP.
                channel_list = module.nn.channel_list
                module.in_channels = channel_list[0]
                module.out_channels = channel_list[-1]
        return model
    except Exception as err:  # was a bare `except:`, which also swallowed Ctrl-C/SystemExit
        print(f"get_model path failed ({type(err).__name__}: {err}); falling back to ProtGNN")

    run_args = loaded["args"]
    # Mutates the shared ProtGNN `model_args` config object before building.
    model_args.model_name = run_args["model_type"]
    model_args.readout = run_args["readout"]
    model_args.num_prototypes_per_class = 3
    model_args.latent_dim = [run_args["hidden"]] * run_args["num_layers"]
    model_args.latent_dim[-1] = 128
    # NOTE(review): input/output dims are hard-coded (11 node features,
    # 2 outputs) -- confirm they match the featurizer and the trained model.
    protgnn = GnnNets(11, 2, model_args)
    for layer in protgnn.model.gnn_layers:
        if isinstance(layer, GINConv):
            layer.in_channels = layer.nn[0].in_features
            layer.out_channels = layer.nn[0].out_features
    protgnn.load_state_dict(loaded["state_dict"])
    return ModelWrapper(protgnn.model).to(device)


def _explanation_node_importance(explanation, edge_index, num_nodes):
    """Return one importance score per node of an explained batch.

    Uses the explanation's node mask when present; otherwise adds each edge's
    mask value onto both of its endpoint nodes.

    Raises:
        ValueError: if the explanation carries neither a node nor an edge mask.
    """
    node_mask = getattr(explanation, "node_mask", None)
    if node_mask is not None:
        return get_node_importance(node_mask)
    edge_mask = getattr(explanation, "edge_mask", None)
    if edge_mask is None:
        raise ValueError(
            "Explanation has neither a node mask nor an edge mask; "
            "check --node_mask_type / --edge_mask_type"
        )
    edge_mask = edge_mask.detach()
    importance = torch.zeros(num_nodes, dtype=edge_mask.dtype, device=edge_mask.device)
    src, dst = edge_index
    importance.index_add_(0, src, edge_mask)
    importance.index_add_(0, dst, edge_mask)
    return importance


def _split_into_graphs(batch, pred, node_importance):
    """Split a collated batch into per-graph CPU ``Data`` objects.

    Each graph keeps its ``x``, its re-indexed ``edge_index``, ``y``, its model
    prediction and its slice of ``node_importance``. Relies on the nodes of
    graph ``b`` being stored contiguously after those of graphs ``0..b-1``.
    """
    graphs = []
    offset = 0
    for b in range(int(batch.batch.max()) + 1):
        x_mask = batch.batch == b
        edge_sel = batch.batch[batch.edge_index[0]] == b
        # Shift global node ids back to graph-local ids.
        local_edge_index = batch.edge_index[:, edge_sel] - offset
        graphs.append(
            Data(
                x=batch.x[x_mask].detach().cpu(),
                edge_index=local_edge_index.detach().cpu(),
                y=batch.y[b].detach().cpu(),
                pred=pred[b].detach().cpu(),
                node_mask=node_importance[x_mask].detach().cpu(),
            )
        )
        offset += int(x_mask.sum())
    return graphs


def main():
    """Explain every test-set graph of a trained checkpoint and save the results.

    Reads the CLI options, loads ``--model_path`` (keys ``args``,
    ``state_dict``, optionally ``model_args``, plus metrics ``f1``/``roc_auc``/
    ``accuracy`` for classification or ``mae``/``rmse`` otherwise), rebuilds
    the test split, runs the chosen explainer batch by batch, and writes a
    dict with the per-graph explanations to ``--save_path`` via ``torch.save``.

    Raises:
        NotImplementedError: for ``--explainer_type PGExplainer``.
    """
    args = args_parser()
    if args.node_mask_type == "none":
        args.node_mask_type = None
    if args.edge_mask_type == "none":
        args.edge_mask_type = None
    print(args)
    if args.explainer_type == "PGExplainer":
        # "PGExplainer" is not a Captum attribution method, so CaptumExplainer
        # cannot build it, and PyG's PGExplainer needs a training phase that
        # this script does not run. Fail before loading data and model.
        raise NotImplementedError(
            "PGExplainer is not supported by this script: it needs a training "
            "phase before explaining. Choose another --explainer_type."
        )

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # weights_only=False unpickles arbitrary objects (the checkpoint stores a
    # plain args mapping alongside the weights): only load trusted files.
    loaded = torch.load(args.model_path, map_location=torch.device("cpu"), weights_only=False)
    loaded_args = loaded["args"]
    task = loaded_args["task"]
    dataset_kwargs = _dataset_kwargs_from_args(loaded_args)

    _, _, dataset_test = build_dataset(dataset_kwargs)
    featurizer = GraphFeaturizer(y_column="Y")
    test_set = featurizer(dataset_test, dataset_kwargs)
    dataloader_test = DataLoader(test_set, batch_size=args.batch_size, shuffle=False)

    model = _load_explainable_model(loaded, device)
    model.eval()

    _, test_metric = test(model, dataloader_test, device, task=task)
    print(f"Test metric: {test_metric}")

    mode = "multiclass_classification" if task == "classification" else "regression"
    model_config = dict(mode=mode, task_level="graph", return_type="raw")
    if args.explainer_type == "GNNExplainer":
        algorithm = GNNExplainer(epochs=args.epochs, lr=args.lr).to(device)
    else:
        algorithm = CaptumExplainer(args.explainer_type)
    explainer = Explainer(
        model=model,
        algorithm=algorithm,
        explanation_type=args.explanation_type,
        node_mask_type=args.node_mask_type,
        edge_mask_type=args.edge_mask_type,
        model_config=model_config,
    )

    explanations = []
    for batch in tqdm(dataloader_test):
        batch = batch.to(device)
        # Predictions are only stored (detached), so skip building a graph.
        with torch.no_grad():
            pred = model(batch.x, batch.edge_index, batch.batch)

        if args.explanation_type == "model":
            e = explainer(batch.x, batch.edge_index, batch=batch.batch)
        else:
            target = batch.y.long() if task == "classification" else batch.y.float()
            e = explainer(batch.x, batch.edge_index, batch=batch.batch, target=target)

        node_importance = _explanation_node_importance(e, batch.edge_index, batch.x.size(0))
        explanations.extend(_split_into_graphs(batch, pred, node_importance))

    print(f"Finished with {len(explanations)}/{len(dataset_test)} explanations")
    metric_keys = ("f1", "roc_auc", "accuracy") if task == "classification" else ("mae", "rmse")
    payload = {
        "args": vars(args),
        "model_args": loaded_args,
        "explanations": explanations,
        **{key: loaded[key] for key in metric_keys},
        "explainer": {},
    }
    torch.save(payload, args.save_path)
    print(f"Saved to {args.save_path}")


if __name__ == "__main__":
    # Run the explanation pipeline only when executed as a script, not on import.
    main()
