import argparse
import copy
import json
import os
from tqdm import tqdm

import numpy as np
import pandas as pd
import torch

from torch.utils.data import Subset
from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from datasetss.dataset import GraphFeaturizer, build_dataset

from sklearn.metrics import roc_auc_score


import sys
from architectures.SEAL import SEALNetwork
# Explainer architectures this script can load; used as the ``choices`` of --explainer_type.
EXPLAINERS = [
    "SEAL"
]


def _str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is ``True`` for
    every non-empty string (including ``"False"`` and ``"0"``), so an explicit
    parser is required for a value-taking boolean flag.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean
            spelling (argparse turns this into a usage error).
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in {"1", "true", "t", "yes", "y"}:
        return True
    # The empty string was falsy under the old ``type=bool`` behaviour; keep it so.
    if normalized in {"", "0", "false", "f", "no", "n"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def args_parser():
    """Parse the command-line options of this script from ``sys.argv``.

    Returns:
        argparse.Namespace with ``seed``, ``trials``, ``save_path``,
        ``explainer_path``, ``batch_size``, ``explainer_type`` and ``save_all``.
    """
    parser = argparse.ArgumentParser(
        description="Export per-node explanations from a trained SEAL checkpoint on its test split."
    )
    parser.add_argument("--seed", type=int, default=123, help="seed")
    parser.add_argument("--trials", type=int, default=1, help="number of trials")
    parser.add_argument("--save_path", type=str, required=True, help="path to save explanations")
    parser.add_argument("--explainer_path", type=str, required=True, help="path to explainer")
    parser.add_argument("--batch_size", type=int, default=1, help="batch size")
    parser.add_argument(
        "--explainer_type",
        type=str,
        required=False,
        default="SEAL",
        choices=EXPLAINERS,
        help="explainer type",
    )

    # Accepts ``--save_all``, ``--save_all True`` and ``--save_all False``;
    # ``type=bool`` would have treated ``False`` as True.
    parser.add_argument(
        "--save_all",
        type=_str2bool,
        nargs="?",
        const=True,
        default=False,
        help="boolean flag (true/false); a bare --save_all means true",
    )
    args = parser.parse_args()
    return args


class ModelWrapper(torch.nn.Module):
    """Expose a ``Data``-consuming model through a ``(x, edge_index, batch)`` signature.

    On each call the tensors are packed into a fresh ``Data`` object, the wrapped
    model is invoked on it, and only the first element of its return value is
    handed back to the caller.
    """

    def __init__(self, model):
        super(ModelWrapper, self).__init__()
        # Attribute assignment lets nn.Module register ``model`` as a child
        # module when it is itself an nn.Module.
        self.model = model

    def forward(self, x, edge_index, batch):
        graph = Data(x=x, edge_index=edge_index, batch=batch)
        outputs = self.model(graph)
        return outputs[0]

def get_node_importance(node_mask):
    """Return a detached, 1-D per-node importance tensor.

    A 2-D mask is collapsed by summing over its last dimension. Any mask that is
    not 1-D after that step trips an ``AssertionError``.
    """
    importance = node_mask.detach()
    if importance.dim() == 2:
        importance = importance.sum(dim=-1)
    assert importance.dim() == 1
    return importance


def create_model(kwargs):
    """Instantiate the architecture named by ``kwargs["model"]``.

    The whole ``kwargs`` mapping is forwarded to the architecture's constructor.

    Raises:
        ValueError: if ``"model"`` is missing or names an unregistered architecture.
    """
    registry = {"SEAL": SEALNetwork}
    requested = kwargs.get("model", "")
    try:
        model_cls = registry[requested]
    except KeyError:
        raise ValueError(f"Unknown model name: {requested}") from None
    return model_cls(kwargs)


def _dataset_kwargs(loaded_args):
    """Build the ``build_dataset`` / featurizer configuration for a checkpoint.

    The ``"sol"`` dataset gets mean/std -2.86/2.38; every other dataset gets
    0.0/1.0 (presumably target standardization constants — confirm against
    ``build_dataset``). ``split`` falls back to 0 for checkpoints saved without it.
    """
    if loaded_args["data_set"] == "sol":
        mean, std = -2.86, 2.38
    else:
        mean, std = 0.0, 1.0
    return {
        "data_set": loaded_args["data_set"],
        "mean": mean,
        "std": std,
        "y_column": 'Y',
        "smiles_col": "Drug",
        "task": loaded_args["task"],
        "split": loaded_args.get("split", 0),
    }


def _build_explainer(loaded, input_features, device):
    """Instantiate the SEAL network described by a checkpoint and load its weights.

    Returns:
        (explainer, model_kwargs): the model moved to ``device`` in eval mode, and
        the constructor kwargs (they are saved alongside the explanations).
    """
    loaded_args = loaded["args"]
    model_kwargs = {
        "model": "SEAL",
        "hidden_features": loaded['model_args']['hidden_features'],
        "input_features": input_features,
        "drop": 0.0,
        "num_layers": loaded['model_args']['num_layers'],
        "task": loaded_args["task"],
        "number_of_clusters": 0,
        "regularize": 0,
        "weight_decay": 0,
    }
    explainer = create_model(model_kwargs)
    explainer.load_state_dict(loaded["state_dict"])
    explainer = explainer.to(device)
    explainer.eval()
    return explainer, model_kwargs


def _node_mask_from_clusters(out, batch, device):
    """Score each node by the ``x_cluster_transformed`` value of its assigned cluster.

    A node's cluster is the argmax of its (flattened) row of ``batch.s``. Nodes
    without a row in ``batch.s`` keep a score of 0.
    """
    mask = torch.zeros(batch.x.shape[0], device=device)
    assignments = batch.s
    num_assigned = assignments.shape[0]
    if num_assigned == 0:
        return mask
    # Row-wise argmax over the flattened row: identical to ``batch.s[i].argmax()``
    # per node, but one kernel instead of one host-device sync per node.
    cluster_idx = assignments.reshape(num_assigned, -1).argmax(dim=1)
    # NOTE(review): entry [0] of x_cluster_transformed is used for every node in
    # the batch. This is only right for batch_size > 1 if cluster indices are
    # batch-global — confirm against SEALNetwork.
    cluster_values = out["x_cluster_transformed"][0][cluster_idx]
    mask[:num_assigned] = cluster_values.reshape(num_assigned)
    return mask


def _split_explanations(batch, node_mask, pred):
    """Split a collated batch into one CPU ``Data`` explanation per graph."""
    explanations = []
    num_nodes = 0
    for b in range(int(batch.batch.max()) + 1):
        x_mask = batch.batch == b
        edge_index_mask = batch.batch[batch.edge_index[0]] == b
        # Shift edges from batch-global to per-graph node numbering; this assumes
        # each graph's nodes are stored contiguously in the batch.
        edge_index = batch.edge_index[:, edge_index_mask] - num_nodes
        explanations.append(
            Data(
                x=batch.x[x_mask].detach().cpu(),
                edge_index=edge_index.detach().cpu(),
                y=batch.y[b].detach().cpu(),
                pred=pred[b].detach().cpu(),
                # Slice per graph so node_mask has exactly one entry per row of
                # this graph's ``x`` (for batch_size=1 this is the whole mask).
                node_mask=node_mask[x_mask].detach().cpu(),
            )
        )
        num_nodes += int(x_mask.sum())
    return explanations


def _save_results(args, loaded, loaded_args, explanations, model_kwargs):
    """Write explanations plus the checkpoint's stored metrics to ``args.save_path``.

    Classification checkpoints contribute ``f1``/``roc_auc``/``accuracy``;
    all others contribute ``mae``/``rmse``.
    """
    if loaded_args["task"] == "classification":
        metric_keys = ("f1", "roc_auc", "accuracy")
    else:
        metric_keys = ("mae", "rmse")
    payload = {
        "args": vars(args),
        "model_args": loaded_args,
        "explanations": explanations,
    }
    payload.update({key: loaded[key] for key in metric_keys})
    payload["explainer"] = model_kwargs if args.explainer_type == "SEAL" else {}
    torch.save(payload, args.save_path)


def main():
    """Load a trained SEAL checkpoint, explain its test split, and save the result.

    For every test graph, a ``Data`` object is stored holding:
    - ``x`` and ``edge_index`` for that graph;
    - ``y`` and the model prediction ``pred``;
    - a per-node ``node_mask`` derived from the model's cluster outputs.
    These are saved together with the checkpoint's metrics and model configuration.
    """
    args = args_parser()
    print(args)
    seed = args.seed
    torch.manual_seed(seed)
    np.random.seed(seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # SECURITY: weights_only=False unpickles arbitrary objects from the file;
    # only point --explainer_path at trusted checkpoints.
    loaded = torch.load(args.explainer_path, map_location=torch.device("cpu"), weights_only=False)
    loaded_args = loaded["args"]
    dataset_kwargs = _dataset_kwargs(loaded_args)

    _, _, dataset_test = build_dataset(dataset_kwargs)
    featurizer = GraphFeaturizer(y_column='Y')
    test_set = featurizer(dataset_test, dataset_kwargs)
    dataloader_test = DataLoader(test_set, batch_size=args.batch_size, shuffle=False)

    explainer, model_kwargs = _build_explainer(loaded, test_set[0].x.shape[1], device)

    explanations = []
    for batch in tqdm(dataloader_test):
        batch = batch.to(device)
        out = explainer(batch, None)
        node_mask = _node_mask_from_clusters(out, batch, device)
        explanations.extend(_split_explanations(batch, node_mask, out["output"]))

    print(f"Finished with {len(explanations)}/{len(dataset_test)} explanations")
    _save_results(args, loaded, loaded_args, explanations, model_kwargs)
    print(f"Saved to {args.save_path}")


if __name__ == "__main__":
    main()
