import argparse
import os
import sys
import json
import csv
import time
import shutil
from dataclasses import dataclass, field
from typing import Optional, List

import torch
import wandb
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

sys.path.append("../..")
from common import set_seed, load_graph_dataset, compute_acc_and_f1
from common import IGNORE_INDEX, UNKNOW
from common import CLASSES as classes 
from common import MODEL_PATHs as llm_paths
from common import get_prompt_template, get_available_prompt_types, get_supported_datasets, requires_neighbor_info, get_classes_str
from common import prepare_edge_list


@dataclass
class ModelArguments:
    """Settings naming the backbone LLM and the locations of its LoRA / merged weights."""

    # Short name of the backbone LLM.
    llm: Optional[str] = "Mistral-7B"
    # Both paths default to None, i.e. "not specified".
    lora_adapter_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the LoRA adapter"},
    )
    merged_model_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to save/load merged model"},
    )


@dataclass
class DataArguments:
    """Dataset selection, prompt style and length limits used when building prompts."""

    dataset: str = field(
        default="cora",
        metadata={"help": "Dataset name"},
    )
    prompt_type: str = field(
        default="instruction_tuning",
        metadata={"help": "Type of prompt template to use"},
    )
    # Length limits (their help strings describe them as lengths of the query
    # prompt, the node's original text and the answer respectively).
    max_txt_length: int = field(
        default=128,
        metadata={"help": "Maximum length of query prompt"},
    )
    max_origin_txt_length: int = field(
        default=128,
        metadata={"help": "Maximum length of node's original text"},
    )
    max_ans_length: int = field(
        default=16,
        metadata={"help": "Maximum length of answer"},
    )
    # Integer flag: 0 means the default split, non-zero means re-split.
    re_split: int = field(
        default=0,
        metadata={"help": "Whether to use re-split"},
    )
    maximum_neighbor: int = field(
        default=8,
        metadata={"help": "Maximum number of neighbors to include"},
    )


@dataclass
class InferenceArguments:
    """Runtime knobs for the inference engine (batching, parallelism, sampling, memory)."""

    batch_size: int = field(
        default=16,
        metadata={"help": "Batch size for inference"},
    )
    num_gpus: int = field(
        default=2,
        metadata={"help": "Number of GPUs for tensor parallelism"},
    )
    temperature: float = field(
        default=0.0,
        metadata={"help": "Sampling temperature"},
    )
    gpu_memory_utilization: float = field(
        default=0.9,
        metadata={"help": "GPU memory utilization for vLLM"},
    )


def get_training_output_path(dataset, llm, prompt_type, seed, re_split=0):
    """Return the directory where training output for this configuration is expected.

    The directory name is built from the dataset (with an ``_s`` suffix when
    ``re_split`` is truthy), the LLM name and the prompt type. ``seed`` is
    accepted for signature parity with the sibling path helpers but does not
    influence the result.
    """
    split_suffix = "_s" if re_split else ""
    run_name = f"{dataset}{split_suffix}_{llm}_{prompt_type}"
    return "../../results/InstructionTuning" + "/output/" + run_name


def get_lora_adapter_path(dataset, llm, prompt_type, seed, re_split=0):
    """Return the ``lora_adapter`` sub-directory of the training output directory."""
    training_dir = get_training_output_path(dataset, llm, prompt_type, seed, re_split)
    return "/".join((training_dir, "lora_adapter"))


def get_merged_model_path(dataset, llm, prompt_type, seed, re_split=0):
    """Return the directory under /data/LLMBackbone/Merged_models for this configuration.

    Unlike the training output path, the seed is part of the directory name.
    An ``_s`` suffix is appended to the dataset name when ``re_split`` is truthy.
    """
    split_suffix = "_s" if re_split else ""
    model_dir_name = f"{dataset}{split_suffix}_{llm}_{prompt_type}_seed{seed}"
    return "/data/LLMBackbone/Merged_models" + "/" + model_dir_name


def _mask_to_node_list(mask):
    """Return the indices of the non-zero entries of a node mask as a Python list."""
    node_ids = mask.nonzero(as_tuple=False).squeeze(-1).detach().cpu().numpy()
    # Guard against a 0-dim array so callers always receive a list.
    if node_ids.ndim == 0:
        return [node_ids.item()]
    return node_ids.tolist()


def prepare_graph_instruction_tuning_data(graph_data, data_type="test", dataset_name="cora", prompt_type="instruction_tuning", maximum_neighbor=8):
    """Build one record per node of the requested split for prompting.

    Args:
        graph_data: Graph object exposing ``train_mask``, ``val_mask``,
            ``test_mask``, ``raw_texts``, ``y`` and, for neighbor-based
            prompts, ``edge_index`` and ``num_nodes``.
        data_type: Which split to export: ``"train"``, ``"val"`` or ``"test"``.
        dataset_name: Key into ``classes`` used to map label ids to names.
        prompt_type: Prompt template name; decides whether neighbor
            information is attached and how neighbor labels are encoded.
        maximum_neighbor: Upper bound on the neighbors kept per node. When a
            node has more, the highest-degree neighbors are kept.

    Returns:
        A list of dicts with keys ``"id"``, ``"input"`` (node text) and
        ``"output"`` (class name). For neighbor-based prompt types each dict
        also has ``"neighbor_texts"`` and ``"neighbor_labels"``:

        * ``"neighbor_label"``: one label per neighbor, aligned with
          ``neighbor_texts``; neighbors outside train/val get ``"unknown"``.
        * any other neighbor-based type: only the labels of neighbors that
          are in train/val (so the list may be shorter than
          ``neighbor_texts``).

    Raises:
        KeyError: If ``data_type`` is not one of the three split names.
    """
    focus_mask = {"train": graph_data.train_mask, "val": graph_data.val_mask, "test": graph_data.test_mask}[data_type]
    focus_nodes = _mask_to_node_list(focus_mask)

    # The prompt type is fixed for the whole call, so decide once instead of per node.
    needs_neighbors = requires_neighbor_info(prompt_type)

    edge_list = None
    node_degrees = None
    train_val_nodes = set()
    if needs_neighbors:
        edge_index = graph_data.edge_index.cpu()
        edge_list = prepare_edge_list(edge_index, graph_data.num_nodes)
        # Degree = number of appearances as either endpoint. A single bincount
        # over both rows replaces a Python loop that rescanned every edge for
        # every node.
        node_degrees = torch.bincount(edge_index.reshape(-1), minlength=graph_data.num_nodes).tolist()
        # A set gives O(1) membership tests for the per-neighbor checks below.
        train_val_nodes = set(_mask_to_node_list(graph_data.train_mask | graph_data.val_mask))

    class_names = classes[dataset_name]

    def label_of(node):
        return class_names[graph_data.y[node].item()]

    data_contents = []
    for node_id in focus_nodes:
        data_item = {
            "id": node_id,
            "input": graph_data.raw_texts[node_id],
            "output": label_of(node_id),
        }

        if needs_neighbors:
            neighbors = edge_list[node_id]
            if len(neighbors) > maximum_neighbor:
                # Stable sort by degree (descending) keeps the original order among ties.
                neighbors = sorted(neighbors, key=lambda neigh: node_degrees[neigh], reverse=True)[:maximum_neighbor]

            # int() normalises the id so the set lookup works whatever integer-like
            # type prepare_edge_list yields (assumed int-convertible; TODO confirm).
            labelled = [int(neigh) in train_val_nodes for neigh in neighbors]

            if prompt_type == "neighbor_label":
                # Keep labels aligned with neighbor_texts; hide labels outside train/val.
                neighbor_labels = [
                    label_of(neigh) if known else "unknown"
                    for neigh, known in zip(neighbors, labelled)
                ]
            else:
                # Every other neighbor-based prompt type: labels of train/val
                # neighbors only. Previously only "neighbor" was handled, which
                # left neighbor_labels unbound for any other type.
                neighbor_labels = [label_of(neigh) for neigh, known in zip(neighbors, labelled) if known]

            data_item["neighbor_texts"] = [graph_data.raw_texts[neigh] for neigh in neighbors]
            data_item["neighbor_labels"] = neighbor_labels

        data_contents.append(data_item)

    return data_contents


def prepare_vllm_inference_data(batch_data, tokenizer, dataset_name, prompt_type, max_txt_length=128, max_origin_txt_length=128):
    """Turn prepared samples into chat-formatted prompt strings.

    Args:
        batch_data: Iterable of sample dicts with an ``"input"`` text and, for
            neighbor-based prompts, optional ``"neighbor_texts"`` /
            ``"neighbor_labels"`` lists.
        tokenizer: Tokenizer used both for token-level truncation and for
            ``apply_chat_template``.
        dataset_name: Dataset whose template and class list are used.
        prompt_type: Name of the prompt template.
        max_txt_length: Accepted for interface compatibility; not read here.
        max_origin_txt_length: Token budget applied to the node text and to
            each neighbor text individually.

    Returns:
        A list with one formatted prompt string per sample, in input order.
    """

    def clip(text):
        # Encode without special tokens, keep the leading ids, decode back to text.
        kept_ids = tokenizer(text, add_special_tokens=False).input_ids[:max_origin_txt_length]
        return tokenizer.decode(kept_ids, skip_special_tokens=True)

    template = get_prompt_template(prompt_type, dataset_name)
    class_listing = get_classes_str(dataset_name)

    formatted_prompts = []
    for sample in batch_data:
        node_text = clip(sample["input"])

        if not requires_neighbor_info(prompt_type):
            # Single-node prompt: the template only needs the node text.
            user_content = template.format(node=node_text, classes=class_listing)
        else:
            raw_neighbors = sample.get("neighbor_texts", [])
            labels = sample.get("neighbor_labels", [])

            entries = []
            for pos, raw_neighbor in enumerate(raw_neighbors):
                clipped = clip(raw_neighbor)
                if prompt_type == "neighbor_label" and pos < len(labels):
                    # Empty / missing label values are shown as "unknown".
                    shown_label = labels[pos] if labels[pos] else "unknown"
                    entries.append(f"node text: {clipped}, label: {shown_label}")
                elif prompt_type == "neighbor":
                    entries.append(clipped)

            neighbor_block = "\n".join(
                f"Neighbor {rank}: {entry}" for rank, entry in enumerate(entries, start=1)
            )
            user_content = template.format(
                origin_text=node_text,
                neighbor_text=neighbor_block,
                classes=class_listing,
            )

        # Single-turn conversation; the generation prompt is appended so the
        # model continues with the answer.
        conversation = [{"role": "user", "content": user_content}]
        formatted_prompts.append(
            tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
        )

    return formatted_prompts


def merge_and_save_model(base_model_path, lora_adapter_path, merged_model_path):
    """Fold a LoRA adapter into its model and write the result, plus tokenizer, to disk.

    Args:
        base_model_path: Used here for logging and as the source of the
            tokenizer that is saved next to the merged weights.
        lora_adapter_path: Directory from which the PEFT model is loaded.
        merged_model_path: Output directory; created if it does not exist.

    Returns:
        ``merged_model_path``, unchanged.
    """
    print(f"Loading LoRA model from {lora_adapter_path}")
    print(f"Base model: {base_model_path}")

    # The merge is done on CPU in bfloat16.
    load_kwargs = {"torch_dtype": torch.bfloat16, "device_map": "cpu"}
    peft_model = AutoPeftModelForCausalLM.from_pretrained(lora_adapter_path, **load_kwargs)

    print("Merging LoRA weights with base model...")
    standalone_model = peft_model.merge_and_unload()

    print(f"Saving merged model to {merged_model_path}")
    os.makedirs(merged_model_path, exist_ok=True)
    standalone_model.save_pretrained(merged_model_path)

    # Ship the tokenizer alongside the weights so the directory is self-contained.
    AutoTokenizer.from_pretrained(base_model_path).save_pretrained(merged_model_path)

    print(f"Merged model saved successfully to {merged_model_path}")
    return merged_model_path


def run_vllm_inference(model_path, test_contents, dataset_name, prompt_type, inference_args, data_args):
    """Generate a class prediction for every test sample with vLLM.

    Args:
        model_path: Directory of the model (and tokenizer) to load.
        test_contents: Sequence of sample dicts carrying at least ``"id"`` and
            ``"output"`` plus whatever the prompt builder needs.
        dataset_name: Dataset whose class list validates predictions.
        prompt_type: Prompt template name forwarded to the prompt builder.
        inference_args: Supplies ``num_gpus``, ``gpu_memory_utilization`` and
            ``temperature``.
        data_args: Supplies ``max_ans_length``, ``max_txt_length`` and
            ``max_origin_txt_length``.

    Returns:
        ``(pred_labels, gt_labels, results)`` where ``results`` holds one dict
        per sample with the node id, ground truth, validated prediction, raw
        generation and whitespace-stripped generation.
    """
    print(f"Initializing vLLM with model: {model_path}")
    print(f"Tensor parallel size: {inference_args.num_gpus}")

    engine = LLM(
        model=model_path,
        tensor_parallel_size=inference_args.num_gpus,
        gpu_memory_utilization=inference_args.gpu_memory_utilization,
        trust_remote_code=True,
        dtype=torch.bfloat16,
    )

    # The tokenizer provides the chat template and the EOS stop string.
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # Make sure a pad token exists: reuse EOS when available, else add "[PAD]".
    if tokenizer.pad_token is None:
        if tokenizer.eos_token is None:
            tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        else:
            tokenizer.pad_token = tokenizer.eos_token
            tokenizer.pad_token_id = tokenizer.eos_token_id

    stop_tokens = [tokenizer.eos_token] if tokenizer.eos_token else []
    sampling_params = SamplingParams(
        temperature=inference_args.temperature,
        max_tokens=data_args.max_ans_length,
        stop=stop_tokens,
    )

    prompts = prepare_vllm_inference_data(
        test_contents,
        tokenizer,
        dataset_name,
        prompt_type,
        max_txt_length=data_args.max_txt_length,
        max_origin_txt_length=data_args.max_origin_txt_length,
    )

    print(f"Running inference on {len(prompts)} samples...")
    generations = engine.generate(prompts, sampling_params)

    pred_labels = []
    gt_labels = []
    results = []
    for position, generation in enumerate(generations):
        sample = test_contents[position]
        truth = sample["output"]
        raw_text = generation.outputs[0].text

        # Strip surrounding whitespace and quote characters; the remaining text
        # must match a class name exactly, otherwise it is mapped to UNKNOW.
        candidate = raw_text.strip().strip('"').strip("'").strip()
        if candidate not in classes[dataset_name]:
            candidate = UNKNOW

        pred_labels.append(candidate)
        gt_labels.append(truth)
        results.append({
            "idx": sample["id"],
            "ground-truth": truth,
            "pred": candidate,
            "raw_output": raw_text,
            "cleaned_pred": raw_text.strip(),
        })

    return pred_labels, gt_labels, results


def main():
    """Command-line entry point: merge the LoRA adapter, run vLLM inference, record metrics.

    Steps:
        1. Parse and validate the command-line arguments.
        2. Locate the LoRA adapter and (re)build the merged model if needed.
        3. Build test prompts, generate predictions and compute metrics.
        4. Write per-sample predictions and append a row to the summary CSV.
        5. Delete the merged model unless ``--save_merged`` was given. This
           happens even when a step after the merge decision fails, so a
           failed run does not leave a large, possibly incomplete model
           directory that a later run would silently reuse.

    Raises:
        ValueError: For an unknown prompt type or unsupported dataset.
        FileNotFoundError: If the LoRA adapter directory does not exist.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--llm', type=str, default="Mistral-7B")
    parser.add_argument('--dataset', type=str, default='cora')
    parser.add_argument('--prompt_type', type=str, default='instruction_tuning', 
                       help=f'Type of prompt template to use. Available types: {", ".join(get_available_prompt_types())}')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--max_txt_length', type=int, default=128)
    parser.add_argument('--max_origin_txt_length', type=int, default=128)
    parser.add_argument('--max_ans_length', type=int, default=16)
    parser.add_argument('--maximum_neighbor', type=int, default=8)
    parser.add_argument('--re_split', type=int, default=0)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--num_gpus', type=int, default=2)
    parser.add_argument('--temperature', type=float, default=0.0)
    parser.add_argument('--gpu_memory_utilization', type=float, default=0.9)
    parser.add_argument('--force_merge', action='store_true', help='Force re-merge even if merged model exists')
    parser.add_argument('--save_merged', action='store_true', help='Save merged model permanently, otherwise delete after inference')
    parser.add_argument('--use_wandb', action='store_true', help='Use wandb for logging')

    args = parser.parse_args()

    # Validate prompt_type and dataset combination
    available_prompt_types = get_available_prompt_types()
    if args.prompt_type not in available_prompt_types:
        raise ValueError(f"Invalid prompt_type: {args.prompt_type}. Available types: {available_prompt_types}")

    supported_datasets = get_supported_datasets(args.prompt_type)
    if args.dataset not in supported_datasets:
        raise ValueError(f"Dataset {args.dataset} not supported for prompt_type {args.prompt_type}. Supported datasets: {supported_datasets}")

    print(f"Using prompt_type: {args.prompt_type} for dataset: {args.dataset}")

    # Bundle the CLI values into the argument objects consumed by the inference helper.
    data_args = DataArguments(
        dataset=args.dataset,
        prompt_type=args.prompt_type,
        max_txt_length=args.max_txt_length,
        max_origin_txt_length=args.max_origin_txt_length,
        max_ans_length=args.max_ans_length,
        re_split=args.re_split,
        maximum_neighbor=args.maximum_neighbor
    )
    inference_args = InferenceArguments(
        batch_size=args.batch_size,
        num_gpus=args.num_gpus,
        temperature=args.temperature,
        gpu_memory_utilization=args.gpu_memory_utilization
    )

    set_seed(args.seed)

    # Initialize wandb
    if args.use_wandb:
        wandb.init(
            project="llm-node-classification_inference",
            name=f"{args.dataset}_{args.llm}_{args.prompt_type}_seed{args.seed}_vllm",
            config=vars(args)
        )

    # Auto-determine paths based on train.py structure
    lora_adapter_path = get_lora_adapter_path(args.dataset, args.llm, args.prompt_type, args.seed, args.re_split)
    merged_model_path = get_merged_model_path(args.dataset, args.llm, args.prompt_type, args.seed, args.re_split)

    print(f"Auto-determined LoRA adapter path: {lora_adapter_path}")
    print(f"Auto-determined merged model path: {merged_model_path}")

    # Check if LoRA adapter exists
    if not os.path.exists(lora_adapter_path):
        raise FileNotFoundError(f"LoRA adapter not found at {lora_adapter_path}. Please run training first.")

    base_model_path = llm_paths[args.llm]

    # Fail fast on an unloadable base tokenizer and actually perform the
    # announced chat-template check before the expensive merge. Only a warning
    # is issued because behaviour without a template depends on the installed
    # transformers version.
    print("Checking chat template availability...")
    tokenizer = AutoTokenizer.from_pretrained(base_model_path)
    if getattr(tokenizer, "chat_template", None) is None:
        print(f"Warning: tokenizer at {base_model_path} defines no chat_template; "
              "prompt formatting via apply_chat_template may fail or use a default format")

    # Check if we need to merge the model
    need_merge = args.force_merge or not os.path.exists(merged_model_path)

    try:
        if need_merge:
            print("Merging LoRA adapter with base model...")
            merge_and_save_model(base_model_path, lora_adapter_path, merged_model_path)
        else:
            print(f"Using existing merged model at {merged_model_path}")

        # Load test data
        print(f"Loading dataset: {args.dataset}")
        graph_data = load_graph_dataset(dataset_name=args.dataset, device="cpu", re_split=args.re_split)
        test_contents = prepare_graph_instruction_tuning_data(graph_data, "test", args.dataset, args.prompt_type, args.maximum_neighbor)
        print(f"Test dataset size: {len(test_contents)}")

        # Run inference (the timing includes model loading inside the helper)
        st_time = time.time()
        pred_labels, gt_labels, results = run_vllm_inference(
            merged_model_path, test_contents, args.dataset, args.prompt_type, inference_args, data_args
        )
        inference_secs = time.time() - st_time

        # Compute metrics
        acc, macro_f1, weight_f1 = compute_acc_and_f1(pred_labels, gt_labels)
        print(f"Accuracy: {acc:.4f}")
        print(f"Macro F1-Score: {macro_f1:.4f}")
        print(f"Weighted F1-Score: {weight_f1:.4f}")
        print(f"Inference time: {inference_secs:.2f} seconds")

        # Save results
        alg_dir = "../../results/InstructionTuning"
        write_dir = f"{alg_dir}/prediction"
        os.makedirs(write_dir, exist_ok=True)

        re_split_str = '_s' if args.re_split else ''
        result_file = f"{write_dir}/{args.dataset}_{args.llm}_{args.prompt_type}{re_split_str}_seed{args.seed}_vllm.json"

        # NOTE: despite the .json extension the file holds one JSON object per line.
        with open(result_file, "w") as f:
            for result in results:
                f.write(json.dumps(result) + "\n")

        print(f"Results saved to {result_file}")

        # Save summary; the header is written only when the CSV is first created.
        summary_file = f"{alg_dir}/summary_vllm.csv"
        file_exists = os.path.exists(summary_file)

        with open(summary_file, 'a', newline='') as file:
            writer = csv.writer(file)
            if not file_exists:
                writer.writerow([
                    "Dataset", "Model", "Prompt_Type", "Accuracy", "Macro_F1", "Weighted_F1", 
                    "Split_Type", "Seed", "Batch_Size", "Inference_Seconds"
                ])

            writer.writerow([
                args.dataset, args.llm, args.prompt_type, acc, macro_f1, weight_f1, 
                "Semi" if not args.re_split else "Supervised", f"Seed-{args.seed}", 
                f"Batch Size-{args.batch_size}", f"Inference Seconds-{inference_secs:.2f}"
            ])

        print(f"Summary saved to {summary_file}")

        # Log to wandb
        if args.use_wandb:
            wandb.log({
                "accuracy": acc,
                "macro_f1": macro_f1,
                "weighted_f1": weight_f1,
                "inference_time_seconds": inference_secs,
                "test_samples": len(test_contents)
            })
            wandb.finish()
    finally:
        # Clean up a model merged by this run unless the user asked to keep it.
        # Running this in `finally` also removes it after a failure.
        if need_merge and not args.save_merged and os.path.isdir(merged_model_path):
            print(f"Cleaning up merged model at {merged_model_path}")
            try:
                shutil.rmtree(merged_model_path)
                print("Merged model deleted to save space")
            except OSError as e:
                # Best effort: a failed cleanup must not mask the run's outcome.
                print(f"Warning: Could not delete merged model: {e}")

    if args.save_merged:
        print(f"Merged model saved permanently at {merged_model_path}")


# Run the inference pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main() 