import os
from pathlib import Path

import torch

import utils
from hyperparams.load import get_config

# Load the project configuration once, as a side effect of importing this module.
# NOTE(review): `config` is not referenced in the visible part of this module —
# confirm it is still needed here.
config = get_config()


class Evaluator:
    """Bookkeeping helper for evaluating a run on a single data split.

    On construction the directory ``<run_path>/<split>`` is created. The
    helper methods accumulate per-epoch quantitative results in
    ``quant_results.pt`` inside that directory and prepare a per-epoch
    artifact directory together with a wandb run.
    """

    def __init__(self, run_path, split, args, device, debug=False):
        """Store the evaluation settings and create the split directory.

        Args:
            run_path: Directory of the run; its last path component is kept
                as ``self.run_id``.
            split: Name of the split; used as sub-directory name and as the
                wandb group.
            args: Object whose ``run_id``, ``model`` and ``trial`` attributes
                are read and whose ``vars()`` are sent to wandb as config.
            device: Stored unchanged as ``self.device``.
            debug: When true, the wandb run is tagged ``'debug'``.
        """
        # The run id is taken from the path (not from ``args``) so that ids
        # of the form f'{ID}_debug' are preserved.
        self.run_id = Path(run_path).name
        self.split_dir = os.path.join(run_path, split)
        os.makedirs(self.split_dir, exist_ok=True)
        self.split = split
        self.args = args
        self.device = device
        self.debug = debug
        # Set by ``set_artifact_dir`` when an evaluation method is called.
        self.artifact_dir = None

    def _update_quant_results(self, log, epoch):
        """Merge ``log`` into the stored results for ``epoch`` and persist them.

        The results file maps ``epoch -> log``. A new epoch stores ``log``
        as given; for an epoch that is already present the stored entry is
        replaced by ``utils.update(stored_entry, log)``.

        Args:
            log: Results to record for ``epoch``.
            epoch: Key under which the results are stored.
        """
        path = os.path.join(self.split_dir, 'quant_results.pt')
        data = torch.load(path) if os.path.isfile(path) else {}
        data[epoch] = utils.update(data[epoch], log) if epoch in data else log

        # Save to a sibling temp file and swap it in atomically, so that an
        # interruption mid-save cannot truncate previously stored epochs.
        tmp_path = f'{path}.tmp'
        try:
            torch.save(data, tmp_path)
            os.replace(tmp_path, path)
        finally:
            # Only still present if saving or replacing failed.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    def _prepare_evaluation_run(self, epoch):
        """Create the artifact directory for ``epoch`` and initialise wandb.

        Args:
            epoch: Epoch being evaluated; determines the artifact directory.
        """
        self.set_artifact_dir(epoch)
        # NOTE(review): wandb receives ``self.args.run_id`` while
        # ``self.run_id`` (derived from the run path) is not used here —
        # confirm which of the two is intended.
        utils.init_wandb(
            run_id=self.args.run_id,
            project='hmvae_images',
            group=self.split,
            wandb_config=vars(self.args),
            name=f'{self.args.model}_{self.args.trial}',
            tags=['debug'] if self.debug else None,
        )

    def set_artifact_dir(self, epoch):
        """Create ``<split_dir>/epoch_<epoch>`` and store it as ``artifact_dir``.

        Args:
            epoch: Epoch whose artifacts will be gathered in the directory.
        """
        self.artifact_dir = os.path.join(self.split_dir, f'epoch_{epoch}')
        os.makedirs(self.artifact_dir, exist_ok=True)
