import argparse
import copy
import json
import os
from tqdm import tqdm

import numpy as np
import pandas as pd
import csv
import os
import torch
from torch.utils.data import Subset

from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader
from torch_geometric.utils import to_dense_adj, subgraph

from sklearn.metrics import f1_score, mean_squared_error as mse_score
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from datasetss.dataset import GraphFeaturizer, build_dataset
from architectures.SEAL import SEALNetwork

def create_model(kwargs):
    """Instantiate the network selected by ``kwargs["model"]``.

    Args:
        kwargs: Mapping of model settings. Its ``"model"`` entry names the
            architecture; the whole mapping is handed to the constructor.

    Returns:
        The constructed model instance.

    Raises:
        ValueError: If the ``"model"`` entry is missing or not a known name.
    """
    # Registry of supported architectures, keyed by their public name.
    registry = {"SEAL": SEALNetwork}

    model_name = kwargs.get("model", "")
    model_cls = registry.get(model_name)
    if model_cls is None:
        raise ValueError(f"Unknown model name: {model_name}")
    return model_cls(kwargs)

def args_parser():
    """Parse the command-line options of this script.

    Returns:
        The ``argparse.Namespace`` with ``seed``, ``save_path``,
        ``explanations_path`` and ``percentage``.
    """
    # (flag, add_argument keyword arguments) for every supported option.
    option_table = (
        ("--seed", {"type": int, "default": 123, "help": "seed"}),
        ("--save_path", {"type": str, "required": True, "help": "path to csv save results"}),
        ("--explanations_path", {"type": str, "required": True, "help": "path to explanations"}),
        ("--percentage", {"type": float, "default": 0.1, "help": "percentage of nodes to mask"}),
    )

    cli = argparse.ArgumentParser(description="")
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()

class ModelWrapper(torch.nn.Module):
    """Adapter exposing a ``Data``-consuming model through tensor arguments.

    The wrapped model is called with a single ``Data`` object; this wrapper
    builds that object from ``x``, ``edge_index`` and ``batch`` and returns
    only the first element of the wrapped model's result.
    """

    def __init__(self, model):
        """Store ``model`` as the wrapped network."""
        super(ModelWrapper, self).__init__()
        self.model = model

    def forward(self, x, edge_index, batch):
        """Pack the inputs into a ``Data`` object and run the wrapped model.

        Returns:
            Element ``0`` of whatever the wrapped model returns.
        """
        graph = Data(x=x, edge_index=edge_index, batch=batch)
        result = self.model(graph)
        return result[0]

def _select_top_nodes(scores, num_nodes, percentage, largest=True, tolerance=1e-3):
    """Pick the nodes that carry the highest-ranked importance values.

    Unique score values are consumed in rank order (largest first when
    ``largest`` is true, smallest first otherwise). Every node whose score
    lies within ``tolerance`` of one of the consumed values is selected, and
    values keep being added until the selected nodes make up at least
    ``percentage`` of ``num_nodes``.

    Args:
        scores: Per-node importance scores (expected to be a 1-D tensor with
            one entry per node).
        num_nodes: Number of nodes in the graph.
        percentage: Minimum fraction of nodes that must be selected.
        largest: Rank unique values from largest to smallest when true.
        tolerance: Absolute distance below which a score counts as equal to
            a ranked value.

    Returns:
        Index tensor of the selected nodes.

    Raises:
        RuntimeError: If the requested fraction cannot be reached even after
            all unique values were used (e.g. ``percentage`` > 1 or scores
            that never compare close, such as NaN).
    """
    # torch.unique returns the values sorted ascending; compute the ranking
    # once instead of calling unique/topk again for every candidate size.
    ranked = torch.unique(scores)
    if largest:
        ranked = ranked.flip(0)

    for k in range(1, ranked.numel() + 1):
        distance = torch.abs(scores.unsqueeze(-1) - ranked[:k])
        is_close = (distance < tolerance).any(dim=1)
        selected = is_close.nonzero(as_tuple=True)[0]
        if len(selected) / num_nodes >= percentage:
            return selected

    raise RuntimeError(
        f"Could not select a fraction of {percentage} of {num_nodes} nodes "
        f"from {ranked.numel()} unique mask values; check --percentage "
        "(must be <= 1) and the node mask contents."
    )


def _append_metrics_row(path, metrics):
    """Append ``metrics`` as one row to the CSV file at ``path``.

    The parent directory is created when missing so a finished evaluation is
    not lost to a failing ``open``. The header row is written when the file
    does not exist yet or is still empty.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    needs_header = not os.path.exists(path) or os.path.getsize(path) == 0
    with open(path, mode='a', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=list(metrics))
        if needs_header:
            writer.writeheader()
        writer.writerow(metrics)


def main():
    """Evaluate explanation fidelity on the test split and log it to CSV.

    Steps:
      1. Load the saved explanations bundle and rebuild the test set.
      2. Restore the explained model from the checkpoint path in the bundle.
      3. For every test graph, select the top-ranked nodes of the
         explanation's node mask until ``--percentage`` of the nodes is
         covered, then run the model on (a) the graph with only those nodes'
         features kept and (b) the graph with those nodes' features zeroed.
      4. Aggregate: for classification, the share of samples whose predicted
         class (output >= 0) changes; otherwise the mean absolute output
         difference. "pos" uses variant (b), "neg" uses variant (a).
      5. Print the metrics and append them as a row to ``--save_path``.
    """
    args = args_parser()
    print(args)
    seed = args.seed
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Experiment toggles kept from the original script:
    #   non_abs    - rank by the signed mask instead of its absolute value.
    #   mask_contr - call the model with mask_idx=None for the perturbed graphs.
    non_abs = False
    mask_contr = False

    # NOTE(security): weights_only=False unpickles arbitrary objects; only
    # load explanation files from a trusted source.
    loaded = torch.load(args.explanations_path, weights_only=False)
    model_args = loaded["model_args"]
    task = model_args["task"]

    # The "sol" data set is the only one configured for regression with
    # non-trivial target normalisation constants.
    is_sol = model_args['data_set'] == "sol"
    dataset_kwargs = {
        "data_set": model_args["data_set"],
        "mean": -2.86 if is_sol else 0.0,
        "std": 2.38 if is_sol else 1.0,
        "y_column": 'Y',
        "smiles_col": "Drug",
        "task": "regression" if is_sol else "classification",
        "split": model_args["split"],
    }
    _, _, dataset_test = build_dataset(dataset_kwargs)
    featurizer = GraphFeaturizer(y_column='Y')

    test_set = featurizer(dataset_test, dataset_kwargs)
    # batch_size=1 and shuffle=False keep graphs aligned with the explanations.
    dataloader_test = DataLoader(test_set, batch_size=1, shuffle=False)
    explanations = loaded["explanations"]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    pos_fidelity_models, neg_fidelity_models = list(), list()

    model_weights = torch.load(loaded["args"]["explainer_path"], map_location=torch.device(device))['state_dict']
    model_kwargs = loaded["explainer"]
    model = create_model(model_kwargs)
    model.load_state_dict(model_weights)
    model = model.to(device)
    model.eval()

    # Pure inference: no autograd graph is needed for any forward pass below.
    with torch.no_grad():
        for e, data in tqdm(zip(explanations, dataloader_test), total=len(explanations)):
            # Sanity check that explanation and test graph refer to the same sample.
            assert torch.equal(e.x, data.x), "explanation and test graph features differ"

            data = data.to(device)

            node_mask = e["node_mask"]
            abs_mask = node_mask if non_abs else torch.abs(node_mask)

            original_output = model(data, mask_idx=None)["output"]

            # With signed scores, a negative prediction ranks the smallest
            # values first; otherwise the largest values always come first.
            largest = (not non_abs) or bool(original_output >= 0)
            topk_indices = _select_top_nodes(
                abs_mask, data.x.shape[0], args.percentage, largest=largest
            )

            # data.s is presumably a node-to-cluster assignment matrix whose
            # argmax is the cluster id of each node - TODO confirm.
            top_k_clusters = torch.unique(data.s[topk_indices].argmax(-1)).detach().cpu().tolist()

            # Boolean node mask: True for the selected (top-ranked) nodes.
            mask = torch.zeros_like(node_mask, dtype=torch.bool, device=device)
            mask[topk_indices] = True
            inverse_mask = ~mask

            # Keep only the selected nodes' features.
            data_with_mask = copy.deepcopy(data)
            data_with_mask.x = data_with_mask.x * mask.unsqueeze(-1)

            # Zero out the selected nodes' features.
            data_with_inverse_mask = copy.deepcopy(data)
            data_with_inverse_mask.x = data_with_inverse_mask.x * inverse_mask.unsqueeze(-1)

            if mask_contr:
                masked_output = model(data_with_mask, mask_idx=None)["output"]
                inverse_masked_output = model(data_with_inverse_mask, mask_idx=None)["output"]
            else:
                masked_output = model(data_with_mask, mask_idx=top_k_clusters)["output"]
                inverse_masked_output = model(data_with_inverse_mask, mask_idx=list(set(range(data.num_cluster.item())) - set(top_k_clusters)))["output"]

            if task == "classification":
                # Outputs are thresholded at 0 to obtain the predicted class;
                # the per-sample score is 1.0 when the prediction is unchanged.
                original_pred = original_output >= 0
                masked_pred = masked_output >= 0
                inverse_masked_pred = inverse_masked_output >= 0
                pos_fidelity_model = (inverse_masked_pred == original_pred).float().mean()
                neg_fidelity_model = (masked_pred == original_pred).float().mean()
            else:
                original_value = original_output.squeeze()
                pos_fidelity_model = torch.abs(inverse_masked_output.squeeze() - original_value)
                neg_fidelity_model = torch.abs(masked_output.squeeze() - original_value)

            pos_fidelity_models.append(pos_fidelity_model.item())
            neg_fidelity_models.append(neg_fidelity_model.item())

    if task == "classification":
        # Convert agreement rates into "prediction changed" rates.
        pos_fidelity_model = 1 - torch.tensor(pos_fidelity_models).mean().item()
        neg_fidelity_model = 1 - torch.tensor(neg_fidelity_models).mean().item()
    else:
        pos_fidelity_model = torch.tensor(pos_fidelity_models).mean().item()
        neg_fidelity_model = torch.tensor(neg_fidelity_models).mean().item()

    metrics = {
        "data_set": loaded["model_args"]["data_set"],
        "explainer_type": loaded["args"]["explainer_type"],
        "split": loaded["model_args"]["split"],
        "pos_fidelity_models": pos_fidelity_model,
        "neg_fidelity_models": neg_fidelity_model,
        "percentage": args.percentage,
    }
    print(f"Positive Fidelity Model: {metrics['pos_fidelity_models']}")
    print(f"Negative Fidelity Model: {metrics['neg_fidelity_models']}")

    if args.save_path is not None:
        _append_metrics_row(args.save_path, metrics)
        

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
