#!/usr/bin/env python3
"""
LLM testing script - measures joint goal accuracy (JGA) separately for each domain.

This script follows the train.py setup and evaluates a causal LLM on dialogue state
tracking (DST). It reports JGA, slot F1 and turn accuracy for five MultiWOZ domains:
hotel, train, restaurant, attraction and taxi.
"""

import argparse
import json
import os
from datetime import datetime
from typing import Any, Dict, List, Tuple

import numpy as np
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, T5Tokenizer, T5ForConditionalGeneration, AutoConfig

from copy_data_loader import prepare_data, EXPERIMENT_DOMAINS
from evaluate import evaluate_metrics, get_slot_information


def parse_args():
    """
    Build the command-line interface for the DST evaluation run and parse sys.argv

    Returns:
        args: argparse.Namespace holding model, dataset, system and test settings
    """
    cli = argparse.ArgumentParser(description="LLM dialogue state tracking test parameters")
    add = cli.add_argument

    # --- model ---
    add("--model_name", type=str, default="Qwen/Qwen2.5-3B-Instruct",
        help="LLM model name to use")
    add("--ckpt", type=str, default="save/models/small",
        help="Pre-trained model checkpoint path")

    # --- dataset / loader ---
    add("--dataset", type=str, default="multiwoz", help="Dataset name")
    add("--data_path", type=str, default="data", help="Dataset path")
    add("--test_batch_size", type=int, default=8, help="Test batch size")
    add("--train_batch_size", type=int, default=16, help="Training batch size")
    add("--dev_batch_size", type=int, default=8, help="Validation batch size")
    add("--slot_lang", type=str, default="question",
        help="Slot description type (none/human/naive/value/question/slottype)")
    add("--max_size", type=int, default=250, help="Maximum token count for model input")
    add("--test_mode", type=str, default="only",
        help="Test mode (only/except)")

    # --- system ---
    add("--gpu_id", type=int, default=0, help="GPU ID to use")
    add("--seed", type=int, default=3407, help="Random seed")
    add("--worker_number", type=int, default=8,
        help="Number of CPU threads for data loader")

    # --- evaluation run ---
    add("--domains", type=str, nargs="+",
        default=["hotel", "train", "restaurant", "attraction", "taxi"],
        help="List of domains to test")
    add("--output_dir", type=str, default="llm_results",
        help="Test results output directory")
    add("--save_predictions", action="store_true",
        help="Whether to save prediction results")

    return cli.parse_args()


def setup_model_and_tokenizer(model_name: str, device: str):
    """
    Set up model and tokenizer

    Parameters:
        model_name: Model name or local path passed to ``from_pretrained``
        device: Target device string, e.g. "cuda:0" or "cpu". The whole model is
            placed on this device so that inputs moved to the same device match
            the weights.

    Returns:
        model: Loaded causal LM in eval mode
        tokenizer: Loaded tokenizer (left padding and left truncation)
    """
    print(f"Loading model: {model_name}")

    # Decoder-only generation needs left padding so that the new tokens directly
    # follow each prompt. Prompts end with the slot question, so truncate from
    # the left; truncating on the right would drop the question itself.
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True,
        padding_side="left",
        truncation_side="left",
    )

    # Ensure tokenizer has pad token (many causal LMs ship without one)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    use_cuda = str(device).startswith("cuda")

    # Pin every weight to the requested device instead of "auto", which can pick
    # a different GPU than the one inputs are moved to. float16 is only used on
    # GPU; on CPU, half-precision matmuls are slow or unsupported.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16 if use_cuda else torch.float32,
        device_map={"": str(device)} if use_cuda else None,
        trust_remote_code=True,
    )

    model.eval()
    return model, tokenizer


def create_prompt(dialog_history: str, slot_description: str, tokenizer=None) -> str:
    """
    Create test prompt using slot_description from data_detail

    The prompt is ``"<dialog_history> <sep> <slot_description>"``. ``<sep>`` is the
    tokenizer's ``sep_token`` when one is set; otherwise it is a single space.

    Parameters:
        dialog_history: Dialogue history
        slot_description: Slot description (from data_detail)
        tokenizer: Tokenizer (optional)

    Returns:
        prompt: Constructed prompt text
    """
    # HF tokenizers always expose a ``sep_token`` attribute, but it is None for
    # many decoder-only models. Checking with hasattr() alone would insert the
    # literal text "None" into every prompt.
    sep_token = getattr(tokenizer, "sep_token", None) or " "
    return f"{dialog_history} {sep_token} {slot_description}"


def generate_predictions(model, tokenizer, test_data: List[Dict],
                        device: str, batch_size: int = 8,
                        max_length: int = 512, max_new_tokens: int = 50,
                        verbose: bool = False) -> Dict[str, Dict]:
    """
    Generate prediction results using slot_description from data_detail

    Each sample asks the model for the value of one slot. Answers are grouped per
    dialogue and turn into the ``{"turns": {turn_id: {"turn_belief", "pred_belief"}}}``
    layout that ``evaluate_metrics`` consumes.

    Parameters:
        model: LLM model
        tokenizer: Tokenizer (expected to use left padding)
        test_data: Test data; each item needs ID, turn_id, dialog_history,
            slot_description, slot_text, slot_domain and turn_belief
        device: Device the tokenized inputs are moved to
        batch_size: Batch size
        max_length: Maximum prompt length in tokens (longer prompts are truncated)
        max_new_tokens: Maximum number of generated tokens per answer
        verbose: Print dialogue, question and answer for every sample

    Returns:
        predictions: Prediction results dictionary keyed by dialogue ID
    """
    predictions: Dict[str, Dict] = {}

    print(f"Starting prediction generation, total {len(test_data)} samples")

    for start in tqdm(range(0, len(test_data), batch_size), desc="Prediction generation progress", unit="batch"):
        batch = test_data[start:start + batch_size]

        prompts = [
            create_prompt(
                dialog_history=item["dialog_history"],
                slot_description=item["slot_description"],
                tokenizer=tokenizer,
            )
            for item in batch
        ]

        inputs = tokenizer(
            prompts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt"
        ).to(device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                temperature=None,  # Explicitly set to None to avoid warnings
                top_p=None,        # Explicitly set to None to avoid warnings
                top_k=None         # Explicitly set to None to avoid warnings
            )

        # For decoder-only models, generate() returns prompt + continuation.
        # After padding, every row has the same prompt length, so slicing it off
        # leaves only the new tokens. This is more reliable than removing the
        # prompt by string replacement after decoding.
        prompt_length = inputs["input_ids"].shape[1]
        generated_texts = tokenizer.batch_decode(outputs[:, prompt_length:], skip_special_tokens=True)

        for item, prompt, generated_text in zip(batch, prompts, generated_texts):
            answer = extract_answer_from_text(generated_text, prompt)

            if verbose:
                print(f"dialog:{item['dialog_history']}")
                print(f"Question:{item['slot_description']}")
                print(f"Answer:{answer}")

            dial_id = item["ID"]
            turn_id = str(item["turn_id"])
            turn = predictions.setdefault(dial_id, {"turns": {}})["turns"].setdefault(
                turn_id, {"turn_belief": [], "pred_belief": []}
            )

            # Ground-truth state for the turn (identical for every slot of the turn)
            for belief in item["turn_belief"]:
                if belief not in turn["turn_belief"]:
                    turn["turn_belief"].append(belief)

            # Gold turn_belief only lists filled slots. Recording "<slot>-none"
            # would make every turn mismatch under set equality and drive JGA
            # to ~0.
            if answer == "none":
                continue

            # Avoid "hotel-hotel-area" when slot_text already carries the domain.
            slot_domain = item["slot_domain"]
            slot_text = str(item["slot_text"])
            if slot_text.startswith(f"{slot_domain}-"):
                slot_key = slot_text
            else:
                slot_key = f"{slot_domain}-{slot_text}"

            pred_belief = f"{slot_key}-{answer}"
            if pred_belief not in turn["pred_belief"]:
                turn["pred_belief"].append(pred_belief)

    return predictions


def extract_answer_from_text(generated_text: str, prompt: str) -> str:
    """
    Extract answer from generated text

    Removes any copy of the prompt, keeps only the first line, strips trailing
    punctuation and lower-cases the result. Empty answers and "none" /
    "unknown" / "n/a" are all mapped to "none".

    Parameters:
        generated_text: Generated text (may or may not include the prompt)
        prompt: Original prompt

    Returns:
        answer: Normalized answer string ("none" when there is no value)
    """
    # Remove prompt part (a no-op when only new tokens were decoded)
    answer_text = generated_text.replace(prompt, "").strip()

    # Take first line as answer
    answer_text = answer_text.partition("\n")[0].strip()

    # Remove trailing punctuation, plus any whitespace it exposes ("cheap ." -> "cheap")
    answer_text = answer_text.rstrip(".!?,;:").strip().lower()

    # Normalize only after cleaning, so "Unknown." or "N/A\n..." also map to "none"
    if not answer_text or answer_text in {"none", "unknown", "n/a"}:
        return "none"

    return answer_text


def _build_test_args(args, domain: str) -> argparse.Namespace:
    """Translate CLI arguments into the config object copy_data_loader.prepare_data reads."""
    return argparse.Namespace(
        dataset=args.dataset,
        except_domain=domain if args.test_mode == "except" else "none",
        only_domain=domain if args.test_mode == "only" else "none",
        slot_lang=args.slot_lang,
        max_size=args.max_size,
        fewshot=0.0,
        seed=args.seed,
        train_batch_size=args.train_batch_size,
        dev_batch_size=args.dev_batch_size,
        test_batch_size=args.test_batch_size,
        # Honor --worker_number; fall back to the old hard-coded 4 for callers
        # whose args object has no such field.
        worker_number=getattr(args, "worker_number", 4),
        gpu_id=args.gpu_id,
        fix_label=True,
    )


def _collect_test_samples(test_loader, domain: str) -> Tuple[List[Dict[str, Any]], int]:
    """Flatten loader batches into per-sample dicts; also count samples lacking a slot description."""
    samples: List[Dict[str, Any]] = []
    missing_description = 0
    for batch in test_loader:
        batch_size = len(batch["dialog_history"])
        for i in range(batch_size):
            try:
                if "slot_description" in batch:
                    description = batch["slot_description"][i]
                elif "slot_lang" in batch:
                    description = batch["slot_lang"][i]
                else:
                    description = "none"
                    missing_description += 1
                slot_domain = batch["slot_domain"][i] if "slot_domain" in batch else domain
                item = {
                    # Fallback IDs use the global sample index so they stay unique
                    # across batches (a per-batch index would collide).
                    'ID': batch["ID"][i] if "ID" in batch else f'test_{len(samples)}',
                    'domain': slot_domain,
                    'turn_id': batch["turn_id"][i] if "turn_id" in batch else 0,
                    'dialog_history': batch["dialog_history"][i],
                    'slot_text': batch["slot_text"][i],
                    'value_text': batch["value_text"][i] if "value_text" in batch else 'none',
                    'turn_belief': batch["turn_belief"][i] if "turn_belief" in batch else [],
                    'slot_description': description,
                    'slot_domain': slot_domain,
                }
            except Exception as e:
                # Diagnostics for malformed batches, then propagate
                print(f"Error processing sample {i}: {e}")
                print(f"Available fields: {list(batch.keys())}")
                print(f"slot_domain field type: {type(batch.get('slot_domain', 'not exists'))}")
                if "slot_domain" in batch:
                    print(f"slot_domain length: {len(batch['slot_domain'])}")
                    print(f"batch_size: {batch_size}")
                raise
            samples.append(item)
    return samples, missing_description


def _count_turns(predictions: Dict[str, Dict]) -> int:
    """Count (dialogue, turn) pairs in a predictions dict."""
    return sum(len(dialog.get("turns", {})) for dialog in predictions.values())


def evaluate_domain(model, tokenizer, args, domain: str) -> Tuple[Dict[str, Any], Dict[str, Dict]]:
    """
    Evaluate single domain using copy_data_loader.py to load MultiWOZ test data

    Parameters:
        model: LLM model
        tokenizer: Tokenizer
        args: Arguments (as produced by parse_args)
        domain: Domain name

    Returns:
        (metrics, predictions): metrics always has joint_acc / f1 / turn_acc. On
        success it also has total_turns and total_dialogues; on failure it has
        an "error" message. predictions is keyed by dialogue ID and is empty
        when nothing was evaluated.
    """
    print(f"\nEvaluating domain: {domain}")
    empty_metrics = {"joint_acc": 0.0, "f1": 0.0, "turn_acc": 0.0}

    # Only support MultiWOZ dataset
    if args.dataset != "multiwoz":
        print(f"Error: Currently only supports MultiWOZ dataset, does not support {args.dataset}")
        return dict(empty_metrics), {}

    test_args = _build_test_args(args, domain)

    print("Preparing test data...")
    try:
        _, _, test_loader, all_slots, _, _, _ = prepare_data(test_args, tokenizer)

        test_data, missing_description = _collect_test_samples(test_loader, domain)
        print(f"Test data quantity: {len(test_data)} samples")

        if not test_data:
            print(f"Warning: Domain {domain} has no test data")
            return dict(empty_metrics), {}

        # Without a slot description the prompt would end in the placeholder
        # "none", so the model could not know which slot is being asked about.
        if missing_description:
            print(f"Error: Test data missing slot_description field")
            return dict(empty_metrics), {}

        print(f"Sample slot_description: {test_data[0]['slot_description']}")

        device = f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu"
        predictions = generate_predictions(
            model, tokenizer, test_data, device, args.test_batch_size
        )

        joint_acc, f1_score, turn_acc = evaluate_metrics(predictions, all_slots)

        metrics = {
            "joint_acc": joint_acc,
            "f1": f1_score,
            "turn_acc": turn_acc,
            "total_turns": _count_turns(predictions),
            "total_dialogues": len(predictions),
        }

        print(f"Domain {domain} evaluation results:")
        print(f"  Joint Accuracy: {joint_acc:.4f}")
        print(f"  F1 Score: {f1_score:.4f}")
        print(f"  Turn Accuracy: {turn_acc:.4f}")

        return metrics, predictions

    except Exception as e:
        # Failures can come from loading, generation or scoring, not only data loading
        print(f"Error evaluating domain {domain}: {e}")
        import traceback
        traceback.print_exc()  # Print complete error stack
        return {**empty_metrics, "error": str(e)}, {}


def main():
    """
    Main function

    Parses CLI arguments, seeds torch and numpy, and loads the model once. It then
    evaluates every requested domain and prints a summary table. Results go to
    ``args.output_dir``: ``final_results.json``, plus ``<domain>_predictions.json``
    when ``--save_predictions`` is set.

    Returns:
        all_results: Per-domain metrics; domains that failed carry an "error" key
    """
    args = parse_args()

    # Set random seed
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    device = f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    model, tokenizer = setup_model_and_tokenizer(args.model_name, device)

    os.makedirs(args.output_dir, exist_ok=True)

    all_results: Dict[str, Dict[str, Any]] = {}

    for domain in tqdm(args.domains, desc="Domain evaluation progress", unit="domain"):
        if domain not in EXPERIMENT_DOMAINS:
            print(f"Skipping invalid domain: {domain}")
            continue

        try:
            metrics, predictions = evaluate_domain(model, tokenizer, args, domain)
            all_results[domain] = metrics

            # Save intermediate results
            if args.save_predictions:
                domain_output_file = os.path.join(args.output_dir, f"{domain}_predictions.json")
                with open(domain_output_file, 'w', encoding='utf-8') as f:
                    json.dump(predictions, f, indent=2, ensure_ascii=False)

        except Exception as e:
            print(f"Error evaluating domain {domain}: {str(e)}")
            all_results[domain] = {"joint_acc": 0.0, "f1": 0.0, "turn_acc": 0.0, "error": str(e)}

    # Print summary results
    print("\n" + "=" * 60)
    print("Summary of evaluation results for all domains:")
    print("=" * 60)

    print(f"{'Domain':<12} {'Joint Acc':<12} {'F1 Score':<12} {'Turn Acc':<12}")
    print("-" * 60)

    for domain, metrics in all_results.items():
        row = (f"{domain:<12} {metrics.get('joint_acc', 0.0):<12.4f} "
               f"{metrics.get('f1', 0.0):<12.4f} {metrics.get('turn_acc', 0.0):<12.4f}")
        if "error" in metrics:
            row += " (failed, excluded from average)"
        print(row)

    # Failed domains report placeholder zeros; averaging them in would drag
    # every mean down and overstate "valid_domains".
    valid_metrics = [m for m in all_results.values() if "error" not in m]
    valid_domains = len(valid_metrics)

    avg_joint_acc = avg_f1 = avg_turn_acc = 0
    if valid_domains > 0:
        avg_joint_acc = sum(m.get("joint_acc", 0.0) for m in valid_metrics) / valid_domains
        avg_f1 = sum(m.get("f1", 0.0) for m in valid_metrics) / valid_domains
        avg_turn_acc = sum(m.get("turn_acc", 0.0) for m in valid_metrics) / valid_domains

        print("-" * 60)
        print(f"{'Average':<12} {avg_joint_acc:<12.4f} {avg_f1:<12.4f} {avg_turn_acc:<12.4f}")

    print("=" * 60)

    final_results = {
        "model_name": args.model_name,
        "results": all_results,
        "summary": {
            "average_joint_acc": avg_joint_acc,
            "average_f1": avg_f1,
            "average_turn_acc": avg_turn_acc,
            "valid_domains": valid_domains,
            "failed_domains": [d for d, m in all_results.items() if "error" in m],
        }
    }

    results_file = os.path.join(args.output_dir, "final_results.json")
    with open(results_file, 'w', encoding='utf-8') as f:
        json.dump(final_results, f, indent=2, ensure_ascii=False)

    print(f"\nResults saved to: {results_file}")

    return all_results


if __name__ == "__main__":
    main()