import logging
import os
from pathlib import Path

import torch
import wandb
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from tqdm import tqdm

import utils

# Module-level logger. Handlers/level for the 'custom' logger are not set in
# this chunk — presumably configured elsewhere in the project (TODO confirm).
logger = logging.getLogger('custom')
# Project configuration, loaded once at import time via utils.get_config().
# NOTE(review): not referenced by the code in this chunk; Evaluator.__init__
# takes its own ``config`` argument, which shadows this name inside __init__.
config = utils.get_config()


class Evaluator:
    """Base class for split-specific evaluation of a trained model.

    Manages the per-split artifact directories under ``config.run_path`` and
    implements likelihood estimation. Subclasses are expected to set
    ``result_flags`` and implement ``_load_data``.
    """

    def __init__(self, split, args, config, device, debug=False):
        """
        :param split: Name of the data split to evaluate. It is used as the
            sub-directory of ``config.run_path`` that receives artifacts and
            as the wandb group.
        :param args: Run arguments. ``args.eval_bs`` is either a batch size
            or a dict whose ``'final'`` entry is used as the batch size.
        :param config: Object providing ``run_path`` and ``eval_k``.
        :param device: Device that evaluation batches are moved to.
        :param debug: True for runs solely used for debugging purposes.
        """
        # Extract run id from run path to comply with ID of the form f'{ID}_debug'
        self.run_id = Path(config.run_path).name
        self.split = split
        self.device = device
        self.debug = debug
        self.eval_bs = args.eval_bs['final'] if isinstance(args.eval_bs, dict) else args.eval_bs
        self.args = args
        self.eval_k = config.eval_k
        # Sentinel: resolved lazily from the model in _init_modalities_property.
        self.modalities = 'not_initialized'
        self.run_path = config.run_path
        self.result_flags = None  # Define in subclass
        self.split_dir = self._setup_split_dir()
        self.artifact_dir = None  # Define when calling evaluation method

    def _update_quant_results(self, log, epoch):
        """Merge ``log`` into the per-epoch results stored in ``quant_results.pt``.

        A log for an epoch that is not yet stored is inserted as-is. Otherwise
        it is combined with the stored entry via ``utils.update``.

        The file is written to a temporary path in the same directory and then
        moved into place. An interrupted save therefore cannot truncate the
        results accumulated so far.
        """
        path = os.path.join(self.split_dir, 'quant_results.pt')
        data = torch.load(path) if os.path.isfile(path) else {}
        if epoch in data:
            data[epoch] = utils.update(data[epoch], log)
        else:
            data[epoch] = log
        tmp_path = f'{path}.tmp'
        torch.save(data, tmp_path)
        # Same-directory rename: os.replace is atomic on POSIX.
        os.replace(tmp_path, path)

    def _setup_split_dir(self):
        """ Prepare dir for placing split-specific experimental artifacts. """
        split_dir = os.path.join(self.run_path, self.split)
        os.makedirs(split_dir, exist_ok=True)
        return split_dir

    def set_artifact_dir(self, epoch):
        """ Prepare dir that gathers experimental artifacts. """
        self.artifact_dir = os.path.join(self.split_dir, f'epoch_{epoch}')
        os.makedirs(self.artifact_dir, exist_ok=True)

    def _prepare_evaluation_run(self, epoch):
        """Create the epoch artifact dir and initialise the wandb run for this split."""
        self.set_artifact_dir(epoch)
        project = utils.get_wandb_project_name(self.args.dset_name)
        # NOTE(review): this passes args.run_id, not self.run_id (the
        # run-path-derived id computed in __init__). Confirm which is intended.
        utils.init_wandb(
            run_id=self.args.run_id,
            project=project,
            group=self.split,
            wandb_config=vars(self.args),
            name=f'{self.args.model}_{self.args.trial}',
            tags=['debug'] if self.debug else None,
        )

    @torch.no_grad()
    def evaluate(self, model, epoch):
        """Put ``model`` in eval mode and run the enabled evaluations (no grad)."""
        model.eval()
        self._init_modalities_property(model)
        self._estimate_likelihoods_wrapper(model, epoch)

    def _load_data(self, *args, **kwargs):
        """Return the evaluation dataset for this split; must be overridden."""
        raise NotImplementedError('Define in subclass.')

    def _estimate_likelihoods_wrapper(self, model, epoch):
        """Estimate likelihoods on this split, then log and persist them.

        Does nothing (returns None) unless
        ``result_flags['estimate_likelihoods']`` is truthy.

        :raises NotImplementedError: if the subclass never defined
            ``result_flags``.
        """
        if self.result_flags is None:
            raise NotImplementedError('Define result_flags in subclass.')
        if not self.result_flags.get('estimate_likelihoods'):
            return None

        with utils.Timer(f'Likelihood Approximation '
                         f'(k={self.eval_k}; {self.split})'):

            dataset = self._load_data(average=True)
            log = self._estimate_likelihoods(
                model=model,
                dataset=dataset,
                k=self.eval_k,
                bs=self.eval_bs)

            logger.info('Likelihoods:')
            # Entries are grouped under the prefix passed to meter.add_dict.
            for k, v in log[f'likelihood_k_{self.eval_k}/'].items():
                logger.info(f'{k}: {v:.2f}')
            self._update_quant_results(log, epoch)
            wandb.log(log, step=epoch)

    @torch.no_grad()
    def _estimate_likelihoods(self,
                              model,
                              dataset: Dataset,
                              k: int,
                              bs: int,
                              **kwargs):
        """ Compute test likelihood.

        Iterates ``dataset`` in order (no shuffling) with batch size ``bs``.
        For each batch it runs the model with ``k`` and accumulates
        ``model.evaluate_likelihood`` results under the prefix
        ``f'likelihood_k_{k}/'``. Gradients are disabled, so the method is
        also safe to call outside ``evaluate``.

        :param kwargs: Forwarded to ``model.evaluate_likelihood``.
        :return: The flushed meter contents.
        """
        model.eval()
        meter = utils.Meter()
        loader = DataLoader(dataset, batch_size=bs, shuffle=False)

        for inp in tqdm(loader, position=0, leave=True,
                        desc=f'Lik (k={k}, bs={bs})'):
            inp = utils.to_device(inp, self.device)
            # Assumes each batch is a 2-tuple whose first element is the
            # model input (second is ignored here) — TODO confirm.
            x, _ = inp
            output = model.forward(x, eval=True, k=k)
            likelihood = model.evaluate_likelihood(inp, output, **kwargs)
            meter.add_dict(likelihood, f'likelihood_k_{k}/')

        logs = meter.flush()
        return logs

    def _init_modalities_property(self, model):
        """Resolve ``self.modalities`` from ``model`` once.

        Falls back to ``['x1']`` when the model has no ``modalities`` attribute.
        """
        if self.modalities == 'not_initialized':
            if hasattr(model, 'modalities'):
                self.modalities = model.modalities
            else:
                # Assume unimodal vae
                self.modalities = ['x1']
