import torch
import importlib
import re
import json
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils.data_utils import load_local_dataset
from tqdm import tqdm

# Dataset name -> name of the metric module that scores it. The module is
# imported as ``trace_evaluation.eval_modules.<value>`` by
# evaluate_model_accuracy.
# NOTE(review): 'eval_NumGlUE_ds' is cased differently from 'eval_NumGLUE_cm';
# confirm it matches the actual module filename.
TASK_EVAL_MAP = dict([
    ('C-STANCE', 'eval_CStance'),
    ('FOMC', 'eval_FOMC'),
    ('MeetingBank', 'eval_MeetingBank'),
    ('ScienceQA', 'eval_ScienceQA'),
    ('NumGLUE-cm', 'eval_NumGLUE_cm'),
    ('NumGLUE-ds', 'eval_NumGlUE_ds'),
    ('20Minuten', 'eval_20Minuten'),
])

# Per-task keyword arguments forwarded to ``model.generate`` by
# evaluate_model_accuracy. Each task gets its own dict, built from a
# (task, max_new_tokens, temperature, do_sample) row.
TASK_GEN_CONFIG = {
    task: {'max_new_tokens': max_tokens, 'temperature': temp, 'do_sample': sample}
    for task, max_tokens, temp, sample in (
        ('C-STANCE', 20, 0.1, False),
        ('FOMC', 20, 0.1, False),
        ('MeetingBank', 150, 0.7, True),
        ('ScienceQA', 100, 0.5, True),
        ('NumGLUE-cm', 20, 0.1, False),
        ('NumGLUE-ds', 20, 0.1, False),
        ('20Minuten', 150, 0.7, True),
    )
}

def extract_answer(text, prompt):
    """Extract the answer part from the generated result.

    Strips the echoed *prompt* from the generated *text*. If *prompt* does not
    occur in *text*, the whole text is treated as the answer.

    Args:
        text: Decoded generation output, possibly starting with the prompt.
        prompt: The prompt that was fed to the model.

    Returns:
        The text after the first occurrence of *prompt*, stripped of
        surrounding whitespace.
    """
    start = text.find(prompt)
    if start != -1:
        # Cut right after where the prompt actually occurs. The old code cut at
        # len(prompt) from the start of the text, which is wrong whenever the
        # prompt does not begin at index 0 (e.g. the decoder adds a leading space).
        text = text[start + len(prompt):]
    return text.strip()

def extract_first_option(text):
    """Extract the chosen multiple-choice option (A, B, C or D) from *text*.

    Patterns are tried from most to least specific:
      1. an option letter at the very start of the answer, optionally wrapped
         in parentheses and followed by punctuation, whitespace or end of text
         (case-insensitive: "b", "(C)", "a) ...");
      2. an uppercase standalone option letter followed by punctuation or
         whitespace within the first 20 characters ("Answer: B. ...");
      3. "答案是 X" ("the answer is X");
      4. the first uppercase standalone option letter anywhere.

    "Standalone" means not directly next to another Latin letter. This stops
    letters inside ordinary words (the "a" in "answer", the "D" in "Do") from
    being read as options, which the old case-insensitive search did.

    Returns:
        The option letter in uppercase, or "" if none is found.
    """
    if not text:
        return ""

    # 1. Leading option such as "b", "(C)", "A." or "d) ..."
    match = re.match(r'\s*[(（]?([A-Da-d])(?:[.。、）)\s]|$)', text)
    if match:
        return match.group(1).upper()

    # 2. Standalone uppercase option followed by punctuation near the start
    match = re.search(r'(?<![A-Za-z])([A-D])[.。、）)\s]', text[:20])
    if match:
        return match.group(1)

    # 3. Explicit "the answer is X" phrasing
    match = re.search(r'答案是\s*([A-D])', text, re.IGNORECASE)
    if match:
        return match.group(1).upper()

    # 4. First standalone uppercase option letter anywhere
    match = re.search(r'(?<![A-Za-z])([A-D])(?![A-Za-z])', text)
    if match:
        return match.group(1)

    return ""

def fallback_evaluate(predicted_sequences, ground_truths, is_multiple_choice=False):
    """
    Score predictions with the generic metrics from trace_evaluation.metrics.

    Used when no task-specific evaluation module is available or it fails.

    Args:
        predicted_sequences: Model answers, aligned with *ground_truths*.
        ground_truths: Reference answers.
        is_multiple_choice: Choose accuracy (True) or ROUGE-L (False).

    Returns:
        dict: ``{"accuracy": ...}`` for multiple choice, otherwise
        ``{"rouge-l": ...}``.
    """
    # Imported lazily so the metrics module is only needed on this fallback path
    from trace_evaluation.metrics import caculate_accuracy, caculate_rouge

    if not is_multiple_choice:
        # Free-form generation: compare with ROUGE-L
        return {"rouge-l": caculate_rouge(predicted_sequences, ground_truths)}
    # Option-style answers: exact-match accuracy
    return {"accuracy": caculate_accuracy(predicted_sequences, ground_truths)}

def evaluate_model_accuracy(model_path, dataset_name, device="auto", result_dir=None):
    """
    Evaluate model performance on specified dataset.

    Generates an answer for every prompt of the dataset's eval split (falling
    back to the test split). The answers are scored with the task module from
    TASK_EVAL_MAP, or with fallback_evaluate if there is none or it fails.
    Prompts, answers, full outputs and scores are written to a timestamped
    JSON file.

    Args:
        model_path: Model path (or id) passed to ``from_pretrained``.
        dataset_name: Dataset name understood by ``load_local_dataset``.
        device: Value passed as ``device_map`` when loading the model.
        result_dir: Experiment result directory. When given, output goes to
            ``<result_dir>/evaluation_outputs``, otherwise to
            ``./evaluation_results/generation_outputs``.

    Returns:
        dict: The evaluation metrics, or ``{"error": message}`` when the
        dataset is missing/empty or the model cannot be loaded.
    """
    import datetime
    import os

    print(f"\nEvaluating model {model_path} performance on {dataset_name} dataset")

    # Create evaluation result save folder - if result_dir is provided, save to that directory
    if result_dir:
        save_dir = os.path.join(result_dir, "evaluation_outputs")
    else:
        save_dir = os.path.join("./evaluation_results", "generation_outputs")
    os.makedirs(save_dir, exist_ok=True)

    # Unique filename from model name, dataset name and timestamp. normpath
    # drops a trailing separator, which would otherwise make basename() return "".
    model_name = os.path.basename(os.path.normpath(model_path))
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    save_file = os.path.join(save_dir, f"{model_name}_{dataset_name}_{timestamp}.json")

    # Prefer the eval split; fall back to the test split
    test_dataset = load_local_dataset(dataset_name, split='eval')
    if test_dataset is None:
        test_dataset = load_local_dataset(dataset_name, split='test')
        if test_dataset is None:
            return {"error": f"Unable to load eval or test split for dataset {dataset_name}"}

    # Ensure dataset has prompt and answer fields
    if "prompt" not in test_dataset.column_names or "answer" not in test_dataset.column_names:
        return {"error": f"Dataset {dataset_name} does not contain required prompt and answer fields"}

    # Collect the data before loading the model, so an empty dataset is
    # rejected without paying the model-loading cost.
    prompts = []
    ground_truths = []
    input_sequences = []  # Source inputs for metrics that need them (e.g. 20Minuten)
    for example in test_dataset:
        if "prompt" in example and "answer" in example:
            prompts.append(example["prompt"])
            ground_truths.append(example["answer"])
            input_sequences.append(example["input"] if "input" in example else example.get("prompt", ""))

    if not prompts:
        return {"error": f"Dataset {dataset_name} contains no samples"}

    # Heuristic: very short reference answers (<= 3 chars) indicate a
    # multiple-choice style task. str() tolerates non-string (e.g. numeric) labels.
    is_multiple_choice = all(len(str(ans).strip()) <= 3 for ans in ground_truths[:5])

    # Get task-specific generation parameters, use defaults if none
    gen_kwargs = TASK_GEN_CONFIG.get(dataset_name, {})
    if not gen_kwargs:
        if is_multiple_choice:
            gen_kwargs = {
                "max_new_tokens": 20,
                "temperature": 0.1,
                "top_p": 0.9,
                "do_sample": False
            }
        else:
            gen_kwargs = {
                "max_new_tokens": 150,
                "temperature": 0.7,
                "top_p": 0.9,
                "do_sample": True
            }

    # Load model and tokenizer
    try:
        print(f"Loading model: {model_path}")
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
            device_map=device,
            low_cpu_mem_usage=True
        )

        tokenizer = AutoTokenizer.from_pretrained(model_path)
        if not tokenizer.pad_token:
            tokenizer.pad_token = tokenizer.eos_token
    except Exception as e:
        return {"error": f"Failed to load model: {e}"}

    # Pass pad_token_id explicitly; a task config that sets its own value wins.
    generate_kwargs = {"pad_token_id": tokenizer.pad_token_id, **gen_kwargs}

    predicted_sequences = []
    full_outputs = []  # Full decoded output (prompt + continuation) for the record

    print(f"Starting to generate answers for {dataset_name}...")

    for prompt in tqdm(prompts):
        try:
            inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
            prompt_len = inputs["input_ids"].shape[1]

            with torch.no_grad():
                outputs = model.generate(**inputs, **generate_kwargs)

            full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
            # The returned ids start with the prompt ids, so decode only the
            # new tokens. This is more robust than removing the prompt by
            # string matching against a re-decoded prompt.
            answer = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()
        except Exception as e:
            print(f"Error generating response: {e}")
            full_output, answer = "", ""

        # Append both together so the two lists always stay aligned with prompts
        full_outputs.append(full_output)
        predicted_sequences.append(answer)

    # Clean up GPU memory
    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Find task's evaluation module
    eval_module_name = TASK_EVAL_MAP.get(dataset_name)
    eval_results = {}

    try:
        if eval_module_name:
            eval_module_path = f"trace_evaluation.eval_modules.{eval_module_name}"
            module = importlib.import_module(eval_module_path)

            if dataset_name == "20Minuten":
                # This metric also needs the source inputs
                eval_results = module.eval(input_sequences, predicted_sequences, ground_truths)
            else:
                eval_results = module.eval(predicted_sequences, ground_truths)
        else:
            print(f"Warning: Cannot find evaluation module for dataset {dataset_name}, using default evaluation")
            eval_results = fallback_evaluate(predicted_sequences, ground_truths, is_multiple_choice)

    except Exception as e:
        print(f"Evaluation error: {e}")
        import traceback
        traceback.print_exc()

        print("Using fallback evaluation method...")
        eval_results = fallback_evaluate(predicted_sequences, ground_truths, is_multiple_choice)

    # Assemble data to save
    save_data = {
        "metadata": {
            "model_path": model_path,
            "model_name": model_name,
            "dataset_name": dataset_name,
            "timestamp": timestamp,
            "num_samples": len(prompts),
            "generation_parameters": gen_kwargs,
            "is_multiple_choice": is_multiple_choice
        },
        "results": eval_results,
        "samples": [
            {
                "prompt": prompt,
                "generated_answer": answer,
                "reference_answer": reference,
                "full_output": full_output,
            }
            for prompt, answer, reference, full_output in zip(
                prompts, predicted_sequences, ground_truths, full_outputs
            )
        ],
    }

    # Save evaluation data to JSON file
    try:
        with open(save_file, "w", encoding="utf-8") as f:
            json.dump(save_data, f, ensure_ascii=False, indent=2)
        print(f"Saved generated content and evaluation results to: {save_file}")
    except Exception as e:
        print(f"Error saving evaluation result file: {e}")

    # default=str so a metric value json cannot serialize does not crash the
    # run here, after all the work is done; this is only a log line.
    print(f"{dataset_name} evaluation results: {json.dumps(eval_results, indent=2, default=str)}")
    return eval_results

def evaluate_model_loss(model_path, dataset_name, device="auto", result_dir=None, max_length=512):
    """
    Evaluate model's average loss (cross-entropy) on specified dataset.

    Each example is scored as the single sequence ``prompt + answer`` with
    ``labels`` equal to ``input_ids``. The reported loss therefore covers
    prompt and answer tokens alike; prompt tokens are not masked out.

    Args:
        model_path: Model path (or id) passed to ``from_pretrained``.
        dataset_name: Dataset name understood by ``load_local_dataset``.
        device: Value passed as ``device_map`` when loading the model.
        result_dir: Experiment result directory. When given, results go to
            ``<result_dir>/loss_outputs``, otherwise to
            ``./evaluation_results/loss_outputs``.
        max_length: Token limit (with truncation) for each ``prompt + answer``
            sequence.

    Returns:
        dict: ``{"loss": average_loss}``. If the dataset or model cannot be
        loaded: ``{"error": message, "loss": 10.0}``.
    """
    import datetime
    import math
    import os
    import numpy as np

    # Placeholder losses. The per-sample value is used when a sample's loss
    # cannot be computed; the run value when nothing could be evaluated.
    sample_fallback_loss = 5.0
    run_fallback_loss = 10.0

    print(f"\nEvaluating model {model_path} loss on {dataset_name} dataset")

    # Create evaluation result save folder
    if result_dir:
        save_dir = os.path.join(result_dir, "loss_outputs")
    else:
        save_dir = os.path.join("./evaluation_results", "loss_outputs")
    os.makedirs(save_dir, exist_ok=True)

    # Unique filename. normpath drops a trailing separator, which would
    # otherwise make basename() return "".
    model_name = os.path.basename(os.path.normpath(model_path))
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    save_file = os.path.join(save_dir, f"{model_name}_{dataset_name}_loss_{timestamp}.json")

    # Prefer the eval split; fall back to the test split
    test_dataset = load_local_dataset(dataset_name, split='eval')
    if test_dataset is None:
        test_dataset = load_local_dataset(dataset_name, split='test')
        if test_dataset is None:
            return {"error": f"Unable to load eval or test split for dataset {dataset_name}", "loss": run_fallback_loss}

    # Load model and tokenizer
    try:
        print(f"Loading model: {model_path}")
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
            device_map=device,
            low_cpu_mem_usage=True
        )
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        if not tokenizer.pad_token:
            tokenizer.pad_token = tokenizer.eos_token
        model.eval()
    except Exception as e:
        return {"error": f"Failed to load model: {e}", "loss": run_fallback_loss}

    losses = []

    for example in tqdm(test_dataset, desc="Calculating loss"):
        try:
            prompt = example["prompt"]
            answer = example["answer"]

            # Score prompt and answer as one sequence; labels mirror the inputs
            # (standard causal-LM setup). str() tolerates non-string labels.
            encodings = tokenizer(
                prompt + str(answer),
                return_tensors="pt",
                truncation=True,
                max_length=max_length,
            )
            input_ids = encodings["input_ids"].to(model.device)
            attention_mask = encodings["attention_mask"].to(model.device)

            with torch.no_grad():
                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=input_ids.clone(),  # labels are required for the model to return a loss
                )

            if outputs.loss is None:
                print(f"Warning: Loss for sample '{prompt[:30]}...' is None")
                losses.append(sample_fallback_loss)
                continue

            loss = outputs.loss.item()
            if not math.isfinite(loss):
                # A NaN/inf here (possible, e.g., for a sequence too short to
                # have any next-token target) would turn the whole average into NaN.
                print(f"Warning: Loss for sample '{prompt[:30]}...' is not finite ({loss})")
                losses.append(sample_fallback_loss)
                continue
            losses.append(loss)

        except Exception as e:
            print(f"Error calculating loss: {e}")
            losses.append(sample_fallback_loss)

    avg_loss = float(np.mean(losses)) if losses else run_fallback_loss

    # Save loss results
    save_data = {
        "model_path": model_path,
        "model_name": model_name,
        "dataset_name": dataset_name,
        "timestamp": timestamp,
        "num_samples": len(test_dataset),
        "avg_loss": avg_loss,
        "individual_losses": losses[:10]  # Save only first 10 loss samples
    }

    try:
        with open(save_file, "w", encoding="utf-8") as f:
            json.dump(save_data, f, ensure_ascii=False, indent=2)
        print(f"Saved loss results to: {save_file}")
    except Exception as e:
        print(f"Error saving loss result file: {e}")

    # Clean up
    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    print(f"{dataset_name} average loss: {avg_loss:.4f}")
    return {"loss": avg_loss}