import os

import hydra
from omegaconf import DictConfig, OmegaConf
import torch
import matplotlib.pyplot as plt
import seaborn as sns

# The bare "seaborn" matplotlib style was renamed to "seaborn-v0_8" in
# matplotlib 3.6 and the old name was removed in 3.8 (it raises there), so
# prefer the new name whenever this matplotlib version provides it.
plt.style.use("seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "seaborn")
sns.set_style("whitegrid")
sns.set_palette("deep")

# Directory with one entry per run (named by run id); each run directory holds
# the saved loss file that main() reads 'eval_loss' from.
EXP_DIR = 'paper/llama_lr_sensitivity'

@hydra.main(version_base=None, config_path="../../../conf", config_name="nethack_config")
def main(cfg: DictConfig):
    """Plot dev loss against learning rate, one line per LR-schedule setting.

    Each entry of ``EXP_DIR`` is treated as a run id. For every run, the
    training config at ``models/<run_id>/cfg.omega`` supplies
    ``optimizer.scheduler_type``, ``optimizer.optim_warmup_steps`` and
    ``optimizer.lr``. The run's loss file is the first entry, in sorted
    order, of ``EXP_DIR/<run_id>``. It is loaded with ``torch.load``, and its
    ``'eval_loss'`` value is plotted against the learning rate.

    Runs whose (scheduler, warmup) pair is not listed below are ignored.
    Settings with no runs are left out of the plot. The figure is written to
    ``paper/figures/llama_lr_sensitivity.pdf``.

    Args:
        cfg: Hydra config; unused, but passed in by ``hydra.main``.
    """
    # (scheduler_type, optim_warmup_steps) -> legend label. Insertion order
    # fixes the plotting (and hence colour / legend) order.
    label_for_setting = {
        ('constant', 0): 'Constant (0)',
        ('constant', 1000): 'Constant (1000)',
        ('constant', 2000): 'Constant (2000)',
        ('constant', 3000): 'Constant (3000)',
        ('cosine', 1000): 'Cosine (1000)',
        ('cosine', 2000): 'Cosine (2000)',
        ('cosine', 3000): 'Cosine (3000)',
        # ('cosine', 5000): 'Cosine (5000)',
    }
    # legend label -> list of (learning rate, eval loss) points.
    points_by_label = {label: [] for label in label_for_setting.values()}

    # Sorted so the run order is reproducible across filesystems.
    for run_id in sorted(os.listdir(EXP_DIR)):
        optim_cfg = OmegaConf.load(os.path.join('models', run_id, 'cfg.omega'))['optimizer']
        setting = (optim_cfg['scheduler_type'], optim_cfg['optim_warmup_steps'])
        label = label_for_setting.get(setting)
        if label is None:
            continue

        run_dir = os.path.join(EXP_DIR, run_id)
        # NOTE(review): assumes each run directory holds exactly one loss
        # file; sorting only makes the pick deterministic if there are more.
        run_loss_file = sorted(os.listdir(run_dir))[0]
        # Load onto CPU so losses saved from GPU tensors can be plotted anywhere.
        data = torch.load(os.path.join(run_dir, run_loss_file), map_location='cpu')
        points_by_label[label].append((optim_cfg['lr'], data['eval_loss']))

    plt.figure()

    # Scales and tick positions are shared by every line, so set them once.
    plt.xscale("log")
    plt.yscale("log")
    plt.yticks(ticks=[0.2, 0.22, 0.24, 0.28, 0.32, 0.36, 0.4, 0.44], labels=["0.2", "0.22", "0.24", "0.28", "0.32", "0.36", "0.4", "0.44"])
    plt.xticks(ticks=[1e-5, 2e-5, 5e-5, 1e-4, 2e-4, 5e-4, 1e-3], labels=['1e-5', '2e-5', '5e-5', '1e-4', '2e-4', '5e-4', '1e-3'])

    for label, points in points_by_label.items():
        if not points:
            # No runs for this setting; zip(*[]) would give nothing to unpack.
            continue
        # Sort by learning rate so each line is drawn left to right.
        points.sort(key=lambda point: point[0])
        xs, ys = zip(*points)

        plt.plot(
            xs,
            ys,
            label=label,
            markersize=9.2,
            linewidth=2,
            markeredgewidth=1.6,
            markeredgecolor="#F7F7FF",
            marker='o'
        )

    plt.xlabel('Learning Rate', fontsize="20")
    plt.ylabel('Dev Loss', fontsize="20")
    plt.yticks(fontsize="17")
    plt.xticks(fontsize="17")
    plt.minorticks_off()

    # handles, labels = plt.gca().get_legend_handles_labels()
    # order = [1, 2, 0, 4, 5, 6, 3]
    # plt.legend([handles[idx] for idx in order],[labels[idx] for idx in order], frameon=True)
    plt.legend(frameon=True, fontsize="14")
    plt.tight_layout()

    out_path = 'paper/figures/llama_lr_sensitivity.pdf'
    # savefig does not create missing directories.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    plt.savefig(out_path)
    plt.close()
        
    
if __name__ == "__main__":
    # hydra.main builds cfg from the config files and CLI overrides, so no
    # argument is passed here.
    main()