from typing import Optional, List
from dataclasses import dataclass
from pathlib import Path
import pickle
import re

import tyro
import numpy as np
import matplotlib.pyplot as plt


@dataclass
class Config:
    """Command-line options; ``main`` parses them with ``tyro.cli(Config)``."""

    # Path to one result pickle. ``main`` uses its parent directory and its
    # stem (minus the trailing ``_``-separated token) to locate the sibling
    # ``*_sft.pkl``, ``*_digit_base.pkl`` and ``*_digit.pkl`` files.
    path: str = "./stats/outputs/1_train3_Llama-3-2-1B-Instruct_digit.pkl"


def draw_outputs(data, hypos, gts, root, max_plots: int = 51):
    """Render one SVG per example into ``root``, up to ``max_plots`` examples.

    Args:
        data: Per-example annotations; each is passed to ``draw_output`` as
            its ``ann`` argument.
        hypos: Per-example predictions, aligned with ``data``.
        gts: Per-example ground truths, aligned with ``data``.
        root: Output directory; example ``i`` is written to
            ``{root}/plot_{i}.svg``.
        max_plots: Maximum number of examples to attempt. The default of 51
            reproduces the previous hard-coded behaviour (indices 0..50).
            Examples that fail still count towards this limit.

    Plotting is best-effort: an exception raised for one example is reported
    on stdout and the loop moves on to the next example.
    """
    if max_plots <= 0:
        return

    for i, (datum, hypo, gt) in enumerate(zip(data, hypos, gts)):
        try:
            draw_output(datum, hypo, gt, f"{root}/plot_{i}.svg")
        except Exception as e:
            # Include the index and exception type: a bare ``print(e)`` on a
            # KeyError, for instance, shows only the missing key.
            print(f"plot_{i}: {type(e).__name__}: {e}")
        # Checked after drawing so no extra item is pulled from the inputs.
        if i + 1 >= max_plots:
            break


def get_num(text: str) -> List[float]:
    """Return every ``digits.digits`` literal found in ``text`` as a float.

    Values are returned in order of appearance. Integers written without a
    decimal point are not matched, and a leading sign is not captured.
    """
    return [float(match.group(1)) for match in re.finditer(r"(\d+\.\d+)", text)]


def parse_context(prompt):
    """Recover the context points and the query input from a prompt string.

    The prompt is expected to have been built like this::

        prompt = ", ".join([f"({x}, {y})" for x, y in hypos])
        prompt = (
            "Given the following data points, find the next point: "
            + prompt
            + ". Inputs for the next data points are: "
            + f"[{xs}]"
        )

    All decimal literals are extracted in order with ``get_num``. The last
    one is taken as the query input; the preceding ones are grouped into
    consecutive ``(x, y)`` pairs.

    Args:
        prompt: The prompt text to parse.

    Returns:
        A tuple ``(context, x)`` where ``context`` is a list of ``(x, y)``
        float tuples and ``x`` is the query input as a float.

    Raises:
        ValueError: If the prompt contains no decimal numbers, or if the
            numbers before the query input cannot be split into pairs.
    """
    nums = get_num(prompt)
    if not nums:
        raise ValueError("no decimal numbers found in prompt")

    *flat, x = nums
    if len(flat) % 2:
        # An unpaired value means the prompt does not match the expected
        # layout; fail with an explanation instead of an index error.
        raise ValueError(
            f"expected an even number of context values, got {len(flat)}"
        )

    context = list(zip(flat[::2], flat[1::2]))
    return context, x


def draw_output(ann, hypos, gt, path, fontsize: int = 20):
    """Plot one example and save it to ``path``.

    The figure shows the context points parsed from the prompt, the straight
    line through the first two of them, a dashed vertical line at the query
    input, and one marker per model prediction.

    Args:
        ann: Example annotation; ``ann["conv"][0]`` is the prompt handed to
            ``parse_context``.
        hypos: Mapping with keys ``"sft"``, ``"digit_base"`` and ``"digit"``.
            Each value is a string whose first and last characters are
            dropped before conversion to float (presumably enclosing
            brackets; confirm against whatever produces these files).
        gt: Ground truth for the example. Currently unused.
        path: Destination passed to ``plt.savefig``.
        fontsize: Font size for the legend and tick labels.

    Raises:
        ZeroDivisionError: If the first two context points share the same x.
        KeyError, ValueError: If ``ann`` / ``hypos`` do not have the expected
            keys or values. The figure is closed in every case.
    """
    train_data, x = parse_context(ann["conv"][0])

    x_vals, y_vals = zip(*train_data)

    # Line through the first two context points: slope and intercept.
    (x1, y1), (x2, y2) = train_data[:2]
    m_manual = (y2 - y1) / (x2 - x1)
    b_manual = y1 - m_manual * x1

    x_line = np.linspace(0, 1, 100)
    y_line_manual = m_manual * x_line + b_manual

    fig = plt.figure(figsize=(6, 5))
    # The caller swallows exceptions and keeps going, so without the
    # ``finally`` below every failed example would leave a figure open.
    try:
        # Reference line derived from the first two context points.
        plt.plot(
            x_line,
            y_line_manual,
            color="black",
            linestyle="-",
            linewidth=1,
        )
        plt.axvline(x=x, color="gray", linestyle="--")

        # Original context points.
        plt.scatter(
            x_vals,
            y_vals,
            color="black",
            marker="x",
            zorder=5,
            s=100,
        )

        colors = ["#4B6A94", "#B2182B", "#1A936F"]
        markers = ["s", "d", "o", "P", "^", "v", "<", ">"]

        keys = ["sft", "digit_base", "digit"]
        for i, k in enumerate(keys):
            v = hypos[k]
            # "digit_base" results are shown as "vocab" in the legend.
            label = "vocab" if k == "digit_base" else k
            plt.scatter(
                x,
                float(v[1:-1]),
                color=colors[i],
                label=label,
                s=200,
                marker=markers[i],
                zorder=5,
            )

        plt.ylim(0, 1)
        plt.legend(
            fontsize=fontsize,
            loc="lower right",
        )
        plt.grid(False)
        plt.xticks(fontsize=fontsize)
        plt.yticks(fontsize=fontsize)
        plt.tight_layout()

        plt.savefig(path)
    finally:
        plt.close(fig)


def main():
    """Load the three result pickles that share a prefix and plot them.

    The prefix is derived from ``Config.path`` by removing the loss suffix
    from the file stem. For each of ``sft``, ``digit_base`` and ``digit`` the
    file ``{prefix}_{loss}.pkl`` in the same directory is loaded; its
    ``"hypo"`` entry is kept per loss, while ``"data"`` and ``"tgt"`` are
    taken from the last file loaded. Plots are written next to the pickles.

    Raises:
        FileNotFoundError: If any of the three files is missing.
    """
    args = tyro.cli(Config)

    path = Path(args.path)
    root = path.parent
    name = path.stem

    losses = ["sft", "digit_base", "digit"]

    # Strip the loss suffix from the stem. Try the longest suffix first so
    # that "..._digit_base" yields "..." rather than "..._digit".
    for suffix in sorted(losses, key=len, reverse=True):
        if name.endswith(f"_{suffix}"):
            rest = name[: -(len(suffix) + 1)]
            break
    else:
        # Unknown suffix: drop the last "_"-separated token.
        rest = "_".join(name.split("_")[:-1])

    data = None
    hypos = {}
    gts = None
    for _loss in losses:
        loss_path = root / f"{rest}_{_loss}.pkl"
        # Explicit raise rather than ``assert``, which is removed under -O.
        if not loss_path.exists():
            raise FileNotFoundError(f"{loss_path} doesn't exist")
        # NOTE: unpickling can execute arbitrary code; only load result
        # files from a trusted source.
        with open(loss_path, "rb") as f:
            _data = pickle.load(f)
        data = _data["data"]
        hypos[_loss] = _data["hypo"]
        gts = _data["tgt"]

    # rearrange: dict of list to list of dict
    hypos = [{k: v[i] for k, v in hypos.items()} for i in range(len(hypos["sft"]))]
    draw_outputs(data, hypos, gts, root)
    print("done")


# Run the plotting pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
