import argparse
import copy
import json
import os
import sys
from pathlib import Path
from tqdm import tqdm

import numpy as np
import pandas as pd
import torch

from torch.utils.data import Subset

from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader
from torch_geometric.utils import to_dense_adj, subgraph
from torch_geometric.nn.conv import MessagePassing

# Add the parent directory to Python path to find local modules
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from datasetss import SyntheticGraphFeaturizer, build_dataset
from architectures import SEALNetwork

# Explainer types accepted by the --explainer_type option.
EXPLAINERS = [
    "SEAL"
]


def args_parser(argv=None):
    """Parse the command-line options of the explanation script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the original behavior.

    Returns:
        argparse.Namespace with the attributes ``seed``, ``trials``,
        ``save_path``, ``explainer_path``, ``batch_size`` and
        ``explainer_type``.
    """
    parser = argparse.ArgumentParser(description="")
    parser.add_argument("--seed", type=int, default=123, help="seed")
    parser.add_argument("--trials", type=int, default=1, help="number of trials")
    parser.add_argument("--save_path", type=str, required=True, help="path to save explanations")
    parser.add_argument(
        "--explainer_path",
        type=str,
        required=True,
        help="path to the saved explainer checkpoint to load",
    )
    parser.add_argument("--batch_size", type=int, default=1, help="batch size")
    # The option is deliberately not marked required: combining required=True
    # with a default made the default unreachable.
    parser.add_argument(
        "--explainer_type",
        type=str,
        default="SEAL",
        choices=EXPLAINERS,
        help="explainer type",
    )
    return parser.parse_args(argv)


class ModelWrapper(torch.nn.Module):
    """Adapter that lets a ``Data``-consuming model be called with raw tensors."""

    def __init__(self, model):
        """Store the wrapped model.

        Args:
            model: Callable that accepts a single ``Data`` object and returns
                an indexable result.
        """
        super().__init__()
        self.model = model

    def forward(self, x, edge_index, batch):
        """Bundle the inputs into a ``Data`` object and run the wrapped model.

        Args:
            x: Passed through as ``Data.x``.
            edge_index: Passed through as ``Data.edge_index``.
            batch: Passed through as ``Data.batch``.

        Returns:
            The first element of whatever the wrapped model returns.
        """
        graph = Data(x=x, edge_index=edge_index, batch=batch)
        outputs = self.model(graph)
        return outputs[0]
def create_model(kwargs):
    """Build the network selected by ``kwargs["model"]``.

    Args:
        kwargs: Mapping of model options; the whole mapping is forwarded to
            the selected model's constructor.

    Returns:
        A new instance of the selected model class.

    Raises:
        ValueError: If the ``"model"`` entry is missing or not a known name.
    """
    registry = {
        "SEAL": SEALNetwork,
    }
    requested = kwargs.get("model", "")
    if requested in registry:
        model_cls = registry[requested]
        return model_cls(kwargs)

    raise ValueError(f"Unknown model name: {requested}")

def get_node_importance(node_mask):
    """Collapse a node mask into a single importance score per node.

    Args:
        node_mask: Tensor with one or two dimensions.

    Returns:
        A detached one-dimensional tensor; a two-dimensional input is summed
        over its last dimension.
    """
    scores = node_mask.detach()
    if scores.dim() == 2:
        # Merge the per-column values of each row into one score.
        scores = scores.sum(dim=-1)
    assert scores.dim() == 1
    return scores



def _seal_node_mask(batch, cluster_scores, device):
    """Give every node the score of the cluster selected by its row of ``batch.s``."""
    mask = torch.zeros(batch.x.shape[0], device=device)
    for i, assignment in enumerate(batch.s):
        # NOTE(review): index 0 is used for every node regardless of the graph
        # it belongs to. With batch_size > 1 this presumably should select the
        # node's own graph -- confirm against SEALNetwork's output layout.
        mask[i] = cluster_scores[0][assignment.argmax().item()]
    return mask


def _split_batch(batch, pred, node_mask):
    """Split a batched graph into one CPU ``Data`` object per graph."""
    graphs = []
    node_offset = 0
    for b in range(int(batch.batch.max()) + 1):
        x_mask = batch.batch == b
        edge_mask = batch.batch[batch.edge_index[0]] == b
        # Shift node indices so each graph's edges are numbered from zero.
        edge_index = batch.edge_index[:, edge_mask] - node_offset
        graphs.append(
            Data(
                x=batch.x[x_mask].detach().cpu(),
                edge_index=edge_index.detach().cpu(),
                y=batch.y[b].detach().cpu(),
                pred=pred[b].detach().cpu(),
                # Always keep only this graph's nodes so the mask lines up
                # with ``x``; storing the whole batch mask was wrong for
                # batch_size > 1 and is identical for batch_size == 1.
                node_mask=node_mask[x_mask].detach().cpu(),
            )
        )
        node_offset += int(x_mask.sum())
    return graphs


def main():
    """Run a trained SEAL explainer on the test split and save the explanations.

    The checkpoint at ``--explainer_path`` must provide the keys ``args``,
    ``model_args``, ``state_dict``, ``f1``, ``roc_auc`` and ``accuracy``. The
    test set is rebuilt from the checkpoint's arguments, the explainer is run
    over it, and one ``Data`` object per graph (features, edges, label,
    prediction, node mask) is written to ``--save_path`` together with the
    checkpoint's metrics.
    """
    args = args_parser()
    print(args)
    seed = args.seed
    torch.manual_seed(seed)
    np.random.seed(seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # SECURITY: weights_only=False unpickles arbitrary Python objects, so only
    # checkpoints from a trusted source must be passed here.
    loaded = torch.load(args.explainer_path, map_location=torch.device("cpu"), weights_only=False)
    loaded_args = loaded["args"]
    print(f"Loaded model with args: {loaded_args}")
    # Read the metrics up front so a checkpoint without them fails before the
    # inference loop instead of after it.
    metrics = {key: loaded[key] for key in ("f1", "roc_auc", "accuracy")}

    dataset_kwargs = {
        "data_set": loaded_args["data_set"],
        "mean": 0.0,
        "std": 1.0,
        "y_column": 'Y',
        "smiles_col": "Drug",
        "task": "classification",
        "split": loaded_args.get("split", 0),
    }

    _, _, dataset_test = build_dataset(dataset_kwargs)
    featurizer = SyntheticGraphFeaturizer(y_column='Y')

    test_set = featurizer(dataset_test, dataset_kwargs)
    dataloader_test = DataLoader(test_set, batch_size=args.batch_size, shuffle=False)

    model_kwargs = {
        "model": "SEAL",
        "hidden_features": loaded['model_args']['hidden_features'],
        "input_features": test_set[0].x.shape[1],
        "drop": 0.0,
        "num_layers": loaded['model_args']['num_layers'],
        "task": loaded_args["task"],
        "number_of_clusters": 0,
        "regularize": 0,
        "weight_decay": 0,
    }
    explainer = create_model(model_kwargs)
    explainer.load_state_dict(loaded["state_dict"])
    explainer = explainer.to(device)
    explainer.eval()

    explanations = []
    # Inference only: every stored tensor is detached, so no autograd graph
    # needs to be built.
    with torch.no_grad():
        for batch in tqdm(dataloader_test):
            batch = batch.to(device)
            out = explainer(batch, None)
            pred = out["output"]
            node_mask = _seal_node_mask(batch, out["x_cluster_transformed"], device)
            explanations.extend(_split_batch(batch, pred, node_mask))

    print(f"Finished with {len(explanations)}/{len(dataset_test)} explanations")

    # Create the target directory so a missing folder cannot discard the
    # results after the whole test set has been processed.
    Path(args.save_path).parent.mkdir(parents=True, exist_ok=True)
    torch.save(
        {
            "args": vars(args),
            "model_args": loaded_args,
            "explanations": explanations,
            "f1": metrics["f1"],
            "roc_auc": metrics["roc_auc"],
            "accuracy": metrics["accuracy"],
            "explainer": model_kwargs if args.explainer_type == "SEAL" else {},
        },
        args.save_path,
    )

    print(f"Saved to {args.save_path}")

# Run the explanation pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
