import json
import numpy as np
import argparse
import os
import logging
from typing import List, Dict, Any, Optional
import re
import matplotlib.pyplot as plt

# Configure logging once at import time: INFO-level records carrying a
# timestamp, the level name, the emitting logger's name and the message.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)
# Module-level logger used by the functions below.
logger = logging.getLogger(__name__)

def load_json_data(file_path: str) -> List[Dict[str, Any]]:
    """
    Load data from a JSON file.

    The file is decoded explicitly as UTF-8 so that non-ASCII text in the
    records is read the same way regardless of the platform's default
    locale encoding.

    Args:
        file_path: Path to the JSON file to read.

    Returns:
        The parsed JSON document. It is returned as-is; the structure
        (a list of dicts, per the annotation) is not validated here.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file content is not valid JSON.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)

def calculate_mean_completion_tokens(data: List[Dict[str, Any]], skipped: Optional[bool] = None) -> float:
    """
    Calculate the mean completion token length from the data.

    Items that have no 'completion_tokens' entry (or whose entry is None)
    contribute 0 to the mean.

    Args:
        data: List of data items
        skipped: If True, only include skipped items. If False, only include non-skipped items.
                If None, include all items. An item without a 'skipped' key counts as
                non-skipped.

    Returns:
        float: Mean of the 'completion_tokens' values of the selected items,
        or 0.0 when no item is selected.
    """
    if skipped is not None:
        filtered_data = [item for item in data if item.get('skipped', False) == skipped]
    else:
        filtered_data = data

    # `or 0` treats an explicit None the same way as a missing key, instead
    # of letting np.mean fail on a None entry.
    completion_tokens = [item.get('completion_tokens') or 0 for item in filtered_data]

    if not completion_tokens:
        return 0.0

    # Convert the numpy scalar so both return paths yield a builtin float.
    return float(np.mean(completion_tokens))


def extract_answer(completion: str) -> Optional[float]:
    """
    Extracts the final answer from the LLM's output.

    The answer is the content of the last ``<answer>...</answer>`` tag pair
    in the text, parsed as a float.

    Args:
        completion: Text to search for answer tags.

    Returns:
        The parsed value if it lies in the closed interval [0, 1]; None when
        no tag pair is present, the content is not a number, or the number is
        outside [0, 1] (this includes NaN and infinities).
    """
    # With a single capture group, findall returns the captured strings.
    answers = re.findall(r"<answer>(.*?)<\/answer>", completion, re.DOTALL)

    if not answers:
        return None

    # Only the last tag pair counts as the final answer.
    answer_text = answers[-1].strip()

    try:
        prediction = float(answer_text)
    except ValueError:
        return None

    # Written as a positive range test so that NaN (for which every
    # comparison is False) is rejected rather than slipping through.
    if not 0 <= prediction <= 1:
        return None

    return prediction

def calculate_brier_score(data: List[Dict[str, Any]]) -> float:
    """
    Compute the mean Brier score over all scorable items.

    For each item, the Brier score is (prediction - resolution)^2, where the
    prediction is extracted from the item's 'response' text. Lower is better.

    Items are left out of the average when the response is empty/missing,
    the resolution is None/missing, or no valid prediction can be extracted.

    Args:
        data: List of data items containing responses and resolutions

    Returns:
        float: Average Brier score, or 0.0 when no item could be scored
        (note that 0.0 is also the value of a perfect score).
    """
    squared_errors = []

    for record in data:
        text = record.get('response', '')
        outcome = record.get('resolution', None)

        # Nothing to score without both a response and a resolution.
        if not text or outcome is None:
            continue

        forecast = extract_answer(text)
        if forecast is None:
            continue

        squared_errors.append((forecast - float(outcome)) ** 2)

    return np.mean(squared_errors) if squared_errors else 0.0

def calculate_log_odds_score(data: List[Dict[str, Any]]) -> float:
    """
    Compute the mean log odds score over all scorable items.

    Per-item score:
    - resolution equal to 1: log(prediction)
    - any other resolution:  log(1 - prediction)

    Predictions are clipped to [0.0001, 0.9999] before taking the logarithm.
    Higher is better.

    Items are left out of the average when the response is empty/missing,
    the resolution is None/missing, or no valid prediction can be extracted.

    Args:
        data: List of data items containing responses and resolutions

    Returns:
        float: Average log odds score, or 0.0 when no item could be scored
    """
    lower_bound, upper_bound = 0.0001, 0.9999
    log_scores = []

    for record in data:
        text = record.get('response', '')
        outcome = record.get('resolution', None)

        if not text or outcome is None:
            continue

        forecast = extract_answer(text)
        if forecast is None:
            continue

        # Keep the probability away from 0 and 1 so the log stays finite.
        clipped = max(min(forecast, upper_bound), lower_bound)

        # Probability the forecast assigned to the outcome that occurred.
        assigned = clipped if float(outcome) == 1 else 1 - clipped
        log_scores.append(np.log(assigned))

    if log_scores:
        return np.mean(log_scores)
    return 0.0

def calculate_accuracy(data: List[Dict[str, Any]]) -> float:
    """
    Compute the binary accuracy of the predictions.

    Each prediction is thresholded at 0.5 (>= 0.5 becomes 1, otherwise 0)
    and counted as correct when that label equals the item's resolution.

    Items are left out when the response is empty/missing, the resolution
    is None/missing, or no valid prediction can be extracted.

    Args:
        data: List of data items containing responses and resolutions

    Returns:
        float: Fraction of scored items that were correct (between 0 and 1),
        or 0.0 when no item could be scored
    """
    hits = []

    for record in data:
        text = record.get('response', '')
        outcome = record.get('resolution', None)

        if not text or outcome is None:
            continue

        forecast = extract_answer(text)
        if forecast is None:
            continue

        predicted_label = 1 if forecast >= 0.5 else 0
        hits.append(predicted_label == float(outcome))

    if not hits:
        return 0.0

    # Each True counts as one correct prediction.
    return sum(hits) / len(hits)

def get_last_chars_of_skipped(data: List[Dict[str, Any]], num_chars: int = 500) -> List[Any]:
    """
    Get the last N characters of responses for skipped items.

    For every item whose 'skipped' flag is truthy, one entry is appended to
    the result, in input order:
    - a dict with the keys 'last_chars', 'final_answer', 'skipped', 'idx'
      and 'generation_idx' when the item has a non-empty response;
    - an empty string when the response is empty or missing.

    Each dict entry is also logged at INFO level, together with the answer
    that `extract_answer` obtains from the full response.

    Args:
        data: List of data items
        num_chars: Number of trailing characters of the response to keep.
                   Values <= 0 keep no characters.

    Returns:
        List with one entry (dict or '') per skipped item.
    """
    last_chars: List[Any] = []

    for idx, item in enumerate(data):
        if not item.get('skipped', False):
            continue

        response = item.get('response', '')
        if not response:
            last_chars.append('')
            continue

        # response[-0:] would be the whole string, so a non-positive count
        # is handled explicitly.
        tail = response[-num_chars:] if num_chars > 0 else ''

        # Include the last characters of the response along with other relevant fields
        item_info = {
            'last_chars': tail,
            'final_answer': item.get('final_answer', None),
            'skipped': item.get('skipped', False),
            'idx': item.get('idx', None),
            'generation_idx': item.get('generation_idx', None)
        }
        last_chars.append(item_info)
        logger.info(f"Skipped item {idx}: {item_info}")

        # Re-extract the answer from the raw response for comparison in the log.
        answer = extract_answer(response)
        logger.info(f"Answer: {answer}")

    return last_chars

def plot_completion_tokens_histogram(data: List[Dict[str, Any]], output_path: str = None) -> None:
    """
    Draw a 30-bin histogram of the items' completion token counts.

    Items without a 'completion_tokens' entry are counted as 0.

    Args:
        data: List of data items
        output_path: Where to save the figure. When empty or None, the
                     figure is shown instead of being saved.
    """
    token_counts = [entry.get('completion_tokens', 0) for entry in data]

    plt.figure(figsize=(10, 6))
    plt.hist(token_counts, bins=30, alpha=0.7, color='blue')
    plt.title('Distribution of Response Lengths (Completion Tokens)')
    plt.xlabel('Completion Tokens')
    plt.ylabel('Frequency')
    plt.grid(True, alpha=0.3)

    if not output_path:
        plt.show()
    else:
        plt.savefig(output_path)
        logger.info(f"Histogram saved to {output_path}")

    # Close the current figure whichever branch ran.
    plt.close()

def analyze_outputs(file_path: str) -> None:
    """
    Analyze the outputs from a JSON file and log summary statistics.

    Loads the records and logs, at INFO level: the mean completion token
    length (overall, for non-skipped items and for skipped items), the
    average Brier score, the average log odds score and the binary accuracy.

    Args:
        file_path: Path to the JSON file containing model outputs.
    """
    logger.info(f"Loading data from {file_path}")
    data = load_json_data(file_path)

    # Calculate overall mean completion token length
    mean_tokens = calculate_mean_completion_tokens(data)
    logger.info(f"Overall mean completion token length: {mean_tokens:.2f}")

    # Calculate mean completion token length for non-skipped items
    mean_tokens_not_skipped = calculate_mean_completion_tokens(data, skipped=False)
    logger.info(f"Mean completion token length (skipped=False): {mean_tokens_not_skipped:.2f}")

    # Calculate mean completion token length for skipped items
    mean_tokens_skipped = calculate_mean_completion_tokens(data, skipped=True)
    logger.info(f"Mean completion token length (skipped=True): {mean_tokens_skipped:.2f}")

    # Calculate Brier score
    brier_score = calculate_brier_score(data)
    logger.info(f"Average Brier score: {brier_score:.4f}")

    # Calculate log odds score
    log_odds_score = calculate_log_odds_score(data)
    logger.info(f"Average log odds score: {log_odds_score:.4f}")

    # Calculate accuracy
    accuracy = calculate_accuracy(data)
    logger.info(f"Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")

    # NOTE: get_last_chars_of_skipped() and plot_completion_tokens_histogram()
    # are not invoked here; the calls were previously present only as
    # commented-out code, along with an unused histogram output path.

if __name__ == "__main__":
    # Command-line entry point: analyze the JSON file given via --file_path.
    parser = argparse.ArgumentParser(description="Analyze model outputs from JSON files")
    parser.add_argument("--file_path", type=str, required=True, help="Path to the JSON file containing model outputs")

    args = parser.parse_args()

    # isfile (rather than exists) also rejects a directory, which would
    # otherwise only fail later when the path is opened for reading.
    if not os.path.isfile(args.file_path):
        logger.error(f"File not found: {args.file_path}")
        # The bare exit() helper is installed by the `site` module and is not
        # guaranteed to exist; raising SystemExit always works.
        raise SystemExit(1)

    analyze_outputs(args.file_path)
