import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from pathlib import Path
import json
import openai
from dotenv import load_dotenv
import os

# --- Configuration ---
# Populate os.environ from a .env file (python-dotenv) before reading the key.
load_dotenv()
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")
# Model that writes the narrative behavioral reports (called through OpenRouter).
REPORT_GENERATOR_MODEL = "anthropic/claude-opus-4.1"

# Model-tier switch: True analyses the mid-sized models, False the large ones.
# plot_comparison_charts also reads it to choose the 'mid' / 'large' output folder.
mid = False

TARGET_MODELS = (
    [
        "openai/gpt-oss-20b",
        "qwen/qwen-2.5-14b",
        "qwen/qwen3-30b-a3b",
        "meta-llama/llama-3.3-70b-instruct",
        "deepseek/deepseek-r1-distill-qwen-14b",
        "deepseek/deepseek-r1-distill-llama-70b",
        "z-ai/glm-4-32b",
        "mistralai/mistral-small-3.2-24b-instruct",
        "huawei/Pangu-Pro-MoE-72B",
    ]
    if mid
    else [
        "openai/gpt-4o",
        "openai/gpt-5",
        "meta-llama/llama-3.1-405b-instruct",
        "anthropic/claude-opus-4.1",
        "google/gemini-2.5-pro",
        "x-ai/grok-4",
        "deepseek/deepseek-r1-0528:free",
        "huawei/Pangu-Ultra-MoE-718B",
        "qwen/qwen3-235b-a22b",
    ]
)

# Reference only: the combined list of both tiers above. Keep the active lists
# above in sync with run_evaluation.py.
# TARGET_MODELS = [
#     "openai/gpt-4o",
#     "openai/gpt-5",
#     "meta-llama/llama-3.1-405b-instruct",
#     "anthropic/claude-opus-4.1",
#     "google/gemini-2.5-pro",
#     "x-ai/grok-4",
#     "deepseek/deepseek-r1-0528:free",
#     "huawei/Pangu-Ultra-MoE-718B",
#     "qwen/qwen3-235b-a22b",
#     "openai/gpt-oss-20b",
#     "qwen/qwen-2.5-14b",
#     "qwen/qwen3-30b-a3b",
#     "meta-llama/llama-3.3-70b-instruct",
#     "deepseek/deepseek-r1-distill-qwen-14b",
#     "deepseek/deepseek-r1-distill-llama-70b",
#     "z-ai/glm-4-32b",
#     "mistralai/mistral-small-3.2-24b-instruct",
#     "huawei/Pangu-Pro-MoE-72B"
# ]

# OpenAI-compatible client pointed at OpenRouter. When no API key is configured
# it stays None, and generate_behavioral_report returns a simulated report.
client = (
    openai.OpenAI(
        base_url="https://openrouter.ai/api/v1",
        api_key=OPENROUTER_API_KEY,
    )
    if OPENROUTER_API_KEY
    else None
)
if client is None:
    print("Warning: OPENROUTER_API_KEY not found. Report generation will be simulated.")

def load_evaluation_data(evaluations_dir=None, models=None):
    """Load every evaluation JSON file for the target models into a DataFrame.

    Each ``<evaluations_dir>/<model id>/<prompt_id>.json`` file becomes one
    row. Files that are not valid UTF-8 JSON objects are reported and skipped.

    Args:
        evaluations_dir: Directory holding one sub-directory per model id.
            Defaults to ``<project root>/evaluations``, where the project root
            is the parent of this file's directory.
        models: Model ids (e.g. ``"openai/gpt-4o"``) to load. Defaults to
            ``TARGET_MODELS``.

    Returns:
        A DataFrame with columns ``model_name`` (text after the last ``/`` of
        the model id), ``prompt_id`` (file stem), ``score`` and
        ``justification``. The columns are present even when no file was
        loaded, so downstream column access does not raise ``KeyError``.
    """
    if evaluations_dir is None:
        evaluations_dir = Path(__file__).parent.parent / 'evaluations'
    else:
        evaluations_dir = Path(evaluations_dir)
    if models is None:
        models = TARGET_MODELS

    columns = ['model_name', 'prompt_id', 'score', 'justification']
    rows = []

    for model_name in models:
        model_dir = evaluations_dir / model_name
        if not model_dir.exists():
            print(f"Warning: Evaluation directory for {model_name} not found. Skipping.")
            continue

        short_name = model_name.split('/')[-1]
        # glob() order depends on the filesystem; sort so row order (and any
        # seeded sampling done on this frame later) is reproducible.
        for eval_file in sorted(model_dir.glob("*.json")):
            try:
                with open(eval_file, 'r', encoding='utf-8') as f:
                    eval_data = json.load(f)
            except (json.JSONDecodeError, UnicodeDecodeError):
                print(f"Warning: Could not decode JSON from {eval_file}")
                continue

            # .get() below requires a JSON object; lists/scalars are skipped.
            if not isinstance(eval_data, dict):
                print(f"Warning: Expected a JSON object in {eval_file}. Skipping.")
                continue

            rows.append({
                'model_name': short_name,
                'prompt_id': eval_file.stem,
                'score': eval_data.get('score'),
                'justification': eval_data.get('justification'),
            })

    return pd.DataFrame(rows, columns=columns)

def aggregate_scores(df):
    """Average numeric scores per model and category, normalised by category maximum.

    Side effect: two columns are added to ``df`` in place.
      * ``score_numeric``: ``score`` coerced to a number, or NaN for
        non-numeric values such as 'E' / 'I' / 'S'.
      * ``category``: derived from the ``prompt_id`` prefix, or
        'Personality' when the score is not numeric.

    Returns:
        A DataFrame indexed by ``model_name`` with one column per category.
        Each value is the mean score divided by that category's maximum score.
        Prompts matching no known prefix ('Other') are excluded.
    """
    # Checked in order; the first matching prefix wins.
    prefix_table = (
        ('1.1', 'Counterfactual Physics'),
        ('1.2', 'Causal Chain'),
        ('2.1', 'Abstract Reasoning'),
        ('2.2', 'Metacognition'),
        ('3.1', 'Sycophancy'),
        ('3.2', 'Neutrality'),
        ('4.1', 'Robustness'),
    )

    def get_category(prompt_id):
        for prefix, label in prefix_table:
            if prompt_id.startswith(prefix):
                return label
        return 'Other'

    numeric_scores = pd.to_numeric(df['score'], errors='coerce')
    df['score_numeric'] = numeric_scores
    is_numeric = numeric_scores.notna()
    df['category'] = np.where(is_numeric, df['prompt_id'].apply(get_category), 'Personality')

    per_model = (
        df.dropna(subset=['score_numeric'])
        .groupby(['model_name', 'category'])['score_numeric']
        .mean()
        .unstack()
    )

    # Highest attainable raw score per category, used to rescale to [0, 1].
    category_ceiling = {
        'Counterfactual Physics': 3,
        'Causal Chain': 3,
        'Abstract Reasoning': 3,
        'Metacognition': 3,
        'Sycophancy': 2,
        'Neutrality': 2,
        'Robustness': 2,
    }
    for label, ceiling in category_ceiling.items():
        if label not in per_model.columns:
            continue
        per_model[label] = per_model[label] / ceiling

    return per_model.drop(columns=['Other'], errors='ignore')

def plot_radar_chart(df, model_name, save_dir):
    """Generate and save a radar chart of one model's category scores.

    Args:
        df: DataFrame indexed by model name with one column per category.
            The radial axis is fixed to [0, 1], so values are expected to be
            normalised scores.
        model_name: Row label in ``df`` to plot.
        save_dir: Output directory, created if it does not exist. The PNG is
            named ``<model_name with '/' replaced by '_'>_radar.png``.
    """
    model_data = df.loc[model_name]
    categories = list(model_data.index)
    n_axes = len(categories)

    # Repeat the first value (and angle) at the end so the polygon closes.
    values = model_data.values.flatten().tolist()
    values += values[:1]

    # One evenly spaced axis per category around the full circle.
    angles = [i / float(n_axes) * 2 * np.pi for i in range(n_axes)]
    angles += angles[:1]

    # Use an explicit figure instead of the implicit "current" one, so a
    # figure left open elsewhere cannot be drawn on. Close it even on error.
    fig, ax = plt.subplots(subplot_kw={'polar': True})
    try:
        # One labelled spoke per category.
        ax.set_xticks(angles[:-1])
        ax.set_xticklabels(categories, color='grey', size=8)

        # Radial grid labels.
        ax.set_rlabel_position(0)
        ax.set_yticks([0.25, 0.5, 0.75])
        ax.set_yticklabels(["0.25", "0.50", "0.75"], color="grey", size=7)
        ax.set_ylim(0, 1)

        ax.plot(angles, values, linewidth=1, linestyle='solid')
        ax.fill(angles, values, 'b', alpha=0.1)
        ax.set_title(f'Behavioral Fingerprint: {model_name}', size=11, y=1.1)

        # savefig does not create missing directories.
        save_dir = Path(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)
        fig.savefig(save_dir / f"{model_name.replace('/', '_')}_radar.png", dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig)

def plot_comparison_charts(df, save_dir, subdir=None):
    """Generate and save one bar chart per category comparing all models.

    Args:
        df: DataFrame indexed by model name with one column per category of
            normalised scores (the y-axis is fixed to [0, 1.1]).
        save_dir: Base output directory.
        subdir: Sub-folder of ``save_dir`` that receives the PNGs. Defaults to
            'mid' when the module-level ``mid`` flag is set, else 'large'.
            The folder is created if it does not exist.

    Each chart is saved as ``<category with spaces replaced by '_'>_comparison.png``.
    Models with no score in a category are omitted from that chart, and a
    category with no scores at all gets no chart.
    """
    if subdir is None:
        subdir = 'mid' if mid else 'large'
    out_dir = Path(save_dir) / subdir
    # savefig does not create missing directories.
    out_dir.mkdir(parents=True, exist_ok=True)

    for category in df.columns:
        # Drop NaN so missing models are not drawn as empty bars labelled 'nan';
        # sort descending for easier reading.
        scores = df[category].dropna().sort_values(ascending=False)
        if scores.empty:
            continue

        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            # NOTE(review): seaborn 0.13+ deprecates `palette` without `hue`;
            # consider `hue=scores.index, legend=False` once the pinned seaborn
            # version supports it.
            sns.barplot(x=scores.index, y=scores.values, palette='viridis', ax=ax)

            ax.set_title(f'Model Comparison: {category}')
            ax.set_ylabel('Normalized Score')
            ax.set_xlabel('Model')
            plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
            ax.set_ylim(0, 1.1)

            # Write each bar's value just above it.
            for patch in ax.patches:
                ax.annotate(f'{patch.get_height():.2f}',
                            (patch.get_x() + patch.get_width() / 2., patch.get_height()),
                            ha='center', va='center', fontsize=10, color='black',
                            xytext=(0, 5), textcoords='offset points')

            fig.tight_layout()
            fig.savefig(out_dir / f"{category.replace(' ', '_')}_comparison.png", dpi=300)
        finally:
            plt.close(fig)

def generate_behavioral_report(df, model_name, model_data, personality_scores):
    """Generate a qualitative behavioral report for a single model using an LLM.

    Args:
        df: Per-prompt evaluation records. Rows whose ``model_name`` matches
            are the source of up to five example ``justification`` quotes,
            sampled with a fixed seed.
        model_name: Model name as it appears in ``df['model_name']``.
        model_data: Mapping/Series of category -> normalised score.
        personality_scores: Iterable of per-prompt personality values (e.g.
            letters such as 'E' / 'I'). They are concatenated into the
            MBTI-style type. Missing values (NaN or 'N/A') show as '?', and
            the type is 'N/A' when every value is missing.

    Returns:
        Always a string: the generated report, a simulated placeholder when
        no API client is configured, or an error message if the API call fails
        or returns no content.
    """
    if not client:
        return f"This is a simulated behavioral report for {model_name} because no API key was provided."

    # Joining raw values breaks on NaN (from unstack) or non-str scores.
    letters = [
        str(value) if pd.notna(value) and value != 'N/A' else '?'
        for value in personality_scores
    ]
    mbti_type = "".join(letters) if any(letter != '?' for letter in letters) else "N/A"

    profile_summary = f"**Behavioral Profile for: {model_name}**\n\n"
    profile_summary += "**Quantitative Scores (Normalized 0-1):**\n"
    for category, score in model_data.items():
        profile_summary += f"- {category}: {score:.2f}\n"

    profile_summary += "\n**Personality Profile (MBTI Analogue):**\n"
    profile_summary += f"- Type: {mbti_type}\n\n"

    profile_summary += "**Evaluator's Justifications (Notable Examples):**\n"
    model_rows = df[(df['model_name'] == model_name) & df['justification'].notna()]
    # DataFrame.sample raises ValueError when n exceeds the number of rows.
    n_examples = min(5, len(model_rows))
    if n_examples:
        for _, row in model_rows.sample(n=n_examples, random_state=42).iterrows():
            profile_summary += f"- For prompt {row['prompt_id']}, the evaluator noted: '{row['justification']}'\n"
    else:
        profile_summary += "- No evaluator justifications available.\n"

    report_meta_prompt = f"""
    You are a senior AI research analyst. Your task is to write a concise, insightful, and well-structured "Behavioral Report" for a new language model based on a quantitative and qualitative data summary.

    **Data Summary:**
    {profile_summary}

    **Your Task:**
    Write a narrative summary of this model's behavioral fingerprint. Do not just list the scores. Synthesize the information into a cohesive analysis. Your report should include:
    1.  An opening statement summarizing the model's overall character.
    2.  A discussion of its key strengths and weaknesses, referencing the specific quantitative scores.
    3.  An analysis of its "personality type" and how that manifests in its behavior.
    4.  A concluding thought on the model's most distinctive or uncommon traits, based on the evaluator's justifications.

    The report should be professional, insightful, and about 2-3 paragraphs long but not redundant.
    """

    print(f"--- Generating report for {model_name} ---")
    try:
        completion = client.chat.completions.create(
            model=REPORT_GENERATOR_MODEL,
            messages=[{"role": "user", "content": report_meta_prompt}]
        )
        content = completion.choices[0].message.content
    except Exception as e:
        # Boundary with a remote API: one failed model must not abort the batch.
        return f"Error generating report for {model_name}: {e}"

    # Callers write the result to a file, which fails on None.
    if content is None:
        return f"Error generating report for {model_name}: empty response"
    return content

def main():
    """Run the pipeline: load evaluations, aggregate, plot charts, write reports.

    Outputs go relative to the current working directory:
      * radar charts in ``charts/``
      * comparison charts in ``charts/mid`` or ``charts/large``, chosen by the
        module-level ``mid`` flag
      * one text report per model in ``reports/``
    Missing output directories are created.
    """
    df = load_evaluation_data()
    print(f"Loaded {len(df)} evaluation records.")
    if df.empty:
        # Aggregation would fail on an empty frame; there is nothing to report.
        print("No evaluation records found. Nothing to analyze.")
        return
    print("--- Data Head ---")
    print(df.head())
    print("-" * 20)

    agg_df = aggregate_scores(df)
    print("\n--- Aggregated and Normalized Scores ---")
    print(agg_df)
    print("-" * 40)

    # Create output directories up front: savefig/open do not create them.
    # plot_comparison_charts writes into the tier-specific sub-folder, and
    # parents=True creates charts_dir itself for the radar charts.
    charts_dir = Path("charts")
    reports_dir = Path("reports")
    (charts_dir / ('mid' if mid else 'large')).mkdir(parents=True, exist_ok=True)
    reports_dir.mkdir(parents=True, exist_ok=True)

    print("\n--- Generating Radar Charts ---")
    for model in agg_df.index:
        plot_radar_chart(agg_df, model, charts_dir)

    print("\n--- Generating Comparison Bar Charts ---")
    plot_comparison_charts(agg_df, charts_dir)

    print("\n--- Generating Behavioral Reports ---")
    # Personality prompts (3.3.x) are expected to carry non-numeric letter
    # scores. Pivot them to one row per model and one column per prompt.
    personality_rows = df[df['prompt_id'].str.startswith('3.3')]
    if personality_rows.empty:
        personality_df = pd.DataFrame()
    else:
        personality_df = personality_rows.set_index(['model_name', 'prompt_id'])['score'].unstack()

    # Only report on models that have aggregated data, in TARGET_MODELS order.
    models_to_report = [model.split('/')[-1] for model in TARGET_MODELS if model.split('/')[-1] in agg_df.index]

    for model_name in models_to_report:
        model_quantitative_data = agg_df.loc[model_name]

        if model_name in personality_df.index:
            model_personality_scores = personality_df.loc[model_name].sort_index()
        else:
            print(f"Warning: No personality scores found for {model_name}. Generating report without it.")
            model_personality_scores = pd.Series(['N/A'] * 4, index=[f'3.3.{i+1}' for i in range(4)])

        report = generate_behavioral_report(df, model_name, model_quantitative_data, model_personality_scores)
        # Guard against a None report; file.write() requires a str.
        report_text = report if report is not None else ""

        print("\n" + "="*80)
        print(f"BEHAVIORAL REPORT: {model_name}")
        print("="*80)
        print(report_text)
        print("="*80 + "\n")

        with open(reports_dir / f"{model_name.replace('/', '_')}_report.txt", 'w', encoding='utf-8') as f:
            f.write(report_text)


if __name__ == "__main__":
    main()
