import os
import os.path as osp
import torch
import torch_geometric.transforms as T
from torch_geometric.utils import to_undirected
import argparse

from greatx.attack.untargeted import Metattack, DICEAttack, PGDAttack, PRBCDAttack, GRBCDAttack
from greatx.datasets import GraphDataset
from greatx.nn.models import GCN
from greatx.training import Trainer
from greatx.training.callbacks import ModelCheckpoint
from torch_geometric.data import Data

import sys
sys.path.append("../")
from common import set_seed, load_inductive_graph_dataset_for_gnn

# Configuration
parser = argparse.ArgumentParser(
    description="Generate adversarially perturbed graph structures (edge_index) per seed."
)
parser.add_argument("--root_path", type=str, default="/path/to/GraphAD_data",
                    help="Dataset root; forwarded to the loader as path_prefix.")
parser.add_argument("--graph_save_dir", type=str, default="/path/to/GraphAD_data/atkg",
                    help="Directory under which attacked edge_index files are written.")
parser.add_argument("--dataset", type=str, default="cora",
                    help="Dataset name ('arxiv' uses its default split and a single seed).")
# Restrict to the attacks the generation loop actually implements, so a typo or
# an unsupported attack is rejected before any data loading or training happens.
parser.add_argument("--attack", type=str, default="dice",
                    choices=["dice", "pgd", "prbcd", "grbcd"],
                    help="Structure attack to run.")
parser.add_argument("--seeds", type=int, default=3,
                    help="Number of seeds to run (seeds 0..N-1).")
parser.add_argument("--ptb_rate", type=float, default=0.20,
                    help="Perturbation budget passed to the attacker's attack() call.")
parser.add_argument("--emb_type", type=str, default="bow",
                    help="Node-feature embedding type; forwarded to the loader as emb_model.")
parser.add_argument("--device", type=int, default=0,
                    help="CUDA device index (CPU is used when CUDA is unavailable).")
parser.add_argument("--re_split", type=int, default=2,
                    help="Forwarded to the loader as re_split (forced to 0 for arxiv).")
args = parser.parse_args()

root_path = args.root_path
dataset_name = args.dataset
if dataset_name != 'arxiv':
    seeds = range(args.seeds)
else:
    # Use the default split for arxiv, so only a single seed is meaningful.
    seeds = [0]
    args.re_split = 0

ptb_rate = args.ptb_rate
attack_type = args.attack

device = torch.device(f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu')
# Create the top-level save directory up front.
os.makedirs(args.graph_save_dir, exist_ok=True)

# Generate attacks for each seed.

# Gradient-based attacks that need a trained surrogate GCN, mapped to the extra
# keyword arguments their ``attack`` call takes (PGD takes no block_size).
_SURROGATE_ATTACKS = {
    "pgd": (PGDAttack, {}),
    "prbcd": (PRBCDAttack, {"block_size": 1_000_000}),
    "grbcd": (GRBCDAttack, {"block_size": 1_000_000}),
}
_SUPPORTED_ATTACKS = ("dice", *_SURROGATE_ATTACKS)

if attack_type not in _SUPPORTED_ATTACKS:
    # Fail fast: an unhandled attack name would otherwise fall through every
    # branch and crash on an unbound `attacker` after training a surrogate.
    raise ValueError(
        f"Unsupported attack {attack_type!r}; expected one of {_SUPPORTED_ATTACKS}"
    )

# round() rather than int(): int(0.29 * 100) == 28 because of float error,
# which would give the output file the wrong budget tag.
ptb_tag = round(ptb_rate * 100)


def _train_clean_gcn(train_data, val_data, num_features, num_classes, checkpoint_path):
    """Train a clean GCN (one hidden layer of 64 units) on the training graph.

    Model selection monitors validation accuracy on ``val_data``.

    Returns:
        The trained model held by the trainer.
    """
    num_train_nodes = train_data.x.shape[0]
    # Score only the nodes of val_data that are not training nodes.
    # NOTE(review): assumes val_data lists the training nodes first, then the
    # validation nodes (inductive split) -- confirm against the loader.
    val_mask = torch.zeros(val_data.x.shape[0], dtype=torch.bool, device=device)
    val_mask[num_train_nodes:] = True

    # NOTE(review): ModelCheckpoint writes this file into the current working
    # directory and nothing here removes it afterwards.
    checkpoint = ModelCheckpoint(checkpoint_path, monitor='val_acc')
    model = GCN(num_features, num_classes, hids=[64])
    trainer = Trainer(model, device=device)
    trainer.reset_optimizer(lr=0.01, weight_decay=0)
    trainer.fit((train_data, val_data), mask=(None, val_mask), verbose=0,
                callbacks=[checkpoint], epochs=200)
    return trainer.model


for seed in seeds:
    save_path = osp.join(args.graph_save_dir, dataset_name, attack_type,
                         f"{args.emb_type}_{ptb_tag}_{seed}.pt")
    # Check before loading data or training, so finished runs cost nothing.
    if osp.exists(save_path):
        print(f"Attack file {save_path} already exists. Skipping.")
        continue
    os.makedirs(osp.dirname(save_path), exist_ok=True)

    set_seed(seed)
    # Load the full dataset (features/labels/masks) plus the inductive splits.
    full_data, (train_data, val_data, test_data) = load_inductive_graph_dataset_for_gnn(
        dataset_name, device, re_split=args.re_split, path_prefix=root_path,
        emb_model=args.emb_type, seed=seed)
    train_data = train_data.to(device)
    val_data = val_data.to(device)
    # The perturbation is applied to the test-time graph.
    attack_data = test_data.to(device)

    if attack_type == "dice":
        # DICE perturbs edges heuristically and needs no surrogate model.
        attacker = DICEAttack(attack_data, device=device)
        attacker.reset()
        attacker.attack(ptb_rate)
    else:
        surrogate = _train_clean_gcn(
            train_data, val_data,
            num_features=full_data.x.shape[-1],
            num_classes=full_data.y.max().item() + 1,
            checkpoint_path=f'{dataset_name}_{attack_type}_{args.emb_type}_{ptb_tag}_{seed}.pt',
        )
        attack_cls, attack_kwargs = _SURROGATE_ATTACKS[attack_type]
        attacker = attack_cls(attack_data, device=device)
        # NOTE(review): victim indices come from full_data.test_mask but are used
        # on attack_data; this assumes test_data keeps full_data's node order.
        victim_nodes = torch.where(full_data.test_mask)[0].to(device)
        attacker.setup_surrogate(surrogate, victim_nodes=victim_nodes, ground_truth=True)
        attacker.reset()
        attacker.attack(ptb_rate, **attack_kwargs)

    # Save only the graph structure (edge_index) of the attacked graph, on CPU.
    torch.save(attacker.data().edge_index.cpu(), save_path)
    print(f"Saved attacked edge_index to {save_path}")