import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import seaborn as sns
from pathlib import Path
import math
import re

# Apply seaborn's "darkgrid" theme to every axes created below; font_scale=3
# enlarges all text, which suits the large 30x20 inch figure built later.
sns.set(style="darkgrid", font_scale=3)

# Training logs, one per temperature schedule. Index i of `loss_logs`,
# `csv_list`, `names` and `colors` is used for the same curve.
loss_logs = [
    "out_files/red_pajama_base_410m.out",
    "out_files/red_pajama_sqrt_log_410m.out",
    "out_files/red_pajama_log_410m.out",
    "out_files/red_pajama_log_squared_410m.out",
]

# Folders holding the per-row scalar CSV files of each run.
csv_base = "csv_logs/model_2026-01-16 04:26:46.211267/scalars/"
csv_sqrt_log = "csv_logs/model_2026-01-16 06:18:06.193117/scalars/"
csv_log = "csv_logs/model_2026-01-16 04:58:43.313344/scalars/"
# NOTE(review): this timestamp differs from csv_base only in its last digit;
# confirm it really points at the (log L)^2 run.
csv_log2 = "csv_logs/model_2026-01-16 04:26:46.211268/scalars/"
csv_list = [csv_base, csv_sqrt_log, csv_log, csv_log2]

# Legend entries (LaTeX) and matplotlib colour-cycle codes for the four runs.
names = [
    r"$T(L) = 1$",
    r"$T(L) = \sqrt{\log L}$",
    r"$T(L) = \log L$",
    r"$T(L) = (\log L)^2$",
]
colors = ["C0", "C1", "C2", "C3"]

# Row indices whose curves are drawn in the two bottom panels.
rows = [8, 512]

# One training-log line: "iter <n> step <n>: loss <float>, iter time: <float>ms".
# Groups: 1 = iter, 2 = step, 3 = loss, 4 = iteration time in ms.
pattern = re.compile(
    r"iter\s+(\d+)\s+step\s+(\d+):"
    r"\s+loss\s+([0-9.eE+-]+),"
    r"\s*iter time:\s*([0-9.eE+-]+)ms"
)

def parse(path):
    """Extract iteration numbers and loss values from a training log.

    Every line of ``path`` (an object with an ``open`` method, e.g. a
    ``pathlib.Path``) is searched with the module-level ``pattern``; lines
    that do not match are ignored.

    Returns:
        A pair ``(iters, losses)`` of NumPy arrays holding, for each matching
        line in file order, the integer from group 1 and the float from
        group 3 of the match.
    """
    with path.open("r") as handle:
        matches = [found for found in map(pattern.search, handle) if found]

    iteration_numbers = [int(found.group(1)) for found in matches]
    loss_values = [float(found.group(3)) for found in matches]
    return np.asarray(iteration_numbers), np.asarray(loss_values)

def mv_average(x, window_size):
    """Return the trailing moving average of ``x``.

    ``x`` must provide pandas' ``rolling`` interface (callers pass a Series).
    With pandas' default ``min_periods`` the first ``window_size - 1`` entries
    of the result are NaN because the window is not yet full.
    """
    windowed = x.rolling(window=window_size)
    return windowed.mean()

def find_kl_file(folder, row):
    """Locate the ``kl_to_uniform`` CSV belonging to ``row`` inside ``folder``.

    Lookup order:
      1. the exact path ``<folder>/Row: <row>/kl_to_uniform.csv``;
      2. files directly in ``folder`` matching ``*Row*<row>*kl_to_uniform*.csv``;
      3. the same glob applied recursively below ``folder``.

    The glob is loose (``*Row*8*`` also matches "Row: 128"), and directory
    listings come back in arbitrary order, so the candidates of each stage
    are ranked: names in which ``row`` appears as a whole number right after
    "Row" come first, ties are broken by path. This makes the result
    deterministic and prefers the intended row, while still returning a file
    whenever the loose glob finds one.

    Args:
        folder: Directory to search (``str`` or ``Path``).
        row: Row index the file belongs to.

    Returns:
        The path of the chosen file as a string, or ``None`` if nothing
        matches.
    """
    folder = Path(folder)
    direct = folder / f"Row: {row}/kl_to_uniform.csv"
    if direct.exists():
        return str(direct)

    glob_pattern = f"*Row*{row}*kl_to_uniform*.csv"
    # "Row", optional non-digit separators, then `row` not followed by a digit.
    whole_row = re.compile(rf"Row\D*{re.escape(str(row))}(?!\d)")

    def rank(candidate):
        # False sorts before True, so whole-number matches win.
        return (whole_row.search(candidate.name) is None, str(candidate))

    # Try the cheap, non-recursive search before walking the whole tree.
    for search in (folder.glob, folder.rglob):
        hits = sorted(search(glob_pattern), key=rank)
        if hits:
            return str(hits[0])
    return None

# Moving-average window lengths (in samples): training loss, the KL curves,
# and the curves read from the "Row: <row>.csv" files, respectively.
loss_window_size = 500
window_size = 15_000
window_size1 = 50_000

# x position at which the "Row <row>" label of curve i is anchored; the label
# of every row other than 512 is shifted right by `label_dx_per_row`.
label_dx_per_row = 8_000
label_x_targets = dict(enumerate((15_000, 25_000, 30_000, 5_000)))

# Layout: the loss panel spans the full top row, two panels share the taller
# bottom row.
fig = plt.figure(figsize=(30, 20))
gs = fig.add_gridspec(nrows=2, ncols=2, height_ratios=[1, 1.5], hspace=0.2, wspace=0.15)

ax_loss = fig.add_subplot(gs[0, :])
ax1, ax2 = (fig.add_subplot(gs[1, col]) for col in range(2))

# Top panel: smoothed training loss of every run.
for run_idx, log_file in enumerate(loss_logs):
    iters, losses = parse(Path(log_file))
    smoothed = mv_average(pd.Series(losses), loss_window_size)
    ax_loss.plot(iters, smoothed.to_numpy(), c=colors[run_idx], lw=2, label=names[run_idx])

ax_loss.set(xlabel="Training Iteration", ylabel="Training Loss", ylim=(3, 7.2))
ax_loss.grid(True, which="both")
ax_loss.legend(loc="center right")

def _label_index(x_vals, y_vals, x_target):
    """Return the index of the drawable curve point whose x is nearest ``x_target``.

    Only points where both coordinates are finite are considered: the moving
    average is NaN until its window fills, and a label anchored on such a
    point would not be shown. Returns ``None`` if no finite point exists
    (empty file, or a series shorter than the averaging window).
    """
    finite = np.flatnonzero(np.isfinite(x_vals) & np.isfinite(y_vals))
    if finite.size == 0:
        return None
    x_finite = x_vals[finite]
    x_target = np.clip(x_target, x_finite.min(), x_finite.max())
    return int(finite[np.argmin(np.abs(x_finite - x_target))])


def _annotate_row(ax, x_vals, y_vals, h, row, fontsize):
    """Write a "Row <row>" tag on the curve of run ``h`` drawn in ``ax``."""
    # Labels of rows other than 512 are shifted right so the two labels of
    # one run do not sit at the same x position.
    x_target = label_x_targets.get(h, 20000) + (0 if row == 512 else label_dx_per_row)
    idx = _label_index(x_vals, y_vals, x_target)
    if idx is None:
        return
    ax.text(
        x_vals[idx],
        y_vals[idx],
        f"Row {row}",
        fontsize=fontsize,
        ha="center",
        va="center",
        clip_on=True,
        color=colors[h],
        bbox=dict(facecolor="white", edgecolor="none", alpha=0.7, pad=0.02),
    )


# Bottom-left panel: smoothed second column of each "Row: <row>.csv" file,
# plotted against its first column. Missing files are skipped.
for h, csv_folder in enumerate(csv_list):
    for row in rows:
        row_file = Path(csv_folder) / f"Row: {row}.csv"
        if not row_file.exists():
            continue

        df = pd.read_csv(str(row_file))
        x_vals = np.asarray(df[df.columns[0]])
        y_vals = np.asarray(mv_average(df[df.columns[1]], window_size=window_size1))

        # Only the row-512 curve carries the run name, so each run would
        # appear once in a legend.
        label = names[h] if row == 512 else None
        ax1.plot(x_vals, y_vals, c=colors[h], lw=2, label=label)
        _annotate_row(ax1, x_vals, y_vals, h, row, fontsize=18)

# Bottom-right panel: smoothed KL-to-uniform curves, divided by log(row).
for h, csv_folder in enumerate(csv_list):
    for row in rows:
        kl_file = find_kl_file(csv_folder, row)
        if kl_file is None:
            continue

        df_kl = pd.read_csv(kl_file)
        x_vals = np.asarray(df_kl[df_kl.columns[0]])
        kl_ma = mv_average(df_kl[df_kl.columns[1]], window_size=window_size)
        y_vals = np.asarray(kl_ma) / math.log(row)

        label = names[h] if row == 512 else None
        ax2.plot(x_vals, y_vals, c=colors[h], lw=2, label=label)
        _annotate_row(ax2, x_vals, y_vals, h, row, fontsize=26)

# Bottom-left panel: logarithmic y axis.
ax1.set_yscale("log")
ax1.set_ylabel(r"$\bar{d_A}(\mathbf{p}, \mathbf{u})$")
ax1.grid(True, which="both")
ax1.set_xlabel("Training Iteration")

# Bottom-right panel: y limits padded slightly beyond [0, 1].
ax2.set_ylabel(r"$\overline{\mathrm{KL}}(\mathbf{p}\,\|\,\mathbf{u})$")
ax2.set_ylim(-0.02, 1.02)
ax2.grid(True, which="both")
ax2.set_xlabel("Training Iteration")

# Mark the same x position with a dashed vertical line in all three panels.
vline_x = 20000
for ax in (ax_loss, ax1, ax2):
    ax.axvline(vline_x, color="k", linestyle="--", linewidth=2, alpha=0.8)

# Create the output directory first: savefig raises FileNotFoundError when
# the target folder does not exist.
output_path = Path("a_plotting_directory/temperature_scale/temperature_scaling_three_plots.png")
output_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(
    str(output_path),
    bbox_inches="tight",
)