import json

import hydra
from omegaconf import OmegaConf
from tqdm import tqdm

from lm_understanding import load_model
from lm_understanding.data import dataset_path, model_behavior_path
from lm_understanding.question_template import TemplateDataset


@hydra.main(config_path='config', config_name='model_behavior.yaml', version_base='1.2')
def main(config) -> None:
    """Evaluate a model on every template task of a dataset and save the results.

    Loads the model described by ``config.model``, loads the template dataset
    found at ``dataset_path(config, new=False)``, filters it with
    ``filter_from_config`` and evaluates each remaining template task. Each
    result is saved in a sub-directory of ``model_behavior_path(config,
    new=True)`` named after the template id; templates whose sub-directory
    already exists are skipped.

    After every evaluated template a ``run_info.json`` file is (re)written in
    the save directory. It holds the config, the number of template tasks, the
    fraction of templates kept by the filter, and the entries of
    ``completer.info``.

    Args:
        config: Configuration object supplied by Hydra. This function reads
            ``config.model`` and, under ``config.dataset_construction``,
            ``templates_per_topic``, ``train_questions_per_template`` and
            ``test_questions_per_template``.

    Raises:
        ValueError: If the loaded dataset contains no templates.
        AssertionError: If the number of template tasks differs from
            ``templates_per_topic``, or a template task has no template id.
    """
    print(config)

    completer, prompter = load_model(config.model)

    data_path = dataset_path(config, new=False)
    save_path = model_behavior_path(config, new=True)

    dataset = TemplateDataset.load(data_path)
    n_templates_before_filter = len(dataset.templates)
    if n_templates_before_filter == 0:
        # Fail with a clear message instead of a bare ZeroDivisionError below.
        raise ValueError(f'Dataset at {data_path} contains no templates; nothing to evaluate.')
    dataset = dataset.filter_from_config(config)
    filter_kept_fraction = len(dataset.templates) / n_templates_before_filter

    construction = config.dataset_construction
    template_tasks = dataset.template_tasks
    assert len(template_tasks) == construction.templates_per_topic, (
        f'Expected {construction.templates_per_topic} template tasks after filtering, '
        f'got {len(template_tasks)}'
    )

    for template_task in tqdm(template_tasks):
        assert template_task.template_id, 'Template task is missing a template_id'
        save_dir = save_path / template_task.template_id
        if save_dir.exists():
            # Results for this template are already on disk.
            continue
        model_behavior = template_task.evaluate(
            completer,
            prompter,
            train_questions=construction.train_questions_per_template,
            test_questions=construction.test_questions_per_template,
        )
        model_behavior.save(save_dir)

        # NOTE(review): the run info is rewritten after every evaluated
        # template and is not written at all when every template is skipped.
        # Kept as in the original; confirm this is intended before moving it
        # out of the loop, since completer.info may change between iterations.
        run_info = dict(
            config=OmegaConf.to_container(config),
            n_templates_evaluated=len(template_tasks),
            filter_kept_fraction=filter_kept_fraction,
            **completer.info
        )
        # ensure_ascii=False emits raw non-ASCII characters, so the file must
        # be opened as UTF-8 rather than with the locale-dependent default.
        with open(save_path / 'run_info.json', 'w', encoding='utf-8') as f:
            json.dump(run_info, f, indent=4, ensure_ascii=False)


# Script entry point: run the Hydra-decorated main() only when this file is
# executed directly, not when it is imported.
if __name__ == '__main__':
    main()