import torch
import evaluate
import numpy as np
import re
import json
import os
from tqdm import tqdm

from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig
)
from peft import PeftModel, PeftConfig

# -------------------------------------------------------------------
# 1. Specify model paths
# -------------------------------------------------------------------
# Identifier handed to from_pretrained() for every tokenizer and base-model load below.
BASE_MODEL_ID = "google/gemma-2-2b" 
# Location handed to PeftConfig/PeftModel.from_pretrained();
# presumably a local LoRA output directory -- TODO confirm.
PEFT_LORA_DIR = "gemma-2-2b_final_model_shallow" 

# -------------------------------------------------------------------
# 2. Load the base model (the "pre-finetune" baseline)
# -------------------------------------------------------------------
print("Loading base model...")
tokenizer_pre = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
model_pre = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    torch_dtype=torch.float16,
    device_map="auto"
)
# Reuse EOS as the padding token so padded batches and generate(pad_token_id=...) have a pad id.
tokenizer_pre.pad_token = tokenizer_pre.eos_token

# -------------------------------------------------------------------
# 3. Load the fine-tuned model (the "post-finetune" candidate)
# -------------------------------------------------------------------
# A second, independent copy of the base weights is loaded here and then wrapped
# with the adapter, so the baseline object above is left untouched.
print("Loading fine-tuned model...")
tokenizer_ft = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
base_model_ft = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    torch_dtype=torch.float16,
    device_map="auto"
)
tokenizer_ft.pad_token = tokenizer_ft.eos_token

# NOTE(review): the message says "Merging", but no explicit merge call appears in
# this file; the adapter is presumably applied as a wrapper around base_model_ft --
# confirm against the PEFT documentation.
print("Merging PEFT adapter with base model...")
# NOTE(review): peft_config is not referenced anywhere else in the visible file.
peft_config = PeftConfig.from_pretrained(PEFT_LORA_DIR)
model_ft = PeftModel.from_pretrained(base_model_ft, PEFT_LORA_DIR)

# -------------------------------------------------------------------
# 4. Define generation function with task-specific parameters
# -------------------------------------------------------------------
def generate_text_safely(model, tokenizer, prompt, max_new_tokens=128, task_type="default"):
    """Generate a continuation for ``prompt`` and return prompt + continuation.

    Decoding is always deterministic (``do_sample=False``). The only thing
    ``task_type`` changes is the beam width: ``"summary"`` and ``"sql"`` use
    4 beams, every other value (including ``"math"``) decodes with 1 beam.

    Args:
        model: Object exposing ``device`` and ``generate(**kwargs)``.
        tokenizer: Callable tokenizer exposing ``pad_token_id`` and ``decode``.
        prompt: Text to condition on.
        max_new_tokens: Upper bound on the number of generated tokens.
        task_type: Task label used to pick the beam width.

    Returns:
        The original prompt string with the decoded continuation appended.
    """
    # Beam search only for summarization and SQL; all other tasks use a single beam.
    beam_width = 4 if task_type in ["summary", "sql"] else 1

    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        output_ids = model.generate(
            **encoded,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            num_beams=beam_width,
            pad_token_id=tokenizer.pad_token_id,
            temperature=1.0,
        )

    # The returned sequence starts with the prompt tokens; drop them before decoding.
    prompt_length = encoded.input_ids.shape[1]
    continuation = tokenizer.decode(output_ids[0][prompt_length:], skip_special_tokens=True)

    return prompt + continuation

# ----------------------------------------------------------
# 5. Evaluate on SamSum
# ----------------------------------------------------------
def evaluate_samsum(model, tokenizer, split="test", max_samples=50):
    """Summarize SamSum dialogues with the model and score them with ROUGE.

    At most ``max_samples`` examples of ``split`` are processed. Each dialogue
    is wrapped in a fixed "Summarize this conversation" prompt and decoded
    with 4-beam search (up to 100 new tokens). Samples whose generation raises
    are reported and skipped.

    Args:
        model: Causal LM exposing ``device`` and ``generate``.
        tokenizer: Tokenizer matching ``model``.
        split: Name of the dataset split to evaluate.
        max_samples: Maximum number of examples to process.

    Returns:
        The mapping produced by the ROUGE metric, or a dict of zeros for
        ``rouge1``/``rouge2``/``rougeL``/``rougeLsum`` when nothing was scored.
    """
    print(f"Evaluating {split} split of SamSum dataset (max {max_samples} samples)...")
    samsum = load_dataset("samsum")
    rouge_metric = evaluate.load("rouge")
    data = samsum[split]

    predictions = []
    references = []
    for idx, example in enumerate(data):
        if idx >= max_samples:
            break
        if idx % 5 == 0:  # Report progress on every fifth sample
            print(f"Processing SamSum sample {idx+1}/{min(max_samples, len(data))}")

        dialogue = example["dialogue"]
        reference_summary = example["summary"]

        try:
            prompt = f"Summarize this conversation:\n{dialogue}\n\nSummary:"
            encoded = tokenizer(prompt, return_tensors="pt").to(model.device)

            with torch.no_grad():
                output_ids = model.generate(
                    input_ids=encoded.input_ids,
                    attention_mask=encoded.attention_mask,
                    max_new_tokens=100,
                    num_beams=4,
                    do_sample=False,
                    pad_token_id=tokenizer.pad_token_id,
                )

            # Keep only the tokens produced after the prompt.
            prompt_length = encoded.input_ids.shape[1]
            completion = tokenizer.decode(output_ids[0][prompt_length:], skip_special_tokens=True)

            # If the model repeats the "Summary:" marker, keep only what follows it.
            _, marker, tail = completion.partition("Summary:")
            pred_summary = (tail if marker else completion).strip()
        except Exception as e:
            print(f"Error processing sample {idx}: {e}")
            continue

        predictions.append(pred_summary)
        references.append(reference_summary)

    if not predictions:
        print("No valid predictions were generated")
        return {
            "rouge1": 0.0,
            "rouge2": 0.0,
            "rougeL": 0.0,
            "rougeLsum": 0.0
        }

    results = rouge_metric.compute(
        predictions=predictions,
        references=references,
        use_stemmer=True,
    )

    # One line per ROUGE variant.
    for label, key in (
        ("ROUGE-1", "rouge1"),
        ("ROUGE-2", "rouge2"),
        ("ROUGE-L", "rougeL"),
        ("ROUGE-Lsum", "rougeLsum"),
    ):
        print(f"{label}: {results[key]:.4f}")

    return results


# ----------------------------------------------------------
# 6. Evaluate on gsm8k
# ----------------------------------------------------------
def solve_math_problem(model, tokenizer, question, max_new_tokens=300):
    """Ask the model to solve ``question`` with a step-by-step prompt.

    Args:
        model: Causal LM passed through to ``generate_text_safely``.
        tokenizer: Tokenizer matching ``model``.
        question: The word problem text.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        Whatever ``generate_text_safely`` returns for the built prompt
        (called with ``task_type="math"``).
    """
    # Use real newlines: the previous "\\n" put a literal backslash followed by
    # "n" into the prompt instead of a line break.
    prompt = f"Solve this step by step:\n{question}\n\nThink through each step carefully. Your final answer is:"
    return generate_text_safely(model, tokenizer, prompt, max_new_tokens=max_new_tokens, task_type="math")

def extract_answer(text):
    """Pull the predicted numeric answer out of a model response.

    The first number following the word "answer" (optionally "answer is",
    ``=``, ``:`` or ``$``) wins; otherwise the last number in the text is used.
    Thousands separators are removed ("1,000" -> "1000") and a fractional part
    made only of zeros is dropped ("18.00" -> "18"), so the result can be
    compared with a digits-only reference string.

    Args:
        text: Model response; may be empty or ``None``.

    Returns:
        The answer as a string, or ``None`` when no number is found.
    """
    if not text:
        return None

    # Either a properly grouped number such as 1,234,567 or a plain run of
    # digits, each with an optional decimal part. The lookahead stops "12,5000"
    # from being read as "12,500" followed by "0".
    number = r"(?:\d{1,3}(?:,\d{3})+(?!\d)|\d+)(?:\.\d+)?"

    def _clean(raw):
        value = raw.replace(",", "")
        if "." in value:
            # "18.00" -> "18", "3.50" -> "3.5"
            value = value.rstrip("0").rstrip(".")
        return value

    match = re.search(rf"answer(?: is)?[\s=:]*\$?\s*({number})", text, flags=re.IGNORECASE)
    if match:
        return _clean(match.group(1))

    # No explicit "answer" marker: fall back to the last number mentioned.
    numbers = re.findall(number, text)
    if numbers:
        return _clean(numbers[-1])

    return None

def evaluate_gsm8k(model, tokenizer, max_samples=None, batch_size=4):
    """Measure answer accuracy on the GSM8K test split.

    Questions are fed to the model in batches (raw question text, 4-beam
    search, up to 512 new tokens). The number extracted from each generated
    continuation by ``extract_answer`` is compared, as a string, with the
    digits following ``####`` in the reference answer.

    Args:
        model: Causal LM exposing ``device`` and ``generate``.
        tokenizer: Tokenizer matching ``model``; its ``padding_side`` is
            switched to ``"left"`` for the duration of the call and restored.
        max_samples: Optional cap on the number of test examples.
        batch_size: Number of questions generated per ``generate`` call.

    Returns:
        Fraction of examples answered correctly, or ``0.0`` when there is
        nothing to evaluate.
    """
    gsm8k = load_dataset("gsm8k", "main")
    eval_dataset = gsm8k["test"]
    if max_samples is not None:
        eval_dataset = eval_dataset.select(range(min(max_samples, len(eval_dataset))))

    num_examples = len(eval_dataset)
    if num_examples == 0:
        # Avoids a ZeroDivisionError when e.g. max_samples=0.
        print("No GSM8K samples to evaluate")
        return 0.0

    correct = 0
    total = 0

    # Pad on the left so every prompt in a batch ends at the same position and
    # the continuation can be cut off with a single column slice. Setting the
    # attribute (rather than passing padding_side per call) does not depend on
    # the tokenizer's __call__ accepting that keyword.
    original_padding_side = tokenizer.padding_side
    tokenizer.padding_side = "left"
    try:
        for i in range(0, num_examples, batch_size):
            batch_data = eval_dataset[i:i + batch_size]
            questions = batch_data["question"]
            answers = batch_data["answer"]

            # Reference answers: keep only the digits after the "####" marker.
            correct_answers = [
                re.sub(r"[^\d]", "", answer.split("####")[-1].strip())
                for answer in answers
            ]

            inputs = tokenizer(questions, padding=True, return_tensors="pt").to(model.device)
            with torch.no_grad():  # Inference only
                outputs = model.generate(
                    inputs.input_ids,
                    attention_mask=inputs.attention_mask,
                    max_new_tokens=512,
                    num_beams=4,
                    do_sample=False,
                    pad_token_id=tokenizer.pad_token_id
                )

            # Decode only the continuation. Decoding the full sequence would
            # let numbers from the question itself be picked up as the answer.
            prompt_length = inputs.input_ids.shape[1]
            responses = tokenizer.batch_decode(outputs[:, prompt_length:], skip_special_tokens=True)

            for response, correct_answer in zip(responses, correct_answers):
                if extract_answer(response) == correct_answer:
                    correct += 1
                total += 1

            if (i + batch_size) % 20 == 0 or (i + batch_size) >= num_examples:
                print(f"Progress: {total}/{num_examples}, Current accuracy: {correct/total:.4f}")
    finally:
        tokenizer.padding_side = original_padding_side

    return correct / total

# ----------------------------------------------------------
# 7. Evaluate on b-mc2/sql-create-context
# ----------------------------------------------------------
# Load dataset function that handles 'answer' field
def load_dataset_with_answer(file_path="sql_create_context_v4.json"):
    """Read the SQL benchmark from a local JSON file.

    When the first record has an ``answer`` key, every record gets a ``sql``
    key holding the same value, so downstream code can read ``sample['sql']``.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The parsed JSON content, or an empty list when the file is missing or
        anything goes wrong while reading or post-processing it.
    """
    print(f"Loading dataset from local file: {file_path}")
    try:
        if not os.path.exists(file_path):
            print(f"ERROR: File '{file_path}' not found.")
            return []

        with open(file_path, 'r', encoding='utf-8') as handle:
            records = json.load(handle)
        print(f"Successfully loaded JSON data from {file_path}")
        print(f"Loaded {len(records)} examples")

        # Nothing to inspect or remap in an empty dataset.
        if not records:
            return records

        first_keys = records[0].keys()
        print(f"Example keys: {list(first_keys)}")
        if 'answer' not in first_keys:
            return records

        # Expose the reference query under the key the evaluator reads.
        for record in records:
            record['sql'] = record['answer']
        print("Mapped 'answer' field to 'sql' for compatibility")
        return records
    except Exception as e:
        print(f"Error loading dataset: {str(e)}")
        return []

# Prompt template for SQL generation
def create_prompt(question, context):
    """Build the SQL-generation prompt for one question/context pair.

    The text starts with a blank line and ends with "SQL Query:" followed by a
    newline, so the model's continuation begins on its own line.
    """
    sections = [
        "",
        "Given the following context and question, generate a SQL query that answers the question.",
        "",
        "Context:",
        f"{context}",
        "",
        "Question:",
        f"{question}",
        "",
        "SQL Query:",
        "",
    ]
    return "\n".join(sections)

def generate_sql(model, tokenizer, question, context, max_new_tokens=200):
    """Generate a SQL query for ``question`` given the schema ``context``.

    Builds the prompt with ``create_prompt`` and delegates to
    ``generate_text_safely`` with ``task_type="sql"``; that function's return
    value is passed back unchanged.
    """
    return generate_text_safely(
        model,
        tokenizer,
        create_prompt(question, context),
        max_new_tokens=max_new_tokens,
        task_type="sql",
    )

# Normalize SQL statements
def normalize_sql(sql):
    """Normalize SQL for comparison.

    Removes ``--`` line comments, lowercases, collapses whitespace runs to a
    single space and drops whitespace around ``, ( ) { }``.

    Args:
        sql: SQL text; falsy values yield an empty string.

    Returns:
        The normalized string. Values that cannot be processed as text fall
        back to ``str(sql).lower().strip()``.
    """
    if not sql:
        return ""

    try:
        # Strip comments BEFORE collapsing whitespace: once newlines are gone a
        # "--" comment can no longer be told apart from the SQL that follows it,
        # and everything up to the end of the string would be deleted.
        sql = re.sub(r'--[^\n]*', '', sql)
        # Lowercase and squeeze every whitespace run into a single space.
        sql = ' '.join(sql.lower().split())
        # Remove whitespace around commas, parentheses and braces.
        sql = re.sub(r'\s*([,(){}])\s*', r'\1', sql)
        return sql
    except (TypeError, AttributeError):
        # Non-string input: best-effort textual form.
        return str(sql).lower().strip()


def token_match_score(pred, ref):
    """Return the share of distinct reference tokens that also occur in ``pred``.

    Both strings are passed through ``normalize_sql`` and split on whitespace;
    duplicates are ignored. An empty reference scores ``0.0``.
    """
    predicted_vocab = set(normalize_sql(pred).split())
    reference_vocab = set(normalize_sql(ref).split())

    if reference_vocab:
        shared = predicted_vocab & reference_vocab
        return len(shared) / len(reference_vocab)
    return 0.0

# SQL evaluation function using token match score instead of exact match
def evaluate_sql(model, tokenizer, file_path="sql_create_context_v4.json", max_samples=50):
    """Score generated SQL against references with the token match metric.

    Loads the dataset through ``load_dataset_with_answer``, generates a query
    for each of the first ``max_samples`` examples and averages
    ``token_match_score`` over the samples that were processed successfully.

    Args:
        model: Causal LM passed through to ``generate_sql``.
        tokenizer: Tokenizer matching ``model``.
        file_path: Path of the local JSON dataset.
        max_samples: Maximum number of examples to evaluate; ``None`` means all.

    Returns:
        The average token match score, or ``0.0`` when the dataset could not
        be loaded or no sample was processed.
    """
    print(f"Evaluating SQL-Create-Context dataset (max {max_samples} samples)...")

    try:
        sql_dataset = load_dataset_with_answer(file_path)
    except Exception as e:
        print(f"Error loading SQL dataset: {e}")
        return 0.0
    if not sql_dataset:
        print("Failed to load SQL dataset")
        return 0.0

    # Slice once so the progress message and the loop agree on the sample count
    # (and so max_samples=None simply means "everything").
    samples = sql_dataset[:max_samples]
    token_scores = []

    for i, sample in enumerate(tqdm(samples, desc="Evaluating SQL samples")):
        if i % 5 == 0:  # Print progress every 5 samples
            print(f"Processing SQL sample {i+1}/{len(samples)}")

        question = sample['question']
        context = sample["context"]
        reference_sql = sample['sql']  # Mapped from 'answer' by the dataset loader

        try:
            prompt = create_prompt(question, context)
            generated = generate_sql(model, tokenizer, question, context)

            # generate_sql returns prompt + continuation. Score only the
            # continuation: the prompt contains the schema and the question, so
            # leaving it in would credit the model with tokens it never wrote.
            if generated.startswith(prompt):
                generated_sql = generated[len(prompt):]
            else:
                # Fallback: keep what follows the last prompt marker.
                generated_sql = generated.rsplit("SQL Query:", 1)[-1]
            generated_sql = generated_sql.strip()

            token_scores.append(token_match_score(generated_sql, reference_sql))
        except Exception as e:
            print(f"Error processing SQL sample {i}: {e}")
            continue

    if not token_scores:
        print("No valid SQL samples were processed")
        return 0.0

    avg_token_score = sum(token_scores) / len(token_scores)
    print(f"Average Token Match Score: {avg_token_score:.4f}")

    return avg_token_score

# ----------------------------------------------------------
# 8. Run all evaluations
# ----------------------------------------------------------
# Wrapper that keeps a failing benchmark from aborting the whole run
def safe_evaluate(eval_func, model, tokenizer, **kwargs):
    """Call ``eval_func(model, tokenizer, **kwargs)`` and absorb any failure.

    On an exception the error and traceback are printed and a neutral score is
    returned instead: a dict of zeroed ROUGE values when the function is named
    ``evaluate_samsum``, otherwise ``0.0``.
    """
    try:
        outcome = eval_func(model, tokenizer, **kwargs)
    except Exception as e:
        print(f"Error during {eval_func.__name__}: {e}")
        import traceback
        traceback.print_exc()
        # Only the SamSum evaluator reports a mapping of metrics.
        if eval_func.__name__ != "evaluate_samsum":
            return 0.0
        return dict.fromkeys(("rouge1", "rouge2", "rougeL", "rougeLsum"), 0.0)
    return outcome

# NOTE: the banners below use real newlines; "\\n" printed a literal
# backslash-n instead of a blank line.
print("\n=== Starting evaluations ===\n")

# Upper bound on the number of examples evaluated per benchmark.
max_samples = 1000

# SamSum: safe_evaluate returns the whole ROUGE mapping, not just ROUGE-L.
print("\n1. Evaluating SamSum (ROUGE)...")
samsum_pre = safe_evaluate(evaluate_samsum, model_pre, tokenizer_pre, max_samples=max_samples)
print(f"Base model SamSum ROUGE scores: {samsum_pre}")

samsum_post = safe_evaluate(evaluate_samsum, model_ft, tokenizer_ft, max_samples=max_samples)
print(f"Fine-tuned model SamSum ROUGE scores: {samsum_post}")

print("\n2. Evaluating GSM8K (Accuracy)...")
gsm8k_pre = safe_evaluate(evaluate_gsm8k, model_pre, tokenizer_pre, max_samples=max_samples)
print(f"Base model GSM8K accuracy: {gsm8k_pre}")

gsm8k_post = safe_evaluate(evaluate_gsm8k, model_ft, tokenizer_ft, max_samples=max_samples)
print(f"Fine-tuned model GSM8K accuracy: {gsm8k_post}")

# SQL: evaluate_sql reports the average token match score, not exact match.
print("\n3. Evaluating SQL (Average Token Match Score)...")
sql_pre = safe_evaluate(evaluate_sql, model_pre, tokenizer_pre, max_samples=max_samples)
print(f"Base model SQL token match score: {sql_pre}")

sql_post = safe_evaluate(evaluate_sql, model_ft, tokenizer_ft, max_samples=max_samples)
print(f"Fine-tuned model SQL token match score: {sql_post}")

# ----------------------------------------------------------
# 9. Show a table of results
# ----------------------------------------------------------
# Summary table covering every ROUGE variant plus the GSM8K and SQL scores
def print_results_table(
    samsum_pre, samsum_post,
    gsm8k_pre, gsm8k_post,
    sql_pre, sql_post
):
    """Print an ASCII table comparing pre- and post-finetune scores.

    ``samsum_pre``/``samsum_post`` are mappings of ROUGE metrics (missing keys
    are shown as 0.0); the GSM8K and SQL arguments are plain numbers.
    """
    rule = "=" * 70

    def score_row(label, before, after):
        # Label column, then two right-aligned 4-decimal score columns.
        return f"{label:25s} | {before:20.4f} | {after:20.4f}"

    print("\n" + rule)
    print(f"{'DATASET':25s} | {'PRE-FINETUNE':20s} | {'POST-FINETUNE':20s}")
    print(rule)

    # One row per ROUGE variant reported for SAMSum.
    for metric in ('rouge1', 'rouge2', 'rougeL', 'rougeLsum'):
        before = samsum_pre.get(metric, 0.0)
        after = samsum_post.get(metric, 0.0)
        print(score_row('SAMSum (' + metric + ')', before, after))

    print(score_row('GSM8K (Accuracy)', gsm8k_pre, gsm8k_post))
    print(score_row('SQL (Average Token Match Score)', sql_pre, sql_post))
    print(rule + "\n")

# Close the run with the side-by-side comparison of every collected metric.
print("\n=== Final Results ===")
print_results_table(
    samsum_pre=samsum_pre,
    samsum_post=samsum_post,
    gsm8k_pre=gsm8k_pre,
    gsm8k_post=gsm8k_post,
    sql_pre=sql_pre,
    sql_post=sql_post,
)