# %%
import os
from datasets import load_dataset
import numpy as np
import pandas as pd
from tqdm.contrib.concurrent import process_map
from tqdm import tqdm
from utils import response_text_without_think, event_filter, calculate_event_type, plot_score_distributions, env_pred_score
import matplotlib.pyplot as plt
import seaborn as sns

# Directory that holds every model's evaluation outputs, resolved relative to
# the directory containing this script.
_MODEL_ROOT = os.path.join(os.path.dirname(__file__), "..", "model")

# Display name -> result.jsonl for every model included in the comparison.
# Disabled runs are kept as comments; uncomment a line to add that model to the
# evaluation table and heatmap.
MODEL_PATHS = {
    # "0.6B-Baseline": os.path.join(_MODEL_ROOT, "env_baseline", "Qwen3-0.6B", "result.jsonl"),
    # "0.6B-no_think": os.path.join(_MODEL_ROOT, "env_baseline", "Qwen3-0.6B-no_think", "result.jsonl"),
    # "1.7B-Baseline": os.path.join(_MODEL_ROOT, "env_baseline", "Qwen3-1.7B", "result.jsonl"),
    # "1.7B-no_think": os.path.join(_MODEL_ROOT, "env_baseline", "Qwen3-1.7B-no_think", "result.jsonl"),
    # "4B-Baseline": os.path.join(_MODEL_ROOT, "env_baseline", "Qwen3-4B", "result.jsonl"),
    # "4B-Baseline-repeat": os.path.join(_MODEL_ROOT, "env_baseline", "Qwen3-4B-repeat", "result.jsonl"),
    # "4B-no_think": os.path.join(_MODEL_ROOT, "env_baseline", "Qwen3-4B-no_think", "result.jsonl"),
    # "4B-Thinking-Baseline": os.path.join(_MODEL_ROOT, "env_baseline", "Qwen3-4B-Thinking", "result.jsonl"),
    # "8B-Baseline": os.path.join(_MODEL_ROOT, "env_baseline", "Qwen3-8B", "result.jsonl"),
    # "8B-no_think": os.path.join(_MODEL_ROOT, "env_baseline", "Qwen3-8B-no_think", "result.jsonl"),
    # "8B-DeepSeek-distill": os.path.join(_MODEL_ROOT, "env_baseline", "DeepSeek-R1-0528-Qwen3-8B", "result.jsonl"),
    # "32B-Baseline": os.path.join(_MODEL_ROOT, "env_baseline", "Qwen3-32B", "result.jsonl"),
    # "32B-no_think": os.path.join(_MODEL_ROOT, "env_baseline", "Qwen3-32B-no_think", "result.jsonl"),
    # "235B-Baseline": os.path.join(_MODEL_ROOT, "env_baseline", "Qwen3-235B", "result.jsonl"),
    # "235B-no_think": os.path.join(_MODEL_ROOT, "env_baseline", "Qwen3-235B-no_think", "result.jsonl"),
    # "235B-Thinking-Baseline": os.path.join(_MODEL_ROOT, "env_baseline", "Qwen3-235B-Thinking", "result.jsonl"),
    # "HuatuoGPT-o1-8B": os.path.join(_MODEL_ROOT, "env_baseline", "HuatuoGPT-o1-8B", "result.jsonl"),
    # "HuatuoGPT-o1-7B": os.path.join(_MODEL_ROOT, "env_baseline", "HuatuoGPT-o1-7B", "result.jsonl"),
    # "HuatuoGPT-o1-70B": os.path.join(_MODEL_ROOT, "env_baseline", "HuatuoGPT-o1-70B", "result.jsonl"),
    # "HuatuoGPT-o1-72B": os.path.join(_MODEL_ROOT, "env_baseline", "HuatuoGPT-o1-72B", "result.jsonl"),
    # "Baichuan-M2-32B": os.path.join(_MODEL_ROOT, "env_baseline", "Baichuan-M2-32B", "result.jsonl"),
    # "Baichuan-M2-32B-no_think": os.path.join(_MODEL_ROOT, "env_baseline", "Baichuan-M2-32B-no_think", "result.jsonl"),
    # "gpt-oss-120b": os.path.join(_MODEL_ROOT, "env_baseline", "gpt-oss-120b", "result.jsonl"),
    # "4B-SFT": os.path.join(_MODEL_ROOT, "env_sft", "Qwen3-4B-SFT", "checkpoint-385", "result.jsonl"),
    "4B-GRPO-final-ckpt200": os.path.join(
        _MODEL_ROOT, "env_grpo", "Qwen3-4B-final", "checkpoint-200", "result.jsonl"
    ),
}

# Event types that each get their own metric column; records whose event type
# is not listed here still count toward the 'Overall' score.
TARGET_EVENT_TYPES = [
    "RadiologyEvent",
    "MicrobiologyEvent",
    "LabEvent",
]


def process_data_entry(data):
    """Score one model response and classify its record's event type.

    Args:
        data: A single record mapping; it must provide 'ground_truth' and
            'response_text', and the whole record is also handed to
            ``calculate_event_type``.

    Returns:
        Dict ``{'score': ..., 'event_type': ...}``. The score comes from
        ``env_pred_score`` applied to the ground truth and the response text
        after ``response_text_without_think`` (presumably strips the model's
        reasoning section — confirm in utils).
    """
    truth, raw_response = data['ground_truth'], data['response_text']
    answer = response_text_without_think(raw_response)
    return {
        'score': env_pred_score(truth, answer),
        'event_type': calculate_event_type(data),
    }


def plot_event_type_scores(metrics_df):
    """Draw *metrics_df* as an annotated heatmap and show it.

    The DataFrame's index is plotted on the y-axis (labelled "Models") and its
    columns on the x-axis (labelled "Metrics"); every cell is annotated with
    its value to four decimal places.

    Args:
        metrics_df: DataFrame of numeric scores, one row per model and one
            column per metric.
    """
    print("\n--- Generating performance heatmap ---")

    fig, ax = plt.subplots(figsize=(12, 10))
    sns.heatmap(
        metrics_df,
        ax=ax,
        annot=True,
        fmt=".4f",
        cmap="YlGnBu",
        linewidths=.5,
        linecolor='gray',
    )

    ax.set_title('Model Performance Heatmap', fontdict={'fontsize': 16}, pad=12)
    ax.set_xlabel('Metrics', fontdict={'fontsize': 12})
    ax.set_ylabel('Models', fontdict={'fontsize': 12})

    # Slant the metric names so long labels do not overlap; keep model names
    # horizontal.
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    plt.setp(ax.get_yticklabels(), rotation=0)

    fig.tight_layout()
    plt.show()

all_results = {}

# Process-pool size: at most 40 workers, capped at the CPU count reported by
# os.cpu_count() so smaller machines are not oversubscribed with processes.
MAX_WORKERS = min(40, os.cpu_count() or 1)

# Score every record of every configured model in parallel.
# NOTE(review): this script has no `if __name__ == "__main__":` guard. Under the
# "spawn" multiprocessing start method (the default on macOS and Windows),
# worker processes re-import the main module and would re-run this loop;
# confirm the target platform uses "fork".
for model_name, path in MODEL_PATHS.items():
    print(f"\n--- Processing model: {model_name} ---")
    if not os.path.exists(path):
        print(f"Warning: Path for model '{model_name}' not found. Skipping. Path: {path}")
        continue

    dataset = load_dataset("json", data_files=path, split="train")
    print(f"Loaded {len(dataset)} records for {model_name}.")

    if len(dataset) == 0:
        print(f"Model '{model_name}' has no data. Skipping.")
        continue

    # One {'score', 'event_type'} dict per record, returned in dataset order.
    results_list = process_map(
        process_data_entry, dataset, max_workers=MAX_WORKERS, chunksize=16,
        desc=f"Calculating scores and event types for {model_name}",
    )
    all_results[model_name] = results_list

def _summarize_model(results):
    """Collapse one model's per-record results into summary metrics.

    Args:
        results: List of dicts with 'score' and 'event_type' keys, as returned
            by process_data_entry.

    Returns:
        Dict holding one mean score per entry of TARGET_EVENT_TYPES (in that
        order), followed by 'Overall' and 'Macro'. 'Overall' is the mean over
        every record, including records whose event type is not targeted.
        'Macro' is the unweighted mean of the per-type means.

        An event type with no records is reported as NaN rather than 0.0. A
        0.0 would look like a genuine zero score and would drag 'Macro' down
        for data that simply is not there. Seaborn leaves NaN cells blank in
        the heatmap.
    """
    scores_by_type = {event_type: [] for event_type in TARGET_EVENT_TYPES}
    for result in results:
        if result['event_type'] in TARGET_EVENT_TYPES:
            scores_by_type[result['event_type']].append(result['score'])

    metrics = {
        event_type: np.mean(scores) if scores else np.nan
        for event_type, scores in scores_by_type.items()
    }

    all_scores = [result['score'] for result in results]
    metrics['Overall'] = np.mean(all_scores) if all_scores else np.nan

    # Macro-average only over event types that actually have records.
    present = [metrics[t] for t in TARGET_EVENT_TYPES if scores_by_type[t]]
    metrics['Macro'] = np.mean(present) if present else np.nan
    return metrics


print("\n--- Aggregating scores ---")
final_metrics = {
    model_name: _summarize_model(results)
    for model_name, results in all_results.items()
}

print("\n--- Model Performance Evaluation Table ---")
metrics_df = pd.DataFrame.from_dict(final_metrics, orient='index')

if metrics_df.empty:
    # Every model was skipped (missing result file or no records). The frame
    # has no metric columns to select, so indexing by column_order would raise
    # KeyError, and seaborn cannot draw an empty heatmap.
    print("No model results were aggregated; skipping table and heatmap.")
else:
    # Shorten the event-type column names for the table and plot.
    metrics_df = metrics_df.rename(columns={
        "RadiologyEvent": "Radiology",
        "MicrobiologyEvent": "Microbiology",
        "LabEvent": "Lab"
    })

    column_order = ['Radiology', 'Microbiology', 'Lab', 'Overall', 'Macro']
    metrics_df = metrics_df[column_order]
    print(metrics_df)

    plot_event_type_scores(metrics_df)
