import os

# Restrict this process to GPU index 6. The assignment unconditionally
# overwrites any CUDA_VISIBLE_DEVICES value already present in the environment.
# NOTE(review): it is placed before the remaining imports, presumably so that
# any CUDA-backed library they pull in only ever sees this device — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "6"
import fire
import matplotlib.pyplot as plt
import pandas as pd

from meta_alignment.config import TrainingConfig
from meta_alignment.dataset import get_dataset
from meta_alignment.inference import generate_completions


def main(bon_type="bon"):
    """Plot overlaid completion-length histograms for a group of models.

    For every model in the selected group, the lengths of its evaluation
    completions are loaded from a per-model JSONL cache (or generated and
    cached when the file is missing) and drawn as a density histogram. The
    figure is saved under ``./figs/length/`` and then shown.

    Args:
        bon_type: Which model group to compare. ``"bon"`` selects the
            ``iama_bon_*`` checkpoints, ``"softbon"`` the ``iama_sbon_*``
            checkpoints; both groups include the baseline model.

    Raises:
        ValueError: If ``bon_type`` is neither ``"bon"`` nor ``"softbon"``.
    """
    args = TrainingConfig().from_dict(
        {
            "task": "length",
            "model": "mistral7b",
        }
    )

    # Each entry pairs a model directory with its legend label, so disabling
    # a run (commenting it out) can never leave the two out of sync.
    if bon_type == "bon":
        runs = [
            # (args.model_dir, "Reference"),
            ("./results/length/models/baseline", "Baseline"),
            ("./results/length/models/iama_bon_2", "IAMA (N=2)"),
            ("./results/length/models/iama_bon_4", "IAMA (N=4)"),
            # ("./results/length/models/iama_bon_8", "IAMA (N=8)"),
            # ("./results/length/models/iama_bon_16", "IAMA (N=16)"),
        ]
    elif bon_type == "softbon":
        runs = [
            # (args.model_dir, "Reference"),
            ("./results/length/models/baseline", "Baseline"),
            ("./results/length/models/iama_sbon_0.1", r"IAMA ($\lambda=0.1$)"),
            ("./results/length/models/iama_sbon_0.09", r"IAMA ($\lambda=0.09$)"),
            # ("./results/length/models/iama_sbon_0.08", r"IAMA ($\lambda=0.08$)"),
            # ("./results/length/models/iama_sbon_0.07", r"IAMA ($\lambda=0.07$)"),
        ]
    else:
        raise ValueError(f"Unknown bon_type: {bon_type}")

    completions_dir = "results/length/eval_completions"

    def get_completion_length_distribution(model_dir: str) -> list[int]:
        """Return the completion lengths for ``model_dir``, using the cache if present.

        The cache file is named after the last component of ``model_dir``.
        On a cache miss the completions are generated, written to the cache
        together with their prompts and lengths, and the lengths returned.
        """
        # normpath drops a trailing slash so the cache name is never empty.
        model_name = os.path.basename(os.path.normpath(model_dir))
        res_path = f"{completions_dir}/{model_name}.jsonl"
        if os.path.exists(res_path):
            cached = pd.read_json(res_path, lines=True)
            return cached["length"].tolist()

        # NOTE(review): train_size=0 / eval_size=-1 presumably mean "no
        # training rows, the full evaluation split" — confirm in get_dataset.
        _, eval_dataset = get_dataset(args, train_size=0, eval_size=-1)
        prompts = [eval_dataset[i]["prompt"] for i in range(len(eval_dataset))]
        completions = generate_completions(
            model_dir,
            prompts,
        )
        # NOTE(review): len() counts tokens only if generate_completions
        # returns token sequences; for plain strings this is a character
        # count, while the x-axis is labelled as tokens — confirm.
        completion_lengths = [len(completion) for completion in completions]
        data = pd.DataFrame(
            {
                "prompt": prompts,
                "completion": completions,
                "length": completion_lengths,
            }
        )
        os.makedirs(completions_dir, exist_ok=True)
        data.to_json(res_path, lines=True, orient="records")
        return completion_lengths

    bins = range(0, 256, 5)
    with plt.style.context("./config/paper.mplstyle"):
        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1)
        for model_dir, label in runs:
            completion_lengths = get_completion_length_distribution(model_dir)
            ax.hist(
                completion_lengths,
                bins=bins,
                alpha=0.3,
                label=label,
                density=True,
            )

        ax.set_title("Distribution of Completion Lengths")
        # Raw string: "\#" is not a valid Python escape sequence. The runtime
        # text (backslash followed by '#') is unchanged.
        ax.set_xlabel(r"\# of Tokens")
        ax.set_ylabel("Density")
        ax.set_ylim(0, 0.03)
        ax.legend()

    fig_path = f"./figs/length/distribution_of_completion_length_{bon_type}.png"
    # Create the output directory up front so savefig cannot fail on a fresh
    # checkout after all the generation work has already been done.
    os.makedirs(os.path.dirname(fig_path), exist_ok=True)
    fig.savefig(
        fig_path,
        bbox_inches="tight",
    )
    plt.show()


# Script entry point: hand `main` to python-fire, which builds the command-line
# interface from it (so `bon_type` can be supplied on the command line).
if __name__ == "__main__":
    fire.Fire(main)
