import argparse
import copy
import os
from tqdm import tqdm
from torch_geometric.nn.conv import MessagePassing
import numpy as np
import pandas as pd
import csv
import os
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from datasetss.dataset import GraphFeaturizer, build_dataset
from train_model import get_model


def args_parser():
    """Parse the command-line options of the fidelity evaluation script.

    Returns:
        argparse.Namespace with the attributes ``seed`` (int, default 123),
        ``save_path`` (str, required), ``explanations_path`` (str, required)
        and ``percentage`` (float, default 0.1).
    """
    # (flag, add_argument keyword options) in the order they appear in --help.
    option_specs = (
        ("--seed", dict(type=int, default=123, help="seed")),
        ("--save_path", dict(type=str, required=True, help="path to save results")),
        ("--explanations_path", dict(type=str, required=True)),
        ("--percentage", dict(type=float, default=0.1, help="percentage of nodes to mask")),
    )
    cli = argparse.ArgumentParser(description="")
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()

class ModelWrapper(torch.nn.Module):
    """Expose a ``Data``-consuming model through an ``(x, edge_index, batch)`` forward.

    The wrapped model is called with a freshly built ``Data`` object, and only
    element ``[0]`` of whatever it returns is passed back to the caller.
    """

    def __init__(self, model):
        super().__init__()
        # If *model* is an nn.Module, this assignment registers it as a submodule.
        self.model = model

    def forward(self, x, edge_index, batch):
        """Pack the loose tensors into a ``Data`` graph and return the model's first output."""
        graph = Data(x=x, edge_index=edge_index, batch=batch)
        outputs = self.model(graph)
        return outputs[0]

def _dataset_kwargs(model_args):
    """Build the ``build_dataset`` keyword dict for the dataset a model was trained on.

    Only ``mean``/``std`` differ between datasets: "sol" uses the hard-coded
    values below and every other dataset uses 0.0/1.0. These presumably act
    as target normalisation (0/1 being the identity); confirm in ``build_dataset``.
    """
    if model_args["data_set"] == "sol":
        mean, std = -2.86, 2.38
    else:
        mean, std = 0.0, 1.0
    return {
        "data_set": model_args["data_set"],
        "mean": mean,
        "std": std,
        "y_column": 'Y',
        "smiles_col": "Drug",
        "task": model_args["task"],
        "split": model_args["split"],
    }


def _load_model(model_path, device):
    """Rebuild a model from a checkpoint dict with "model_args" and "state_dict" keys.

    The model is moved to *device* and switched to eval mode. MessagePassing
    layers without ``in_channels``/``out_channels`` get them from their inner
    ``nn.channel_list``, because explainers read these attributes.
    """
    # NOTE: weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=torch.device("cpu"), weights_only=False)
    model = get_model(**checkpoint["model_args"]).to(device)
    model.load_state_dict(checkpoint["state_dict"])
    for module in model.modules():
        if isinstance(module, MessagePassing) and not hasattr(module, "in_channels"):
            channel_list = module.nn.channel_list
            module.in_channels = channel_list[0]
            module.out_channels = channel_list[-1]
    model.eval()
    return model


def _top_k_node_mask(node_mask, num_nodes, percentage, device):
    """Return a bool mask on *device* selecting the ``int(num_nodes * percentage)`` top-scored nodes.

    ``k`` is clamped to ``[0, number of nodes]``, so out-of-range percentages
    select no / all nodes instead of raising. A multi-dimensional mask (e.g.
    ``[N, 1]`` or per-feature ``[N, F]``) is reduced to one score per node by
    summing its trailing dimensions; a 1-D mask is used as-is.
    """
    scores = node_mask
    if scores.dim() > 1:
        scores = scores.reshape(scores.shape[0], -1).sum(dim=-1)
    k = max(0, min(int(num_nodes * percentage), scores.shape[0]))
    # topk runs where the scores live so tie-breaking does not depend on *device*.
    top_indices = torch.topk(scores, k).indices
    keep = torch.zeros(scores.shape[0], dtype=torch.bool, device=device)
    keep[top_indices.to(device)] = True
    return keep


def _fidelity_terms(model, data, keep_mask, task):
    """Compare the model's output on *data* with two feature-masked variants.

    Returns ``(pos, neg)`` as floats:

    * ``pos`` compares the output after zeroing the features of the selected
      nodes (only the complement is kept).
    * ``neg`` compares the output after keeping only the selected nodes' features.

    For classification each term is the prediction-agreement rate. For any
    other task it is the mean absolute output difference.
    """
    def run(features):
        return model(features, data.edge_index, batch=data.batch)

    original = run(data.x)
    kept = run(data.x * keep_mask.unsqueeze(-1))
    removed = run(data.x * (~keep_mask).unsqueeze(-1))

    if task == "classification":
        original = original.argmax(dim=-1)
        pos = (removed.argmax(dim=-1) == original).float().mean()
        neg = (kept.argmax(dim=-1) == original).float().mean()
    else:
        original = original.squeeze()
        # .mean() keeps .item() valid for multi-output regression heads.
        pos = torch.abs(removed.squeeze() - original).mean()
        neg = torch.abs(kept.squeeze() - original).mean()
    return pos.item(), neg.item()


def _append_csv_row(path, row):
    """Append *row* to the CSV at *path*, writing a header first if the file is new or empty."""
    write_header = not os.path.exists(path) or os.path.getsize(path) == 0
    with open(path, mode='a', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=list(row))
        if write_header:
            writer.writeheader()
        writer.writerow(row)


def main():
    """Evaluate the model-fidelity of saved node-mask explanations on the test split.

    For each test graph, the top ``--percentage`` fraction of nodes (by
    explanation score) is selected. The model is then re-run twice:

    * with only those nodes' features kept, giving negative fidelity;
    * with those nodes' features zeroed, giving positive fidelity.

    For classification, the fidelity reported is ``1 - agreement`` with the
    original prediction. For other tasks it is the mean absolute output
    change. The averages are printed and appended as one row to the CSV at
    ``--save_path``.
    """
    args = args_parser()
    print(args)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # NOTE: weights_only=False unpickles arbitrary objects; only load trusted files.
    loaded = torch.load(args.explanations_path, weights_only=False)
    model_args = loaded["model_args"]
    task = model_args["task"]
    # Read up front so a malformed file fails before the expensive loop.
    explainer_type = loaded["args"]["explainer_type"]
    model_path = loaded["args"]['model_path']

    dataset_kwargs = _dataset_kwargs(model_args)
    _, _, dataset_test = build_dataset(dataset_kwargs)
    featurizer = GraphFeaturizer(y_column='Y')
    test_set = featurizer(dataset_test, dataset_kwargs)
    dataloader_test = DataLoader(test_set, batch_size=1, shuffle=False)
    explanations = loaded["explanations"]
    if len(explanations) != len(dataloader_test):
        # zip() below silently stops at the shorter of the two.
        print(f"Warning: {len(explanations)} explanations for {len(dataloader_test)} test graphs; "
              "only the common prefix is evaluated.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    try:
        model = _load_model(model_path, device)
    except Exception as exc:
        # Exit non-zero so scripts/pipelines notice the failure.
        sys.exit(f"Error loading model: {exc}")

    pos_scores, neg_scores = [], []
    with torch.no_grad():  # pure evaluation: no autograd graph needed
        for explanation, data in tqdm(zip(explanations, dataloader_test), total=len(explanations)):
            # Explanations must line up 1:1 with the (unshuffled) test graphs.
            if not torch.equal(explanation.x, data.x):
                raise ValueError("Explanation does not match its test graph (node features differ)")
            data = data.to(device)
            keep_mask = _top_k_node_mask(explanation["node_mask"], data.x.shape[0], args.percentage, device)
            pos, neg = _fidelity_terms(model, data, keep_mask, task)
            pos_scores.append(pos)
            neg_scores.append(neg)

    pos_fidelity = torch.tensor(pos_scores).mean().item()
    neg_fidelity = torch.tensor(neg_scores).mean().item()
    if task == "classification":
        # Convert agreement rates into the fraction of changed predictions.
        pos_fidelity, neg_fidelity = 1 - pos_fidelity, 1 - neg_fidelity

    metrics = {
        "data_set": model_args["data_set"],
        "explainer_type": explainer_type,
        "split": model_args["split"],
        "pos_fidelity_models": pos_fidelity,
        "neg_fidelity_models": neg_fidelity,
        "percentage": args.percentage,
    }
    print(f"Positive Fidelity Model: {metrics['pos_fidelity_models']}")
    print(f"Negative Fidelity Model: {metrics['neg_fidelity_models']}")

    if args.save_path is not None:
        _append_csv_row(args.save_path, metrics)
        

if __name__ == "__main__":
    main()
