import os 
import torch 
import json 
import sys 
import time
from tqdm import tqdm
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader
sys.path.append("../..")
from common import load_atk_graph_dataset_for_llaga, load_inductive_atk_graph_dataset, set_seed, get_cur_time, compute_acc_and_f1, save_checkpoint, reload_best_model
from common import MODEL_PATHs as llm_paths, UNKNOW
from common.model_path import get_model_save_path, check_model_exists
from graphgpt_model import GraphGPTModel
from dataset import GraphInstructionTuningDataset, GraphMatchingDataset, classes
import argparse


if __name__ == "__main__":
    # Build the command-line interface. `add_arg` is only a short alias for
    # parser.add_argument; options are registered in the same order as they
    # appear in --help.
    parser = argparse.ArgumentParser()
    add_arg = parser.add_argument

    # --- General run configuration ----------------------------------------
    add_arg("--dataset", type=str, default="cora")
    add_arg("--re_split", type=int, default=0)
    add_arg("--seed", type=int, default=0)
    add_arg("--llm", type=str, default="Mistral-7B")
    add_arg("--device", type=str, default="cuda:0")
    add_arg("--gpu_id", type=int, default=0)

    # --- Attack description (locates the attacked graph, tags outputs) ----
    add_arg("--attack", type=str, default="pgd",
            help="Attack type (pgd, grbcd, prbcd, text_fooler, etc.)")
    add_arg("--atk_type", type=str, default="structure",
            choices=["structure", "text", "hybrid"], help="Attack category")
    add_arg("--ptb_rate", type=float, default=0.1, help="Perturbation rate")
    add_arg("--atk_emb_type", type=str, default="bow",
            help="Embedding type used for attack")
    add_arg("--atk_seed", type=int, default=0,
            help="Seed used for attack generation")

    # --- Source model (selects which pre-trained checkpoint is loaded) ----
    add_arg("--source_prompt", type=str, default=None,
            choices=["noise", "noisetxt", "noisefull"],
            help="Source model's prompt type (for loading pre-trained models with noise)")

    # --- Stage 1: self-supervised graph matching --------------------------
    add_arg("--do_stage1", type=int, default=1)
    add_arg("--s1_k_hop", type=int, default=2)
    add_arg("--s1_num_neighbors", type=int, default=5)
    add_arg("--s1_max_txt_length", type=int, default=512)
    add_arg("--s1_max_ans_length", type=int, default=256)
    add_arg("--s1_epoch", type=int, default=2)
    add_arg("--s1_batch_size", type=int, default=16)
    add_arg("--s1_lr", type=float, default=1e-4)

    # --- Stage 2: instruction tuning --------------------------------------
    add_arg("--do_stage2", type=int, default=1)
    add_arg("--s2_num_neighbors", type=int, default=4)
    add_arg("--s2_max_txt_length", type=int, default=256)
    add_arg("--s2_max_ans_length", type=int, default=16)
    add_arg("--s2_epoch", type=int, default=10)
    add_arg("--s2_batch_size", type=int, default=32)
    add_arg("--s2_lr", type=float, default=1e-4)
    add_arg("--s2_patience", type=int, default=2)

    # --- Shared model / optimisation / output options ---------------------
    add_arg("--output_dim", type=int, default=2048)
    add_arg("--wd", type=float, default=0.05)
    add_arg("--load_ground_embedding", type=int, default=0)
    add_arg("--output_dir", type=str, default="../../results/GraphGPT")

    args = parser.parse_args()
    
    print("=" * 20)
    print("## Starting Attack Evaluation Time:", get_cur_time(), flush=True)
    print(f"## Attack Info: {args.attack} with ptb_rate={args.ptb_rate}")
    print(args, "\n")

    # NOTE(review): `--device` is parsed but unused; the CUDA device is derived
    # from `--gpu_id` only.
    device = torch.device("cuda:" + str(args.gpu_id))
    set_seed(args.seed)

    llm_path = llm_paths[args.llm]
    # Overrides any CLI value of --output_dim with a per-LLM width
    # (presumably the LLM hidden size -- confirm in GraphGPTModel).
    args.output_dim = {"Qwen-3B": 2048, "Qwen-7B": 3584, "Mistral-7B": 4096, "Llama-8B": 4096}[args.llm]

    # Attack metadata used to load the attacked graph (and, for transductive
    # runs, to build the attack-specific model directory).
    atk_meta_info = {
        'attack': args.attack,
        'ptb_rate': args.ptb_rate,
        'atk_emb_type': args.atk_emb_type,
        'seed': args.atk_seed,
        'atk_type': args.atk_type
    }

    # arxiv is evaluated inductively even with re_split=0.
    is_inductive = (args.re_split == 2) or (args.re_split == 0 and args.dataset == "arxiv")

    if is_inductive:
        print("Loading attacked inductive dataset...")
        full_graph_data, (train_data, val_data, test_data) = load_atk_graph_dataset_for_llaga(
            dataset_name=args.dataset, device=device, atk_meta_info=atk_meta_info,
            re_split=args.re_split, encoder="roberta", seed=args.seed
        )

        # One node-feature matrix per phase; passed to the model as
        # `inductive_embs` (presumably selected via model.set_phase()).
        graph_embeddings = {
            'train': train_data.x.to(device),
            'val': val_data.x.to(device),
            'test': test_data.x.to(device),
        }
        graph_data = full_graph_data
        graph_embedding = graph_embeddings['train']  # used for model initialisation
    else:
        # Transductive setting: a single attacked graph serves every phase.
        graph_data = load_atk_graph_dataset_for_llaga(
            dataset_name=args.dataset, device=device, atk_meta_info=atk_meta_info,
            encoder="roberta", re_split=args.re_split
        )
        # NOTE(review): kept for reference but currently unused below -- the
        # train/val/test datasets are all built from `graph_data`, i.e. from the
        # noisy copy when --source_prompt=noise. Confirm that is intended.
        attacked_only_data = graph_data

        if args.source_prompt == "noise":
            print("Applying additional noise to attacked data (transductive setting)")
            from common.noise_utils import apply_noise_to_graph_data
            # Extra noise on top of the attack; the noisy graph replaces graph_data.
            graph_data = apply_noise_to_graph_data(graph_data, args.source_prompt, 0.1, args.seed)

        graph_embedding = graph_data.x.to(device)
        graph_embeddings = None
        # No per-split graphs in the transductive setting.
        train_data, val_data, test_data = None, None, None

    if args.load_ground_embedding:
        ground_path = f"{args.output_dir}/ground_emb/{args.dataset}.pt"
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not os.path.exists(ground_path):
            raise FileNotFoundError("You have set `load_ground_embedding` to True. Please run `main_text_graph_grounding` to generate grounded embedding first!")
        # map_location avoids failures when the tensor was saved on another device.
        ground_embedding = torch.load(ground_path, map_location=device).to(device)
        if is_inductive:
            # train: train nodes only; val: train+val nodes; test: all nodes.
            train_mask = graph_data.train_mask
            val_mask = graph_data.val_mask
            train_val_mask = train_mask | val_mask
            graph_embeddings = {
                'train': ground_embedding[train_mask],
                'val': ground_embedding[train_val_mask],
                'test': ground_embedding
            }
            graph_embedding = graph_embeddings['train']
        else:
            graph_embedding = ground_embedding
    # Describe the evaluation setting: 0 = semi-supervised, 1 = supervised,
    # 2 = inductive (arxiv counts as inductive even with re_split=0).
    # `is_inductive` was already computed during data preparation.
    if is_inductive:
        setting_type, setting_name = 2, "inductive"
    else:
        setting_type = 0 if args.re_split == 0 else 1
        setting_name = "semi-supervised" if args.re_split == 0 else "supervised"
    print(f"## Setting: {setting_name} (type {setting_type})")

    # Inductive runs evaluate a clean pre-trained model, so its path carries no
    # attack info. Transductive runs (re_split 0 or 1, non-arxiv) train on the
    # attacked graph and use an attack-specific path.
    model_save_path = get_model_save_path(
        model_name="GraphGPT",
        dataset=args.dataset,
        re_split=args.re_split,
        llm=args.llm,
        seed=args.seed,
        atk_name=None if is_inductive else atk_meta_info,
        s1_epoch=args.s1_epoch,
        s2_epoch=args.s2_epoch,
        load_ground_embedding=args.load_ground_embedding,
        prompt=args.source_prompt,  # selects the pre-trained (possibly noise-trained) source model
    )
    if is_inductive:
        print(f"Clean model save path (for inductive/arxiv): {model_save_path}")
    else:
        print(f"Attack model save path (for transductive): {model_save_path}")
    
    st_time = time.time()
    if args.do_stage1:
        print("Preparing Stage 1 [Graph Matching] ...")

        model = GraphGPTModel(args, llm_path, graph_embedding=graph_embedding, stage="matching", inductive_embs=graph_embeddings)

        # Decide whether to reuse an existing Stage 1 checkpoint:
        #   * inductive (incl. arxiv): a clean pre-trained checkpoint is mandatory;
        #   * re_split == 1: always retrain on the attacked graph;
        #   * otherwise: reuse an attack-specific checkpoint if one exists.
        stage1_model_path = os.path.join(model_save_path, "stage1_best.pth")
        load_stage1 = False
        if is_inductive:
            if not os.path.exists(stage1_model_path):
                raise FileNotFoundError(f"Clean pre-trained Stage 1 model not found: {stage1_model_path}")
            print(f"Loading clean pre-trained Stage 1 model from {stage1_model_path}")
            load_stage1 = True
        elif args.re_split == 1:
            print("Re_split=1, training attack Stage 1 model from scratch...")
        elif os.path.exists(stage1_model_path):
            print(f"Loading existing attack Stage 1 model from {stage1_model_path}")
            load_stage1 = True
        else:
            print("No existing attack Stage 1 model found. Training from scratch...")

        if load_stage1:
            model = reload_best_model(model, model_save_path, config_str="stage1")
            # Skip Stage 1 training; the optimizer built below then uses the Stage 2 LR.
            args.s1_epoch = 0
            args.s1_lr = args.s2_lr

        params = [p for _, p in model.named_parameters() if p.requires_grad]
        optimizer = torch.optim.AdamW([{'params': params, 'lr': args.s1_lr, 'weight_decay': args.wd}])

        trainable_params, all_params = model.print_trainable_params()
        print(f"Trainable params {trainable_params} || all params {all_params} || trainable% {100 * trainable_params / all_params:.5f}")

        if args.s1_epoch > 0:  # train only when no checkpoint was loaded
            graph_type = {
                "cora": "academic_network", "citeseer": "academic_network", "pubmed": "academic_network", "wikics": "academic_network", "arxiv": "academic_network",
                "reddit": "social_network", "instagram": "social_network",
                "computer": "ecommerce_network", "photo": "ecommerce_network", "history": "ecommerce_network"
            }[args.dataset]
            dataset = GraphMatchingDataset(graph_data=graph_data, k_hop=args.s1_k_hop, num_sampled_neighbors=args.s1_num_neighbors, graph_type=graph_type, re_split=args.re_split)
            train_loader = DataLoader(dataset, batch_size=args.s1_batch_size, drop_last=True, shuffle=True)

            num_training_steps = args.s1_epoch * len(train_loader)
            progress_bar = tqdm(range(num_training_steps))

            model.model.gradient_checkpointing_enable()
            for epoch in range(args.s1_epoch):
                model.train()
                if is_inductive:
                    model.set_phase('train')

                epoch_loss, accum_loss = 0.0, 0.0
                for step, batch in enumerate(train_loader):
                    optimizer.zero_grad()
                    loss = model(batch)
                    loss.backward()
                    clip_grad_norm_(optimizer.param_groups[0]['params'], 0.1)
                    optimizer.step()

                    loss_value = loss.item()
                    epoch_loss += loss_value
                    accum_loss += loss_value
                    if (step + 1) % 20 == 0:
                        print(f"(Temporary) Step {step} in Epoch {epoch+1} Accum Loss {accum_loss:.4f}")
                        accum_loss = 0.0

                    progress_bar.update(1)

                print(f"[TRAIN] Epoch {epoch+1}|{args.s1_epoch}: Train Loss (Epoch Mean): {epoch_loss / len(train_loader):.5f}")
                # No validation in Stage 1: the latest epoch is saved as "best".
                # Pass the current epoch, consistent with the Stage 2 checkpoints.
                save_checkpoint(model, epoch, model_save_path, config_str="stage1", is_best=True)

            torch.cuda.empty_cache()
            # reset_max_memory_allocated() is deprecated in favour of this call.
            torch.cuda.reset_peak_memory_stats()
    
    if args.do_stage2:
        print("Preparing Stage 2 [Instruction Tuning] ...")

        if not args.do_stage1:
            # NOTE(review): stage="matching" mirrors Stage 1; confirm this is the
            # intended GraphGPTModel stage for instruction tuning.
            model = GraphGPTModel(args, llm_path, graph_embedding=graph_embedding, stage="matching", inductive_embs=graph_embeddings)

            params = [p for _, p in model.named_parameters() if p.requires_grad]
            optimizer = torch.optim.AdamW([{'params': params, 'lr': args.s2_lr, 'weight_decay': args.wd}])
            trainable_params, all_params = model.print_trainable_params()
            print(f"Trainable params {trainable_params} || all params {all_params} || trainable% {100 * trainable_params / all_params:.5f}")
        else:
            # The Stage 1 model and its optimizer (with its learning rate) are reused as-is.
            print("Directly load stage1's pretrained graph projector layer")

        # The test split is built right before inference, so only train/val are needed here.
        train_dataset = GraphInstructionTuningDataset(graph_data=graph_data, maximum_neighbors=args.s2_num_neighbors, dataset_name=args.dataset, data_type="train", re_split=args.re_split, split_data=train_data if is_inductive else None)
        val_dataset = GraphInstructionTuningDataset(graph_data=graph_data, maximum_neighbors=args.s2_num_neighbors, dataset_name=args.dataset, data_type="val", re_split=args.re_split, split_data=val_data if is_inductive else None)

        train_loader = DataLoader(train_dataset, batch_size=args.s2_batch_size, drop_last=False, pin_memory=True, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=args.s2_batch_size*2, drop_last=False, pin_memory=True, shuffle=False)

        model.model.gradient_checkpointing_enable()

        llm_config_str = f"{args.llm}_Epoch{args.s2_epoch}"

        # Same checkpoint policy as Stage 1: mandatory clean checkpoint for
        # inductive runs, always retrain for re_split == 1, otherwise reuse if present.
        stage2_model_path = os.path.join(model_save_path, f"{llm_config_str}_best.pth")
        if is_inductive:
            if not os.path.exists(stage2_model_path):
                raise FileNotFoundError(f"Clean pre-trained Stage 2 model not found: {stage2_model_path}")
            print(f"Loading clean pre-trained Stage 2 model from {stage2_model_path}")
            model = reload_best_model(model, model_save_path, llm_config_str)
            print("Using clean pre-trained Stage 2 model for attack evaluation")
            should_train_stage2 = False
        elif args.re_split == 1:
            print("Re_split=1, training attack Stage 2 model from scratch...")
            should_train_stage2 = True
        elif os.path.exists(stage2_model_path):
            print(f"Loading existing attack Stage 2 model from {stage2_model_path}")
            model = reload_best_model(model, model_save_path, llm_config_str)
            print("Skipping Stage 2 training as attack model already exists")
            should_train_stage2 = False
        else:
            print("No existing attack Stage 2 model found. Training attack model from scratch...")
            should_train_stage2 = True

        if should_train_stage2:
            print("Training Stage 2...")
            progress_bar = tqdm(range(args.s2_epoch * len(train_loader)))
            best_val_loss = float('inf')
            # -1 = no checkpoint saved yet. Initialising it also avoids a NameError
            # when the validation loss never improves (e.g. it is NaN).
            best_epoch = -1

            for epoch in range(args.s2_epoch):
                model.train()
                if is_inductive:
                    model.set_phase('train')

                epoch_loss = 0.0
                for batch in train_loader:
                    optimizer.zero_grad()
                    loss = model(batch)
                    loss.backward()
                    clip_grad_norm_(optimizer.param_groups[0]['params'], 0.1)
                    optimizer.step()
                    epoch_loss += loss.item()
                    progress_bar.update(1)

                print(f"[TRAIN] Epoch {epoch+1}|{args.s2_epoch}: Train Loss (Epoch Mean): {epoch_loss / len(train_loader):.5f}")

                val_loss = 0.0
                model.eval()
                if is_inductive:
                    model.set_phase('val')

                with torch.no_grad():
                    for batch in val_loader:
                        val_loss += model(batch).item()
                print(f"[VAL] Epoch {epoch+1}|{args.s2_epoch}: Val Loss {val_loss / len(val_loader):.5f}")

                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    best_epoch = epoch
                    save_checkpoint(model, epoch, model_save_path, llm_config_str, is_best=True)

                if epoch - best_epoch >= args.s2_patience:
                    print(f"[TRAIN] Early stop at epoch {epoch+1}")
                    break

            # Restore the best checkpoint once, after training, including after an
            # early stop. Previously this ran inside the loop: it reverted weights
            # every epoch and was skipped by `break`, leaving the worse weights in place.
            if best_epoch >= 0:
                model = reload_best_model(model, model_save_path, llm_config_str)
            else:
                print("[TRAIN] No Stage 2 checkpoint was saved (no epochs run or validation loss never improved); keeping current weights.")
     
    train_secs = time.time() - st_time
    torch.cuda.empty_cache()
    # reset_max_memory_allocated() is deprecated in favour of this call.
    torch.cuda.reset_peak_memory_stats()

    # `model` only exists if at least one stage ran; fail with a clear message
    # instead of a NameError.
    if not (args.do_stage1 or args.do_stage2):
        raise ValueError("Nothing to evaluate: both --do_stage1 and --do_stage2 are disabled, so no model was built.")

    model.eval()
    if is_inductive:
        model.set_phase('test')

    test_dataset = GraphInstructionTuningDataset(
        graph_data=graph_data, maximum_neighbors=args.s2_num_neighbors,
        dataset_name=args.dataset, data_type="test", re_split=args.re_split,
        split_data=test_data if is_inductive else None
    )
    test_loader = DataLoader(test_dataset, batch_size=args.s2_batch_size*3, drop_last=False, pin_memory=True, shuffle=False)

    # round(), not int(): float error makes e.g. int(0.29 * 100) == 28, which
    # would put results in the wrong directory.
    attack_dir = f"attack_{args.attack}_ptb{round(args.ptb_rate * 100)}"
    prediction_dir = os.path.join(args.output_dir, "prediction", attack_dir)
    os.makedirs(prediction_dir, exist_ok=True)
    re_split_str = '_s' if args.re_split else ''
    path = os.path.join(prediction_dir, f"{args.dataset}_{args.llm}{re_split_str}_seed{args.seed}.json")
    print(f"\n[Attack Prediction] Write predictions on {path} ...")
    progress_bar = tqdm(range(len(test_loader)))

    pred_labels, gt_labels = [], []
    st_time = time.time()
    label_names = classes[args.dataset]
    valid_labels = set(label_names)  # O(1) membership checks per prediction

    try:
        # One JSON object per line; flushed after each batch so partial results survive.
        with open(path, 'w') as file, torch.no_grad():
            for batch in test_loader:
                id_list, predictions = model.inference(batch)

                for node_idx, llm_pred in zip(id_list, predictions):
                    node_idx = node_idx.item()
                    # Keep only the text before the "</s>" marker, if present.
                    pred_label = llm_pred[:llm_pred.index("</s>")] if "</s>" in llm_pred else llm_pred

                    if is_inductive:
                        # node_idx indexes test_data; map back to the original node id.
                        original_node_idx = test_data.node_ids[node_idx].item()
                        gt_label = label_names[test_data.y[node_idx].item()]
                    else:
                        original_node_idx = node_idx
                        gt_label = label_names[graph_data.y[node_idx].item()]

                    write_obj = {
                        "id": original_node_idx,
                        "pred": llm_pred,
                        "ground-truth": gt_label,
                        "attack": args.attack,
                        "ptb_rate": args.ptb_rate,
                        "atk_emb_type": args.atk_emb_type,
                        "atk_seed": args.atk_seed,
                        "atk_type": args.atk_type,
                        "model": args.llm,
                        "s1_epoch": args.s1_epoch,
                        "s2_epoch": args.s2_epoch,
                        "data_seed": args.seed
                    }
                    # Out-of-vocabulary generations are scored as UNKNOW.
                    pred_labels.append(pred_label if pred_label in valid_labels else UNKNOW)
                    gt_labels.append(gt_label)
                    file.write(json.dumps(write_obj) + "\n")

                file.flush()
                progress_bar.update(1)
    except Exception as e:
        import traceback
        # Best effort: keep the predictions collected so far, but surface the full traceback.
        print(f"Error during prediction or writing results: {e}")
        traceback.print_exc()

    inference_secs = time.time() - st_time
    print(f"Training time {train_secs:.1f}s | Inference time {inference_secs:.1f}s")

    print(f"Attack Results - {args.attack} (ptb_rate={args.ptb_rate}):")
    if pred_labels:
        acc, macrof1, weightf1 = compute_acc_and_f1(pred_labels, gt_labels)
        unknown_count = pred_labels.count(UNKNOW)
        unknown_ratio = unknown_count / len(pred_labels)
        print(f"Accuracy {acc:.3f}  Macro F1-Score {macrof1:.3f}  Weight F1-Score {weightf1:.3f}")
        print(f"Unknown predictions: {unknown_count}/{len(pred_labels)} ({unknown_ratio:.2%})")
    else:
        print("No predictions were collected; metrics are unavailable.")
    print('\n## Finishing Time:', get_cur_time(), flush=True)
    print('= ' * 20)
    print("Attack evaluation done!")