import os
import json
import matplotlib.pyplot as plt
import re
import numpy as np

# Set global plotting parameters for large text and high clarity.
# Grouped by what they control, then applied in a single rcParams update.
_text_sizes = {
    'font.size': 20,
    'axes.labelsize': 24,
    'axes.titlesize': 24,
    'xtick.labelsize': 20,
    'ytick.labelsize': 20,
    'legend.fontsize': 20,
}
_stroke_sizes = {
    'lines.linewidth': 4,
    'lines.markersize': 14,
}
plt.rcParams.update({
    **_text_sizes,
    **_stroke_sizes,
    'figure.figsize': (16, 10),
})

# Input directories (relative to the current working directory)
ABS_DATA_DIR = "./result"
BUCKET_DIR = "./result_buckets"

data_records = []

# Parse absolute zero results
raw_abs_data = {}  # Key: (gen_sz, dtype, test_sz), Value: data_list

# Run directories are named qwen2.5-<test_sz>-instruct_<gen_sz>_code_<dtype>.
_ABS_DIR_PATTERN = re.compile(r"qwen2\.5-([\w\.]+)-instruct_(.+)_code_([fio])")

# Generator names that are folded into another display name.
_GEN_ALIASES = {'14b_coder': '14b'}

# Keys whose data came from a directory named after the display name itself.
# An aliased directory never replaces such an entry.
_exact_abs_keys = set()

if os.path.exists(ABS_DATA_DIR):
    # os.listdir() returns entries in arbitrary order; sorting makes the
    # outcome reproducible when several directories map to the same key.
    for dirname in sorted(os.listdir(ABS_DATA_DIR)):
        if dirname.startswith("."):
            continue
        match = _ABS_DIR_PATTERN.search(dirname)
        if not match:
            continue
        test_sz, gen_sz, dtype = match.groups()

        # Normalize naming: map models to clean display names
        display_gen_sz = _GEN_ALIASES.get(gen_sz, gen_sz)
        is_alias = display_gen_sz != gen_sz
        key = (display_gen_sz, dtype, test_sz)

        result_path = os.path.join(ABS_DATA_DIR, dirname, "metrics_history.json")
        if not os.path.exists(result_path):
            continue

        try:
            with open(result_path, 'r', encoding='utf-8') as f:
                data_list = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            print(f"Error reading {result_path}: {e}")
            continue

        # Only non-empty lists are usable metric histories.
        if not isinstance(data_list, list) or not data_list:
            continue

        if is_alias and key in _exact_abs_keys:
            # NOTE(review): preferring the exactly-named run over the aliased
            # one is an assumption about intent - confirm.
            print(f"Skipping {dirname}: {key} already loaded from an exact match")
            continue

        raw_abs_data[key] = data_list
        if not is_alias:
            _exact_abs_keys.add(key)

# Process raw data using MDL minimization
def _mdl_or_inf(entry):
    """Return the entry's ``mdl_per_token``, or +inf when it is unusable.

    Missing, non-numeric (e.g. JSON ``null``) and NaN values are ranked as
    +inf so that such an epoch is never preferred over one with a real value.
    """
    value = entry.get("mdl_per_token")
    # `value != value` is true only for NaN.
    if not isinstance(value, (int, float)) or value != value:
        return float('inf')
    return value


for key, data_list in raw_abs_data.items():
    gen_sz, dtype, test_sz = key

    # Ignore history items that are not JSON objects; they carry no metrics.
    entries = [entry for entry in data_list if isinstance(entry, dict)]
    if not entries:
        continue

    # logic: choose the epoch with minimum mdl_per_token
    best_entry = min(entries, key=_mdl_or_inf)
    ep = best_entry.get("epiplexity")

    # Runs whose best epoch has no epiplexity value are left out.
    if ep is not None:
        data_records.append({
            "source": "absolute",
            "test_sz": test_sz,
            "gen_sz": gen_sz,
            "dtype": dtype,
            "epiplexity": ep
        })

# Parse bucket results
raw_bucket_data = []  # List of (test_sz, dtype, bucket_id, data_list)

if os.path.exists(BUCKET_DIR):
    # Directory names look like qwen2.5-<test_sz>-instruct_<dtype>_bucket_<id>_train
    bucket_dir_re = re.compile(r"qwen2\.5-([\w\.]+)-instruct_([fio])_bucket_(\d+)_train")

    for entry_name in os.listdir(BUCKET_DIR):
        # Hidden entries are skipped.
        if entry_name.startswith("."):
            continue

        found = bucket_dir_re.search(entry_name)
        if found is None:
            continue

        solver_sz, task_code, bucket_txt = found.groups()
        bucket_no = int(bucket_txt)

        result_path = os.path.join(BUCKET_DIR, entry_name, "metrics_history.json")
        if not os.path.exists(result_path):
            continue

        # Best effort: an unreadable file is reported and skipped.
        try:
            with open(result_path, 'r') as fh:
                history = json.load(fh)
        except Exception as e:
            print(f"Error reading {result_path}: {e}")
            continue

        # Keep only non-empty lists.
        if isinstance(history, list) and len(history) > 0:
            raw_bucket_data.append((solver_sz, task_code, bucket_no, history))

# Helper to sort size strings
def size_to_float(s):
    """Convert a model-size label to a float number of billions.

    The label is matched case-insensitively and surrounding whitespace is
    ignored. A ``b`` suffix is dropped (``"7b"`` -> ``7.0``) and an ``m``
    suffix is read as millions (``"500m"`` -> ``0.5``).

    Args:
        s: Size label such as ``"0.5b"``, ``"14B"`` or ``"500m"``.

    Returns:
        The size as a float, in billions.

    Raises:
        ValueError: If the text left after removing the suffix is not a
            valid float.
    """
    text = s.strip().lower()
    divisor = 1.0
    if text.endswith('m'):
        text = text[:-1]
        divisor = 1000.0
    # Every 'b' is removed, not just a trailing one, as the original did.
    return float(text.replace('b', '')) / divisor

# Chart 1: Epiplexity vs Solver Size
# One subplot per task type; within each, one line per generator model with
# the solver size on the x axis.
records_by_dtype = {}
for r in data_records:
    if r["source"] == "absolute":
        records_by_dtype.setdefault(r["dtype"], []).append(r)

dtypes = sorted(records_by_dtype)

legend_handles = {}
# Task-type code -> subplot title. Also read by the bucket chart functions.
type_map = {
    'i': 'Abduction',
    'o': 'Deduction',
    'f': 'Induction'
}

# Generator key -> legend label / marker. Unknown generators fall back to
# their raw key and to a marker picked from default_markers.
name_map = {
    '7b': 'Qwen2.5 7B',
    '14b': 'Qwen2.5 14B',
    'qwen3_4b': 'Qwen3 4B'
}
marker_map = {
    '14b': 'o',
    '7b': 's',
    'qwen3_4b': '^'
}
default_markers = ['o', 's', '^', 'v', 'D', 'x', '*', '+']

if not dtypes:
    # plt.subplots(1, 0) raises ValueError, which would abort the script
    # before the bucket charts are drawn.
    print("No absolute records found; skipping chart 1.")
else:
    fig, axes = plt.subplots(1, len(dtypes), figsize=(28, 8), sharey=False)
    if len(dtypes) == 1:
        # A single subplot comes back as a bare Axes, not an array.
        axes = [axes]

    for idx, (ax, dtype) in enumerate(zip(axes, dtypes)):
        groups_gen = {}
        for r in records_by_dtype[dtype]:
            groups_gen.setdefault(r["gen_sz"], []).append(r)

        gen_models = sorted(groups_gen)

        # Keep the index from the full sorted list so the fallback marker of
        # a generator does not depend on which other generators are skipped.
        plotted = [
            (i, gen_sz) for i, gen_sz in enumerate(gen_models)
            if gen_sz != "llama_64" and "coder" not in gen_sz
        ]

        # Shared x positions ordered by solver size. Plotting the size strings
        # directly would order categories by first appearance, which is wrong
        # when generators cover different sets of solver sizes.
        size_order = sorted(
            {rec["test_sz"] for _, gen_sz in plotted for rec in groups_gen[gen_sz]},
            key=size_to_float,
        )
        x_pos = {sz: pos for pos, sz in enumerate(size_order)}

        for i, gen_sz in plotted:
            group = sorted(groups_gen[gen_sz], key=lambda rec: x_pos[rec["test_sz"]])
            x_vals = [x_pos[rec["test_sz"]] for rec in group]
            y_vals = [rec["epiplexity"] for rec in group]

            marker = marker_map.get(gen_sz, default_markers[i % len(default_markers)])
            label_name = name_map.get(gen_sz, gen_sz)

            lines = ax.plot(x_vals, y_vals,
                            marker=marker,
                            linestyle='-',
                            label=label_name,
                            alpha=0.9,
                            linewidth=5,
                            markersize=10)

            # One legend entry per label, taken from the first subplot that
            # draws it.
            if label_name not in legend_handles:
                legend_handles[label_name] = lines[0]

        if size_order:
            ax.set_xticks(list(range(len(size_order))))
            ax.set_xticklabels(size_order)

        ax.set_title(type_map.get(dtype, dtype))
        if idx == 0:
            ax.set_ylabel("Epiplexity")
        ax.grid(True, alpha=0.15)

    # Single x label shared by all subplots.
    fig.text(0.5, 0.04, "Solver LLM Size", ha='center', va='center', fontsize=24)

    if legend_handles:
        handles = list(legend_handles.values())
        labels = list(legend_handles.keys())

        # Simple Legend arrangement
        fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 0.99),
                   ncol=len(labels), frameon=True, edgecolor='black', framealpha=1,
                   handletextpad=0.2, columnspacing=1.5)

    # Leave room at the bottom for the shared x label and at the top for the
    # figure-level legend.
    plt.tight_layout(rect=[0, 0.05, 1, 0.90])
    plt.savefig("chart1_epiplexity_vs_test_size.png")
    plt.savefig("chart1_epiplexity_vs_test_size.pdf")
    plt.close()

# Chart 2: Bucket (Self-play) Trends
def plot_buckets_chart(filename, target_sz="3b"):
    """Plot epiplexity per self-play bucket for one solver size.

    For every bucket run of ``target_sz`` (bucket 0 excluded) the epoch with
    the lowest ``mdl_per_token`` is selected and its ``epiplexity`` plotted.
    The figure has two stacked panels sharing the x axis: dtype ``'f'`` on
    top, dtypes ``'i'`` and ``'o'`` below.

    Args:
        filename: Output image path. A PDF copy is written next to it with
            the extension replaced by ``.pdf``.
        target_sz: Solver size label to plot (as found in ``raw_bucket_data``).

    Returns:
        None. Prints a message and writes nothing when no records match.
    """
    def mdl_or_inf(entry):
        # Missing, non-numeric (JSON null) or NaN values rank last so they
        # can neither crash min() nor win the comparison.
        value = entry.get("mdl_per_token")
        if not isinstance(value, (int, float)) or value != value:
            return float('inf')
        return value

    recs_target = []
    for test_sz, dtype, bucket_id, data_list in raw_bucket_data:
        # Filter first so unrelated runs cannot affect this chart.
        if test_sz != target_sz or bucket_id == 0:
            continue
        entries = [entry for entry in data_list if isinstance(entry, dict)]
        if not entries:
            continue
        epiplexity = min(entries, key=mdl_or_inf).get("epiplexity")
        if epiplexity is not None:
            recs_target.append({
                "test_sz": test_sz,
                "dtype": dtype,
                "bucket": bucket_id,
                "epiplexity": epiplexity
            })

    if not recs_target:
        print(f"No records found for model size {target_sz}")
        return

    fig, axes = plt.subplots(2, 1, figsize=(14, 12), sharex=True)

    type_colors = {'f': '#1f77b4', 'i': '#ff7f0e', 'o': '#2ca02c'}

    def plot_on_ax(ax, recs):
        """Draw one line per dtype in ``recs`` on ``ax``, x = bucket id."""
        if not recs:
            ax.text(0.5, 0.5, "No Data", ha='center', va='center')
            return

        by_dtype = {}
        for r in recs:
            by_dtype.setdefault(r["dtype"], []).append(r)

        for dtype in sorted(by_dtype):
            items = sorted(by_dtype[dtype], key=lambda rec: rec["bucket"])
            x_vals = [rec["bucket"] for rec in items]
            y_vals = [rec["epiplexity"] for rec in items]

            c = type_colors.get(dtype, 'black')
            ax.plot(x_vals, y_vals, marker='o', label=type_map.get(dtype, dtype),
                    color=c, linewidth=4, markersize=8)

        ax.set_ylabel("Epiplexity")
        ax.legend(loc='lower right')
        ax.grid(True, alpha=0.15)

    plot_on_ax(axes[0], [r for r in recs_target if r["dtype"] == 'f'])
    plot_on_ax(axes[1], [r for r in recs_target if r["dtype"] in ['i', 'o']])

    axes[1].set_xlabel("Self-Play Iteration")

    fig.tight_layout()
    fig.savefig(filename)
    # Derive the PDF name from the extension only; str.replace would rewrite
    # every ".png" in the path and do nothing for other extensions.
    pdf_path = os.path.splitext(filename)[0] + ".pdf"
    if pdf_path != filename:
        fig.savefig(pdf_path)
    plt.close(fig)

def plot_buckets_bar_dual_axis(filename, target_sz="3b"):
    """Draw grouped bars of epiplexity per self-play bucket on two y axes.

    For every bucket run of ``target_sz`` (bucket 0 excluded) the epoch with
    the lowest ``mdl_per_token`` is selected and its ``epiplexity`` used.
    Abduction (``'i'``) and deduction (``'o'``) go on the left axis,
    induction (``'f'``) on the right axis; each series also gets a trend line
    through its bar tops.

    Args:
        filename: Output image path. A PDF copy is written next to it with
            the extension replaced by ``.pdf``.
        target_sz: Solver size label to plot (as found in ``raw_bucket_data``).

    Returns:
        None. Prints a message and writes nothing when no records match.
    """
    def mdl_or_inf(entry):
        # Missing, non-numeric (JSON null) or NaN values rank last so they
        # can neither crash min() nor win the comparison.
        value = entry.get("mdl_per_token")
        if not isinstance(value, (int, float)) or value != value:
            return float('inf')
        return value

    # (bucket, dtype) -> epiplexity; the first run seen for a pair wins.
    value_at = {}
    for test_sz, dtype, bucket_id, data_list in raw_bucket_data:
        if test_sz != target_sz or bucket_id == 0:
            continue
        entries = [entry for entry in data_list if isinstance(entry, dict)]
        if not entries:
            continue
        epiplexity = min(entries, key=mdl_or_inf).get("epiplexity")
        if epiplexity is not None:
            value_at.setdefault((bucket_id, dtype), epiplexity)

    if not value_at:
        # Same message as plot_buckets_chart instead of a silent return.
        print(f"No records found for model size {target_sz}")
        return

    buckets = sorted({bucket for bucket, _ in value_at})
    series = ('f', 'i', 'o')
    # Bars for a missing (bucket, dtype) pair get height 0, i.e. stay invisible.
    bar_heights = {dt: [value_at.get((b, dt), 0) for b in buckets] for dt in series}
    # Trend lines use NaN for missing pairs so the line breaks instead of
    # dipping to a zero that was never measured.
    line_values = {dt: [value_at.get((b, dt), np.nan) for b in buckets] for dt in series}

    fig, ax1 = plt.subplots(figsize=(18, 10))
    ax2 = ax1.twinx()

    width = 0.25
    x = np.arange(len(buckets))
    colors = {'i': '#ff7f0e', 'o': '#2ca02c', 'f': '#1f77b4'}

    ax1.bar(x - width, bar_heights['i'], width, label='Abduction', color=colors['i'], alpha=0.8)
    ax1.bar(x, bar_heights['o'], width, label='Deduction', color=colors['o'], alpha=0.8)
    ax2.bar(x + width, bar_heights['f'], width, label='Induction', color=colors['f'], alpha=0.8)

    ax1.plot(x - width, line_values['i'], color=colors['i'], marker='o', linewidth=2, markersize=6)
    ax1.plot(x, line_values['o'], color=colors['o'], marker='s', linewidth=2, markersize=6)
    ax2.plot(x + width, line_values['f'], color=colors['f'], marker='^', linewidth=2, markersize=6)

    ax1.set_xlabel('Self-Play Iteration', fontsize=24)
    ax1.set_ylabel('Epiplexity (Abduction & Deduction)', fontsize=24)
    ax2.set_ylabel('Epiplexity (Induction)', fontsize=24)
    ax1.set_xticks(x)
    ax1.set_xticklabels(buckets)
    ax1.set_title(f"Epiplexity Trends (Solver: {target_sz})", fontsize=28, pad=20)
    ax1.grid(True, axis='y', alpha=0.2)

    # Merge the handles of both axes into one legend on the left axis.
    h1, l1 = ax1.get_legend_handles_labels()
    h2, l2 = ax2.get_legend_handles_labels()
    ax1.legend(h1 + h2, l1 + l2, loc='upper left', frameon=True, edgecolor='black')

    fig.tight_layout()
    fig.savefig(filename)
    # Derive the PDF name from the extension only; str.replace would rewrite
    # every ".png" in the path and do nothing for other extensions.
    pdf_path = os.path.splitext(filename)[0] + ".pdf"
    if pdf_path != filename:
        fig.savefig(pdf_path)
    plt.close(fig)

# Generate the bucket charts
# Both bucket figures are drawn for the "3b" solver size.
for _draw_chart, _png_name in (
    (plot_buckets_chart, "chart2_bucket_trends.png"),
    (plot_buckets_bar_dual_axis, "chart2_bucket_bar_dual.png"),
):
    _draw_chart(_png_name, "3b")

print("Charts generated successfully.")
