import os 
import csv
from tqdm import tqdm 
import torch 
import json 
import sys
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader
sys.path.append("../..")
from common import load_graph_dataset_for_llaga, set_seed, get_cur_time, compute_acc_and_f1
from common import save_checkpoint, reload_best_model, MODEL_PATHs as llm_paths, UNKNOW
from common.model_path import get_model_save_path, check_model_exists
from common.noise_utils import apply_noise_to_graph_data, apply_noise_to_inductive_data
from llaga_model import LLaGAModel
import argparse
import time
from dataset import LLaGADataset, build_laplacian_emb, build_hopfield_emb, classes


def filter_edges_by_similarity(node_embeddings, edge_index, threshold=0.5):
    """Keep only the edges whose two endpoints have similar embeddings.

    Args:
        node_embeddings: Tensor indexed by node id along dim 0; row ``i`` is
            the embedding of node ``i``.
        edge_index: Two-row index tensor; row 0 holds one endpoint of every
            edge and row 1 the other, so each column is one edge.
        threshold: Minimum cosine similarity (inclusive) an edge needs in
            order to be kept.

    Returns:
        The columns of ``edge_index`` whose endpoint cosine similarity is
        ``>= threshold``, in their original order.

    Side effects:
        Prints the edge count before and after filtering.
    """
    print(f"Original edges: {edge_index.shape[1]}")

    # Look up the embedding at both ends of every edge in one indexing step.
    head_vectors = node_embeddings[edge_index[0]]
    tail_vectors = node_embeddings[edge_index[1]]

    # One similarity score per edge (dim=1 reduces over the feature axis).
    edge_scores = torch.cosine_similarity(head_vectors, tail_vectors, dim=1)

    # Boolean mask over edges; True marks an edge that survives.
    is_kept = edge_scores >= threshold
    surviving_edges = edge_index[:, is_kept]

    print(f"Filtered edges: {surviving_edges.shape[1]} (kept {is_kept.sum().item()}/{len(is_kept)} edges)")
    return surviving_edges


if __name__ == "__main__":
    # Entry point. Pipeline, in order:
    #   1. parse CLI options and pick the CUDA device from --gpu_id;
    #   2. load the dataset in the inductive or transductive layout and
    #      optionally perturb it according to --prompt;
    #   3. build the LLaGA model, reusing a saved checkpoint when one exists,
    #      otherwise training with early stopping on validation loss;
    #   4. run inference on the test split, write per-node predictions as
    #      JSON lines, and append one summary row to a CSV file.
    parser = argparse.ArgumentParser()

    parser.add_argument("--dataset", type=str, default="cora")
    parser.add_argument("--re_split", type=int, default=1)
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--llm", type=str, default="Mistral-7B")
    parser.add_argument("--lm_encoder", type=str, default="roberta")
    # NOTE: --device is parsed but not used; the device is built from --gpu_id below.
    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument("--token_counter", type=int, default=1)
    parser.add_argument("--gpu_id", type=int, default=0)
    parser.add_argument("--num_gpus", type=int, default=1)

    # Noise configuration
    parser.add_argument("--prompt", type=str, default=None, choices=["noise", "noisetxt", "noisefull", "sim"],
                        help="Noise type: 'noise' for structure noise, 'noisetxt' for text noise, 'noisefull' for both, None for no noise")
    # argparse %-formats help strings, so a literal percent sign must be written as %%.
    parser.add_argument("--noise_ratio", type=float, default=0.1,
                        help="Ratio of noise to add (default: 0.1 for 10%%)")

    # Configuration of Neighborhood Encoding
    parser.add_argument("--neighbor_template", default="HO", choices=["ND", "HO"])
    parser.add_argument("--nd_mean", type=int, default=1)
    parser.add_argument("--k_hop", type=int, default=2)
    parser.add_argument("--sample_size", type=int, default=10)
    parser.add_argument("--hopfield", type=int, default=4)

    parser.add_argument("--max_txt_length", type=int, default=256)
    parser.add_argument("--max_ans_length", type=int, default=16)

    # Configuration of Linear Projection
    parser.add_argument("--n_linear_layer", type=int, default=2)
    parser.add_argument("--hidden_dim", type=int, default=2048)
    parser.add_argument("--output_dim", type=int, default=2048)

    # Configuration of Model Training
    parser.add_argument("--num_epochs", type=int, default=6)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--eval_batch_size", type=int, default=128)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--wd", type=float, default=0.05)
    parser.add_argument("--output_dir", type=str, default="../../results/LLaGA")
    # NOTE(review): --grad_steps is accepted but the training loop below steps
    # the optimizer on every batch; gradient accumulation is not implemented.
    parser.add_argument("--grad_steps", type=int, default=4)
    parser.add_argument("--patience", type=int, default=4)
    parser.add_argument("--llm_freeze", type=int, default=1)

    args = parser.parse_args()

    print("= " * 20)
    print("## Starting Time:", get_cur_time(), flush=True)
    if args.prompt is not None:
        print(f"## Noise Info: {args.prompt} with noise_ratio={args.noise_ratio}")
    print(args, "\n")

    device = torch.device("cuda:" + str(args.gpu_id))
    set_seed(args.seed)

    # Both lookups raise KeyError for an --llm name that is not registered.
    llm_path = llm_paths[args.llm]
    # --output_dim is always overwritten with a per-LLM value
    # (presumably the LLM hidden size -- TODO confirm).
    args.output_dim = {"Qwen-3B": 2048, "Qwen-7B": 3584, "Mistral-7B": 4096, "Llama-8B": 4096, "Qwen-14B": 5120, "Qwen-32B": 5120, 'Qwen3-8B': 4096, 'Ministral-8B': 4096}[args.llm]

    # Pre-process Node Classification Training Data - Support both transductive and inductive settings
    # Special case: arxiv is inductive even with re_split=0
    is_inductive = (args.re_split == 2) or (args.re_split == 0 and args.dataset == "arxiv")

    if is_inductive:
        # Inductive setting (including arxiv): separate graphs per split.
        print("Loading inductive dataset...")
        full_graph_data, (train_data, val_data, test_data) = load_graph_dataset_for_llaga(
            dataset_name=args.dataset, device=device,
            re_split=args.re_split, encoder=args.lm_encoder, seed=args.seed
        )

        if args.prompt == "sim":
            print("Applying similarity filtering by modifying edge_index...")
            test_data.edge_index = filter_edges_by_similarity(test_data.x, test_data.edge_index, threshold=0.5)
            print("Edge filtering completed")

        # NOTE(review): "sim" also enters this branch, so it is forwarded to
        # apply_noise_to_inductive_data as well -- confirm that helper treats
        # it as a no-op (or as intended).
        if args.prompt is not None:
            print(f"Applying {args.prompt} noise to training data only (inductive setting)")
            train_data, val_data, test_data, full_graph_data = apply_noise_to_inductive_data(
                train_data, val_data, test_data, full_graph_data, args.prompt, args.noise_ratio, args.seed
            )

            # Text was perturbed for "noisetxt" and "noisefull", so the cached
            # node features no longer match the raw texts: re-encode them.
            if args.prompt in ("noisetxt", "noisefull") and args.lm_encoder != "shallow":
                # Only the log messages differ between the two noise modes.
                tag = " (noisefull)" if args.prompt == "noisefull" else ""
                print(f"Regenerating embeddings for noisy training texts{tag} using {args.lm_encoder}...")
                from common.lm import TextEncoder
                encoder_type = "LLM" if args.lm_encoder in ["Mistral-7B", "Qwen-7B", "Llama-8B", "Qwen3-8B", "Ministral-8B"] else "LM"

                # Clear GPU cache before creating encoder
                torch.cuda.empty_cache()

                text_encoder = TextEncoder(args.lm_encoder, encoder_type, device)

                with torch.no_grad():
                    new_embeddings = []
                    for i, text in enumerate(train_data.raw_texts):
                        text = "Empty text" if len(text) == 0 else text
                        emb = text_encoder.forward(text, pooling="mean")
                        # Park each embedding on the CPU so GPU memory stays flat.
                        new_embeddings.append(emb.cpu())
                        torch.cuda.empty_cache()

                        if (i + 1) % 10 == 0:
                            print(f"Processed {i + 1}/{len(train_data.raw_texts)} training texts")

                    # Update training data embeddings
                    train_data.x = torch.cat(new_embeddings, dim=0).to(device)

                # Clean up encoder so its weights do not compete with the LLM for memory
                if hasattr(text_encoder, 'model'):
                    del text_encoder.model
                del text_encoder
                torch.cuda.empty_cache()

                print(f"Updated training embeddings for {len(train_data.raw_texts)} texts{tag}")

        # Build separate hopfield embeddings for each phase; the model picks
        # one via set_phase().
        hopfield_emb = {
            'train': build_hopfield_emb(train_data.x, train_data.edge_index, n_layers=args.hopfield),
            'val': build_hopfield_emb(val_data.x, val_data.edge_index, n_layers=args.hopfield),
            'test': build_hopfield_emb(test_data.x, test_data.edge_index, n_layers=args.hopfield),
        }

        # Create datasets for each split
        train_dataset = LLaGADataset(args, graph_data=train_data, full_graph=full_graph_data, data_type="train", repeats=1, inductive=True)
        val_dataset = LLaGADataset(args, graph_data=val_data, full_graph=full_graph_data, data_type="val", repeats=1, inductive=True)
        test_dataset = LLaGADataset(args, graph_data=test_data, full_graph=full_graph_data, data_type="test", repeats=1, inductive=True)

        structure_emb = build_laplacian_emb(args.k_hop, args.sample_size).to(device)

    else:
        # Transductive setting: one graph shared by all splits.
        graph_data = load_graph_dataset_for_llaga(dataset_name=args.dataset, device=device, encoder=args.lm_encoder, re_split=args.re_split, seed=args.seed)

        if args.prompt is not None:
            print(f"Applying {args.prompt} noise to training data (transductive setting)")
            # Noisy graph for the training dataset; val/test datasets keep the clean graph.
            train_graph_data = apply_noise_to_graph_data(graph_data, args.prompt, args.noise_ratio, args.seed)
        else:
            train_graph_data = graph_data
        val_test_graph_data = graph_data

        train_dataset = LLaGADataset(args, graph_data=train_graph_data, full_graph=graph_data, data_type="train", repeats=1)
        val_dataset = LLaGADataset(args, graph_data=val_test_graph_data, full_graph=graph_data, data_type="val", repeats=1)
        test_dataset = LLaGADataset(args, graph_data=val_test_graph_data, full_graph=graph_data, data_type="test", repeats=1)
        full_graph_data = graph_data

        # train_graph_data is the noisy graph when noise is enabled and the
        # original graph otherwise, so a single call covers both cases.
        # NOTE(review): with noise enabled this one (noisy) table is what the
        # model receives for the HO template -- confirm that is intended for
        # validation and test as well.
        hopfield_emb = build_hopfield_emb(train_graph_data.x, train_graph_data.edge_index, n_layers=args.hopfield)
        structure_emb = build_laplacian_emb(args.k_hop, args.sample_size).to(device)

    # Report split sizes and wrap each split in a loader.
    print(f"[DATA] {'Supervised 6:2:2' if args.re_split else 'Semi-supervised Setting'} # Train {len(train_dataset)} # Val {len(val_dataset)} # Test {len(test_dataset)}")

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, drop_last=False, pin_memory=True, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=args.eval_batch_size, drop_last=False, pin_memory=True, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=args.eval_batch_size, drop_last=False, pin_memory=True, shuffle=False)

    # Load LLaGA Model
    if is_inductive:
        # For inductive setting (including arxiv), use train data embeddings
        graph_embedding_for_model = train_data.x if args.neighbor_template == "ND" else hopfield_emb['train']

        model = LLaGAModel(args,
                           llm_path,
                           graph_embedding=graph_embedding_for_model,
                           structure_embedding=structure_emb if args.neighbor_template == "ND" else None,
                           inductive_embs=hopfield_emb if args.neighbor_template == "HO" else None)
    else:
        # For transductive setting, use full graph embeddings
        graph_embedding_for_model = full_graph_data.x if args.neighbor_template == "ND" else hopfield_emb
        model = LLaGAModel(args,
                           llm_path,
                           graph_embedding=graph_embedding_for_model,
                           structure_embedding=structure_emb if args.neighbor_template == "ND" else None)

    re_split_str = '_s' if args.re_split else ''

    # Generate model save path
    save_path = get_model_save_path(
        model_name="LLaGA",
        dataset=args.dataset,
        re_split=args.re_split,
        llm=args.llm,
        seed=args.seed,
        atk_name=None,  # Can be modified later for adversarial experiments
        neighbor_template=args.neighbor_template,
        num_epochs=args.num_epochs,
        lm_encoder=args.lm_encoder,
        prompt=args.prompt  # Add prompt parameter for noise experiments
    )
    print(f"Model save path: {save_path}")

    # Check if trained model already exists
    llm_config_str = f"{args.llm}_{args.neighbor_template}_Epoch{args.num_epochs}{re_split_str}{'_LoRA' if not args.llm_freeze else ''}"
    model_path = os.path.join(save_path, f"{llm_config_str}_best.pth")

    should_train = True
    try:
        if os.path.exists(model_path):
            print(f"Loading existing model from {model_path}")
            model = reload_best_model(model, save_path, llm_config_str)
            print("Skipping training as model already exists")
            should_train = False
    except Exception as e:
        # Best effort: an unreadable checkpoint falls back to training from scratch.
        print(f"Error loading existing model: {e}")
        print("Proceeding with training new model")
        should_train = True

    # (Temporary) Token Counter to Decide MAX_ANS_LENGTH & MAX_TXT_LENGTH
    if args.token_counter:
        input_lengths, txt_lengths, output_lengths = [], [], []
        for sample in train_dataset + val_dataset + test_dataset:
            encoded_query = model.tokenizer(sample["query"])
            encoded_txt = model.tokenizer(sample["origin_txt"])
            encoded_label = model.tokenizer(sample["label"])
            input_lengths.append(len(encoded_query["input_ids"]))
            # Length of the raw node text, not of the full query.
            txt_lengths.append(len(encoded_txt["input_ids"]))
            output_lengths.append(len(encoded_label["input_ids"]))
        print(f"[ANALYSIS] # Avg Input Token {sum(input_lengths)/len(input_lengths):.3f} # Avg txt Token {sum(txt_lengths)/len(txt_lengths):.2f}  # Avg Output Token {sum(output_lengths)/len(output_lengths):.3f}  Max Output Token {max(output_lengths)}")

    # Optimise only the parameters left trainable by the model.
    params = [p for _, p in model.named_parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(
        [{'params': params, 'lr': args.lr, 'weight_decay': args.wd}]
    )

    trainable_params, all_params = model.print_trainable_params()
    print(f"Trainable params {trainable_params} || all params {all_params} || trainable% {100 * trainable_params / all_params:.5f}")

    model.model.gradient_checkpointing_enable()
    st_time = time.time()

    if should_train:
        progress_bar = tqdm(range(args.num_epochs * len(train_loader)))

        best_val_loss = float('inf')
        # -1 means "no checkpoint saved yet"; keeps the early-stop test below
        # well defined even if the validation loss never improves (e.g. NaN).
        best_epoch = -1

        for epoch in range(args.num_epochs):
            model.train()

            # Set phase for inductive learning
            if is_inductive:
                model.set_phase('train')

            epoch_loss = 0.0

            for batch in train_loader:
                optimizer.zero_grad()

                loss = model(batch)
                loss.backward()

                clip_grad_norm_(optimizer.param_groups[0]['params'], 0.1)
                optimizer.step()
                epoch_loss += loss.item()

                progress_bar.update(1)

            print(f"[TRAIN] Epoch {epoch+1}|{args.num_epochs}: Train Loss (Epoch Mean): {epoch_loss / len(train_loader):.5f}")

            val_loss = 0.0
            model.eval()

            # Set phase for validation in inductive learning
            if is_inductive:
                model.set_phase('val')

            with torch.no_grad():
                for batch in val_loader:
                    loss = model(batch)
                    val_loss += loss.item()
                print(f"[VAL] Epoch: {epoch+1}|{args.num_epochs}: Val Loss: {val_loss / len(val_loader):.5f}")

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_epoch = epoch
                save_checkpoint(model, epoch, save_path, llm_config_str, is_best=True)

            if epoch - best_epoch >= args.patience:
                print(f"[TRAIN] Early stop at epoch {epoch+1}")
                break

        progress_bar.close()

        # Load best model once, after training ends (normally or via early stop).
        if best_epoch >= 0:
            model = reload_best_model(model, save_path, llm_config_str)
        else:
            print("[TRAIN] Validation loss never improved; no checkpoint saved, keeping final weights")

    # Close to zero when an existing checkpoint was reused.
    train_secs = time.time() - st_time
    torch.cuda.empty_cache()
    # Replaces the deprecated reset_max_memory_allocated(), which delegates to this call.
    torch.cuda.reset_peak_memory_stats()

    model.eval()

    # Set phase for test in inductive learning
    if is_inductive:
        model.set_phase('test')

    # Create prediction directory
    prediction_dir = os.path.join(args.output_dir, "prediction")
    os.makedirs(prediction_dir, exist_ok=True)
    path = os.path.join(prediction_dir, f"{args.dataset}_{args.llm}{re_split_str}_seed{args.seed}.json")
    print(f"\n[Prediction] Write predictions on {path} ...")

    progress_bar_test = tqdm(range(len(test_loader)))
    pred_labels, gt_labels = [], []
    st_time = time.time()
    try:
        with open(path, 'w') as file:
            for step, batch in enumerate(test_loader):
                with torch.no_grad():
                    try:
                        id_list, predictions = model.inference(batch)

                        for node_idx, llm_pred in zip(id_list, predictions):
                            if is_inductive:
                                # node_idx is local to the test graph: map it back to
                                # the original node id and read the label from test_data.
                                original_node_idx = test_data.node_ids[node_idx.item()].item()
                                gt_label = classes[args.dataset][test_data.y[node_idx.item()].item()]
                            else:
                                # For transductive setting
                                original_node_idx = node_idx.item()
                                gt_label = classes[args.dataset][full_graph_data.y[node_idx.item()].item()]

                            # Metrics use the text before "</s>", mapped to UNKNOW when it
                            # is not a known class name; the file keeps the raw output.
                            pred_label = llm_pred[:llm_pred.index("</s>")] if "</s>" in llm_pred else llm_pred
                            pred_label = pred_label if pred_label in classes[args.dataset] else UNKNOW
                            write_obj = {
                                "id": original_node_idx,
                                "pred": llm_pred,
                                "ground-truth": gt_label
                            }
                            pred_labels.append(pred_label)
                            gt_labels.append(write_obj["ground-truth"])
                            file.write(json.dumps(write_obj) + "\n")
                            file.flush()
                    except Exception as e:
                        # Best effort: a failing batch is reported and skipped.
                        print(f"Error processing batch {step}: {e}")
                        continue
                    progress_bar_test.update(1)
    except Exception as e:
        print(f"Error writing predictions to file: {e}")

    # Explicitly close the progress bar before the final reporting
    progress_bar_test.close()

    inference_secs = time.time() - st_time

    acc, macro_f1, weight_f1 = compute_acc_and_f1(pred_labels, gt_labels)
    # Save results to CSV (append mode: one row per run)
    summary_path = f"{args.output_dir}/summary{'_semi' if not args.re_split else ''}.csv"
    os.makedirs(os.path.dirname(summary_path), exist_ok=True)
    try:
        with open(summary_path, 'a', newline='') as file:
            writer = csv.writer(file)
            writer.writerow([args.dataset, args.llm, acc, macro_f1, weight_f1,
                             args.neighbor_template, args.hidden_dim, args.lm_encoder, args.num_epochs, args.patience, args.batch_size, args.lr, args.seed,
                             f"Train Minutes-{train_secs/60:.3f}", f"Inference Seconds-{inference_secs:.2f}"])
    except Exception as e:
        print(f"Error saving results to CSV: {e}")
    print(f"Accuracy {acc:.2f}  Macro F1-Score {macro_f1:.2f}  Weight F1-Score {weight_f1:.2f}")
    print('\n## Finishing Time:', get_cur_time(), flush=True)
    print('= ' * 20)
    print("Done!")
