import hydra
from omegaconf import DictConfig, OmegaConf
from trainer.utils import seed_everything
from model import get_model
from evals import get_evaluators
import wandb

def _init_wandb(wandb_cfg):
    """Start (or resume) a wandb run as described by ``wandb_cfg``.

    Args:
        wandb_cfg (DictConfig | None): The ``wandb`` section of the config.
            Recognised keys: ``project`` (reporting is disabled when it is
            missing/None), ``run_id`` (resume that run instead of creating a
            new one), ``run_name`` (name for a new run) and ``training_args``
            (merged into the run's config).

    Returns:
        The active wandb run, or ``None`` when reporting is disabled.
    """
    if wandb_cfg is None:
        return None
    project = wandb_cfg.get("project", None)
    if project is None:
        return None

    run_id = wandb_cfg.get("run_id", None)
    if run_id is None:
        print(f"[INFO] Initializing new wandb run on project {project}")
        run_name = wandb_cfg.get("run_name", None)
        if run_name is not None:
            print(f"[INFO] Setting wandb run name to {run_name}")
        # name=None is wandb's default, so a single call covers both cases.
        run = wandb.init(project=project, name=run_name)
    else:
        print(f"[INFO] Reporting results to run ID {run_id} on wandb project {project}")
        # resume="must" fails loudly if the given run ID does not exist.
        run = wandb.init(project=project, id=run_id, resume="must")

    training_args = wandb_cfg.get("training_args", None)
    if training_args is not None:
        # allow_val_change lets a resumed run overwrite keys it already has.
        run.config.update(OmegaConf.to_container(training_args, resolve=True), allow_val_change=True)
    return run


@hydra.main(version_base=None, config_path="../configs", config_name="eval.yaml")
def main(cfg: DictConfig):
    """Entry point of the code to evaluate models.

    Seeds RNGs, loads the model/tokenizer from ``cfg.model``, runs every
    evaluator built from ``cfg.eval`` and, when ``cfg.wandb.project`` is set,
    logs each evaluator's results to wandb under ``"<evaluator_name>/<key>"``.

    Args:
        cfg (DictConfig): Evaluation config composed by Hydra from eval.yaml.

    Raises:
        ValueError: If no model config is provided.
    """
    run = _init_wandb(cfg.get("wandb", None))
    try:
        seed_everything(cfg.seed)

        model_cfg = cfg.model
        # Check before dereferencing so a missing model section produces a
        # clear error instead of an AttributeError on ``template_args``.
        if model_cfg is None:
            raise ValueError("Invalid model yaml passed in eval config.")
        template_args = model_cfg.template_args

        model, tokenizer = get_model(model_cfg)

        # The same arguments go to every evaluator; ``**`` unpacking hands each
        # call its own kwargs dict, so building this once is safe.
        eval_args = {
            "template_args": template_args,
            "model": model,
            "tokenizer": tokenizer,
        }
        evaluators = get_evaluators(cfg.eval)
        for evaluator_name, evaluator in evaluators.items():
            results = evaluator.evaluate(**eval_args)
            if run is not None:
                run.log({f"{evaluator_name}/{k}": v for k, v in results.items()})
    except BaseException:
        # Close the run explicitly (so e.g. Hydra multirun jobs don't share an
        # active run) and mark it failed before propagating the error.
        if run is not None:
            run.finish(exit_code=1)
        raise
    if run is not None:
        run.finish()

if __name__ == "__main__":
    # Called without arguments: the @hydra.main decorator on ``main`` is what
    # composes and passes in ``cfg``.
    main()
