import os 
import torch 
import json 
import sys 
import time
from tqdm import tqdm
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader
sys.path.append("../..")
from common import load_graph_dataset_for_llaga, load_inductive_graph_dataset, set_seed, get_cur_time, compute_acc_and_f1, save_checkpoint, reload_best_model
from common import MODEL_PATHs as llm_paths, UNKNOW
from common.model_path import get_model_save_path, check_model_exists
from common.noise_utils import apply_noise_to_graph_data, apply_noise_to_inductive_data
from graphgpt_model import GraphGPTModel
from dataset import GraphInstructionTuningDataset, GraphMatchingDataset, classes
import argparse


def build_arg_parser():
    """Build the command-line parser for the two-stage GraphGPT run.

    Stage 1 is self-supervised graph matching, stage 2 is instruction tuning;
    the script then runs test-set inference with the stage-2 model.

    Returns:
        argparse.ArgumentParser: parser with data, noise, stage-1, stage-2 and
        output options.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--dataset", type=str, default="cora")
    parser.add_argument("--re_split", type=int, default=0)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--llm", type=str, default="Mistral-7B")
    parser.add_argument("--device", type=str, default="cuda:0",
                        help="Currently unused: the GPU is selected with --gpu_id")
    parser.add_argument("--gpu_id", type=int, default=0)

    # Noise configuration
    parser.add_argument("--prompt", type=str, default=None, choices=["noise", "noisetxt"],
                        help="Noise type: 'noise' for structure noise, 'noisetxt' for text noise, None for no noise")
    # argparse %-formats help strings, so a literal percent sign must be written as '%%'
    # (a bare '%)' made `--help` crash with "unsupported format character").
    parser.add_argument("--noise_ratio", type=float, default=0.1,
                        help="Ratio of noise to add (default: 0.1 for 10%%)")

    # Training Configuration for Stage 1 - (Self-supervised Training) Graph Matching
    parser.add_argument("--do_stage1", type=int, default=1)
    parser.add_argument("--s1_k_hop", type=int, default=2)
    parser.add_argument("--s1_num_neighbors", type=int, default=5)
    parser.add_argument("--s1_max_txt_length", type=int, default=512)
    parser.add_argument("--s1_max_ans_length", type=int, default=256)
    parser.add_argument("--s1_epoch", type=int, default=2)
    parser.add_argument("--s1_batch_size", type=int, default=16)
    parser.add_argument("--s1_lr", type=float, default=1e-4)

    # Training Configuration for Stage 2 - Instruction Tuning
    parser.add_argument("--do_stage2", type=int, default=1)
    parser.add_argument("--s2_num_neighbors", type=int, default=4)
    parser.add_argument("--s2_max_txt_length", type=int, default=256)
    parser.add_argument("--s2_max_ans_length", type=int, default=16)
    parser.add_argument("--s2_epoch", type=int, default=10)
    parser.add_argument("--s2_batch_size", type=int, default=32)  # adjust this parameter based on devices
    parser.add_argument("--s2_lr", type=float, default=1e-4)
    parser.add_argument("--s2_patience", type=int, default=2)

    parser.add_argument("--output_dim", type=int, default=2048,
                        help="Overridden at startup from the hidden size of --llm")
    parser.add_argument("--wd", type=float, default=0.05)
    parser.add_argument("--load_ground_embedding", type=int, default=0)
    parser.add_argument("--output_dir", type=str, default="../../results/GraphGPT")
    return parser


if __name__ == "__main__":
    parser = build_arg_parser()

    args = parser.parse_args()
    
    # Run banner: start time, optional noise configuration, then the parsed arguments.
    banner_rule = "=" * 20
    print(banner_rule)
    print("## Starting Time:", get_cur_time(), flush=True)
    noise_type = args.prompt
    if noise_type is not None:
        print(f"## Noise Info: {noise_type} with noise_ratio={args.noise_ratio}")
    print(args, "\n")

    # NOTE: `--device` is not consulted; the CUDA device comes from `--gpu_id`.
    device = torch.device(f"cuda:{args.gpu_id}")
    set_seed(args.seed)

    # Hidden width per supported LLM; presumably consumed by GraphGPTModel via
    # args.output_dim as the graph projector's output size — confirm there.
    llm_hidden_sizes = {
        "Qwen-3B": 2048,
        "Qwen-7B": 3584,
        "Mistral-7B": 4096,
        "Llama-8B": 4096,
    }
    llm_path = llm_paths[args.llm]
    args.output_dim = llm_hidden_sizes[args.llm]
    
    # Prepare data.
    # Special case: arxiv is treated as inductive even with re_split=0.
    is_inductive = (args.re_split == 2) or (args.re_split == 0 and args.dataset == "arxiv")

    if is_inductive:
        # Inductive setting (including arxiv): the loader returns the full graph
        # plus separate train/val/test split objects.
        print("Loading inductive dataset...")
        full_graph_data, (train_data, val_data, test_data) = load_graph_dataset_for_llaga(
            dataset_name=args.dataset, device=device,
            re_split=args.re_split, encoder="roberta", seed=args.seed
        )

        # Noise is applied to the training split only.
        if args.prompt is not None:
            print(f"Applying {args.prompt} noise to training data only (inductive setting)")
            train_data, val_data, test_data, full_graph_data = apply_noise_to_inductive_data(
                train_data, val_data, test_data, full_graph_data, args.prompt, args.noise_ratio, args.seed
            )

            # Text noise rewrites raw_texts, so the training node features are re-encoded.
            if args.prompt == "noisetxt":
                print("Regenerating embeddings for noisy training texts using roberta...")
                from common.lm import TextEncoder

                text_encoder = None
                try:
                    torch.cuda.empty_cache()
                    text_encoder = TextEncoder("roberta", "LM", device)
                    num_texts = len(train_data.raw_texts)

                    with torch.no_grad():
                        new_embeddings = []
                        for i, text in enumerate(train_data.raw_texts):
                            # Normalise None / empty / whitespace-only texts to a placeholder.
                            text = "" if text is None else str(text).strip()
                            if not text:
                                text = "Empty text"

                            try:
                                emb = text_encoder.forward(text, pooling="cls", max_length=512)
                            except Exception as e:
                                print(f"Error processing text {i}: {e}")
                                # Fall back to the embedding of a fixed placeholder text.
                                emb = text_encoder.forward("default text", pooling="cls", max_length=512)
                            # Keep results on CPU; the GPU copy is freed as `emb` is rebound.
                            new_embeddings.append(emb.cpu())

                            if (i + 1) % 10 == 0:
                                print(f"Processed {i + 1}/{num_texts} training texts")

                        train_data.x = torch.cat(new_embeddings, dim=0).to(device)

                    print(f"Updated training embeddings for {num_texts} texts")
                except Exception as e:
                    # Best effort: train_data.x is only replaced on success, so the
                    # original embeddings remain in place on failure.
                    print(f"Error in text noise processing: {e}")
                    print("Continuing with original embeddings...")
                finally:
                    # Release the encoder whether or not re-encoding succeeded.
                    if hasattr(text_encoder, "model"):
                        del text_encoder.model
                    text_encoder = None
                    torch.cuda.empty_cache()

        # Each phase (train/val/test) gets its own node-feature matrix.
        graph_embeddings = {
            'train': train_data.x.to(device),
            'val': val_data.x.to(device),
            'test': test_data.x.to(device),
        }
        graph_data = full_graph_data
        graph_embedding = graph_embeddings['train']  # Default to train for model initialization
    else:
        # Transductive setting (every case not handled above): a single graph.
        graph_data = load_graph_dataset_for_llaga(dataset_name=args.dataset, device=device, encoder="roberta", re_split=args.re_split)
        # Kept un-noised; stage 2 uses it for validation/testing when noise is on.
        clean_graph_data = graph_data
        graph_embeddings = None

        if args.prompt is not None:
            print(f"Applying {args.prompt} noise to training data (transductive setting)")
            noisy_graph_data = apply_noise_to_graph_data(graph_data, args.prompt, args.noise_ratio, args.seed)

            # Text noise rewrites raw_texts, so the node features are re-encoded.
            if args.prompt == "noisetxt":
                print("Regenerating embeddings for noisy texts using roberta...")
                from common.lm import TextEncoder

                torch.cuda.empty_cache()
                text_encoder = TextEncoder("roberta", "LM", device)
                num_texts = len(noisy_graph_data.raw_texts)

                with torch.no_grad():
                    new_embeddings = []
                    for i, text in enumerate(noisy_graph_data.raw_texts):
                        # Guard None as well as "" (previously len(None) raised TypeError).
                        if text is None or len(text) == 0:
                            text = "Empty text"
                        emb = text_encoder.forward(text, pooling="cls", max_length=512)
                        new_embeddings.append(emb.cpu())

                        if (i + 1) % 100 == 0:
                            print(f"Processed {i + 1}/{num_texts} texts")

                    noisy_graph_data.x = torch.cat(new_embeddings, dim=0).to(device)

                if hasattr(text_encoder, 'model'):
                    del text_encoder.model
                del text_encoder
                torch.cuda.empty_cache()

                print(f"Updated embeddings for {num_texts} texts")

            # Train on the noisy graph; `clean_graph_data` still holds the original.
            graph_data = noisy_graph_data

        graph_embedding = graph_data.x.to(device)

        # For consistency, set split data to None for transductive setting
        train_data, val_data, test_data = None, None, None
        
    if args.load_ground_embedding:
        ground_emb_path = f"{args.output_dir}/ground_emb/{args.dataset}.pt"
        # Explicit raise instead of `assert`: assertions are stripped under `python -O`.
        if not os.path.exists(ground_emb_path):
            raise FileNotFoundError(
                "You have set `load_ground_embedding` to True. Please run `main_text_graph_grounding` "
                f"to generate grounded embedding first! (missing: {ground_emb_path})"
            )
        # Load onto CPU first so a tensor saved from a different GPU still loads, then move it.
        ground_embedding = torch.load(ground_emb_path, map_location="cpu").to(device)
        if is_inductive:
            # Per-phase slices: train nodes only, train+val nodes for validation,
            # and every node for testing.
            train_mask = graph_data.train_mask
            train_val_mask = train_mask | graph_data.val_mask
            graph_embeddings = {
                'train': ground_embedding[train_mask],
                'val': ground_embedding[train_val_mask],
                'test': ground_embedding,
            }
            graph_embedding = graph_embeddings['train']
        else:
            graph_embedding = ground_embedding
    
    # Directory holding this run's stage-1 and stage-2 checkpoints.
    # NOTE(review): --noise_ratio is not passed, so runs differing only in the
    # noise ratio appear to share one path (and its cached checkpoints) — confirm.
    save_path_kwargs = dict(
        model_name="GraphGPT",
        dataset=args.dataset,
        re_split=args.re_split,
        llm=args.llm,
        seed=args.seed,
        atk_name=None,  # placeholder for future adversarial-attack experiments
        s1_epoch=args.s1_epoch,
        s2_epoch=args.s2_epoch,
        load_ground_embedding=args.load_ground_embedding,
        prompt=args.prompt,  # noise type (None for clean runs)
    )
    save_path = get_model_save_path(**save_path_kwargs)
    print(f"Model save path: {save_path}")
    
    st_time = time.time()
    if args.do_stage1:
        print("Preparing Stage 1 [Graph Matching] ...")

        model = GraphGPTModel(args, llm_path, graph_embedding=graph_embedding, stage="matching", inductive_embs=graph_embeddings)

        # Reuse a cached stage-1 checkpoint if one exists; stage-1 training is then skipped.
        stage1_model_path = os.path.join(save_path, "stage1_best.pth")
        if os.path.exists(stage1_model_path):
            print(f"Loading existing Stage 1 model from {stage1_model_path}")
            model = reload_best_model(model, save_path, config_str="stage1")
            args.s1_epoch = 0
            # The optimizer built below is carried over into stage 2.
            args.s1_lr = args.s2_lr

        params = [p for p in model.parameters() if p.requires_grad]
        optimizer = torch.optim.AdamW([{'params': params, 'lr': args.s1_lr, 'weight_decay': args.wd}])

        trainable_params, all_params = model.print_trainable_params()
        print(f"Trainable params {trainable_params} || all params {all_params} || trainable% {100 * trainable_params / all_params:.5f}")

        if args.s1_epoch > 0:
            graph_type = {
                "cora": "academic_network", "citeseer": "academic_network", "pubmed": "academic_network", "wikics": "academic_network", "arxiv": "academic_network",
                "reddit": "social_network", "instagram": "social_network",
                "computer": "ecommerce_network", "photo": "ecommerce_network", "history": "ecommerce_network"
            }[args.dataset]
            dataset = GraphMatchingDataset(graph_data=graph_data, k_hop=args.s1_k_hop, num_sampled_neighbors=args.s1_num_neighbors, graph_type=graph_type, re_split=args.re_split)
            train_loader = DataLoader(dataset, batch_size=args.s1_batch_size, drop_last=True, shuffle=True)
            # With drop_last=True a dataset smaller than one batch gives zero batches;
            # avoid dividing the epoch loss by zero in that case.
            num_batches = max(len(train_loader), 1)

            num_training_steps = args.s1_epoch * len(train_loader)
            progress_bar = tqdm(range(num_training_steps))

            model.model.gradient_checkpointing_enable()
            for epoch in range(args.s1_epoch):
                model.train()

                # Set phase for inductive learning
                if is_inductive:
                    model.set_phase('train')

                epoch_loss, accum_loss = 0.0, 0.0

                for step, batch in enumerate(train_loader):
                    optimizer.zero_grad()

                    loss = model(batch)
                    loss.backward()

                    clip_grad_norm_(optimizer.param_groups[0]['params'], 0.1)
                    optimizer.step()
                    loss_value = loss.item()  # single host sync per step
                    epoch_loss += loss_value
                    accum_loss += loss_value

                    if (step + 1) % 20 == 0:
                        print(f"(Temporary) Step {step} in Epoch {epoch+1} Accum Loss {accum_loss:.4f}")
                        accum_loss = 0.0

                    progress_bar.update(1)

                print(f"[TRAIN] Epoch {epoch+1}|{args.s1_epoch}: Train Loss (Epoch Mean): {epoch_loss / num_batches:.5f}")
                # Pass the current epoch (the original passed args.s1_epoch every time).
                save_checkpoint(model, epoch, save_path, config_str="stage1", is_best=True)

            progress_bar.close()
            torch.cuda.empty_cache()
            # reset_max_memory_allocated() is deprecated and just forwards to this.
            torch.cuda.reset_peak_memory_stats()
    
    if args.do_stage2:
        print("Preparing Stage 2 [Instruction Tuning] ...")

        if not args.do_stage1:
            # Stage 1 skipped entirely: build a fresh model and optimizer here.
            # NOTE(review): stage="matching" is passed here as in stage 1 — confirm
            # GraphGPTModel expects that for instruction tuning too.
            model = GraphGPTModel(args, llm_path, graph_embedding=graph_embedding, stage="matching", inductive_embs=graph_embeddings)
            params = [p for p in model.parameters() if p.requires_grad]
            optimizer = torch.optim.AdamW([{'params': params, 'lr': args.s2_lr, 'weight_decay': args.wd}])
            trainable_params, all_params = model.print_trainable_params()
            print(f"Trainable params {trainable_params} || all params {all_params} || trainable% {100 * trainable_params / all_params:.5f}")
        else:
            print("Directly load stage1's pretrained graph projector layer")
            # The stage-1 optimizer is reused. Switch it to the stage-2 learning rate;
            # otherwise --s2_lr was silently ignored whenever stage 1 actually trained.
            for param_group in optimizer.param_groups:
                param_group["lr"] = args.s2_lr

        train_dataset = GraphInstructionTuningDataset(graph_data=graph_data, maximum_neighbors=args.s2_num_neighbors, dataset_name=args.dataset, data_type="train", re_split=args.re_split, split_data=train_data if is_inductive else None)
        # For validation and test, use clean data in transductive setting
        val_test_graph_data = clean_graph_data if args.prompt is not None and not is_inductive else graph_data
        val_dataset = GraphInstructionTuningDataset(graph_data=val_test_graph_data, maximum_neighbors=args.s2_num_neighbors, dataset_name=args.dataset, data_type="val", re_split=args.re_split, split_data=val_data if is_inductive else None)
        test_dataset = GraphInstructionTuningDataset(graph_data=val_test_graph_data, maximum_neighbors=args.s2_num_neighbors, dataset_name=args.dataset, data_type="test", re_split=args.re_split, split_data=test_data if is_inductive else None)

        train_loader = DataLoader(train_dataset, batch_size=args.s2_batch_size, drop_last=False, pin_memory=True, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=args.s2_batch_size*2, drop_last=False, pin_memory=True, shuffle=False)
        test_loader = DataLoader(test_dataset, batch_size=args.s2_batch_size*3, drop_last=False, pin_memory=True, shuffle=False)

        num_training_steps = args.s2_epoch * len(train_loader)
        progress_bar = tqdm(range(num_training_steps))

        best_val_loss = float('inf')
        best_epoch = -1  # -1 until a best checkpoint has been saved
        model.model.gradient_checkpointing_enable()

        llm_config_str = f"{args.llm}_Epoch{args.s2_epoch}"

        # Reuse a cached stage-2 checkpoint if one exists; training is then skipped.
        stage2_model_path = os.path.join(save_path, f"{llm_config_str}_best.pth")
        should_train_stage2 = not os.path.exists(stage2_model_path)
        if not should_train_stage2:
            print(f"Loading existing Stage 2 model from {stage2_model_path}")
            model = reload_best_model(model, save_path, llm_config_str)
            print("Skipping Stage 2 training as model already exists")

        if should_train_stage2:
            print("Training Stage 2...")
            # Guard the per-epoch means against empty loaders.
            num_train_batches = max(len(train_loader), 1)
            num_val_batches = max(len(val_loader), 1)
            for epoch in range(args.s2_epoch):
                model.train()

                # Set phase for training in inductive learning
                if is_inductive:
                    model.set_phase('train')

                epoch_loss = 0.0
                for batch in train_loader:
                    optimizer.zero_grad()
                    loss = model(batch)
                    loss.backward()

                    clip_grad_norm_(optimizer.param_groups[0]['params'], 0.1)
                    optimizer.step()
                    epoch_loss += loss.item()

                    progress_bar.update(1)

                print(f"[TRAIN] Epoch {epoch+1}|{args.s2_epoch}: Train Loss (Epoch Mean): {epoch_loss / num_train_batches:.5f}")

                val_loss = 0.0
                model.eval()

                # Set phase for validation in inductive learning
                if is_inductive:
                    model.set_phase('val')

                with torch.no_grad():
                    for batch in val_loader:
                        val_loss += model(batch).item()
                print(f"[VAL] Epoch {epoch+1}|{args.s2_epoch}: Val Loss {val_loss / num_val_batches:.5f}")

                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    best_epoch = epoch
                    save_checkpoint(model, epoch, save_path, llm_config_str, is_best=True)

                # Stop once s2_patience epochs have passed without a new best.
                if epoch - best_epoch >= args.s2_patience:
                    print(f"[TRAIN] Early stop at epoch {epoch+1}")
                    break

            progress_bar.close()
            # Restore the best checkpoint once, after training ends. Previously the
            # reload sat inside the epoch loop: it reverted weights every epoch and was
            # skipped by the early-stop `break`, so inference used the last epoch.
            if best_epoch >= 0:
                model = reload_best_model(model, save_path, llm_config_str)
     
    train_secs = time.time() - st_time
    torch.cuda.empty_cache()
    # reset_max_memory_allocated() is deprecated and just forwards to this.
    torch.cuda.reset_peak_memory_stats()

    # `test_loader` (and the instruction-tuned model) only exist when stage 2 ran;
    # previously this path crashed with a NameError under --do_stage2 0.
    if not args.do_stage2:
        print("Stage 2 is disabled (--do_stage2 0); skipping test-set inference.")
        print(f"Training time {train_secs:.1f}s")
        print('\n## Finishing Time:', get_cur_time(), flush=True)
        sys.exit(0)

    model.eval()

    # Set phase for testing in inductive learning
    if is_inductive:
        model.set_phase('test')

    # Create prediction directory and ensure it exists
    prediction_dir = os.path.join(args.output_dir, "prediction")
    os.makedirs(prediction_dir, exist_ok=True)
    re_split_str = '_s' if args.re_split else ''
    # Tag noisy runs so they no longer overwrite the clean run's prediction file
    # (clean-run file names are unchanged).
    noise_str = f"_{args.prompt}{args.noise_ratio}" if args.prompt is not None else ''
    path = os.path.join(prediction_dir, f"{args.dataset}_{args.llm}{re_split_str}{noise_str}_seed{args.seed}.json")
    print(f"\n[Prediction] Write predictions on {path} ...")
    progress_bar = tqdm(range(len(test_loader)))

    pred_labels, gt_labels = [], []
    st_time = time.time()
    valid_labels = classes[args.dataset]

    try:
        with open(path, 'w', encoding="utf-8") as file:
            for batch in test_loader:
                with torch.no_grad():
                    id_list, predictions = model.inference(batch)

                for node_idx, llm_pred in zip(id_list, predictions):
                    node_idx = node_idx.item()
                    # Keep only the text before the first "</s>" marker, if any.
                    pred_label = llm_pred.partition("</s>")[0]

                    if is_inductive:
                        # Batch ids appear to index into the test split; map back to
                        # the original node id via test_data.node_ids.
                        original_node_idx = test_data.node_ids[node_idx].item()
                        gt_label = valid_labels[test_data.y[node_idx].item()]
                    else:
                        original_node_idx = node_idx
                        gt_label = valid_labels[graph_data.y[node_idx].item()]

                    write_obj = {
                        "id": original_node_idx,
                        "pred": llm_pred,
                        "ground-truth": gt_label
                    }
                    pred_labels.append(pred_label if pred_label in valid_labels else UNKNOW)
                    gt_labels.append(gt_label)
                    file.write(json.dumps(write_obj) + "\n")
                    # Flush per node so partial results survive a crash mid-run.
                    file.flush()
                progress_bar.update(1)
    except Exception as e:
        # Best effort: keep whatever was predicted so far, but show the full traceback.
        import traceback
        print(f"Error during prediction or writing results: {e}")
        traceback.print_exc()
    progress_bar.close()

    inference_secs = time.time() - st_time

    # Compute metrics (only if at least one prediction was collected).
    if pred_labels:
        acc, macrof1, weightf1 = compute_acc_and_f1(pred_labels, gt_labels)
        print(f"Accuracy {acc:.3f}  Macro F1-Score {macrof1:.3f}  Weight F1-Score {weightf1:.3f}")
    else:
        print("No predictions were produced; skipping metric computation.")

    print(f"Training time {train_secs:.1f}s  Inference time {inference_secs:.1f}s")
    print('\n## Finishing Time:', get_cur_time(), flush=True)
    print('= ' * 20)
    print("Done!")
    