import os

import hydra
from omegaconf import DictConfig, OmegaConf
import torch
import matplotlib.pyplot as plt
import seaborn as sns

# Matplotlib 3.6 renamed its bundled seaborn styles to "seaborn-v0_8*" and
# 3.8 removed the old names (raising OSError), so try the new name first and
# fall back to the legacy one on older matplotlib releases.
try:
    plt.style.use("seaborn-v0_8")
except OSError:
    plt.style.use("seaborn")
sns.set_style("whitegrid")

# Context lengths (data.unroll_length) that have a dedicated line colour.
CTX_LENGTHS = (128, 256, 512, 1024, 2048, 4096, 8192)

# One shade of the sequential 'Blues' palette per context length, in the
# same order as CTX_LENGTHS.
colors = list(sns.color_palette('Blues', len(CTX_LENGTHS)))

# Context length -> line colour used by main().
len_to_color = dict(zip(CTX_LENGTHS, colors))

# Directory with one sub-directory of saved eval-loss files per run id.
CTX_LEN_DIR = 'paper/llama_effective_ctx_length'

def _load_run_curve(run_dir):
    """Load every saved eval file in *run_dir*; return (samples, losses) sorted by samples.

    Each file is read with ``torch.load`` and indexed with ``'samples'`` and
    ``'eval_loss'``. Returns two empty tuples when *run_dir* contains no files.
    """
    points = []
    for loss_file in os.listdir(run_dir):
        data = torch.load(os.path.join(run_dir, loss_file))
        points.append((data['samples'], data['eval_loss']))
    if not points:
        return (), ()
    # Sort by sample count (ties broken by loss) so the line is drawn left to right.
    points.sort()
    xs, ys = zip(*points)
    return xs, ys


@hydra.main(version_base=None, config_path="../../../conf", config_name="nethack_config")
def main(cfg: DictConfig):
    """Plot dev loss vs. training samples for each context-length run and save a PDF.

    Every entry of ``CTX_LEN_DIR`` is treated as a run id. Its config is loaded
    from ``models/<run_id>/cfg.omega`` to read ``data.unroll_length``, which
    sets the line's label and colour. The figure is written to
    ``llama_effective_ctx_len.pdf`` (a path relative to the working directory).

    ``cfg`` is supplied by ``hydra.main`` and is not used here.
    """
    plt.figure()

    # Gather (context length, run id) pairs and sort them so lines -- and hence
    # legend entries -- appear in ascending context-length order. os.listdir()
    # returns entries in arbitrary order, so a hard-coded legend permutation
    # is not reproducible (and breaks whenever the run count changes).
    runs = []
    for run_id in os.listdir(CTX_LEN_DIR):
        run_cfg = OmegaConf.load(os.path.join('models', run_id, 'cfg.omega'))
        runs.append((run_cfg['data']['unroll_length'], run_id))
    runs.sort()

    for ctx_len, run_id in runs:
        xs, ys = _load_run_curve(os.path.join(CTX_LEN_DIR, run_id))
        if not xs:
            # No saved eval files for this run; nothing to draw.
            continue

        # 4K
        # (0.2533087134361267, 0.20105645060539246, 0.18379859626293182, 0.17474332451820374, 0.170623779296875)
        # 8K
        # (0.2669707238674164, 0.20453573763370514, 0.18468832969665527, 0.17483405768871307, 0.17055025696754456)

        plt.plot(
            xs,
            ys,
            label=f"Ctx len {ctx_len}",
            markersize=9.2,
            linewidth=3,
            markeredgewidth=1.6,
            markeredgecolor="#F7F7FF",
            marker='o',
            # Unknown lengths fall back to matplotlib's colour cycle (color=None)
            # instead of raising KeyError.
            color=len_to_color.get(ctx_len),
        )

    plt.xlabel('Samples', fontsize="20")
    plt.ylabel('Dev Loss', fontsize="20")
    plt.yticks(fontsize="17")
    plt.xticks(fontsize="17")

    # Runs were plotted in ascending context length, so the default legend
    # order is already sorted.
    plt.legend(frameon=True, fontsize="14", loc="upper right")

    plt.savefig('llama_effective_ctx_len.pdf')
        
    
if __name__ == "__main__":
    # The @hydra.main decorator on main() composes `cfg` from the config
    # directory / command line, so main() is invoked without arguments.
    main()