import os
import time
import datetime
import argparse
import numpy as np

import torch
import torch.nn.functional as F
# from torch.utils.tensorboard import SummaryWriter

from config import get_config
from utils import create_logger, seed_set
from utils import NoamLR, build_scheduler, build_optimizer, get_metric_func
from utils import load_checkpoint, save_best_checkpoint, load_best_result
from dataset import build_loader
from torch_geometric.data import Data
from loss import bulid_loss
from model import build_model


def parse_args():
    """Parse the command line and build the experiment config.

    Returns:
        A ``(args, cfg)`` tuple: the parsed ``argparse`` namespace and the
        object returned by ``get_config(args)``.
    """
    cli = argparse.ArgumentParser(description="codes for HiGNN")

    cli.add_argument(
        "--cfg",
        type=str,
        default="../configs/bbbp.yaml",
        required=False,
        help="decide which cfg to use",
    )
    cli.add_argument(
        "--opts",
        nargs='+',
        default=None,
        help="Modify config options by adding 'KEY VALUE' pairs. ",
    )

    # Shortcut flags for easy config modification, registered in a fixed order
    # so the generated ``--help`` output lists them the same way every time.
    shortcut_flags = (
        ('--batch-size', dict(type=int, help="batch size for training")),
        ('--resume', dict(help='resume from checkpoint')),
        ('--tag', dict(help='tag of experiment')),
        ('--eval', dict(action='store_true', help='Perform evaluation only')),
        ('--lr-scheduler', dict(type=str, help='lr scheduler')),
        ('--seed', dict(type=int, default=123, help='random seed')),
        ('--split', dict(type=int, default=0, help='which data split to use')),
        ('--num_layers', dict(type=int, help='number of layers')),
        ('--hidden_dim', dict(type=int, help='hidden dimension')),
        ('--dropout', dict(type=float, help='dropout rate')),
        ('--lr', dict(type=float, help='learning rate')),
    )
    for flag, options in shortcut_flags:
        cli.add_argument(flag, **options)

    args = cli.parse_args()
    # The batch size is always forced to 1 here, so any ``--batch-size`` value
    # given on the command line is overridden before the config is built.
    args.batch_size = 1

    return args, get_config(args)

def load_best_result(model, path):
    """Load the weights stored under the ``'model'`` key of a checkpoint.

    The checkpoint is deserialised onto the CPU and its ``'model'`` entry is
    copied into ``model`` with ``load_state_dict``.

    Args:
        model: Module whose parameters are overwritten in place.
        path: Location of the checkpoint file passed to ``torch.load``.

    Returns:
        The same ``model`` object, now holding the checkpoint weights.
    """
    # NOTE: this definition shadows the ``load_best_result`` imported from
    # ``utils`` at the top of the file.
    checkpoint = torch.load(path, map_location=torch.device('cpu'))
    state_dict = checkpoint['model']
    model.load_state_dict(state_dict)
    return model


def train(cfg, logger):
    """Run a saved checkpoint over the test split and dump per-graph explanations.

    Despite its name, this function performs no optimisation. It seeds the
    run, builds the data loaders and the model from ``cfg``, loads weights
    from a fixed checkpoint path, runs the model on every test batch and
    collects one ``Data`` object per graph (node features, re-indexed edges,
    label, prediction and node mask). The list is written with ``torch.save``
    under the key ``"explanations"``.

    Args:
        cfg: Experiment config; ``cfg.SEED`` is read here and the object is
            forwarded to ``build_loader`` and ``build_model``.
        logger: Logger forwarded to ``build_loader`` and used to log the model.
    """
    seed_set(cfg.SEED)
    # Only the test loader is consumed; the other returned values are unused.
    _train_loader, _val_loader, test_loader, _weights = build_loader(cfg, logger)

    model = build_model(cfg)
    logger.info(model)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    path = "/pth model.pth"
    model = load_best_result(model, path)
    # Switch to inference mode once instead of once per batch.
    model.eval()

    explanations = []
    # No gradients are needed: every tensor taken from the output is detached.
    with torch.no_grad():
        for batch in test_loader:
            batch = batch.to(device)
            output = model(batch)

            # presumably output[0] holds predictions and output[1] a pair of
            # (attention values, index array) — TODO confirm against the model.
            pred = output[0]
            att = output[1][0].detach().cpu().numpy()
            cross = output[1][1].detach().cpu().numpy()
            idx = cross[1]
            num = np.where(idx == 0)[0]
            # NOTE(review): selecting entries where ``idx == 0`` looks like it
            # assumes a single graph per batch — confirm before raising the
            # batch size above 1.
            node_mask = att[num][batch.cluster_index.detach().cpu().numpy()]

            # Global index of the first node of the current graph. Edge indices
            # in a batch are global, so this is subtracted to make them local.
            node_offset = 0
            num_graphs = int(batch.batch.max()) + 1

            for b in range(num_graphs):
                x_mask = batch.batch == b

                if batch.edge_index.numel() == 0:
                    # No edges anywhere in the batch.
                    edge_index = torch.empty((2, 0), dtype=torch.long)
                else:
                    # Keep the edges whose source node belongs to graph ``b``.
                    edge_index_mask = batch.batch[batch.edge_index[0]] == b
                    edge_index = batch.edge_index[:, edge_index_mask] - node_offset

                data = Data(
                    x=batch.x[x_mask].detach().cpu(),
                    edge_index=edge_index.detach().cpu(),
                    y=batch.y[b].detach().cpu(),
                    pred=pred[b].detach().cpu(),
                    node_mask=node_mask[x_mask.detach().cpu().numpy()],
                )
                explanations.append(data)

                # Advance past this graph's nodes so the next graph's edges
                # are re-based correctly (the offset used to stay at 0).
                node_offset += int(x_mask.sum())

    print(f"Finished with {len(explanations)} explanations")
    torch.save(
        {
            "explanations": explanations,
        },
        "/pth",
    )
    print(f"Saved explanations to /pth")


    

if __name__ == "__main__":
    # Only the config is needed; the raw argparse namespace is discarded.
    cfg = parse_args()[1]
    logger = create_logger(cfg)

    # Log the full configuration, then the device mode.
    logger.info(cfg.dump())
    logger.info('GPU mode...' if torch.cuda.is_available() else 'CPU mode...')

    # Entry point of the run (the function is named ``train``).
    train(cfg, logger)


