from dataclasses import dataclass, field
import os
from typing import Literal, Optional
from src.models.predict import predict_dataset
from src.data.dataset_loader import load_by_name
from src.utils import parse_args
from collections import Counter
import pandas as pd
import re
from datetime import datetime
import json
from pathlib import Path
from typing import Optional

def clean_output(text: str) -> str:
    """
    Drop any <think>...</think> sections from a model response and trim it.

    A ``None`` input yields the empty string. After the think sections are
    removed (tags matched case-insensitively, content may span lines), the
    text is stripped of surrounding whitespace and then of any leading or
    trailing characters from the set ``.,!?;:``.
    """
    if text is None:
        return ""

    # Non-greedy match so several separate think sections are removed one by one.
    think_section = re.compile(r'<think>.*?</think>', re.IGNORECASE | re.DOTALL)
    without_reasoning = think_section.sub('', text)

    # Whitespace first, punctuation second: whitespace exposed by the
    # punctuation strip is intentionally left in place.
    return without_reasoning.strip().strip('.,!?;:')

def create_timestamped_dir(
    eval_type: str = "ri",
    dataset_name: str = None,
    model_config: str = None,
    results_base: str = "results",
    logs_base: str = "logs",
):
    """
    Create matching results and logs directories named after the current time.

    Both directories are named ``<YYYYmmdd_HHMMSS>-<eval_type>[-<dataset>][-<model>]``
    where the dataset and model parts are short tags derived from the given
    names and are omitted when the corresponding argument is empty or None.

    Returns:
        A ``(results_dir, logs_dir)`` tuple of the created paths.
    """
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Pieces of the directory suffix, joined with "-" at the end.
    name_parts = [f"{eval_type}"]

    if dataset_name:
        # Known datasets get a fixed tag; anything else uses its first
        # three characters with "/" made path-safe.
        dataset_tags = {"mix": "mix", "simpleqa": "sq"}
        dataset_tag = dataset_tags.get(dataset_name.lower())
        if dataset_tag is None:
            dataset_tag = dataset_name.replace("/", "_")[:3]
        if dataset_tag:
            name_parts.append(dataset_tag)

    if model_config:
        # Order matters: the first family whose keyword occurs in the config
        # name wins; unknown families fall back to the first three characters.
        model_families = (
            ("qwen", "qwen"),
            ("gemini", "gem"),
            ("gpt", "gpt"),
            ("mistral", "mis"),
            ("gemma", "gma"),
        )
        lowered = model_config.lower()
        model_tag = next(
            (tag for keyword, tag in model_families if keyword in lowered),
            model_config[:3],
        )
        if model_tag:
            name_parts.append(model_tag)

    leaf = f"{stamp}-{'-'.join(name_parts)}"
    created = []
    for base in (results_base, logs_base):
        target = f"{base}/{leaf}"
        os.makedirs(target, exist_ok=True)
        created.append(target)

    results_dir, logs_dir = created
    return results_dir, logs_dir

def extract_answer(predicted_answer: str) -> str:
    """
    Reduce a raw model response to the answer it contains.

    Returns:
        * ``"UNANSWERED"`` when the response is ``None`` or contains the
          literal text ``UNANSWERED`` anywhere;
        * otherwise the stripped content of the first ``<answer>...</answer>``
          pair (tags matched case-insensitively, content may span lines);
        * otherwise the response unchanged.
    """
    # Identity check: None is a singleton, so `is` is the correct comparison.
    if predicted_answer is None or "UNANSWERED" in predicted_answer:
        return "UNANSWERED"

    tagged = re.search(r'<answer>(.*?)</answer>', predicted_answer, re.IGNORECASE | re.DOTALL)
    if tagged:
        return tagged.group(1).strip()

    return predicted_answer

def extract_grade(grade: str) -> str:
    """
    Normalise a grader response to one of ``"A"``, ``"B"``, ``"C"`` or ``"X"``.

    Only the first line of the response is considered; it is upper-cased and
    stripped of surrounding whitespace. If the result is exactly A, B or C it
    is returned, otherwise (including a ``None`` response) ``"X"`` is returned
    to mark the grade as unparseable.
    """
    if grade is None:
        return "X"
    first_line = grade.split('\n')[0].upper().strip()
    if first_line in ('A', 'B', 'C'):
        return first_line
    # Anything else is not a valid grade; do not leak arbitrary grader text.
    return "X"

def write_json_logs(dataset, prompt_template_path: str, prompt_index: int, predicted_dataset, graded_dataset, logs_dir: str, args):
    """
    Write JSON logs in apricot format for a specific prompt.

    One entry is produced per example, pairing rows of ``dataset``,
    ``predicted_dataset`` and ``graded_dataset`` positionally. Each entry holds
    the running index, the prompt template with ``{question}`` replaced by the
    example's question, the reference answer, the raw model output and a
    binary label (1 when the grade is A, 0 otherwise).

    Args:
        dataset: Iterable of rows providing ``'question'`` and ``'answer'``.
        prompt_template_path: Path of the prompt template text file.
        prompt_index: Index of the prompt, used in the output file name.
        predicted_dataset: Iterable of rows providing ``'predicted_answer_raw'``.
        graded_dataset: Iterable of rows providing ``'grade'``.
        logs_dir: Existing directory the JSON file is written into.
        args: Object providing ``dataset_name``, ``model_config`` and
            ``max_samples``, all used in the output file name.

    Returns:
        The path of the JSON file that was written.
    """
    # Read the prompt template
    with open(prompt_template_path, 'r') as f:
        prompt_template = f.read()

    json_data = []

    for i, (row, pred_row, grade_row) in enumerate(zip(dataset, predicted_dataset, graded_dataset)):
        # Reconstruct the full question by replacing {question} in template
        full_question = prompt_template.replace('{question}', row['question'])

        # Normalise the grade exactly as the results table does (first line,
        # upper-cased) so the JSON label cannot disagree with the CSV grade
        # for outputs such as "a" or "A\n<explanation>".
        grade = extract_grade(clean_output(grade_row['grade']))
        label = 1 if grade == 'A' else 0

        json_data.append({
            "id": i,
            "question": full_question,
            "answer": row['answer'],
            "model_output": pred_row['predicted_answer_raw'],
            "label": label,
        })

    # File name: prompt index, template base name (text before the first dot),
    # dataset name with "/" made path-safe, model config and sample budget.
    prompt_name = prompt_template_path.split('/')[-1].split('.')[0]
    filename = f"{logs_dir}/prompt_{prompt_index}_{prompt_name}_{args.dataset_name.replace('/', '_')}_{args.model_config}_{args.max_samples}_samples.json"

    with open(filename, 'w') as f:
        json.dump(json_data, f, indent=2)

    print(f"JSON logs saved to {filename}")
    return filename

    

def summarize_and_dump_metrics(results_df, args, prompt_name=None, results_dir=None):
    """
    Compute per-prompt summary metrics and write them to a JSON file.

    For every ``grade<i>`` column in ``results_df`` the function records the
    share of rows graded ``'C'`` (rejection rate) and ``'A'`` (accuracy, also
    stored as ``mean``), plus the mean/max/min of the matching
    ``char_count<i>`` column. When there are no rows, or no non-missing
    character counts, the affected values are reported as zero instead of
    failing.

    Args:
        results_df: DataFrame with ``grade<i>`` and ``char_count<i>`` columns
            numbered consecutively from 0.
        args: Object providing the run settings echoed into the JSON
            (``dataset_name``, ``model_name``, ``model_config``,
            ``max_samples``, ``temperature``, ``top_p``, ``num_proc``,
            ``verbose``, ``retry_attempts``, ``suffix``, ``max_tokens``,
            ``max_thinking_tokens``).
        prompt_name: Optional prompt name stored in the settings section.
        results_dir: Output directory; defaults to ``"results"`` and is
            created if missing.
    """
    # Optionally limit rows for summary
    if args.max_samples:
        results_df = results_df.head(args.max_samples)

    # Detect number of prompts. fullmatch keeps look-alike columns such as
    # "grade0_raw" from inflating the count.
    grade_cols = [col for col in results_df.columns if re.fullmatch(r'grade\d+', str(col))]
    num_prompts = len(grade_cols)

    total_rows = len(results_df)

    def value_ratio(column, value):
        """Fraction of rows in `column` equal to `value` (0.0 for no rows)."""
        if total_rows == 0:
            return 0.0
        return float(results_df[column].value_counts().get(value, 0)) / total_rows

    # Compute per-prompt basic metrics needed for downstream refusal-index computation
    prompt_metrics = []
    for i in range(num_prompts):
        rej = value_ratio(f'grade{i}', 'C')  # refusal rate
        acc = value_ratio(f'grade{i}', 'A')  # accuracy with refusals counted as wrong

        # mean()/max()/min() return NaN when there is nothing to aggregate,
        # and int(NaN) raises; fall back to zeros like the ratios do.
        char_counts = results_df[f'char_count{i}']
        has_counts = bool(char_counts.notna().any())

        prompt_metrics.append({
            "prompt_index": i,
            "rejection_rate": rej,
            "accuracy": acc,
            "mean": acc,  # keep a mean field for downstream consumers expecting it
            "avg_char_count": float(char_counts.mean()) if has_counts else 0.0,
            "max_char_count": int(char_counts.max()) if has_counts else 0,
            "min_char_count": int(char_counts.min()) if has_counts else 0,
        })

    # Prepare JSON payload without deprecated metric fields
    json_results = {
        "settings": {
            "dataset_name": args.dataset_name,
            "model_name": args.model_name,
            "max_samples": args.max_samples,
            "temperature": args.temperature,
            "top_p": args.top_p,
            "prompt_name": prompt_name,
            "num_proc": args.num_proc,
            "verbose": args.verbose,
            "retry_attempts": args.retry_attempts,
            "suffix": args.suffix,
            "max_tokens": args.max_tokens,
            "max_thinking_tokens": args.max_thinking_tokens,
        },
        "metrics": [
            {
                "prompts": prompt_metrics
            }
        ],
    }

    # Save the metrics JSON
    if results_dir is None:
        results_dir = "results"
    # The default directory is not guaranteed to exist yet.
    os.makedirs(results_dir, exist_ok=True)
    json_path = f"{results_dir}/{args.dataset_name.replace('/', '_')}_{args.model_config}_{args.temperature}_{args.max_samples}_metrics.json"
    with open(json_path, 'w') as f:
        json.dump(json_results, f, indent=2)
    print(f"Saved metrics to {json_path}")

def load_model_config(config_name: str) -> dict:
    """
    Read ``model_configs/<config_name>.json`` and return the parsed content.

    The path is relative to the current working directory.

    Raises:
        ValueError: If the configuration file does not exist.
    """
    location = Path("model_configs") / f"{config_name}.json"
    if location.exists():
        with open(location) as handle:
            return json.load(handle)
    raise ValueError(f"Model config file {location} does not exist")

@dataclass
class Arguments:
    """
    Run settings for the evaluation, merged from three sources.

    Precedence, lowest to highest: the dataclass defaults below, the entries
    of the JSON model config named by ``model_config``, and any of the
    overridable fields that were given a non-None value at construction.
    """
    dataset_name: str = "mix"
    max_samples: int = 6000
    verbose: bool = False
    num_proc: int = 50
    model_config: str = field(default=None)  # Mandatory in practice: None is rejected in __post_init__
    retry_attempts: int = 5
    max_thinking_tokens: Optional[int] = None  # Token budget for reasoning
    google_api_key: Optional[str] = field(default_factory=lambda: os.environ.get("GOOGLE_API_KEY"))
    use_same_model_for_grading: bool = False  # When True the grader reuses the inference backend and model
    grader_backend: Literal["openai", "google", "vllm", "vllm_offline"] = "openai"  # Backend used to grade
    grader_model: str = "google/gemini-2.0-flash-lite-001"  # Model used to grade
    
    # Fields that take precedence over the model config when not None
    inference_backend: Optional[Literal["openai", "google", "vllm", "vllm_offline"]] = None
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    top_p: Optional[float] = None
    model_name: Optional[str] = None
    suffix: Optional[str] = None
    # Base directories for outputs
    results_base: Optional[str] = "results"
    logs_base: Optional[str] = "logs"

    def __post_init__(self):
        """Merge the model config into the instance and validate backends."""
        if self.model_config is None:
            raise ValueError("model_config must be specified explicitly")

        # Snapshot the overridable values now; the config is about to
        # overwrite the attributes they came from.
        overridable = (
            'inference_backend',
            'temperature',
            'max_tokens',
            'top_p',
            'model_name',
            'suffix',
        )
        requested = {name: getattr(self, name) for name in overridable}

        # Every key in the config file becomes an attribute of the instance.
        for name, setting in load_model_config(self.model_config).items():
            setattr(self, name, setting)

        # Re-apply the snapshot on top, skipping values that were left as None.
        for name, setting in requested.items():
            if setting is None:
                continue
            setattr(self, name, setting)
            if self.verbose:
                print(f"CLI override: {name} = {setting}")

        # Reject the unsupported inference backend.
        if getattr(self, 'inference_backend', None) == "huggingface":
            raise ValueError("HuggingFace backend is not supported in this evaluation. Please use 'openai', 'google', 'vllm', or 'vllm_offline' backends instead.")

        if not self.use_same_model_for_grading:
            return

        # Grade with the same backend/model that produced the answers.
        self.grader_backend = self.inference_backend
        self.grader_model = self.model_name

        if self.grader_backend == "huggingface":
            raise ValueError("HuggingFace backend is not supported for grading.")

def main():
    """
    Run the evaluation end to end.

    For each instruction prompt: generate predictions for the dataset, extract
    the answers, have the grader model grade them, and record raw predictions,
    extracted answers, grades and character counts in a results table. The
    table is written to CSV after every prompt and a per-prompt JSON log is
    written alongside it. Finally the run arguments and the per-prompt summary
    metrics are saved in the results directory.
    """
    args = parse_args(Arguments)
    print(f"Arguments: {args}")

    instruction_prompt_paths: list[str] = ["prompts/PROMPT_A.txt", "prompts/PROMPT_C.txt"]
    test_set = load_by_name(args.dataset_name, max_samples=args.max_samples)

    # Initialize DataFrame with question and answer columns
    results_df = pd.DataFrame({
        'question': test_set['question'],
        'answer': test_set['answer']
    })

    results_dir, logs_dir = create_timestamped_dir(
        "ri",
        args.dataset_name,
        args.model_config,
        results_base=args.results_base or "results",
        logs_base=args.logs_base or "logs",
    )

    # The CSV is rewritten after each prompt, always at this same path.
    csv_path = f"{logs_dir}/{args.dataset_name.replace('/', '_')}_{args.model_config}.csv"

    # Process each prompt and add predictions and grades
    for i, instruction_prompt_path in enumerate(instruction_prompt_paths):
        print(f"Evaluating {instruction_prompt_path}")
        predicted_dataset = predict_dataset(
            test_set, 
            model_name=args.model_name, 
            prompt_template_path=instruction_prompt_path, 
            num_proc=args.num_proc, 
            output_column="predicted_answer_raw", 
            temperature=args.temperature, 
            max_tokens=args.max_tokens,
            suffix=args.suffix,
            top_p=args.top_p,
            max_thinking_tokens=args.max_thinking_tokens,
            inference_backend=args.inference_backend,
            google_api_key=args.google_api_key,
        )
        predicted_dataset = predicted_dataset.map(lambda x: {"predicted_answer": extract_answer(x["predicted_answer_raw"])})
        graded_dataset = predict_dataset(
            predicted_dataset, 
            model_name=args.grader_model, 
            prompt_template_path="prompts/GRADER.txt", 
            output_column="grade", 
            num_proc=args.num_proc, 
            max_tokens=10,
            inference_backend=args.grader_backend,
            google_api_key=args.google_api_key,
            suffix=None,  # Explicitly set suffix to None for grading
        )
        
        # Add predictions and grades to DataFrame
        results_df[f'pred_raw{i}'] = predicted_dataset['predicted_answer_raw']
        results_df[f'pred{i}'] = predicted_dataset['predicted_answer']
        results_df[f'grade{i}'] = [extract_grade(clean_output(grade)) for grade in graded_dataset['grade']]
        
        # Calculate and store character counts
        results_df[f'char_count{i}'] = results_df[f'pred_raw{i}'].str.len()
        avg_chars = results_df[f'char_count{i}'].mean()
        print(f"\nCharacter count statistics for {instruction_prompt_path}:")
        print(f"Average characters: {avg_chars:.2f}")
        print(f"Max characters: {results_df[f'char_count{i}'].max()}")
        print(f"Min characters: {results_df[f'char_count{i}'].min()}")
        
        # Display up to the first 3 raw predictions. head(3) simply yields
        # fewer rows when the run has fewer than three samples, where fixed
        # indexing of rows 0..2 would raise.
        print(f"First 3 raw predictions for {instruction_prompt_path}:")
        for j, raw_prediction in enumerate(results_df[f'pred_raw{i}'].head(3)):
            print(f"Example {j+1}: {raw_prediction}")

        # Print grade counts for this prompt
        grade_counts = Counter(results_df[f'grade{i}'])
        print(f"\nGrade counts for prompt {i}:")
        print(grade_counts)

        # Save results to csv after each prompt (overwrite the same file)
        results_df.to_csv(csv_path, index=False)
        print(f"Results saved to {csv_path}")
        
        # Save JSON logs in apricot format for this prompt
        write_json_logs(test_set, instruction_prompt_path, i, predicted_dataset, graded_dataset, logs_dir, args)

    # Calculate and print overall statistics
    print("\nOverall character count statistics:")
    for i, instruction_prompt_path in enumerate(instruction_prompt_paths):
        avg_chars = results_df[f'char_count{i}'].mean()
        print(f"Prompt {i} ({instruction_prompt_path.split('/')[-1]}): {avg_chars:.2f} characters on average")

    # Base name of the first prompt file, recorded in the metrics settings
    prompt_name = instruction_prompt_paths[0].split("/")[-1].split(".")[0] if instruction_prompt_paths else None

    # Save arguments to JSON
    args_dict = {
        "dataset_name": args.dataset_name,
        "max_samples": args.max_samples,
        "verbose": args.verbose,
        "num_proc": args.num_proc,
        "model_config": args.model_config,
        "model_name": args.model_name,
        "temperature": args.temperature,
        "top_p": args.top_p,
        "retry_attempts": args.retry_attempts,
        "suffix": args.suffix,
        "max_tokens": args.max_tokens,
        "max_thinking_tokens": args.max_thinking_tokens,
        "use_same_model_for_grading": args.use_same_model_for_grading,
        "grader_backend": args.grader_backend,
        "grader_model": args.grader_model,
        "instruction_prompt_paths": instruction_prompt_paths
    }
    args_path = f"{results_dir}/args.json"
    with open(args_path, 'w') as f:
        json.dump(args_dict, f, indent=2)
    print(f"Arguments saved to {args_path}")

    # Save per-prompt summary metrics for downstream refusal-index analysis
    summarize_and_dump_metrics(results_df, args, prompt_name, results_dir)

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
