import os
import os.path as osp
import torch
import torch_geometric.transforms as T
from torch_geometric.utils import to_undirected
import argparse
import numpy as np
import torch.nn.functional as F
from datetime import datetime

from greatx.attack.untargeted import PGDAttack
from greatx.datasets import GraphDataset
from greatx.nn.models import GCN, GNNGUARD
from greatx.training import Trainer
from greatx.training.callbacks import ModelCheckpoint

import sys
sys.path.append("../")
from common import set_seed, load_inductive_graph_dataset_for_gnn

class ThresholdPGDAttack(PGDAttack):
    """PGD attack whose perturbable node pairs are filtered by feature similarity.

    After the parent ``reset()``, the inherited ``victim_mask`` is AND-ed with a
    pairwise mask that keeps only node pairs whose feature cosine similarity is
    ``>= threshold``. A non-positive threshold adds no filtering, which leaves
    the attack equivalent to the original PGD attack.

    NOTE(review): this relies on the parent class defining ``victim_mask`` in
    ``reset()`` and consulting it while perturbing — confirm against the
    greatx version in use.
    """

    def __init__(self, *args, threshold=0.1, **kwargs):
        super().__init__(*args, **kwargs)
        # Minimum cosine similarity a node pair needs to remain perturbable.
        self.threshold = threshold

    def reset(self):
        """Reset the parent attacker, then restrict its mask to similar node pairs.

        Returns:
            ``self``, so calls can be chained.
        """
        super().reset()

        # Guard clause: with threshold <= 0 the parent's mask is left as is.
        if not self.threshold > 0:
            return self

        # Row-normalise features so the Gram matrix holds pairwise cosines (N x N).
        unit_feat = F.normalize(self.feat, p=2, dim=1)
        pairwise_cos = torch.mm(unit_feat, unit_feat.T)
        similar_pairs = pairwise_cos >= self.threshold

        # Keep a pair only if the parent allowed it AND it is similar enough.
        self.victim_mask = self.victim_mask & similar_pairs
        return self

def compute_edge_cosine_similarity(x, edge_index):
    """Return the cosine similarity between the endpoint features of each edge.

    Args:
        x: Node feature matrix, one row per node (assumed 2-D).
        edge_index: Tensor whose first dimension holds the source ids and the
            target ids (``(2, E)`` layout).

    Returns:
        Tensor with one cosine similarity per edge (column of ``edge_index``).
    """
    src_ids, dst_ids = edge_index
    return F.cosine_similarity(x[src_ids], x[dst_ids], dim=1)

def _write_edge_cosine_stats(f, label, edges, x):
    """Write mean/std/min/max endpoint cosine similarity of ``edges`` to ``f``.

    Does nothing when ``edges`` is empty.
    """
    if not edges:
        return
    # Explicit long dtype: the ids are used to index rows of ``x``.
    edge_tensor = torch.tensor(list(edges), dtype=torch.long, device=x.device).T
    sims = compute_edge_cosine_similarity(x, edge_tensor)
    # The sample std of a single value is undefined (torch returns nan and
    # warns about degrees of freedom); report 0.0 for a single edge instead.
    std = sims.std().item() if sims.numel() > 1 else 0.0
    f.write(f"{label} edges cosine similarity - Mean: {sims.mean().item():.4f}, "
            f"Std: {std:.4f}, "
            f"Min: {sims.min().item():.4f}, "
            f"Max: {sims.max().item():.4f}\n")


def log_edge_modifications(original_edge_index, attacked_edge_index, x, log_file):
    """Append a summary of the edge differences between two graphs to ``log_file``.

    Edges are compared as ``(src, dst)`` pairs taken column-wise from each
    ``edge_index``, so counts refer to the pairs exactly as stored (duplicate
    columns collapse into one pair).

    Args:
        original_edge_index: Edge index of the clean graph.
        attacked_edge_index: Edge index of the perturbed graph.
        x: Node features used to compute endpoint cosine similarities.
        log_file: Path of the text log; opened in append mode.
    """
    original_edges = set(map(tuple, original_edge_index.T.cpu().numpy()))
    attacked_edges = set(map(tuple, attacked_edge_index.T.cpu().numpy()))

    added_edges = attacked_edges - original_edges
    removed_edges = original_edges - attacked_edges

    with open(log_file, 'a') as f:
        f.write("\n=== Edge Modifications ===\n")
        f.write(f"Original edges: {len(original_edges)}\n")
        f.write(f"Attacked edges: {len(attacked_edges)}\n")
        f.write(f"Added edges: {len(added_edges)}\n")
        f.write(f"Removed edges: {len(removed_edges)}\n")

        _write_edge_cosine_stats(f, "Added", added_edges, x)
        _write_edge_cosine_stats(f, "Removed", removed_edges, x)

        f.write("=" * 50 + "\n")

# Configuration
parser = argparse.ArgumentParser()
parser.add_argument("--root_path", type=str, default="/path/to/GraphAD_data")
parser.add_argument("--graph_save_dir", type=str, default="/path/to/GraphAD_data/atkg")
parser.add_argument("--dataset", type=str, default="cora")
parser.add_argument("--seeds", type=int, default=3)
parser.add_argument("--ptb_rate", type=float, default=0.20)
parser.add_argument("--emb_type", type=str, default="roberta")
parser.add_argument("--device", type=int, default=0)
parser.add_argument("--re_split", type=int, default=2)
parser.add_argument("--atk_threshold", type=float, default=0.1, help="Cosine similarity threshold for edge perturbation (0.0 = original PGD)")
parser.add_argument("--def_thresholds", type=float, nargs="+", default=[0.3, 0.5, 0.7],
                    help="GNNGUARD defense thresholds to evaluate")
args = parser.parse_args()

root_path = args.root_path
dataset_name = args.dataset
if dataset_name != 'arxiv':
    seeds = range(args.seeds)
else:
    # Use the default split for arxiv: a single seed and no re-splitting
    seeds = [0]
    args.re_split = 0

ptb_rate = args.ptb_rate
attack_type = "pgd"  # Only PGD attack

device = torch.device(f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu')

# Create save directory
os.makedirs(args.graph_save_dir, exist_ok=True)

# GNNGUARD defense thresholds (default [0.3, 0.5, 0.7], overridable on the CLI)
def_thresholds = args.def_thresholds

# Create log directory
log_dir = f"log_guard/{dataset_name}"
os.makedirs(log_dir, exist_ok=True)
# round() rather than int(): float error makes int(0.29 * 100) == 28.
log_file = f"{log_dir}/{attack_type}guard_{args.atk_threshold}_{args.emb_type}_{round(ptb_rate * 100)}.log"

# Append instead of truncating: seeds whose attack file already exists are
# skipped later, so truncating here would erase their previously logged results.
with open(log_file, 'a') as f:
    f.write("PGD Attack with Cosine Similarity Threshold\n")
    f.write(f"Dataset: {dataset_name}\n")
    f.write(f"Embedding: {args.emb_type}\n")
    f.write(f"Perturbation Rate: {ptb_rate}\n")
    f.write(f"Cosine Similarity Attack Threshold: {args.atk_threshold}\n")
    if args.atk_threshold == 0.0:
        f.write("  -> Original PGD attack (no cosine similarity filtering)\n")
    f.write(f"Defense Thresholds: {def_thresholds}\n")
    f.write("Surrogate Model: GCN (hid=64)\n")
    f.write("Defense Model Hidden Dim: 128\n")
    f.write(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
    f.write("=" * 80 + "\n")

# Perturbation percentage used in file names. round() rather than int():
# float error makes int(0.29 * 100) == 28.
ptb_pct = round(ptb_rate * 100)


def _fit_with_checkpoint(model, ckpt_path, train_data, val_data):
    """Train ``model`` with the settings shared by every model in this script.

    Uses lr=0.01, weight_decay=0 and 200 epochs, with a ``ModelCheckpoint`` on
    ``val_acc`` written to ``ckpt_path``.

    Returns:
        The ``Trainer`` wrapping the fitted model.
    """
    trainer = Trainer(model, device=device)
    trainer.reset_optimizer(lr=0.01, weight_decay=0)
    trainer.fit((train_data, val_data), verbose=0,
                callbacks=[ModelCheckpoint(ckpt_path, monitor='val_acc')],
                epochs=200)
    return trainer


# Generate attacks for each seed
for seed in seeds:
    # The save path depends only on CLI args and the seed, so check for an
    # existing result before paying for dataset loading and model training.
    save_path = (f"{args.graph_save_dir}/{dataset_name}/{attack_type}guard_{args.atk_threshold}/"
                 f"{args.emb_type}_{ptb_pct}_{seed}.pt")
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    if os.path.exists(save_path):
        print(f"Attack file {save_path} already exists. Skipping.")
        continue

    set_seed(seed)
    # Load full dataset for features and labels
    full_data, (train_data, val_data, test_data) = load_inductive_graph_dataset_for_gnn(
        dataset_name, device, re_split=args.re_split, path_prefix=root_path,
        emb_model=args.emb_type, seed=seed)

    train_data = train_data.to(device)
    val_data = val_data.to(device)
    test_data = test_data.to(device)

    num_features = full_data.x.shape[-1]
    num_classes = full_data.y.max().item() + 1

    with open(log_file, 'a') as f:
        f.write(f"\n--- Seed {seed} ---\n")

    # Train GCN surrogate model using only training data (hidden dim 64)
    surrogate_model = GCN(num_features, num_classes, hids=[64])
    trainer_surrogate = _fit_with_checkpoint(
        surrogate_model,
        f'gcn_surrogate_{dataset_name}_{args.atk_threshold}_{args.emb_type}_{ptb_pct}_{seed}.pt',
        train_data, val_data)

    surrogate_test_logs = trainer_surrogate.evaluate(test_data, mask=full_data.test_mask)
    with open(log_file, 'a') as f:
        f.write(f"GCN Surrogate Test Acc: {surrogate_test_logs['acc']:.4f}\n")

    # Attack using the GCN surrogate with the cosine similarity threshold
    victim_nodes_in_test_data = torch.where(full_data.test_mask)[0]

    attacker = ThresholdPGDAttack(test_data, device=device, threshold=args.atk_threshold)
    attacker.setup_surrogate(trainer_surrogate.model,
                             victim_nodes=victim_nodes_in_test_data,
                             ground_truth=True)
    attacker.reset()
    attacker.attack(ptb_rate)

    attacked_data = attacker.data().to(device)

    log_edge_modifications(test_data.edge_index, attacked_data.edge_index, test_data.x, log_file)

    # GCN defense model (hidden dim 128), trained on clean data
    gcn_model = GCN(num_features, num_classes, hids=[128])
    trainer_gcn = _fit_with_checkpoint(
        gcn_model,
        f'gcn_defense_{dataset_name}_{args.atk_threshold}_{args.emb_type}_{ptb_pct}_{seed}.pt',
        train_data, val_data)

    gcn_clean_logs = trainer_gcn.evaluate(test_data, mask=full_data.test_mask)
    gcn_attacked_logs = trainer_gcn.evaluate(attacked_data, mask=full_data.test_mask)

    # GNNGUARD defense models (one per defense threshold, hidden dim 128), trained on clean data
    guard_results = {}
    for def_threshold in def_thresholds:
        guard_model = GNNGUARD(num_features, num_classes, hids=[128], threshold=def_threshold)
        # atk_threshold is part of the name (as for the GCN checkpoints) so that
        # concurrent runs with different attack thresholds do not share a file.
        trainer_guard = _fit_with_checkpoint(
            guard_model,
            f'guard_defense_{dataset_name}_{args.atk_threshold}_{def_threshold}_'
            f'{args.emb_type}_{ptb_pct}_{seed}.pt',
            train_data, val_data)

        guard_clean_logs = trainer_guard.evaluate(test_data, mask=full_data.test_mask)
        guard_attacked_logs = trainer_guard.evaluate(attacked_data, mask=full_data.test_mask)

        guard_results[def_threshold] = {
            'clean_acc': guard_clean_logs['acc'],
            'attacked_acc': guard_attacked_logs['acc'],
        }

    # Log all results
    with open(log_file, 'a') as f:
        f.write("\n=== Defense Evaluation Results ===\n")
        f.write(f"GCN Defense (hid=128) - Clean Acc: {gcn_clean_logs['acc']:.4f}, Attacked Acc: {gcn_attacked_logs['acc']:.4f}\n")

        for def_threshold, results in guard_results.items():
            f.write(f"GNNGUARD Defense (def_threshold={def_threshold}, hid=128) - Clean Acc: {results['clean_acc']:.4f}, Attacked Acc: {results['attacked_acc']:.4f}\n")

        f.write("\n=== Robustness Ratios ===\n")
        # Robustness ratio = attacked accuracy / clean accuracy (0 if clean accuracy is 0)
        gcn_robustness = gcn_attacked_logs['acc'] / gcn_clean_logs['acc'] if gcn_clean_logs['acc'] > 0 else 0
        f.write(f"GCN Robustness Ratio: {gcn_robustness:.4f}\n")

        for def_threshold, results in guard_results.items():
            guard_robustness = results['attacked_acc'] / results['clean_acc'] if results['clean_acc'] > 0 else 0
            f.write(f"GNNGUARD (def_threshold={def_threshold}) Robustness Ratio: {guard_robustness:.4f}\n")

        f.write("=" * 50 + "\n")

    # Save attacked graph (edge_index only, on CPU)
    torch.save(attacked_data.cpu().edge_index, save_path)
    print(f"Saved attacked graph to {save_path}")

print(f"All experiments completed. Results saved to {log_file}")