import os

import hydra
from omegaconf import DictConfig, OmegaConf
import torch
import matplotlib.pyplot as plt
import seaborn as sns

# Matplotlib 3.6 renamed its bundled seaborn style to "seaborn-v0_8" and 3.8
# removed the old "seaborn" name (using it now raises OSError). Try the new
# name first and fall back to the old one for older Matplotlib releases.
try:
    plt.style.use("seaborn-v0_8")
except OSError:
    plt.style.use("seaborn")
sns.set_style("whitegrid")
sns.set_palette("deep")

# Directory containing one subdirectory of saved eval-loss files per run.
PARTIAL_OBS_DIR = 'paper/llama_partial_obs'

def _run_label(network_cfg):
    """Return the legend label for a run, or ``None`` if it should not be plotted.

    Only runs that use both the observation and the message inputs are
    plotted; among those, runs without the inventory input are labelled
    '- Inventory' and the rest 'Fully observed'.
    """
    if not network_cfg['use_observation'] or not network_cfg['use_message']:
        return None
    if network_cfg['use_inventory']:
        return 'Fully observed'
    return '- Inventory'


def _load_curve(run_dir):
    """Load the (samples, eval_loss) points saved in ``run_dir``.

    Every file in ``run_dir`` is read with ``torch.load`` and must be a mapping
    with ``'samples'`` and ``'eval_loss'`` keys.

    Returns:
        Two lists ``(xs, ys)`` ordered by increasing sample count; both are
        empty when the directory holds no files.
    """
    points = []
    for loss_file in os.listdir(run_dir):
        # map_location='cpu' lets the script run on a machine without the GPU
        # the files may have been saved from. Files are trusted local outputs
        # (torch.load unpickles).
        data = torch.load(os.path.join(run_dir, loss_file), map_location='cpu')
        points.append((data['samples'], data['eval_loss']))
    # Sort on the sample count alone: sorting whole pairs would fall through
    # to comparing loss values on ties, which can fail for tensor values.
    points.sort(key=lambda point: point[0])
    xs = [samples for samples, _ in points]
    ys = [loss for _, loss in points]
    return xs, ys


@hydra.main(version_base=None, config_path="../../../conf", config_name="nethack_config")
def main(cfg: DictConfig):
    """Plot dev loss against training samples for the partial-observation runs.

    Each subdirectory of ``PARTIAL_OBS_DIR`` is one run. Its network config is
    read from ``models/<run_id>/cfg.omega``, and its saved eval losses are
    drawn as one line. The figure is written to
    ``paper/figures/llama_partial_obs.pdf``. ``cfg`` is injected by Hydra but
    is not used here.
    """
    plt.figure()

    # os.listdir order is arbitrary; sort so line colours and legend order
    # are reproducible between invocations.
    for run_id in sorted(os.listdir(PARTIAL_OBS_DIR)):
        run_dir = os.path.join(PARTIAL_OBS_DIR, run_id)
        if not os.path.isdir(run_dir):
            # Ignore stray files sitting next to the run directories.
            continue

        run_cfg = OmegaConf.load(os.path.join('models', run_id, 'cfg.omega'))
        label = _run_label(run_cfg['network'])
        if label is None:
            # Decide before loading any loss files so excluded runs cost no I/O.
            continue

        xs, ys = _load_curve(run_dir)
        if not xs:
            # Nothing has been evaluated for this run yet.
            continue

        plt.plot(
            xs,
            ys,
            label=label,
            markersize=9.8,
            linewidth=3,
            markeredgewidth=1.6,
            markeredgecolor="#F7F7FF",
            marker='o'
        )

    plt.xlabel('Samples', fontsize="20")
    plt.ylabel('Dev Loss', fontsize="20")
    plt.yticks(fontsize="17")
    plt.xticks(fontsize="17")

    plt.legend(frameon=True, fontsize="14", loc="upper right")

    out_path = 'paper/figures/llama_partial_obs.pdf'
    # savefig does not create missing parent directories.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    plt.savefig(out_path)
    plt.close()
        
    
if __name__ == "__main__":
    # main() is called with no arguments: the @hydra.main decorator composes
    # the config and passes it in as ``cfg``.
    main()