import os 
import csv
from tqdm import tqdm 
import torch 
import json 
import sys
import swanlab
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader
sys.path.append("../..")
from common import load_atk_graph_dataset_for_llaga, set_seed, get_cur_time, compute_acc_and_f1
from common import save_checkpoint, reload_best_model, MODEL_PATHs as llm_paths, UNKNOW
from common.model_path import get_model_save_path, check_model_exists
from llaga_model import LLaGAModel
import argparse
import time
from dataset import LLaGADataset, build_laplacian_emb, build_hopfield_emb, classes


def filter_edges_by_similarity(node_embeddings, edge_index, threshold=0.5):
    """Keep only edges whose two endpoints have similar embeddings.

    An edge survives when the cosine similarity between the embedding of its
    source node (row 0 of ``edge_index``) and its target node (row 1) is at
    least ``threshold``.  Before/after edge counts are printed.

    Args:
        node_embeddings: Tensor indexed by node id along its first dimension
            (presumably ``(num_nodes, dim)`` — TODO confirm with callers).
        edge_index: Tensor whose first two rows hold source and target node
            ids; each column is one edge.
        threshold: Minimum cosine similarity (inclusive) for an edge to be kept.

    Returns:
        The subset of ``edge_index`` columns that pass the similarity test.
    """
    print(f"Original edges: {edge_index.shape[1]}")

    sources, targets = edge_index[0], edge_index[1]
    cos_sim = torch.cosine_similarity(
        node_embeddings[sources],
        node_embeddings[targets],
        dim=1,
    )

    # Inclusive comparison: an edge exactly at the threshold is retained.
    mask = cos_sim >= threshold
    kept_edges = edge_index[:, mask]

    n_kept = mask.sum().item()
    print(f"Filtered edges: {kept_edges.shape[1]} (kept {n_kept}/{len(mask)} edges)")
    return kept_edges


if __name__ == "__main__":
    # Attack-evaluation driver for LLaGA.
    #
    # 1. Parse CLI arguments and optionally start a SwanLab run.
    # 2. Load the attacked graph via load_atk_graph_dataset_for_llaga, either as a
    #    single transductive graph or as inductive train/val/test subgraphs.
    # 3. Build LLaGADataset splits plus the neighbourhood embeddings the chosen
    #    template uses ("ND": raw node features + Laplacian structure embedding;
    #    "HO": Hopfield embeddings).
    # 4. Build LLaGAModel, then either load an existing checkpoint or train on the
    #    attacked data with early stopping on validation loss.
    # 5. Run inference on the test split, write one JSON object per line, and
    #    append a summary row to a CSV file.
    parser = argparse.ArgumentParser()

    parser.add_argument("--dataset", type=str, default="cora")
    parser.add_argument("--re_split", type=int, default=1)
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--llm", type=str, default="Mistral-7B")
    parser.add_argument("--lm_encoder", type=str, default="roberta")
    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument("--token_counter", type=int, default=1)
    parser.add_argument("--gpu_id", type=int, default=0)
    parser.add_argument("--num_gpus", type=int, default=1)

    # Attack-specific arguments
    parser.add_argument("--attack", type=str, default="pgd", help="Attack type (pgd, grbcd, prbcd, text_fooler, etc.)")
    parser.add_argument("--atk_type", type=str, default="structure", choices=["structure", "text", "hybrid"], help="Attack category")
    parser.add_argument("--ptb_rate", type=float, default=0.1, help="Perturbation rate")
    parser.add_argument("--atk_emb_type", type=str, default="bow", help="Embedding type used for attack")
    parser.add_argument("--atk_seed", type=int, default=0, help="Seed used for attack generation")

    # Model source arguments - for loading pre-trained models
    parser.add_argument("--source_prompt", type=str, default=None, choices=["noise", "noisetxt", "sim", "noisefull"],
                        help="Source model's prompt type (for loading pre-trained models with noise) or test variant (sim for similarity filtering)")

    # Configuration of Neighborhood Encoding
    parser.add_argument("--neighbor_template", default="HO", choices=["ND", "HO"])
    parser.add_argument("--nd_mean", type=int, default=1)
    parser.add_argument("--k_hop", type=int, default=2)
    parser.add_argument("--sample_size", type=int, default=10)
    parser.add_argument("--hopfield", type=int, default=4)

    parser.add_argument("--max_txt_length", type=int, default=256)
    parser.add_argument("--max_ans_length", type=int, default=16)

    # Configuration of Linear Projection
    parser.add_argument("--n_linear_layer", type=int, default=2)
    parser.add_argument("--hidden_dim", type=int, default=2048)
    parser.add_argument("--output_dim", type=int, default=2048)

    # Configuration of Model Training
    parser.add_argument("--num_epochs", type=int, default=6)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--eval_batch_size", type=int, default=128)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--wd", type=float, default=0.05)
    parser.add_argument("--output_dir", type=str, default="../../results/LLaGA")
    parser.add_argument("--grad_steps", type=int, default=4)
    parser.add_argument("--patience", type=int, default=4)
    parser.add_argument("--llm_freeze", type=int, default=1)
    parser.add_argument("--use_swanlab", action="store_true")

    args = parser.parse_args()

    # Perturbation rate as an integer percentage, used in run and directory names.
    # round() first: plain int() truncates float products (int(0.29 * 100) == 28).
    ptb_pct = int(round(args.ptb_rate * 100))

    # Initialize SwanLab for attack evaluation
    exp_name = f"{args.dataset}_{args.llm}_{args.neighbor_template}_{args.attack}_ptb{ptb_pct}_seed{args.seed}"
    if args.use_swanlab:
        swanlab.init(
            project="LLaGA_Attack_Eval",
            experiment_name=exp_name,
            config=vars(args),
            description=f"LLaGA attack evaluation on {args.dataset} with {args.attack} attack (ptb_rate={args.ptb_rate})",
            mode="local",
        )

    print("= " * 20)
    print("## Starting Attack Evaluation Time:", get_cur_time(), flush=True)
    print(f"## Attack Info: {args.attack} with ptb_rate={args.ptb_rate}")
    print(args, "\n")

    # NOTE(review): --device is parsed but unused; the device is built from --gpu_id.
    device = torch.device("cuda:" + str(args.gpu_id))
    set_seed(args.seed)

    llm_path = llm_paths[args.llm]
    # The projection output must match the chosen LLM's hidden size, so the
    # --output_dim flag is overridden here (unknown --llm values raise KeyError).
    args.output_dim = {"Qwen-3B": 2048, "Qwen-7B": 3584, "Mistral-7B": 4096, "Llama-8B": 4096, "Qwen-14B": 5120, "Qwen-32B": 5120}[args.llm]

    # Prepare attack metadata
    atk_meta_info = {
        'attack': args.attack,
        'ptb_rate': args.ptb_rate,
        'atk_emb_type': args.atk_emb_type,
        'seed': args.atk_seed,
        'atk_type': args.atk_type
    }

    noise_prompts = ["noise", "noisetxt", "noisefull"]

    # Load attacked dataset
    is_inductive = (args.re_split == 2) or (args.re_split == 0 and args.dataset == "arxiv")
    if not is_inductive:
        # Transductive setting: one graph shared by all splits.
        graph_data = load_atk_graph_dataset_for_llaga(
            dataset_name=args.dataset, device=device, atk_meta_info=atk_meta_info,
            encoder=args.lm_encoder, re_split=args.re_split, seed=args.seed
        )

        # Apply additional noise based on source_prompt for transductive training
        if args.source_prompt in noise_prompts:
            print(f"Applying additional {args.source_prompt} noise to attacked data (transductive setting)")
            from common.noise_utils import apply_noise_to_graph_data
            # Training data gets extra noise on top of the attack (rate 0.1).
            train_graph_data = apply_noise_to_graph_data(graph_data, args.source_prompt, 0.1, args.seed)

            # For text noise, regenerate embeddings for the noisy texts
            if args.source_prompt in ["noisetxt", "noisefull"]:
                print(f"Regenerating embeddings for noisy training texts using {args.lm_encoder}...")
                from common.lm import TextEncoder
                encoder_type = "LLM" if args.lm_encoder in ["Mistral-7B", "Qwen-7B", "Llama-8B", "Qwen3-8B", "Ministral-8B"] else "LM"

                torch.cuda.empty_cache()
                text_encoder = TextEncoder(args.lm_encoder, encoder_type, device)

                with torch.no_grad():
                    new_embeddings = []
                    for i, text in enumerate(train_graph_data.raw_texts):
                        # The encoder is never given an empty string.
                        text = "Empty text" if len(text) == 0 else text
                        emb = text_encoder.forward(text, pooling="mean")
                        new_embeddings.append(emb.cpu())
                        torch.cuda.empty_cache()

                        if (i + 1) % 100 == 0:
                            print(f"Processed {i + 1}/{len(train_graph_data.raw_texts)} training texts")

                    train_graph_data.x = torch.cat(new_embeddings, dim=0).to(device)

                # Release the text encoder before the LLM is loaded.
                if hasattr(text_encoder, 'model'):
                    del text_encoder.model
                del text_encoder
                torch.cuda.empty_cache()

                print(f"Updated training embeddings for {len(train_graph_data.raw_texts)} texts")

            # Use attacked data (without additional noise) for validation and test
            val_test_graph_data = graph_data
        else:
            # No additional noise, use attacked data for all phases
            train_graph_data = graph_data
            val_test_graph_data = graph_data

        train_dataset = LLaGADataset(args, graph_data=train_graph_data, full_graph=graph_data, data_type="train", repeats=1)
        val_dataset = LLaGADataset(args, graph_data=val_test_graph_data, full_graph=graph_data, data_type="val", repeats=1)
        test_dataset = LLaGADataset(args, graph_data=val_test_graph_data, full_graph=graph_data, data_type="test", repeats=1)
        full_graph_data = graph_data

        # Build Hopfield embeddings from the noisy training graph when extra
        # noise is used, otherwise from the attacked graph.
        if args.source_prompt in noise_prompts:
            hopfield_emb = build_hopfield_emb(train_graph_data.x, train_graph_data.edge_index, n_layers=args.hopfield)
        else:
            hopfield_emb = build_hopfield_emb(graph_data.x, graph_data.edge_index, n_layers=args.hopfield)
        structure_emb = build_laplacian_emb(args.k_hop, args.sample_size).to(device)

    else:
        # Inductive setting: separate train/val/test subgraphs.
        print("Loading attacked inductive dataset...")
        full_graph_data, (train_data, val_data, test_data) = load_atk_graph_dataset_for_llaga(
            dataset_name=args.dataset, device=device, atk_meta_info=atk_meta_info,
            re_split=args.re_split, encoder=args.lm_encoder, seed=args.seed
        )

        # "sim" variant: drop dissimilar test edges before embeddings are built.
        if args.source_prompt == "sim":
            print("Applying similarity filtering by modifying edge_index...")
            test_data.edge_index = filter_edges_by_similarity(test_data.x, test_data.edge_index, threshold=0.5)
            print("Edge filtering completed")

        # Build separate hopfield embeddings for each phase
        train_hop_emb = build_hopfield_emb(train_data.x, train_data.edge_index, n_layers=args.hopfield)
        val_hop_emb = build_hopfield_emb(val_data.x, val_data.edge_index, n_layers=args.hopfield)
        test_hop_emb = build_hopfield_emb(test_data.x, test_data.edge_index, n_layers=args.hopfield)

        hopfield_emb = {
            'train': train_hop_emb,
            'val': val_hop_emb,
            'test': test_hop_emb
        }

        # Create datasets for each split
        train_dataset = LLaGADataset(args, graph_data=train_data, full_graph=full_graph_data, data_type="train", repeats=1, inductive=True)
        val_dataset = LLaGADataset(args, graph_data=val_data, full_graph=full_graph_data, data_type="val", repeats=1, inductive=True)
        test_dataset = LLaGADataset(args, graph_data=test_data, full_graph=full_graph_data, data_type="test", repeats=1, inductive=True)

        structure_emb = build_laplacian_emb(args.k_hop, args.sample_size).to(device)

    print(f"[DATA] Attack Setting: {args.attack} ptb_rate={args.ptb_rate} | # Train {len(train_dataset)} # Val {len(val_dataset)} # Test {len(test_dataset)}")

    # Create DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, drop_last=False, pin_memory=True, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=args.eval_batch_size, drop_last=False, pin_memory=True, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=args.eval_batch_size, drop_last=False, pin_memory=True, shuffle=False)

    # Build the model. "ND" feeds raw node features plus the Laplacian structure
    # embedding; "HO" feeds Hopfield embeddings (per phase when inductive).
    if is_inductive:
        graph_embedding_for_model = train_data.x if args.neighbor_template == "ND" else hopfield_emb['train']
        model = LLaGAModel(args,
                           llm_path,
                           graph_embedding=graph_embedding_for_model,
                           structure_embedding=structure_emb if args.neighbor_template == "ND" else None,
                           inductive_embs=hopfield_emb if args.neighbor_template == "HO" else None)
    else:
        graph_embedding_for_model = full_graph_data.x if args.neighbor_template == "ND" else hopfield_emb
        model = LLaGAModel(args,
                           llm_path,
                           graph_embedding=graph_embedding_for_model,
                           structure_embedding=structure_emb if args.neighbor_template == "ND" else None)

    re_split_str = '_s' if args.re_split else ''

    # Resolve the checkpoint directory.
    # - Inductive (incl. arxiv with re_split=0): the clean pre-trained model is
    #   loaded, so no attack info is part of the path.
    # - Transductive: the model is trained on the attacked graph, so the attack
    #   metadata is part of the path.
    # Only the noise prompts select a differently-trained model; "sim" only
    # filters test edges and maps to the model saved without a prompt.
    model_prompt = args.source_prompt if args.source_prompt in noise_prompts else None
    save_path = get_model_save_path(
        model_name="LLaGA",
        dataset=args.dataset,
        re_split=args.re_split,
        llm=args.llm,
        seed=args.seed,
        atk_name=None if is_inductive else atk_meta_info,
        neighbor_template=args.neighbor_template,
        num_epochs=args.num_epochs,
        lm_encoder=args.lm_encoder,
        prompt=model_prompt,
    )
    if is_inductive:
        print(f"Clean model save path (for inductive/arxiv): {save_path}")
    else:
        print(f"Attack model save path (for transductive): {save_path}")

    # Check if trained model already exists
    llm_config_str = f"{args.llm}_{args.neighbor_template}_Epoch{args.num_epochs}{re_split_str}{'_LoRA' if not args.llm_freeze else ''}"
    model_path = os.path.join(save_path, f"{llm_config_str}_best.pth")

    should_train = True
    train_secs = 0

    if is_inductive:
        # Inductive evaluation never trains; a clean checkpoint is required.
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Clean pre-trained model not found: {model_path}")
        print(f"Loading clean pre-trained model from {model_path}")
        model = reload_best_model(model, save_path, llm_config_str)
        should_train = False
    elif args.re_split == 1:
        # For re_split=1, always retrain with attack data
        print("Re_split=1, training attack model from scratch...")
        should_train = True
    elif os.path.exists(model_path):
        # Other transductive settings reuse an existing attack model if present.
        print(f"Loading existing attack model from {model_path}")
        model = reload_best_model(model, save_path, llm_config_str)
        should_train = False
    else:
        print("No existing attack model found. Training attack model from scratch...")
        should_train = True

    # (Temporary) Token Counter to Decide MAX_ANS_LENGTH & MAX_TXT_LENGTH
    if args.token_counter:
        input_lengths, txt_lengths, output_lengths = [], [], []
        for sample in train_dataset + val_dataset + test_dataset:
            encoded_query = model.tokenizer(sample["query"])
            encoded_txt = model.tokenizer(sample["origin_txt"])
            encoded_label = model.tokenizer(sample["label"])
            input_lengths.append(len(encoded_query["input_ids"]))
            txt_lengths.append(len(encoded_txt["input_ids"]))
            output_lengths.append(len(encoded_label["input_ids"]))
        print(f"[ANALYSIS] # Avg Input Token {sum(input_lengths)/len(input_lengths):.3f} # Avg txt Token {sum(txt_lengths)/len(txt_lengths):.2f}  # Avg Output Token {sum(output_lengths)/len(output_lengths):.3f}  Max Output Token {max(output_lengths)}")

    params = [p for _, p in model.named_parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(
        [{'params': params, 'lr': args.lr, 'weight_decay': args.wd}]
    )

    trainable_params, all_params = model.print_trainable_params()
    print(f"Trainable params {trainable_params} || all params {all_params} || trainable% {100 * trainable_params / all_params:.5f}")

    num_training_steps = args.num_epochs * len(train_loader)

    best_val_loss = float('inf')
    # Initialized before the loop so the early-stopping check never reads an
    # undefined name (e.g. when a NaN validation loss never beats +inf).
    # -1 also means "no best checkpoint saved yet".
    best_epoch = -1

    model.model.gradient_checkpointing_enable()
    st_time = time.time()

    if should_train:
        progress_bar = tqdm(range(num_training_steps))
        for epoch in range(args.num_epochs):
            model.train()

            # Set phase for inductive learning
            if is_inductive:
                model.set_phase('train')

            epoch_loss, accum_loss = 0.0, 0.0

            for step, batch in enumerate(train_loader):
                # NOTE(review): the optimizer steps every batch; --grad_steps only
                # controls how often the running loss is logged (no accumulation).
                optimizer.zero_grad()
                loss = model(batch)
                loss.backward()

                clip_grad_norm_(optimizer.param_groups[0]['params'], 0.1)
                optimizer.step()
                epoch_loss, accum_loss = epoch_loss + loss.item(), accum_loss + loss.item()

                if (step + 1) % args.grad_steps == 0:
                    lr = optimizer.param_groups[0]["lr"]
                    # Log training metrics
                    if args.use_swanlab:
                        swanlab.log({
                            "training/step": epoch * len(train_loader) + step,
                            "training/epoch": epoch + 1,
                            "training/loss": accum_loss / args.grad_steps,
                            "training/learning_rate": lr,
                        })
                    accum_loss = 0.

                progress_bar.update(1)

            print(f"[TRAIN] Epoch {epoch+1}|{args.num_epochs}: Train Loss (Epoch Mean): {epoch_loss / len(train_loader):.5f}")

            val_loss = 0.0
            model.eval()

            # Set phase for validation in inductive learning
            if is_inductive:
                model.set_phase('val')

            with torch.no_grad():
                for step, batch in enumerate(val_loader):
                    loss = model(batch)
                    val_loss += loss.item()
                print(f"[VAL] Epoch: {epoch+1}|{args.num_epochs}: Val Loss: {val_loss / len(val_loader):.5f}")

                # Log epoch completion and validation metrics
                if args.use_swanlab:
                    swanlab.log({
                        "training/epoch_train_loss": epoch_loss / len(train_loader),
                        "training/epoch_val_loss": val_loss / len(val_loader),
                        "training/epoch": epoch + 1,
                        "training/is_best": val_loss < best_val_loss,
                    })

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_epoch = epoch
                save_checkpoint(model, epoch, save_path, llm_config_str, is_best=True)

            if epoch - best_epoch >= args.patience:
                print(f"[TRAIN] Early stop at epoch {epoch+1}")
                break

        progress_bar.close()

        # Load the best checkpoint once, after training ends (including after an
        # early stop), so the test run uses the best-validation weights rather
        # than those of the last epoch.
        if best_epoch >= 0:
            model = reload_best_model(model, save_path, llm_config_str)

    train_secs = time.time() - st_time
    torch.cuda.empty_cache()
    # Equivalent to the deprecated torch.cuda.reset_max_memory_allocated().
    torch.cuda.reset_peak_memory_stats()

    model.eval()

    # Set phase for testing in inductive learning
    if is_inductive:
        model.set_phase('test')

    # Create attack prediction directory
    attack_dir = f"attack_{args.attack}_ptb{ptb_pct}"
    prediction_dir = os.path.join(args.output_dir, "prediction", attack_dir)
    os.makedirs(prediction_dir, exist_ok=True)
    path = os.path.join(prediction_dir, f"{args.dataset}_{args.llm}{re_split_str}_seed{args.seed}.json")
    print(f"\n[Attack Prediction] Write predictions on {path} ...")

    progress_bar_test = tqdm(range(len(test_loader)))
    pred_labels, gt_labels = [], []
    st_time = time.time()

    try:
        with open(path, 'w') as file:
            for step, batch in enumerate(test_loader):
                with torch.no_grad():
                    try:
                        id_list, predictions = model.inference(batch)

                        for node_idx, llm_pred in zip(id_list, predictions):
                            # Inductive ids index into the test subgraph and are
                            # mapped back to original node ids via node_ids.
                            if is_inductive:
                                original_node_idx = test_data.node_ids[node_idx.item()].item()
                                gt_label = classes[args.dataset][test_data.y[node_idx.item()].item()]
                            else:
                                original_node_idx = node_idx.item()
                                gt_label = classes[args.dataset][full_graph_data.y[node_idx.item()].item()]

                            # Cut at the end-of-sequence marker; anything not an
                            # exact class name is counted as UNKNOW.
                            pred_label = llm_pred[:llm_pred.index("</s>")] if "</s>" in llm_pred else llm_pred
                            pred_label = pred_label if pred_label in classes[args.dataset] else UNKNOW
                            write_obj = {
                                "id": original_node_idx,
                                "pred": llm_pred,
                                "ground-truth": gt_label,
                                "attack": args.attack,
                                "ptb_rate": args.ptb_rate,
                                "atk_emb_type": args.atk_emb_type,
                                "atk_seed": args.atk_seed,
                                "atk_type": args.atk_type,
                                "model": args.llm,
                                "neighbor_template": args.neighbor_template,
                                "data_seed": args.seed
                            }
                            pred_labels.append(pred_label)
                            gt_labels.append(write_obj["ground-truth"])
                            file.write(json.dumps(write_obj) + "\n")
                            file.flush()
                    except Exception as e:
                        # Best-effort: report the failing batch and keep going so
                        # one bad batch does not abort the whole evaluation.
                        print(f"Error processing batch {step}: {e}")
                        import traceback
                        traceback.print_exc()
                    # Advance the bar for failed batches too, so it reaches 100%.
                    progress_bar_test.update(1)
    except Exception as e:
        print(f"Error writing predictions to file: {e}")

    progress_bar_test.close()

    inference_secs = time.time() - st_time

    acc, macro_f1, weight_f1 = compute_acc_and_f1(pred_labels, gt_labels)

    # Calculate unknown predictions ratio
    unknown_count = sum(1 for pred in pred_labels if pred == UNKNOW)
    unknown_ratio = unknown_count / len(pred_labels) if len(pred_labels) > 0 else 0

    # Save results to CSV (appends one row per run)
    summary_csv = f"{args.output_dir}/summary_attack{'_semi' if not args.re_split else ''}.csv"
    os.makedirs(os.path.dirname(summary_csv), exist_ok=True)
    try:
        with open(summary_csv, 'a', newline='') as file:
            writer = csv.writer(file)
            writer.writerow([args.dataset, args.llm, acc, macro_f1, weight_f1,
                             args.neighbor_template, args.hidden_dim, args.lm_encoder, args.num_epochs, args.patience, args.batch_size, args.lr, args.seed,
                             args.attack, args.ptb_rate, args.atk_emb_type, args.atk_seed,
                             f"Inference Seconds-{inference_secs:.2f}"])
    except Exception as e:
        print(f"Error saving results to CSV: {e}")

    print(f"Attack Results - {args.attack} (ptb_rate={args.ptb_rate}):")
    print(f"Accuracy {acc:.2f}  Macro F1-Score {macro_f1:.2f}  Weight F1-Score {weight_f1:.2f}")
    print(f"Unknown predictions: {unknown_count}/{len(pred_labels)} ({unknown_ratio:.2%})")
    print('\n## Finishing Time:', get_cur_time(), flush=True)
    print('= ' * 20)
    print("Attack evaluation done!")

    # Log final results and finish SwanLab logging
    if args.use_swanlab:
        swanlab.log({
            "final_results/accuracy": acc,
            "final_results/macro_f1": macro_f1,
            "final_results/weighted_f1": weight_f1,
            "final_results/inference_time_seconds": inference_secs,
            "final_results/total_predictions": len(pred_labels),
            "final_results/unknown_predictions": unknown_count,
            "final_results/unknown_ratio": unknown_ratio,
        })
        swanlab.finish()