import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import os
import numpy as np
import argparse
import logging
from datetime import datetime
from itertools import product

from torch_geometric.utils import to_undirected
from torch_geometric.data import Data

from greatx.nn.models import GCN
from greatx.nn.layers import GCNConv, Sequential, activations
from greatx.training import Trainer
from greatx.training.callbacks import ModelCheckpoint, EarlyStopping
from greatx.utils import wrapper

import sys
sys.path.append("../")
from common import set_seed, load_inductive_graph_dataset_for_gnn

import yaml


class AutoGCN(nn.Module):
    """GCN variant that simulates auto behavior:
    1) Training: Add text_attacked class for attack detection
    2) Test: Predict if nodes are attacked, apply filtering accordingly

    The classifier emits ``out_channels + 1`` logits per node; the last column
    is the extra ``text_attacked`` class. When ``forward`` is called with
    ``training=False`` and ``apply_auto_logic=True`` that column is dropped, so
    callers receive ``out_channels`` logits per node.
    """

    @wrapper
    def __init__(self, in_channels: int, out_channels: int,
                 hids: list = [16], acts: list = ['relu'],
                 dropout: float = 0.5, bias: bool = True,
                 normalize: bool = True, similarity_threshold: float = 0.5):
        """Build the hidden GCN stack and the (out_channels + 1)-way classifier.

        Args:
            in_channels: node feature dimension.
            out_channels: number of real classes.
            hids: hidden width of each GCN layer (one entry per layer).
            acts: activation name per hidden layer; must match ``hids`` in length.
            dropout: dropout rate applied after every hidden layer.
            bias: forwarded to every ``GCNConv``.
            normalize: forwarded to every ``GCNConv``.
            similarity_threshold: minimum cosine similarity between the two
                endpoint feature vectors for an edge to survive
                inference-time filtering.
        """
        super().__init__()

        # NOTE: the list defaults above are only read (len/zip), never mutated.
        # They are left untouched because the greatx ``@wrapper`` decorator may
        # inspect the signature's defaults.
        self.original_out_channels = out_channels
        # One extra class is reserved for text_attacked detection.
        self.attack_out_channels = out_channels + 1
        self.similarity_threshold = similarity_threshold

        # Hidden stack: GCNConv -> activation -> dropout, once per hidden layer.
        conv = []
        assert len(hids) == len(acts)
        for hid, act in zip(hids, acts):
            conv.append(GCNConv(in_channels, hid, bias=bias, normalize=normalize))
            conv.append(activations.get(act))
            conv.append(nn.Dropout(dropout))
            in_channels = hid

        self.conv_layers = Sequential(*conv)
        # Final layer outputs the normal classes plus text_attacked.
        self.classifier = GCNConv(in_channels, self.attack_out_channels, bias=bias, normalize=normalize)

    def reset_parameters(self):
        """Re-initialize the hidden stack and the classifier."""
        self.conv_layers.reset_parameters()
        self.classifier.reset_parameters()

    def _predict(self, x, edge_index, edge_weight=None):
        """Run the hidden GCN stack and the classifier; return all logits."""
        h = self.conv_layers(x, edge_index, edge_weight)
        return self.classifier(h, edge_index, edge_weight)

    def _similarity_keep_mask(self, x, edge_index):
        """Boolean mask over the columns of ``edge_index`` whose endpoint
        features have cosine similarity >= ``self.similarity_threshold``."""
        similarities = torch.cosine_similarity(x[edge_index[0]], x[edge_index[1]], dim=1)
        return similarities >= self.similarity_threshold

    def filter_edges_by_similarity(self, x, edge_index):
        """Apply similarity filtering to edges.

        Keeps only the edges whose endpoint features (rows of ``x``) have
        cosine similarity >= ``self.similarity_threshold``. An empty
        ``edge_index`` is returned unchanged.
        """
        if edge_index.shape[1] == 0:
            return edge_index
        return edge_index[:, self._similarity_keep_mask(x, edge_index)]

    def forward(self, x, edge_index, edge_weight=None, training=True, apply_auto_logic=True, return_stats=False):
        """Forward pass with auto-like behavior during inference.

        Args:
            x: node feature matrix (rows indexed by node id).
            edge_index: ``(2, E)`` edge list.
            edge_weight: optional per-edge weights aligned with the columns of
                ``edge_index``; filtered together with the edges below.
            training: when True (or ``apply_auto_logic`` is False) run a plain
                GCN pass and return all ``out_channels + 1`` logits.
            apply_auto_logic: enable the 3-stage inference procedure.
            return_stats: also return a dict of per-stage counters.

        Returns:
            Logits, or ``(logits, stats)`` when ``return_stats`` is True. In
            auto mode the text_attacked column is removed from the logits.
        """
        if training or not apply_auto_logic:
            # Standard forward pass (used for training and plain evaluation).
            result = self._predict(x, edge_index, edge_weight)
            if return_stats:
                return result, {}
            return result

        stats = {}

        # Stage 1: attack detection. A node is flagged when its argmax is the
        # extra (last) text_attacked class.
        logits = self._predict(x, edge_index, edge_weight)
        text_attacked_mask = torch.argmax(logits, dim=1) == self.attack_out_channels - 1
        text_attacked_indices = text_attacked_mask.nonzero(as_tuple=False).squeeze(-1)
        num_text_attacked = text_attacked_indices.numel()

        stats['text_attacked_detected'] = num_text_attacked
        stats['total_nodes'] = x.shape[0]

        # Stage 2: zero the features of flagged nodes, re-predict, and take the
        # real-class logits of the flagged nodes from that pass.
        text_recovery_used = 0
        if num_text_attacked > 0:
            masked_x = x.clone()
            masked_x[text_attacked_indices] = 0
            stage2_logits = self._predict(masked_x, edge_index, edge_weight)
            logits[text_attacked_indices, :-1] = stage2_logits[text_attacked_indices, :-1]
            text_recovery_used = num_text_attacked
        stats['text_recovery_used'] = text_recovery_used

        # Stage 3: drop every edge touching a flagged node, then drop
        # low-similarity edges. Both tests are per-edge, so they are combined
        # into one vectorized mask (no per-edge Python loop / device sync).
        original_edges = edge_index.shape[1]
        if original_edges > 0:
            touches_attacked = text_attacked_mask[edge_index[0]] | text_attacked_mask[edge_index[1]]
            keep = ~touches_attacked & self._similarity_keep_mask(x, edge_index)
            filtered_edge_index = edge_index[:, keep]
            # Keep edge weights aligned with the surviving edges.
            filtered_edge_weight = None if edge_weight is None else edge_weight[keep]
        else:
            filtered_edge_index, filtered_edge_weight = edge_index, edge_weight
        filtered_edges = filtered_edge_index.shape[1]

        stats['original_edges'] = original_edges
        stats['filtered_edges'] = filtered_edges
        stats['edges_removed'] = original_edges - filtered_edges

        # Re-predict on the filtered graph (only if something was removed) and
        # use those logits for the nodes that were not flagged.
        structure_recovery_used = 0
        if filtered_edges != original_edges:
            stage3_logits = self._predict(x, filtered_edge_index, filtered_edge_weight)
            non_text_attacked_mask = ~text_attacked_mask
            logits[non_text_attacked_mask, :-1] = stage3_logits[non_text_attacked_mask, :-1]
            structure_recovery_used = non_text_attacked_mask.sum().item()
        stats['structure_recovery_used'] = structure_recovery_used

        # Return only the original classes (the text_attacked column is dropped).
        final_logits = logits[:, :-1]

        if return_stats:
            return final_logits, stats
        return final_logits


def create_training_data_with_attacks(train_data, attack_ratio=0.15, num_classes=None):
    """Create training data with text_attacked samples.

    A random ``attack_ratio`` fraction of the nodes (at least one) is
    relabelled with the extra class index ``num_classes``. AutoGCN treats
    that index as the ``text_attacked`` class. Node features and edges are
    copied unchanged, and ``train_data`` itself is not modified.

    Args:
        train_data: graph providing ``x``, ``y``, ``edge_index`` and,
            optionally, ``edge_attr``.
        attack_ratio: fraction of nodes to relabel as text_attacked.
        num_classes: number of real classes. Defaults to
            ``train_data.y.max() + 1`` (the original behavior). Pass the
            model's class count explicitly when the training split might not
            contain the highest label; otherwise the attack label would
            collide with a real class.

    Returns:
        A new ``Data`` object carrying the augmented labels.
    """
    num_nodes = train_data.x.shape[0]
    if num_classes is None:
        num_classes = train_data.y.max().item() + 1

    # Clone so the caller's labels are left untouched.
    augmented_y = train_data.y.clone()

    # Select nodes to mark as text_attacked. The permutation is still drawn
    # with the default CPU generator, as before, so seeded runs pick the same
    # nodes; it is then moved to the label device for indexing.
    num_attack_samples = max(1, int(num_nodes * attack_ratio))
    attack_indices = torch.randperm(num_nodes)[:num_attack_samples].to(augmented_y.device)

    # The text_attacked class is the index right after the real classes.
    augmented_y[attack_indices] = num_classes

    return Data(
        x=train_data.x,
        y=augmented_y,
        edge_index=train_data.edge_index,
        edge_attr=getattr(train_data, 'edge_attr', None)
    )


# ------------ Argument Parsing ------------
# Embedding backbones accepted for both the attacker and the defender side.
_EMB_CHOICES = ['bow', 'roberta', 'Mistral-7B', 'MiniLM']

# (flag, add_argument keyword options), registered in this order so --help
# lists them the same way.
_CLI_OPTIONS = [
    ("--root_path", dict(type=str, default="/path/to/GraphAD_data/")),
    ("--graph_save_dir", dict(type=str, default="atkg")),
    ("--config", dict(type=str, default="config.yaml")),
    ("--dataset", dict(type=str, default="cora")),
    ("--atk_type", dict(type=str, default='structure', choices=['structure', 'text', 'hybrid'])),
    ("--attack", dict(type=str, default="pgd")),
    ("--atk_phase", dict(type=str, default="inductive", choices=['inductive', 'transductive'])),
    ("--atk_emb_type", dict(type=str, default="bow", choices=_EMB_CHOICES)),
    ("--def_emb_type", dict(type=str, default="bow", choices=_EMB_CHOICES)),
    ("--re_split", dict(type=int, default=2)),
    ("--ptb_rate", dict(type=float, default=0.1)),
    ("--device", dict(type=int, default=0)),
    ("--patience", dict(type=int, default=50)),
    ("--epochs", dict(type=int, default=400)),
    ("--similarity_threshold", dict(type=float, default=0.5)),
    ("--attack_ratio", dict(type=float, default=0.15)),
    ("--use_existing_logs", dict(action='store_true')),
]

parser = argparse.ArgumentParser()
for _flag, _options in _CLI_OPTIONS:
    parser.add_argument(_flag, **_options)
args = parser.parse_args()

# ------------ Load Config ------------
with open(args.config, 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

# GCN hyperparameters: each value is an iterable of candidates that the sweep
# below expands with itertools.product. A missing or null `gcn` entry means
# "no sweep" (a single run with the defaults).
model_config = config['hyperparams'].get('gcn') or {}
hyperparams = dict(model_config)

# CLI validation uses parser.error (a clean usage message and exit code 2)
# rather than `assert`, which is stripped under `python -O`.
if args.atk_phase != 'inductive':
    parser.error("only --atk_phase inductive is supported by this script")

if args.dataset != 'arxiv':
    if args.re_split != 2:
        parser.error("Only inductive split with re_split=2 (6/2/2) is supported")
    seeds = range(3)
else:
    # arxiv always uses its original split (re_split=0); any --re_split value
    # given on the command line is overridden.
    args.re_split = 0
    seeds = [0]

# ------------ Fixed Settings ------------
root_path = args.root_path
emb_save_dir = f"{root_path}/datasets/{args.atk_emb_type}"
graph_save_dir = f"{root_path}/{args.graph_save_dir}"
device = torch.device(f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu')

# ------------ Logging Setup ------------
log_dir = f"eval_logs_ind/{args.dataset}/auto_gcn/{args.atk_type}_{args.attack}_{args.atk_phase}"
os.makedirs(log_dir, exist_ok=True)
# NOTE(review): --attack_ratio is not part of the log name, so runs that differ
# only in attack_ratio overwrite each other's log, and --use_existing_logs
# treats them as already done.
log_file = f"{log_dir}/{args.atk_emb_type}_{args.def_emb_type}_ptb{int(args.ptb_rate*100)}_sim{int(args.similarity_threshold*100)}.log"

# Check if log file exists, skip experiment if it does.
if args.use_existing_logs and os.path.exists(log_file):
    print(f"Log file {log_file} already exists. Experiment already completed. Skipping...")
    # sys.exit, not the interactive-only `exit` builtin injected by `site`.
    sys.exit(0)

# The log lines contain non-ASCII characters ('→', '±'), so the file is always
# written as UTF-8 instead of the platform's locale encoding.
logging.basicConfig(
    level=logging.INFO,
    format='%(message)s',
    handlers=[logging.FileHandler(log_file, mode='w', encoding='utf-8')],
)
logger = logging.getLogger()

logger.info("=" * 80)
logger.info(f"EXPERIMENT - Dataset: {args.dataset}, Model: AutoGCN")
logger.info(f"Attack Type: {args.atk_type}, Attack: {args.attack}, Attack Phase: {args.atk_phase}")
logger.info(f"Embedding: atk={args.atk_emb_type}, def={args.def_emb_type}, ptb_rate={args.ptb_rate}")
logger.info(f"Auto Config: similarity_threshold={args.similarity_threshold}, attack_ratio={args.attack_ratio}")
logger.info(f"Training: epochs={args.epochs}, patience={args.patience}")
logger.info(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
logger.info("=" * 80)

# ------------ Run over hyperparams ------------
param_names = list(hyperparams.keys())
param_values = list(hyperparams.values())
param_combinations = list(product(*param_values))

# Hidden width depends only on the dataset, so it is chosen once, outside the sweep.
if args.dataset in ['cora', 'citeseer', 'pubmed', 'instagram', 'wikics']:
    hidden_dim = 128
else:
    hidden_dim = 256


def _evaluate_auto(model, x, edge_index, y, mask):
    """Run ``model`` with the 3-stage auto logic and score it.

    Returns ``(accuracy, stats)``: accuracy (a Python float) over the nodes
    selected by ``mask``, and the stats dict reported by the model.
    """
    with torch.no_grad():
        logits, stats = model(x, edge_index, training=False,
                              apply_auto_logic=True, return_stats=True)
        pred = torch.argmax(logits, dim=1)
        acc = (pred[mask] == y[mask]).float().mean()
    return acc.item(), stats


def _log_auto_stats(title, stats):
    """Log the per-stage AutoGCN counters in ``stats`` under ``title``."""
    logger.info(f"  {title} auto stats:")
    logger.info(f"    - Nodes detected as text_attacked: {stats.get('text_attacked_detected', 0)}")
    logger.info(f"    - Nodes using text recovery: {stats.get('text_recovery_used', 0)}")
    logger.info(f"    - Nodes using structure recovery: {stats.get('structure_recovery_used', 0)}")
    logger.info(f"    - Edges: {stats.get('original_edges', 0)} → {stats.get('filtered_edges', 0)} (removed: {stats.get('edges_removed', 0)})")


for params in param_combinations:
    param_dict = dict(zip(param_names, params))

    # Extract parameters (with fallbacks when a key is absent from the config).
    lr = param_dict.get('learning_rates', 0.01)
    wd = param_dict.get('weight_decays', 0.0)
    dropout = param_dict.get('dropouts', 0.5)

    clean_accs, evasion_accs = [], []
    clean_val_accs = []

    logger.info("-" * 60)
    logger.info(f"Hyperparams: {param_dict}")
    start_time = time.time()

    for seed in seeds:
        set_seed(seed)
        full_data, (train_data, val_data, test_data) = load_inductive_graph_dataset_for_gnn(
            args.dataset, device, re_split=args.re_split, path_prefix=root_path,
            emb_model=args.def_emb_type, seed=seed
        )
        full_data = full_data.to(device)
        train_data = train_data.to(device)
        val_data = val_data.to(device)
        test_data = test_data.to(device)

        num_features = full_data.x.shape[-1]
        num_classes = full_data.y.max().item() + 1

        # Training labels gain an extra text_attacked class.
        augmented_train_data = create_training_data_with_attacks(train_data, args.attack_ratio)

        model = AutoGCN(
            num_features, num_classes,
            hids=[hidden_dim], acts=['relu'],
            dropout=dropout,
            similarity_threshold=args.similarity_threshold
        )

        trainer = Trainer(model, device=device, verbose=0)
        trainer.reset_optimizer(lr=lr, weight_decay=wd)
        # The checkpoint path names every CLI setting that distinguishes a run,
        # so concurrent sweeps launched from the same directory (e.g. over
        # --similarity_threshold or --attack_ratio) do not clobber each other's
        # best-weights file.
        ckp_path = (
            f'auto_gcn_{args.dataset}_{args.atk_type}_{args.attack}_'
            f'{args.atk_emb_type}_{args.def_emb_type}_{args.ptb_rate}_'
            f'sim{args.similarity_threshold}_ar{args.attack_ratio}_{seed}.pth'
        )
        ckp = ModelCheckpoint(ckp_path, monitor='val_acc')
        early_stopping = EarlyStopping(monitor='val_acc', patience=args.patience, mode='max')

        # NOTE(review): assumes val_data holds the train nodes first, followed by
        # the validation nodes; confirm against load_inductive_graph_dataset_for_gnn.
        num_train_nodes = train_data.x.shape[0]
        num_val_nodes = val_data.x.shape[0] - num_train_nodes
        val_mask_for_val_data = torch.cat([
            torch.zeros(num_train_nodes, dtype=torch.bool, device=device),
            torch.ones(num_val_nodes, dtype=torch.bool, device=device)
        ])

        # Training with augmented data (includes text_attacked class).
        trainer.fit((augmented_train_data, val_data), mask=(None, val_mask_for_val_data),
                    verbose=1, callbacks=[ckp, early_stopping], epochs=args.epochs)

        # Evaluate on the clean test graph with auto logic.
        model.eval()
        test_acc, clean_stats = _evaluate_auto(
            model, test_data.x, test_data.edge_index, test_data.y, full_data.test_mask)

        # float() normalizes ckp.best (possibly a tensor, depending on how the
        # metric was recorded) so numpy can aggregate it below.
        clean_val_accs.append(float(ckp.best))
        clean_accs.append(test_acc)

        # Load the attack graph and evaluate evasion with auto logic.
        atk_path = f"{graph_save_dir}/{args.dataset}/{args.attack}/{args.atk_emb_type}_{int(args.ptb_rate*100)}_{seed}.pt"
        if not os.path.exists(atk_path):
            raise FileNotFoundError(f"Missing attack graph: {atk_path}")

        perturbed_edge_index = torch.load(atk_path, map_location=device)

        evasion_acc, attack_stats = _evaluate_auto(
            model, test_data.x, perturbed_edge_index, test_data.y, full_data.test_mask)
        evasion_accs.append(evasion_acc)

        # Log detailed auto statistics.
        logger.info(f"Seed {seed}: Clean={clean_accs[-1]:.4f}, Evasion={evasion_accs[-1]:.4f}")
        _log_auto_stats("Clean data", clean_stats)
        _log_auto_stats("Attack data", attack_stats)

    # Log results aggregated over seeds.
    logger.info(f"> Clean Val Acc: {np.mean(clean_val_accs):.4f} ± {np.std(clean_val_accs):.4f}")
    logger.info(f"> Clean Test Acc: {np.mean(clean_accs):.4f} ± {np.std(clean_accs):.4f}")
    logger.info(f"> Evasion Acc: {np.mean(evasion_accs):.4f} ± {np.std(evasion_accs):.4f}")
    logger.info(f"> Time: {(time.time() - start_time) / len(seeds):.2f}s")

logger.info("AutoGCN evaluation completed!")