import os

import hydra
from omegaconf import DictConfig, OmegaConf
import torch
import matplotlib.pyplot as plt
import seaborn as sns

# Matplotlib 3.6 renamed its bundled "seaborn" style to "seaborn-v0_8" and 3.8
# removed the old name (plt.style.use("seaborn") then raises OSError), so use
# whichever of the two names this Matplotlib installation provides.
plt.style.use("seaborn" if "seaborn" in plt.style.available else "seaborn-v0_8")
sns.set_style("whitegrid")
sns.set_palette("deep")

# Directory with one sub-directory per run id; each sub-directory holds that
# run's saved results, which main() reads back with torch.load.
EXP_DIR = 'paper/llama_depth_vs_width'

# Legend label -> target parameter count for each plotted model size. A run is
# assigned to the first bucket whose target is within _SIZE_TOLERANCE of the
# run's parameter count; runs matching no bucket are ignored.
_SIZE_TARGETS = {
    '40M': 4e7,
    '68M': 6.8e7,
    '105M': 1.05e8,
    # '144M': 1.44e8,
}
_SIZE_TOLERANCE = 5e6


def _size_label(num_params):
    """Return the _SIZE_TARGETS label matching ``num_params``, or None if no bucket fits."""
    for label, target in _SIZE_TARGETS.items():
        if abs(num_params - target) < _SIZE_TOLERANCE:
            return label
    return None


def _load_run_data(run_dir):
    """Load the results object saved in ``run_dir``.

    A run directory presumably holds a single results file. If it holds
    several, the alphabetically first one is used, because ``os.listdir``
    order is arbitrary and this keeps the choice deterministic.

    Raises:
        FileNotFoundError: if ``run_dir`` contains no files.
    """
    files = sorted(os.listdir(run_dir))
    if not files:
        raise FileNotFoundError(f"no results file found in {run_dir}")
    # NOTE(review): torch.load unpickles the file; only use on trusted experiment output.
    return torch.load(os.path.join(run_dir, files[0]))


@hydra.main(version_base=None, config_path="../../../conf", config_name="nethack_config")
def main(cfg: DictConfig) -> None:
    """Plot dev loss against aspect ratio (hdim / tf_num_layers), one line per model size.

    Each run under EXP_DIR is bucketed by its saved parameter count
    (``data['params']``). Its architecture is read from
    ``models/<run_id>/cfg.omega``, and ``data['eval_loss']`` is plotted against
    the aspect ratio. The figure is written to
    ``paper/figures/llama_depth_vs_width.pdf``.

    ``cfg`` is supplied by Hydra but is not used. All paths are relative to the
    working directory. NOTE(review): this assumes Hydra does not chdir into its
    output directory (hydra.job.chdir=False); confirm for your Hydra version.
    """
    plt.figure()

    # Collect (aspect_ratio, eval_loss) points per size bucket, loading each
    # results file only once.
    points = {label: [] for label in _SIZE_TARGETS}
    for run_id in sorted(os.listdir(EXP_DIR)):
        run_dir = os.path.join(EXP_DIR, run_id)
        if not os.path.isdir(run_dir):
            continue  # ignore stray files (e.g. .DS_Store) next to run dirs
        data = _load_run_data(run_dir)
        label = _size_label(data['params'])
        if label is None:
            continue  # model size is not one of the plotted buckets

        run_cfg = OmegaConf.load(os.path.join('models', run_id, 'cfg.omega'))
        aspect_ratio = run_cfg['network']['hdim'] / run_cfg['network']['tf_num_layers']
        points[label].append((aspect_ratio, data['eval_loss']))

    for label, run_points in points.items():
        if not run_points:
            # An empty bucket would make zip(*[]) yield nothing to unpack.
            continue
        # Sort by aspect ratio so the line is drawn left to right.
        xs, ys = zip(*sorted(run_points))
        plt.plot(
            xs,
            ys,
            label=f"{label} params",
            markersize=9.2,
            linewidth=2,
            markeredgewidth=1.6,
            markeredgecolor="#F7F7FF",
            marker='o',
        )

    plt.xlabel('Aspect Ratio', fontsize="20")
    plt.ylabel('Dev Loss', fontsize="20")
    plt.yticks(fontsize="17")
    plt.xticks(fontsize="17")

    plt.legend(frameon=True, fontsize="14")
    plt.tight_layout()

    out_path = 'paper/figures/llama_depth_vs_width.pdf'
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    plt.savefig(out_path)
    plt.close()
        
    
if __name__ == "__main__":
    # main is wrapped by hydra.main (see its decorator), which composes the
    # config and passes it in as `cfg`, so it is called with no arguments here.
    main()