import os
import os.path as osp
import torch
import torch_geometric.transforms as T
from torch_geometric.utils import to_undirected
import argparse
import numpy as np
import json
from datetime import datetime

from greatx.attack.untargeted import Metattack, DICEAttack, PGDAttack, PRBCDAttack, GRBCDAttack, MetaApprox, STRGAttack
from greatx.datasets import GraphDataset
from greatx.nn.models import GCN
from greatx.training import Trainer
from greatx.training.callbacks import ModelCheckpoint
from torch_geometric.data import Data

import sys
sys.path.append("../")
from common import set_seed, load_graph_dataset_for_gnn

# Command-line configuration --------------------------------------------------
# Each entry is (flag, type, default); every option is registered identically.
_CLI_OPTIONS = (
    ("--root_path", str, "/path/to/GraphAD_data"),
    ("--graph_save_dir", str, "/path/to/GraphAD_data/atkg"),
    ("--dataset", str, "cora"),
    ("--attack", str, "dice"),
    ("--seeds", int, 3),
    ("--ptb_rate", float, 0.20),
    ("--emb_type", str, "bow"),
    ("--device", int, 0),
    ("--re_split", int, 1),
    ("--threshold", float, 0.5),
)

parser = argparse.ArgumentParser()
for _flag, _kind, _default in _CLI_OPTIONS:
    parser.add_argument(_flag, type=_kind, default=_default)
args = parser.parse_args()

# Short aliases used throughout the rest of the script.
root_path, dataset_name = args.root_path, args.dataset
seeds = range(args.seeds)  # seed values 0 .. args.seeds - 1
ptb_rate, attack_type = args.ptb_rate, args.attack

# Use the requested CUDA device when one is available, otherwise the CPU.
device = torch.device(
    'cpu' if not torch.cuda.is_available() else f'cuda:{args.device}'
)

# Make sure the root directory for attacked graphs exists up front.
os.makedirs(args.graph_save_dir, exist_ok=True)

def _summarize_accs(accs):
    """Summarize per-run accuracies (fractions) as percentages rounded to 2 dp.

    Returns a dict with ``mean``, ``std`` and ``all_runs``. When ``accs`` is
    empty, ``mean`` and ``std`` are ``None`` rather than NaN: ``np.mean([])``
    yields NaN, which ``json.dump`` would emit as the non-standard ``NaN``
    token and make the results file unparsable by strict JSON readers.
    """
    runs = [round(float(acc) * 100, 2) for acc in accs]
    if not runs:
        return {"mean": None, "std": None, "all_runs": runs}
    return {
        "mean": round(float(np.mean(accs)) * 100, 2),
        "std": round(float(np.std(accs)) * 100, 2),
        "all_runs": runs,
    }


def _format_summary(stats):
    """Render a ``_summarize_accs`` dict as ``"mean ± std"`` for console output."""
    if stats["mean"] is None:
        return "N/A (no runs)"
    return f"{stats['mean']:.4f} ± {stats['std']:.4f}"


def save_results(args, clean_accs, no_valid_accs, valid_accs):
    """Write aggregated accuracies to a timestamped JSON file and print a summary.

    The file goes to ``./logs/<dataset>/<attack>/`` and is named
    ``results_<threshold>_<ptb%>_<timestamp>.json``.

    Args:
        args: parsed CLI namespace; must provide ``dataset``, ``attack``,
            ``threshold`` and ``ptb_rate``. All of its fields are stored
            under ``"args"``.
        clean_accs: test accuracies of the model trained on the clean graph.
        no_valid_accs: test accuracies after the attack, trained without a
            validation-monitored checkpoint.
        valid_accs: test accuracies after the attack, trained with a
            validation-monitored checkpoint.
    """
    results = {
        "args": vars(args),
        "clean": _summarize_accs(clean_accs),
        "no_valid": _summarize_accs(no_valid_accs),
        "valid": _summarize_accs(valid_accs),
    }

    log_dir = os.path.join("./logs", args.dataset, args.attack)
    os.makedirs(log_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    # round(), not int(): int() truncates float error, e.g. int(0.29 * 100) == 28.
    ptb_pct = round(args.ptb_rate * 100)
    log_path = os.path.join(log_dir, f"results_{args.threshold}_{ptb_pct}_{timestamp}.json")

    with open(log_path, 'w', encoding="utf-8") as f:
        json.dump(results, f, indent=4)

    print("\nResults Summary:")
    print(f"Clean - Acc: {_format_summary(results['clean'])}")
    print(f"No Valid - Acc: {_format_summary(results['no_valid'])}")
    print(f"Valid - Acc: {_format_summary(results['valid'])}")

# Main experiment loop: for every seed, train a clean GCN, perturb the graph
# with the selected attack, retrain on the attacked graph (once without and
# once with a validation-monitored checkpoint) and record test accuracies.
_SUPPORTED_ATTACKS = ("mettack", "dice", "strg")
if attack_type not in _SUPPORTED_ATTACKS:
    # Fail fast: otherwise `attacker` below would be undefined on the first
    # seed (NameError) or silently reused from a previous seed.
    raise ValueError(
        f"Unsupported attack '{attack_type}'; expected one of {_SUPPORTED_ATTACKS}"
    )

# round(), not int(): int() truncates float error, e.g. int(0.29 * 100) == 28.
ptb_pct = round(ptb_rate * 100)

clean_accs = []
no_valid_accs = []
valid_accs = []

for seed in seeds:
    save_path = f"{args.graph_save_dir}/{dataset_name}/{attack_type}/{args.emb_type}_{ptb_pct}_{seed}.pt"

    # Create directory if it doesn't exist
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    # Check before the (expensive) data load; the next seed re-seeds anyway.
    if os.path.exists(save_path):
        print(f"Attack file {save_path} already exists. Skipping.")
        continue

    set_seed(seed)
    full_data = load_graph_dataset_for_gnn(dataset_name, device, re_split=args.re_split, path_prefix=root_path, emb_model=args.emb_type)
    full_data = full_data.to(device)
    num_features = full_data.x.shape[-1]
    num_classes = full_data.y.max().item() + 1

    # Train clean model using only training data
    model = GCN(num_features, num_classes, hids=[64])
    trainer_before = Trainer(model, device=device)
    trainer_before.reset_optimizer(lr=0.01, weight_decay=5e-4)
    ckp = ModelCheckpoint(f'Before_{dataset_name}_{attack_type}_{args.emb_type}_{args.threshold}_{ptb_pct}_{seed}.pt', monitor='val_acc')
    trainer_before.fit(full_data, mask=(full_data.train_mask, full_data.val_mask), verbose=0, epochs=200, callbacks=[ckp])
    print(f"Generating {attack_type} attack with seed {seed}...")
    test_logs = trainer_before.evaluate(full_data, mask=full_data.test_mask)
    clean_acc = test_logs['acc']
    clean_accs.append(clean_acc)
    print(f"Before attack test acc: {clean_acc:.4f}")

    # NOTE(review): the test nodes are handed to Metattack as its unlabeled
    # set with ground_truth=True; confirm this is the intended threat model.
    unlabeled_mask = full_data.test_mask

    if attack_type == "mettack":
        attacker = Metattack(full_data, device=device)
        attacker.setup_surrogate(trainer_before.model,
                                 labeled_nodes=torch.where(full_data.train_mask)[0],
                                 unlabeled_nodes=torch.where(unlabeled_mask)[0],
                                 lambda_=0., ground_truth=True)
        attacker.reset()
        attacker.attack(ptb_rate, disable=False)

    elif attack_type == "dice":
        attacker = DICEAttack(full_data, device=device)
        attacker.reset()
        # Use the CLI threshold (as strg does); output file names already
        # encode args.threshold, so a hard-coded value made them misleading.
        attacker.attack(ptb_rate, threshold=args.threshold)

    elif attack_type == "strg":
        attacker = STRGAttack(full_data, device=device)
        attacker.reset()
        attacker.attack(ptb_rate, train_mask=full_data.train_mask, val_mask=full_data.val_mask, test_mask=full_data.test_mask,
                        threshold=args.threshold)

    # The attacked graph. Saving its edge_index to save_path is currently
    # disabled, so the existence check above only skips seeds whose files
    # were written by an earlier run.
    attacked_data = attacker.data()
    #torch.save(attacked_data.edge_index, save_path)

    # Evaluate the attack: retrain from scratch on the train mask only, with
    # no validation-based model selection.
    model.reset_parameters()
    trainer_after = Trainer(model, device=device)
    trainer_after.fit(attacked_data, mask=attacked_data.train_mask, verbose=0, epochs=200)
    test_logs = trainer_after.evaluate(attacked_data, mask=attacked_data.test_mask)
    no_valid_acc = test_logs['acc']
    no_valid_accs.append(no_valid_acc)
    print(f"Attack {attack_type} with seed {seed} completed. No validation. Test Acc: {no_valid_acc:.4f}")

    # Retrain again, this time with a val_acc-monitored checkpoint.
    model.reset_parameters()
    cb = ModelCheckpoint(f'After_{dataset_name}_{attack_type}_{args.emb_type}_{args.threshold}_{ptb_pct}_{seed}.pt', monitor='val_acc')
    trainer_after = Trainer(model, device=device)
    trainer_after.fit(attacked_data, mask=(attacked_data.train_mask, attacked_data.val_mask), verbose=0, callbacks=[cb], epochs=200)
    test_logs = trainer_after.evaluate(attacked_data, mask=attacked_data.test_mask)
    valid_acc = test_logs['acc']
    valid_accs.append(valid_acc)
    print(f"Attack {attack_type} with seed {seed} completed. With validation. Test Acc: {valid_acc:.4f}")

# Save final results (nothing to aggregate if every seed was skipped).
if clean_accs:
    save_results(args, clean_accs, no_valid_accs, valid_accs)
else:
    print("No new runs were executed; skipping results summary.")