import argparse
import copy
import json
import os
from tqdm import tqdm
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.conv import GINConv
import sys
import numpy as np
import pandas as pd
import csv
import os
import torch
from torch.utils.data import Subset
from config import get_config
from utils import create_logger, seed_set
import random
from torch_geometric.data import Data, Batch
from torch_geometric.explain import Explainer, GNNExplainer, PGExplainer
from torch_geometric.loader import DataLoader
from torch_geometric.utils import to_dense_adj, subgraph
from config import get_config
from utils import create_logger, seed_set
from utils import NoamLR, build_scheduler, build_optimizer, get_metric_func
from utils import load_checkpoint, save_best_checkpoint, load_best_result
from dataset import build_loader
from torch_geometric.data import Data
from loss import bulid_loss
from model import build_model

from sklearn.metrics import f1_score, mean_squared_error as mse_score



def parse_args():
    """Build the command-line interface, parse ``sys.argv`` and load the config.

    Returns:
        tuple: ``(args, cfg)`` -- the parsed ``argparse.Namespace`` and the
        config returned by ``get_config(args)``. ``args.batch_size`` is forced
        to 1 before ``get_config`` sees it, whatever ``--batch-size`` was given.
    """
    parser = argparse.ArgumentParser(description="codes for HiGNN")

    # (flag, keyword options for add_argument), registered in this order so the
    # --help listing is unchanged.
    cli_options = (
        ("--cfg", dict(help="decide which cfg to use", required=False,
                       default="../configs/bbbp.yaml", type=str)),
        ("--opts", dict(help="Modify config options by adding 'KEY VALUE' pairs. ",
                        default=None, nargs='+')),
        # easy config modification
        ('--batch-size', dict(type=int, help="batch size for training")),
        ('--resume', dict(help='resume from checkpoint')),
        ('--tag', dict(help='tag of experiment')),
        ('--eval', dict(action='store_true', help='Perform evaluation only')),
        ('--lr-scheduler', dict(type=str, help='lr scheduler')),
        ('--seed', dict(type=int, default=123, help='random seed')),
        ('--split', dict(type=int, default=0, help='which data split to use')),
        ('--num_layers', dict(type=int, help='number of layers')),
        ('--hidden_dim', dict(type=int, help='hidden dimension')),
        ('--dropout', dict(type=float, help='dropout rate')),
        ('--lr', dict(type=float, help='learning rate')),
        ("--explanations_path", dict(type=str, default="explanations path", required=False)),
        ("--percentage", dict(type=float, default=0.1, help="percentage of nodes to mask")),
    )
    for flag, options in cli_options:
        parser.add_argument(flag, **options)

    args = parser.parse_args()
    # One graph per batch: train() pairs each test batch with one stored
    # explanation, which presumably relies on this -- TODO confirm.
    args.batch_size = 1
    return args, get_config(args)


class ModelWrapper(torch.nn.Module):
    """Adapt a ``Data``-consuming model to a ``forward(x, edge_index, batch)`` signature.

    The wrapped model is called with a ``torch_geometric.data.Data`` object
    assembled from the three tensors, and only the first element of its output
    is returned. Presumably meant for the tensor-based torch_geometric
    ``Explainer`` API -- TODO confirm; it is not used by the code in this file.
    """

    def __init__(self, model):
        super().__init__()
        # Registered as a submodule, so its parameters are exposed by the wrapper.
        self.model = model

    def forward(self, x, edge_index, batch):
        graph = Data(x=x, edge_index=edge_index, batch=batch)
        outputs = self.model(graph)
        # The wrapped model returns an indexable result; keep only its head.
        return outputs[0]


def load_best_result(model, path):
    """Load ``ckpt['model']`` from the checkpoint at ``path`` into ``model``.

    The checkpoint is deserialized onto the CPU first; ``load_state_dict`` then
    copies the values into the model's existing parameters. This definition
    shadows the ``load_best_result`` imported from ``utils`` at the top of the file.

    Returns:
        The same ``model`` object, updated in place.
    """
    checkpoint = torch.load(path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['model'])
    return model
def set_seed(seed=2021):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs with ``seed``.

    Also exports ``PYTHONHASHSEED`` and turns on cuDNN's deterministic mode.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    for seed_cuda in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_cuda(seed)
    torch.backends.cudnn.deterministic = True

def _topk_node_mask(node_mask, k):
    """Return a bool tensor shaped like ``node_mask`` marking its ``k`` largest entries.

    A ``k`` outside ``[0, node_mask.shape[0]]`` selects every entry.
    Assumes ``node_mask`` is 1-D with one score per node -- TODO confirm for all
    stored explanations.
    """
    num_nodes = node_mask.shape[0]
    if not 0 <= k <= num_nodes:
        # E.g. more nodes requested than the graph has: fall back to all nodes,
        # explicitly rather than by swallowing whatever topk raises.
        k = num_nodes
    _, top_indices = torch.topk(node_mask, k)
    mask = torch.zeros_like(node_mask, dtype=torch.bool)
    mask[top_indices] = True
    return mask


def _fidelity_scores(task, y, original_output, masked_output, inverse_masked_output):
    """Return ``(pos_fidelity, neg_fidelity, metric)`` tensors for one graph.

    For ``task == "classification"`` outputs are reduced with ``argmax(dim=-1)``.
    Fidelities are then the mean absolute change in correctness w.r.t. ``y``, and
    ``metric`` is the accuracy of the original prediction. For any other task the
    outputs are squeezed. Fidelities are then absolute output differences, and
    ``metric`` is the mean absolute error of the original prediction.
    ``masked_output`` comes from the graph that keeps only the selected nodes;
    ``inverse_masked_output`` comes from the graph with those nodes zeroed.
    """
    if task == "classification":
        original_pred = original_output.argmax(dim=-1)
        kept_pred = masked_output.argmax(dim=-1)
        removed_pred = inverse_masked_output.argmax(dim=-1)
        original_correct = (y == original_pred).float()
        metric = original_correct.mean()
        pos_fidelity = (original_correct - (removed_pred == y).float()).abs().mean()
        neg_fidelity = (original_correct - (kept_pred == y).float()).abs().mean()
    else:
        original_pred = original_output.squeeze()
        pos_fidelity = torch.abs(inverse_masked_output.squeeze() - original_pred)
        neg_fidelity = torch.abs(masked_output.squeeze() - original_pred)
        metric = (y - original_pred).abs().float().mean()
    return pos_fidelity, neg_fidelity, metric


def _append_metrics_row(save_path, metrics):
    """Append ``metrics`` as one CSV row to ``save_path``, writing a header if the file is new."""
    file_exists = os.path.exists(save_path)
    with open(save_path, mode='a', newline='', encoding='utf-8') as f:
        writer = csv.DictWriter(f, fieldnames=list(metrics))
        if not file_exists:
            writer.writeheader()
        writer.writerow(metrics)


def train(cfg, logger, percentage, explanations_path):
    """Score stored node-level explanations on the test split with fidelity metrics.

    Despite its name, no training happens. For every test graph the nodes ranked
    highest by the explanation's ``node_mask`` are selected. The count per graph
    is read from the ``number_masked`` file. Two perturbed copies are evaluated:
    one where only the selected nodes keep their features, and one where the
    selected nodes' features are zeroed. The mean positive and negative fidelity,
    and a per-graph metric, are printed and appended as one row to a CSV file.

    Args:
        cfg: experiment config; reads ``SEED``, ``DATA.DATASET``,
            ``DATA.TASK_TYPE`` and ``DATA.SPLIT``.
        logger: logger handed to ``build_loader`` and used for model/warning logs.
        percentage: value recorded in the ``"percentage"`` result column. It does
            not affect how many nodes are selected.
        explanations_path: file saved with ``torch.save`` containing a dict with
            key ``'explanations'``, ordered like the test loader.

    Raises:
        ValueError: if an explanation's ``x`` differs from its test graph's ``x``.
    """
    set_seed(cfg.SEED)
    # Only the test split is evaluated; train/val loaders and class weights are unused.
    _, _, test_loader, _ = build_loader(cfg, logger)

    model = build_model(cfg)
    logger.info(model)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # NOTE(review): looks like a placeholder checkpoint path -- set the real one.
    model = load_best_result(model, "/pth model.pth")
    model.eval()

    # Per-graph count of top-ranked nodes to select, indexed in test-loader order.
    # NOTE(review): looks like a placeholder path -- set the real one.
    number_masked = torch.load("number masked contributions pt")

    # SECURITY: weights_only=False unpickles arbitrary objects; only load trusted files.
    loaded = torch.load(explanations_path, weights_only=False)
    explanations = loaded['explanations']
    if len(explanations) != len(test_loader):
        logger.warning(
            "number of explanations (%d) != number of test batches (%d); "
            "only the shorter sequence is evaluated",
            len(explanations), len(test_loader),
        )

    task = cfg.DATA.TASK_TYPE
    pos_fidelitys, neg_fidelitys, accs = [], [], []
    # Pure inference: no autograd graph needs to be kept for any forward pass.
    with torch.no_grad():
        for ids, (e, data) in tqdm(enumerate(zip(explanations, test_loader)), total=len(explanations)):
            # Explanations must be stored in the same order as the test loader yields graphs.
            if not torch.equal(e.x, data.x):
                raise ValueError(f"explanation {ids} does not match its test graph (node features differ)")

            # int() so a float-typed count is still accepted by torch.topk.
            number_to_mask = int(number_masked[ids].item())
            mask = _topk_node_mask(torch.as_tensor(e["node_mask"]), number_to_mask).to(device)
            inverse_mask = ~mask

            data = data.to(device)
            # Keep only the selected nodes' features (all others zeroed).
            data_with_mask = copy.deepcopy(data)
            data_with_mask.x = data_with_mask.x * mask.unsqueeze(-1)
            # Zero the selected nodes' features, keep all others.
            data_with_inverse_mask = copy.deepcopy(data)
            data_with_inverse_mask.x = data_with_inverse_mask.x * inverse_mask.unsqueeze(-1)

            pos_fidelity, neg_fidelity, metric_ = _fidelity_scores(
                task,
                data.y,
                model(data)[0],
                model(data_with_mask)[0],
                model(data_with_inverse_mask)[0],
            )
            pos_fidelitys.append(pos_fidelity.item())
            neg_fidelitys.append(neg_fidelity.item())
            accs.append(metric_.item())

    metrics = {
        "dataset": cfg.DATA.DATASET,
        "explainer_type": "HIGNN",
        "split": cfg.DATA.SPLIT,
        "pos_fidelity": torch.tensor(pos_fidelitys).mean().item(),
        "neg_fidelity": torch.tensor(neg_fidelitys).mean().item(),
        "metric_calc": np.mean(np.array(accs)).item(),
        # Use the parameter, not the script-level global ``args``.
        "percentage": percentage,
    }
    print(f"Metric calculated: {metrics['metric_calc']}")
    print(f"Positive Fidelity: {metrics['pos_fidelity']}")
    print(f"Negative Fidelity: {metrics['neg_fidelity']}")
    # NOTE(review): looks like a placeholder path; set to None to skip writing.
    save_path = "path to save results"
    if save_path is not None:
        _append_metrics_row(save_path, metrics)
        

if __name__ == "__main__":
    # ``args`` stays a module-level name here on purpose: other code in this
    # file may read it as a global.
    args, cfg = parse_args()
    logger = create_logger(cfg)

    # Log the fully resolved config, then which device the run will use.
    logger.info(cfg.dump())
    logger.info('GPU mode...' if torch.cuda.is_available() else 'CPU mode...')

    # train() only evaluates stored explanations; it does not fit the model.
    train(cfg, logger, args.percentage, args.explanations_path)

