import os
import json
import argparse
import logging
from typing import List, Dict, Any
import matplotlib.pyplot as plt
import numpy as np

# Set up logging: INFO level, each line prefixed with timestamp, logger name and level.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# `scienceplots` is imported only for its side effect of registering the
# 'science' matplotlib style applied on the next line; the name is unused.
import scienceplots
plt.style.use('science')


def load_jsonl_file(file_path: str) -> List[Dict[str, Any]]:
    """
    Load data from a JSONL file.

    Each non-blank line is parsed as one JSON value; blank or
    whitespace-only lines are skipped.

    Args:
        file_path: Path to the JSONL file (read as UTF-8).

    Returns:
        List of parsed records, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON. The
            message is prefixed with the file path and 1-based line number.
    """
    data = []
    # Explicit UTF-8 so parsing does not depend on the platform's locale encoding.
    with open(file_path, 'r', encoding='utf-8') as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                data.append(json.loads(line))
            except json.JSONDecodeError as err:
                # Keep the exception type so existing handlers still match,
                # but say which file/line broke.
                raise json.JSONDecodeError(
                    f"{file_path}:{line_no}: {err.msg}", err.doc, err.pos
                ) from err
    return data

# Per-model token prices under "input" / "output". Presumably USD per 1M
# tokens: plot_cost_analysis divides accumulated tokens x price by 1e6 and
# labels the axis in USD. TODO: confirm against the provider price sheets.
# Lookups happen both for evaluator models (get_samples' `models` list) and
# for the student `model` field of each counted record; a missing key raises
# KeyError there.
# "Qwen3-32B" duplicates "qwen/qwen3-32b", presumably for records that store
# the bare model name. TODO: confirm.
MODEL_COST = {
    "deepseek/deepseek-chat-v3-0324" : {
        "input": 0.30,
        "output": 0.88,
    },
    "qwen/qwen-2.5-7b-instruct" : {
        "input": 0.04,
        "output": 0.1,
    },
    "openai/o4-mini" : {
        "input": 1.1,
        "output": 4.4,
    },
    "openai/gpt-4o" : {
        "input": 2.5,
        "output": 10,
    },
    "meta-llama/llama-4-maverick" : {
        "input": 0.16,
        "output": 0.6,
    },
    "qwen/qwen3-32b" : {
        "input": 0.1,
        "output": 0.3,
    },
    "Qwen3-32B" : {
        "input": 0.1,
        "output": 0.3,
    },
}

# Display names for models, used in the x-axis labels of plot_cost_analysis.
# NOTE(review): "qwen/qwen-2.5-7b-instruct" is displayed as "Qwen3-4B" (the
# original "Qwen2.5-7B" label is left commented out). Its prices are presumably
# standing in for Qwen3-4B; confirm this relabelling is intended.
MODEL_NAMES = {
    "deepseek/deepseek-chat-v3-0324" : "Deepseek-v3",
    "qwen/qwen-2.5-7b-instruct" : "Qwen3-4B", # "Qwen2.5-7B",
    "openai/o4-mini" : "o4-mini",
    "openai/gpt-4o" : "gpt-4o",
    "meta-llama/llama-4-maverick" : "Llama-4-Maverick",
    "qwen/qwen3-32b" : "Qwen3-32B",
    "Qwen3-32B" : "Qwen3-32B",
}

def save_jsonl_file(data: List[Dict[str, Any]], file_path: str) -> None:
    """
    Write records to a file in JSON Lines format, one record per line.

    The destination is truncated first. Each record is serialized with
    ``json.dumps`` defaults and terminated by a newline.

    Args:
        data: Records to serialize.
        file_path: Destination path (overwritten if it already exists).
    """
    lines = (json.dumps(record) + '\n' for record in data)
    with open(file_path, 'w') as out:
        out.writelines(lines)
    logger.info(f"Saved {len(data)} items to {file_path}")
    
def _token_count(value: Any) -> int:
    """Return a token count stored either as a scalar or as a list (first element is used)."""
    return int(value[0]) if isinstance(value, list) else int(value)


def _per_question(total: float, count: float) -> float:
    """Return ``total / count``, or 0.0 when ``count`` is zero instead of raising."""
    return total / count if count else 0.0


def _log_token_usage(token_usage: Dict[str, int]) -> None:
    """Log the four running student token totals."""
    logger.info(f"Free-form input tokens: {token_usage['free_form_input']}")
    logger.info(f"Free-form output tokens: {token_usage['free_form_output']}")
    logger.info(f"MCQ input tokens: {token_usage['mcq_input']}")
    logger.info(f"MCQ output tokens: {token_usage['mcq_output']}")


def get_samples(
    input_dir: str,
    dataset_filter: str = "mmlu",
    student_filter: str = "qwen3",
) -> tuple:
    """
    Aggregate token usage and cost for student answers and matcher/judge calls.

    Walks ``input_dir`` recursively, skipping any directory whose path contains
    "wrong". Only JSONL files whose path contains ``dataset_filter`` and either
    "gen" (free-form) or "mcq" (multiple choice) are read. Within them, only
    records whose ``model`` field (lower-cased) contains ``student_filter`` and
    not "flash" are counted.

    Fields read per record (each a scalar or a list whose first item is used):
      * ``prompt_tokens`` / ``completion_tokens``: the student model's usage,
        priced with ``MODEL_COST[record["model"]]``.
      * ``prompt_tokens_<name>`` / ``completion_tokens_<name>``: matcher usage
        for each evaluator in ``models`` (``<name>`` is the part after "/").
      * the same two fields with a ``-JUDGE`` suffix: judge usage.

    Args:
        input_dir: Directory containing model output files.
        dataset_filter: Substring a file path must contain to be processed.
        student_filter: Lower-case substring the student ``model`` name must
            contain for a record to be counted.

    Returns:
        ``(cost_usage, token_usage, matcher_usage, judge_usage, num_questions)``:
          * cost_usage: mean student cost per question (tokens x MODEL_COST
            price) under ``free_form_input``, ``free_form_output``,
            ``mcq_input`` and ``mcq_output``.
          * token_usage: total student tokens, with the same keys.
          * matcher_usage / judge_usage: ``{model: {"input", "output"}}``, the
            mean per free-form question of tokens x MODEL_COST price.
          * num_questions: counts of ``free_form`` and ``mcq`` records.
        Means over a category with no questions are reported as 0.0.

    Raises:
        ValueError: If an MCQ file carries a matcher completion-token field.
        KeyError: If a counted student model or an evaluator model is missing
            from ``MODEL_COST``.
    """
    # Evaluator models whose matcher/judge token fields are read from each record.
    # Other candidates that were previously listed here (commented out):
    # openai/gpt-4o, openai/o4-mini, deepseek/deepseek-r1, qwen/qwen3-8b,
    # qwen/qwen3-14b, qwen/qwen3-32b, qwen/qwen-2.5-72b-instruct,
    # meta-llama/llama-4-scout, meta-llama/llama-4-maverick,
    # meta-llama/llama-3.1-8b-instruct, meta-llama/llama-3-8b-instruct,
    # meta-llama/llama-3.2-3b-instruct, meta-llama/llama-3.2-1b-instruct.
    # Any model enabled here needs an entry in MODEL_COST.
    models = [
        "qwen/qwen-2.5-7b-instruct",
        "deepseek/deepseek-chat-v3-0324",
    ]

    token_usage = {
        "free_form_input": 0,
        "free_form_output": 0,
        "mcq_input": 0,
        "mcq_output": 0,
    }
    num_questions = {"free_form": 0.0, "mcq": 0.0}
    cost_usage = {key: 0.0 for key in token_usage}

    # o4 models are excluded from matcher totals and Qwen models from judge totals;
    # their matcher/judge fields are ignored.
    matcher_usage = {
        model: {"input": 0, "output": 0}
        for model in models
        if "o4" not in model.lower()
    }
    judge_usage = {
        model: {"input": 0, "output": 0}
        for model in models
        if "qwen" not in model.lower()
    }

    processed_files = []
    separator = "\n\n\n--------------------------------\n\n\n"

    for root, dirs, files in os.walk(input_dir):
        # Prune "wrong" sub-directories in place so os.walk never descends into
        # them. A directory's own files stay in scope even if it has such a child.
        dirs[:] = [d for d in dirs if "wrong" not in d]
        if "wrong" in root:
            continue

        for file_name in files:
            model_file = os.path.join(root, file_name)
            is_gen = "gen" in model_file
            is_mcq = "mcq" in model_file
            if not (is_gen or is_mcq) or dataset_filter not in model_file:
                continue

            logger.info(f"Processing file: {model_file}")
            model_data = load_jsonl_file(model_file)
            logger.info(f"Num items: {len(model_data)}")
            category = "mcq" if is_mcq else "free_form"

            for item in model_data:
                student_model = item["model"]
                student_lower = student_model.lower()
                if "flash" in student_lower or student_filter not in student_lower:
                    continue

                # NOTE(review): matcher/judge tokens are accumulated before the
                # "completion_tokens"/"dataset" check below. Records skipped there
                # still add to the matcher/judge numerators but not to
                # num_questions, which is the denominator. Confirm this is intended.
                for model in models:
                    model_name = model.split("/")[-1]
                    prompt_field = f"prompt_tokens_{model_name}"
                    completion_field = f"completion_tokens_{model_name}"

                    # Free-form records lacking this model's matcher completion are
                    # skipped for this model (both matcher and judge fields).
                    if is_gen and completion_field not in item:
                        continue

                    if prompt_field in item and model in matcher_usage:
                        matcher_usage[model]["input"] += _token_count(item[prompt_field])

                    if completion_field in item:
                        if is_mcq:
                            raise ValueError(
                                f"Unexpected matcher field {completion_field!r} "
                                f"in MCQ file {model_file}"
                            )
                        if model in matcher_usage:
                            matcher_usage[model]["output"] += _token_count(item[completion_field])

                    if model in judge_usage:
                        judge_completion_field = f"{completion_field}-JUDGE"
                        judge_prompt_field = f"{prompt_field}-JUDGE"
                        if judge_completion_field in item:
                            judge_usage[model]["output"] += _token_count(item[judge_completion_field])
                        if judge_prompt_field in item:
                            judge_usage[model]["input"] += _token_count(item[judge_prompt_field])

                if "completion_tokens" not in item or "dataset" not in item:
                    continue

                prices = MODEL_COST[student_model]

                # NOTE(review): student prompt tokens are only counted for non-Qwen
                # students. With the default student_filter ("qwen3"), every counted
                # record is a Qwen model, so student input cost stays 0.
                if "prompt_tokens" in item and "qwen" not in student_lower:
                    used = _token_count(item["prompt_tokens"])
                    token_usage[f"{category}_input"] += used
                    cost_usage[f"{category}_input"] += used * prices["input"]

                used = _token_count(item["completion_tokens"])
                token_usage[f"{category}_output"] += used
                cost_usage[f"{category}_output"] += used * prices["output"]
                num_questions[category] += 1

            processed_files.append(model_file)
            _log_token_usage(token_usage)
            logger.info(separator)

    actual_files = list(set(processed_files))
    logger.info("Token usage statistics:")
    _log_token_usage(token_usage)
    logger.info(f"Actual files: {actual_files}")

    n_free_form = num_questions["free_form"]
    n_mcq = num_questions["mcq"]
    for label, count in (("free-form", n_free_form), ("MCQ", n_mcq)):
        if not count:
            logger.warning(f"No {label} questions were counted; their mean costs are reported as 0.0")

    # Student costs -> mean per question of the matching category.
    for key in cost_usage:
        count = n_mcq if key.startswith("mcq") else n_free_form
        cost_usage[key] = _per_question(cost_usage[key], count)

    # Matcher/judge tokens -> mean per free-form question.
    for model in models:
        logger.info(f"Model: {model}")
        if model in matcher_usage:
            for direction in ("input", "output"):
                matcher_usage[model][direction] = _per_question(matcher_usage[model][direction], n_free_form)
            logger.info(f"Input Matcher tokens: {matcher_usage[model]['input']}")
            logger.info(f"Output Matcher tokens: {matcher_usage[model]['output']}")
        if model in judge_usage:
            for direction in ("input", "output"):
                judge_usage[model][direction] = _per_question(judge_usage[model][direction], n_free_form)
            logger.info(f"Input Judge tokens: {judge_usage[model]['input']}")
            logger.info(f"Output Judge tokens: {judge_usage[model]['output']}")
        logger.info(separator)

    _log_token_usage(token_usage)
    logger.info(f"Free-form input cost: {cost_usage['free_form_input']}")
    logger.info(f"Free-form output cost: {cost_usage['free_form_output']}")
    logger.info(f"MCQ input cost: {cost_usage['mcq_input']}")
    logger.info(f"MCQ output cost: {cost_usage['mcq_output']}")

    # Per-question token means -> cost, using each evaluator's own prices.
    for model in models:
        logger.info(f"Model: {model}")
        if model in matcher_usage:
            evaluator_prices = MODEL_COST[model]
            for direction in ("input", "output"):
                matcher_usage[model][direction] *= evaluator_prices[direction]
            logger.info(f"Matcher input cost: {matcher_usage[model]['input']}")
            logger.info(f"Matcher output cost: {matcher_usage[model]['output']}")
        if model in judge_usage:
            evaluator_prices = MODEL_COST[model]
            for direction in ("input", "output"):
                judge_usage[model][direction] *= evaluator_prices[direction]
            logger.info(f"Judge input cost: {judge_usage[model]['input']}")
            logger.info(f"Judge output cost: {judge_usage[model]['output']}")
        logger.info(separator)

    logger.info(f"Num questions: {num_questions}")
    return cost_usage, token_usage, matcher_usage, judge_usage, num_questions

def plot_cost_analysis(cost_usage, matcher_usage, judge_usage, num_questions):
    """
    Draw a stacked bar chart of mean per-question cost and save it as a PNG.

    Bars, left to right:
      * "MCQ": student answer cost on multiple-choice questions.
      * one "Matcher" bar per model in ``matcher_usage``: free-form student
        cost, with that model's matcher cost stacked on top.
      * one "Judge" bar per model in ``judge_usage``: free-form student cost,
        with that model's judge cost stacked on top.

    Every cost is divided by 1e6 before plotting. The inputs are presumably
    tokens x price-per-million-tokens (see MODEL_COST). TODO: confirm units.

    Args:
        cost_usage: Mean student cost per question with keys ``mcq_input``,
            ``mcq_output``, ``free_form_input`` and ``free_form_output``.
        matcher_usage: ``{model: {"input": cost, "output": cost}}`` for matchers.
        judge_usage: ``{model: {"input": cost, "output": cost}}`` for judges.
        num_questions: Question counts. Currently unused; kept for call
            compatibility.

    Side effects:
        Creates ``plots/`` if needed, writes ``plots/cost_analysis4.png`` and
        calls ``plt.show()``.
    """
    # tokens x price-per-1M-tokens / 1e6 -> USD.
    mcq_cost = (cost_usage["mcq_input"] + cost_usage["mcq_output"]) / 1e6
    free_form_cost = (cost_usage["free_form_input"] + cost_usage["free_form_output"]) / 1e6

    # Judges come from judge_usage itself, so a judge that is not also a matcher
    # still gets its own bar.
    stacked_bars = [
        ("Matcher", model, matcher_usage[model]) for model in matcher_usage
    ] + [
        ("Judge", model, judge_usage[model]) for model in judge_usage
    ]

    # Unlisted models fall back to the part of their id after "/".
    labels = ["MCQ"] + [
        f"{role}\n{MODEL_NAMES.get(model, model.split('/')[-1])}"
        for role, model, _ in stacked_bars
    ]

    fig, ax = plt.subplots(figsize=(8, 8))

    model_response_color = '#4682B4'  # steel blue: student answer cost
    role_colors = {"Matcher": '#32CD32', "Judge": '#FF6347'}  # green / red

    bar_width = 0.4
    x_pos = np.arange(len(labels)) * 0.75  # 0.75 spacing leaves gaps between bars

    logger.info(f"MCQ cost: {mcq_cost}")
    ax.bar(x_pos[0], mcq_cost, width=bar_width, color=model_response_color)

    tallest = mcq_cost
    for pos, (role, model, usage) in enumerate(stacked_bars, start=1):
        evaluator_cost = (usage["input"] + usage["output"]) / 1e6
        logger.info(f"{role} cost for {model}: {evaluator_cost}")
        ax.bar(x_pos[pos], free_form_cost, width=bar_width, color=model_response_color)
        ax.bar(x_pos[pos], evaluator_cost, width=bar_width, bottom=free_form_cost,
               color=role_colors[role])
        tallest = max(tallest, free_form_cost + evaluator_cost)

    ax.set_xticks(x_pos)
    ax.set_xticklabels(labels, fontsize=22, ha='center')
    ax.set_ylabel('Mean Cost Per Sample (USD)', fontsize=24)
    ax.tick_params(axis='y', labelsize=22)

    # The y-range is fixed; warn when a bar would be clipped by it.
    y_max = 0.8e-3
    if tallest > y_max:
        logger.warning(
            f"Tallest bar ({tallest:.2e} USD) exceeds the fixed y-limit ({y_max:.1e}); it will be clipped"
        )
    ax.set_ylim(0, y_max)

    # Ticks every 0.2e-3 below y_max; tick labels are shown in units of 1e-3.
    ax.set_yticks(np.arange(0, y_max, 0.2e-3))
    ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, pos: f'{x*1e3:.1f}'))
    ax.text(-0.01, 1.01, r'$\times 10^{-3}$', transform=ax.transAxes, fontsize=20)

    # Explicit handles give exactly one legend entry per colour.
    legend_elements = [
        plt.Rectangle((0, 0), 1, 1, color=model_response_color, label='Model Response'),
        plt.Rectangle((0, 0), 1, 1, color=role_colors["Matcher"], label='Matcher'),
        plt.Rectangle((0, 0), 1, 1, color=role_colors["Judge"], label='Judge'),
    ]
    ax.legend(handles=legend_elements, loc='upper center', bbox_to_anchor=(0.56, 1.1),
              ncol=3, frameon=True, fontsize=16)

    ax.grid(axis='y', linestyle='--', alpha=0.7)

    os.makedirs('plots', exist_ok=True)
    plt.tight_layout(pad=2.0)
    output_path = 'plots/cost_analysis4.png'
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    logger.info(f"Cost analysis plot saved as '{output_path}'")
    plt.show()

def main():
    """Command-line entry point: aggregate token costs under --input_dir and plot them."""
    cli = argparse.ArgumentParser(description="Combine model outputs for annotation")
    cli.add_argument(
        "--input_dir",
        default="/fast/XXXX-1/qaevals/judge_outputs/alignment_plot/",
        help="Directory containing model output files",
    )
    options = cli.parse_args()

    # get_samples also returns raw token totals, which the plot does not use.
    results = get_samples(input_dir=options.input_dir)
    cost_usage, _, matcher_usage, judge_usage, num_questions = results
    plot_cost_analysis(cost_usage, matcher_usage, judge_usage, num_questions)


if __name__ == "__main__":
    main()
