import argparse
import copy
import json
import os
from tqdm import tqdm

import numpy as np
import pandas as pd

import torch
from torch.utils.data import Subset
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from torch_geometric.data import Data, Batch, DataLoader
from torch_geometric.utils import to_dense_adj, subgraph

from sklearn.metrics import roc_auc_score, average_precision_score

from datasetss.dataset import build_dataset, SyntheticGraphFeaturizer


def args_parser():
    """Parse the command-line options of the explanation evaluation script.

    Returns:
        argparse.Namespace with ``seed`` (int, default 123), ``save_path``
        (str, required) and ``explanations_path`` (str, required).
    """
    cli = argparse.ArgumentParser(description="")
    option_specs = (
        ("--seed", dict(type=int, default=123, help="seed")),
        ("--save_path", dict(type=str, required=True, help="path to csv save results")),
        ("--explanations_path", dict(type=str, required=True, help="path to explanations")),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()




def get_node_importance(e, data):
    """Return a detached 1-D per-node importance vector from an explanation.

    Args:
        e: Explanation-like object expected to expose a ``node_mask`` tensor.
        data: The graph the explanation belongs to. Not used by this function;
            kept so existing callers keep working.

    Returns:
        torch.Tensor: 1-D tensor of node scores. NaN/inf entries are replaced
        via ``torch.nan_to_num``; a 2-D mask is summed over its last dimension.

    Raises:
        ValueError: if ``e`` has no ``node_mask`` (or it is ``None``), or the
            mask is neither 1-D nor 2-D.
    """
    node_mask = getattr(e, "node_mask", None)
    if node_mask is None:
        # Fail with a clear message instead of a confusing error from
        # torch.nan_to_num(None) / None.detach().
        raise ValueError("Explanation has no node_mask; cannot compute node importance")
    node_mask = torch.nan_to_num(node_mask.detach())
    if node_mask.dim() == 2:
        # Collapse per-feature importances into a single score per node.
        node_mask = node_mask.sum(-1)
    if node_mask.dim() != 1:
        raise ValueError(
            f"Expected a 1-D or 2-D node_mask, got shape {tuple(node_mask.shape)}"
        )
    return node_mask





def no_outliers(x):
    """Return 1.0 if ``x`` has no outliers under the 1.5*IQR rule, else 0.0.

    A value is an outlier when it lies below ``Q1 - 1.5*IQR`` or above
    ``Q3 + 1.5*IQR``. Q1 and Q3 are the 25th and 75th percentiles over all
    elements of ``x``, and ``IQR = Q3 - Q1``. A constant input therefore has
    no outliers.

    Args:
        x: Non-empty tensor of scores. Integer, boolean and low-precision float
            tensors are promoted to float32 because ``torch.quantile`` only
            accepts float32/float64 input.

    Returns:
        float: ``1.0`` when no element is an outlier, ``0.0`` otherwise.
    """
    if x.dtype not in (torch.float32, torch.float64):
        x = x.float()
    q1 = torch.quantile(x, 0.25)
    q3 = torch.quantile(x, 0.75)
    iqr = q3 - q1
    lower_bound = q1 - 1.5 * iqr
    upper_bound = q3 + 1.5 * iqr
    has_outlier = bool(((x < lower_bound) | (x > upper_bound)).any())
    return float(not has_outlier)


def _evaluate_explanations(explanations, dataloader_test, metric_fn):
    """Score every explanation against the ground-truth mask of its test graph.

    Explanations and test graphs are paired positionally with ``zip``, so
    iteration stops at the shorter of the two.

    Args:
        explanations: Sequence of explanation objects exposing ``x``,
            ``edge_index``, ``y`` and ``node_mask``.
        dataloader_test: Iterable of single-graph batches exposing ``x``,
            ``edge_index``, ``y`` and ``explanation`` (ground-truth mask).
        metric_fn: Callable ``metric_fn(y_true, y_score)`` used when the
            ground-truth mask has more than one distinct value.

    Returns:
        tuple[list[float], list[float]]: two score lists. The first holds
        ``metric_fn`` scores for graphs with a non-constant ground truth. The
        second holds ``no_outliers`` scores of the predicted mask for graphs
        whose ground truth is constant.

    Raises:
        ValueError: if a test graph and its explanation do not match.
    """
    non_empty_ex, empty_ex = [], []
    for e, data in tqdm(zip(explanations, dataloader_test), total=len(explanations)):
        # Explanations must belong to exactly these test graphs, in this order.
        # Explicit raises (not asserts) so the checks survive `python -O`.
        if not (torch.equal(data.x, e.x) and torch.equal(data.edge_index, e.edge_index)):
            raise ValueError("Data and explanation do not match")
        if data.y.item() != e.y.item():
            raise ValueError("Data and explanation do not match")
        if data.explanation.reshape(-1).shape[0] != e.node_mask.reshape(-1).shape[0]:
            raise ValueError("Data and explanation do not match")

        pred_mask = get_node_importance(e, data)
        gt = data.explanation

        if gt.min() == gt.max():
            # Constant ground truth has a single class, so ROC-AUC is undefined.
            # Score whether the predicted mask is free of outliers instead.
            empty_ex.append(no_outliers(pred_mask))
        else:
            # Flatten both masks to 1-D and hand sklearn plain CPU arrays.
            y_true = gt.reshape(-1).cpu().numpy()
            y_score = pred_mask.reshape(-1).cpu().numpy()
            non_empty_ex.append(float(metric_fn(y_true, y_score)))
    return non_empty_ex, empty_ex


def _append_metrics_csv(save_path, metrics):
    """Append ``metrics`` as one CSV row, writing a header row if the file is new."""
    import csv
    file_exists = os.path.exists(save_path)
    with open(save_path, mode='a', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=list(metrics.keys()))
        if not file_exists:
            writer.writeheader()
        writer.writerow(metrics)


def main():
    """Evaluate saved node-level explanations against ground-truth masks.

    Steps:
    1. Load the bundle at ``--explanations_path``.
    2. Rebuild the test split named in its ``model_args``.
    3. Score each explanation, using ROC-AUC for graphs with a non-constant
       ground-truth mask and ``no_outliers`` otherwise.
    4. Print the averaged scores and append them, together with the stored
       ``f1``/``roc_auc``/``accuracy``/``split``, as a row of the CSV at
       ``--save_path``.
    """
    args = args_parser()
    print(args)
    seed = args.seed
    torch.manual_seed(seed)
    np.random.seed(seed)

    # map_location="cpu": the bundle may come from a GPU run. Loading onto the
    # CPU keeps torch.equal / sklearn comparisons with the CPU test loader on one
    # device and lets the script run on CPU-only machines.
    # NOTE(review): weights_only=False unpickles arbitrary objects; only load
    # trusted files.
    loaded = torch.load(args.explanations_path, map_location="cpu", weights_only=False)
    task = loaded["model_args"]["task"]

    # NOTE(review): "task" is hard-coded to "classification" rather than taken
    # from loaded["model_args"]["task"]; confirm this is intended.
    dataset_kwargs = {
        "data_set": loaded["model_args"]["data_set"],
        "mean": 0.0,
        "std": 1.0,
        "y_column": 'Y',
        "smiles_col": "Drug",
        "task": "classification",
        "split": loaded["model_args"]["split"],
    }
    _, _, dataset_test = build_dataset(dataset_kwargs)
    featurizer = SyntheticGraphFeaturizer(y_column='Y')

    test_set = featurizer(dataset_test, dataset_kwargs)
    dataloader_test = DataLoader(test_set, batch_size=1, shuffle=False)
    explanations = loaded["explanations"]

    metric_fn = roc_auc_score
    print(f"Task: {task} Eval metric: {metric_fn.__name__}")

    non_empty_ex, empty_ex = _evaluate_explanations(explanations, dataloader_test, metric_fn)

    # An empty score list averages to NaN.
    metrics = {
        "non_empty_ex": torch.tensor(non_empty_ex).mean().item(),
        "empty_ex": torch.tensor(empty_ex).mean().item(),
        "f1": loaded["f1"],
        "roc_auc": loaded["roc_auc"],
        "accuracy": loaded["accuracy"],
        "split": loaded["model_args"]["split"],
    }

    print(f"'non_empty_ex': '{metrics['non_empty_ex']:.4f}', 'empty_ex': '{metrics['empty_ex']:.4f}', 'f1': '{metrics['f1']:.4f}'")
    if args.save_path is not None:
        _append_metrics_csv(args.save_path, metrics)


if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()
