import argparse
import copy
import os
from tqdm import tqdm

import numpy as np

import torch

from torch_geometric.loader import DataLoader

from sklearn.metrics import roc_auc_score, average_precision_score

import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from datasetss.dataset import SyntheticGraphFeaturizer, build_dataset


def args_parser():
    """Parse the command-line options of the evaluation script.

    Recognised options:
        --seed               integer seed, defaults to 123
        --save_path          required; path to save results
        --explanations_path  required; path to explanations
        --mask_to_eval       "node" (default) or "edge"

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    # Each entry is (flag, keyword arguments forwarded to add_argument).
    option_specs = (
        ("--seed", {"type": int, "default": 123, "help": "seed"}),
        ("--save_path", {"type": str, "required": True, "help": "path to save results"}),
        ("--explanations_path", {"type": str, "required": True, "help": "path to explanations"}),
        ("--mask_to_eval", {"type": str, "default": "node", "choices": ["node", "edge"]}),
    )

    cli = argparse.ArgumentParser(description="")
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()




def get_node_importance(e, data):
    """Return a 1-D per-node importance vector taken from ``e.node_mask``.

    NaN entries are replaced via ``torch.nan_to_num``. A 2-D mask is
    collapsed to one value per row by summing over its last dimension.

    Args:
        e: Object exposing a ``node_mask`` tensor attribute.
        data: Unused; kept so existing callers keep working.

    Returns:
        A detached 1-D tensor of importances.

    Raises:
        ValueError: If ``e`` has no ``node_mask`` (or it is ``None``), or if
            the mask is neither 1-D nor 2-D.
    """
    node_mask = getattr(e, "node_mask", None)
    if node_mask is None:
        # Fail with a clear message instead of passing None to nan_to_num.
        raise ValueError("Explanation does not provide a node_mask")

    node_mask = torch.nan_to_num(node_mask.detach())
    if node_mask.dim() == 2:
        node_mask = node_mask.sum(-1)
    if node_mask.dim() != 1:
        # Explicit raise: an assert would be stripped when running with -O.
        raise ValueError(
            f"node_mask must be 1-D or 2-D, got shape {tuple(node_mask.shape)}"
        )
    return node_mask


def get_edge_importance(e, data):
    """Return the 1-D per-edge importance vector taken from ``e.edge_mask``.

    NaN entries are replaced via ``torch.nan_to_num``.

    Args:
        e: Object exposing an ``edge_mask`` tensor attribute.
        data: Unused; kept so existing callers keep working.

    Returns:
        A detached 1-D tensor of importances.

    Raises:
        ValueError: If ``e`` has no ``edge_mask`` (or it is ``None``), or if
            the mask is not 1-D.
    """
    edge_mask = getattr(e, "edge_mask", None)
    if edge_mask is None:
        # Fail with a clear message instead of passing None to nan_to_num.
        raise ValueError("Explanation does not provide an edge_mask")

    edge_mask = torch.nan_to_num(edge_mask.detach())
    if edge_mask.dim() != 1:
        # Explicit raise: an assert would be stripped when running with -O.
        raise ValueError(
            f"edge_mask must be 1-D, got shape {tuple(edge_mask.shape)}"
        )
    return edge_mask


def no_outliers(x):
    """Tell whether ``x`` is free of outliers under the 1.5 * IQR rule.

    A value counts as an outlier when it lies below ``Q1 - 1.5 * IQR`` or
    above ``Q3 + 1.5 * IQR``, with the quartiles taken from
    ``torch.quantile``.

    Returns:
        ``1.0`` when no element is an outlier, ``0.0`` otherwise.
    """
    first_quartile = torch.quantile(x, 0.25)
    third_quartile = torch.quantile(x, 0.75)
    whisker = 1.5 * (third_quartile - first_quartile)

    too_low = x < first_quartile - whisker
    too_high = x > third_quartile + whisker
    return 0.0 if bool((too_low | too_high).any()) else 1.0


def main():
    """Score saved explanations against ground-truth masks and log the result.

    Steps:
      1. Parse the CLI arguments and seed torch and numpy.
      2. Load the saved bundle at ``--explanations_path`` (keys read here:
         ``model_args``, ``explanations``, ``f1``, ``roc_auc``, ``accuracy``)
         and rebuild the test split it refers to.
      3. Walk explanations and test graphs in lockstep, check that each pair
         describes the same graph, and score the node mask:
           * constant ground truth -> ``no_outliers`` on the predicted mask
           * otherwise             -> ROC-AUC of prediction vs. ground truth
      4. Print a summary and append one row to the CSV at ``--save_path``.
    """
    import csv

    args = args_parser()
    print(args)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # NOTE(security): weights_only=False unpickles arbitrary objects; only
    # point this at explanation files from a trusted source.
    loaded = torch.load(args.explanations_path, weights_only=False)
    model_args = loaded["model_args"]
    task = model_args["task"]

    dataset_kwargs = {
        "data_set": model_args["data_set"],
        "mean": 0.0,
        "std": 1.0,
        "y_column": 'Y',
        "smiles_col": "Drug",
        "task": task,
        "split": model_args["split"],
    }
    _, _, dataset_test = build_dataset(dataset_kwargs)
    featurizer = SyntheticGraphFeaturizer(y_column='Y')

    test_set = featurizer(dataset_test, dataset_kwargs)
    # batch_size=1 and shuffle=False so graphs come out one at a time in
    # dataset order, which is what the pairing with `explanations` relies on.
    dataloader_test = DataLoader(test_set, batch_size=1, shuffle=False)
    explanations = loaded["explanations"]

    metric_fn = roc_auc_score
    print(f"Task: {task} Eval metric: {metric_fn.__name__}")

    non_empty_ex, empty_ex = [], []

    # NOTE(review): --mask_to_eval is parsed but not consulted here; only the
    # node mask is evaluated.
    for e, data in tqdm(zip(explanations, dataloader_test), total=len(explanations)):
        assert torch.equal(data.x, e.x) and torch.equal(data.edge_index, e.edge_index), "Data and explanation do not match"
        assert data.y.item() == e.y.item(), "Data and explanation do not match"
        assert data.explanation.reshape(-1).shape[0] == e.node_mask.reshape(-1).shape[0], "Data and explanation do not match"
        pred_mask = get_node_importance(e, data)
        gt = data.explanation

        if gt.min() == gt.max():
            # Constant ground truth: a ranking metric is undefined, so check
            # instead that the predicted mask has no outliers.
            empty_ex.append(no_outliers(pred_mask))
        else:
            non_empty_ex.append(metric_fn(gt, pred_mask))

    # The mean of an empty list comes out as NaN, so a missing category
    # shows up as NaN in the output rather than raising.
    metrics = {
        "non_empty_ex": torch.tensor(non_empty_ex).mean().item(),
        "empty_ex": torch.tensor(empty_ex).mean().item(),
        "f1": loaded["f1"],
        "roc_auc": loaded["roc_auc"],
        "accuracy": loaded["accuracy"],
        "split": model_args["split"],
    }

    print(f"'non_empty_ex': '{metrics['non_empty_ex']:.4f}', 'empty_ex': '{metrics['empty_ex']:.4f}', 'f1': '{metrics['f1']:.4f}'")
    if args.save_path is not None:
        # Write the header for a new file and also for an existing but empty
        # one, which would otherwise end up with rows and no header.
        needs_header = (
            not os.path.exists(args.save_path)
            or os.path.getsize(args.save_path) == 0
        )
        with open(args.save_path, mode='a', newline='') as f:
            writer = csv.DictWriter(f, fieldnames=list(metrics))
            if needs_header:
                writer.writeheader()
            writer.writerow(metrics)


# Run the evaluation only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
