import logging
import os

import numpy as np
import torch
import wandb
from tqdm import tqdm

import utils
from hyperparams.load import get_config

# Module logger shared by all trainers; its handlers are presumably set up
# elsewhere (utils.close_logger() is called at the end of training) — confirm.
logger = logging.getLogger(name='custom')
# Project hyperparameter config, loaded once at import time.
# NOTE(review): not read in this module's visible code; VaeTrainer.__init__
# takes its own ``config`` argument, which shadows this name there.
config = get_config()


class Trainer:
    """Base trainer owning a model, its Adam optimizer and the first epoch to run.

    Subclasses implement :meth:`train` and :meth:`_run_epoch`. ``run_path`` starts
    as ``None`` here and must be set (by a subclass) before
    :meth:`_save_checkpoint` is called.
    """

    def __init__(self, args, model, checkpoint=None, debug=False):
        """
        :param args: run arguments; ``args.lr`` is the Adam learning rate and
            ``dset_name``, ``run_id``, ``model`` and ``trial`` are read when
            spawning wandb.
        :param model: module whose ``requires_grad`` parameters are optimized.
        :param checkpoint: optional dict holding ``'optimizer'`` (optimizer state)
            and ``'epoch'`` (epoch it was saved at). When truthy, the optimizer
            state is restored and training resumes at the following epoch.
        :param debug: True for runs solely used for debugging purposes.
        """
        self.args = args
        self.model = model
        self.debug = debug
        self.run_path = None

        # Frozen parameters are excluded from optimization.
        params = [p for p in model.parameters() if p.requires_grad]
        self.opt = torch.optim.Adam(params, args.lr)

        if not checkpoint:
            self.start_epoch = 1
            return

        logger.info('\nResuming checkpoint.')
        self.opt.load_state_dict(checkpoint['optimizer'])
        self.start_epoch = checkpoint['epoch'] + 1

    def _save_checkpoint(self, epoch, filename):
        """Write model state, optimizer state and ``epoch`` to ``<run_path>/models/<filename>``."""
        target = os.path.join(self.run_path, 'models', filename)
        torch.save(
            {
                'state_dict': self.model.state_dict(),
                'epoch': epoch,
                'optimizer': self.opt.state_dict(),
            },
            target,
        )
        logger.info(f'Saved checkpoint at {target}.')

    def train(self, *args, **kwargs):
        raise NotImplementedError

    def _run_epoch(self, *args, **kwargs):
        raise NotImplementedError

    def _spawn_wandb(self):
        """Initialise a wandb run in group ``'train'`` configured from ``self.args``."""
        run_args = self.args
        project_name = utils.get_wandb_project_name(run_args.dset_name)
        run_tags = ['debug'] if self.debug else None
        utils.init_wandb(
            run_id=run_args.run_id,
            project=project_name,
            group='train',
            wandb_config=vars(run_args),
            name=f'{run_args.model}_{run_args.trial}',
            tags=run_tags,
        )


class VaeTrainer(Trainer):
    """Trains variational autoencoders.

    Each epoch optimizes ``objective`` over ``loader`` with a linearly warmed-up
    KL weight and logs the flushed diagnostics to wandb. It also periodically
    runs the evaluators and writes checkpoints to ``<run_path>/models``.
    """

    def __init__(self, args, config, objective, loader, evaluators, device, **kwargs):
        """
        :param args: run arguments; ``epochs``, ``kl_factor`` and
            ``kl_end_warmup`` are read here (plus whatever ``Trainer`` reads).
        :param config: provides ``run_path``, ``eval_freq`` and ``save_freq``.
        :param objective: instantiated as ``objective(args)``. The instance is
            called as ``objective(model, data=..., beta=...)`` and returns
            ``(loss, diagnostics)``.
        :param loader: iterable of 2-tuple batches used for training; only the
            first element is fed to the model.
        :param evaluators: dict mapping split name -> evaluator exposing
            ``evaluate(model=..., epoch=...)``.
        :param device: target passed to ``utils.to_device``.
        :param kwargs: forwarded to ``Trainer`` (``model``, ``checkpoint``, ``debug``).
        """
        super().__init__(args, **kwargs)
        self.objective = objective(args)
        self.loader = loader
        self.evaluators = evaluators
        self.device = device

        self.end_epoch = args.epochs
        self.kl_factor = args.kl_factor
        self.kl_end_warmup = args.kl_end_warmup
        self.run_path = config.run_path
        self.eval_freq = config.eval_freq
        self.save_freq = config.save_freq

        logger.info(f'\nRun path:\n{self.run_path}')
        os.makedirs(os.path.join(self.run_path, 'models'), exist_ok=True)

    def train(self, k: int):
        """
        Run epochs ``start_epoch`` .. ``end_epoch`` (inclusive), then close the
        logger and finish the wandb run.

        :param k: number of importance samples per batch sample
        """
        self._spawn_wandb()
        for epoch in range(self.start_epoch, self.end_epoch + 1):
            with utils.Timer(f'Epoch: {epoch:03d}', event_frequency='medium'):
                self._run_epoch(epoch, k)
                if epoch % self.eval_freq == 0:
                    # _validate_epoch already writes model_epoch_{epoch}.pt, so
                    # saving again below would rewrite the identical file.
                    self._validate_epoch(epoch)
                    # NOTE(review): wandb is re-initialised after evaluation,
                    # presumably because evaluators manage their own run — confirm.
                    self._spawn_wandb()
                elif epoch % self.save_freq == 0 or epoch == self.end_epoch:
                    self._save_checkpoint(epoch, f'model_epoch_{epoch}.pt')
        self._finish_training()

    def _kl_beta(self, epoch):
        """Return the KL weight for ``epoch``.

        The weight ramps linearly up to ``kl_factor``, which it reaches at epoch
        ``kl_end_warmup``. A non-positive ``kl_end_warmup`` disables warm-up
        (full weight from the start) instead of dividing by zero.
        """
        if self.kl_end_warmup <= 0:
            return self.kl_factor
        return self.kl_factor * min(1, epoch / self.kl_end_warmup)

    def _run_epoch(self, epoch, k):
        """Run one optimization pass over ``self.loader`` and log the epoch's diagnostics."""
        self.model.train()
        meter = utils.Meter()
        beta = self._kl_beta(epoch)

        for inp in tqdm(self.loader, desc='Training'):
            inp = utils.to_device(inp, self.device)
            x, _ = inp
            # Call the module rather than .forward() so registered hooks run.
            output = self.model(x, k=k)
            loss, diagnostics = self.objective(
                self.model, data={'inp': inp, 'output': output}, beta=beta)
            self.opt.zero_grad()
            loss.backward()
            self.opt.step()
            meter.add_dict(log=diagnostics)

        self._finish_epoch(meter, epoch)

    def _finish_epoch(self, meter, epoch):
        """Flush the meter, log its contents to wandb and the training loss to the logger.

        :raises RuntimeError: if the flushed total loss is NaN.
        """
        log = meter.flush()
        if np.isnan(log['loss']['total']):
            raise RuntimeError('Encountered nan-loss, abort.')
        wandb.log(log, step=epoch)
        rounded = {str(name): np.around(value, 4) for name, value in log['loss'].items()}
        logger.info(f'- Loss (train): {rounded}')

    def _validate_epoch(self, epoch):
        """Run every evaluator on the current model, then checkpoint this epoch."""
        for evaluator in self.evaluators.values():
            evaluator.evaluate(model=self.model, epoch=epoch)
        self._save_checkpoint(epoch, f'model_epoch_{epoch}.pt')

    def _finish_training(self):
        """Log completion, close the project logger and finish the wandb run."""
        logger.info('\nModel has been trained.')
        utils.close_logger()
        wandb.finish()


def define_trainer(package, **kwargs):
    """Build a ``VaeTrainer`` whose objective is ``package.Objective``.

    :param package: module or namespace exposing an ``Objective`` attribute.
    :param kwargs: the remaining ``VaeTrainer`` constructor arguments.
    :return: the constructed trainer.
    """
    return VaeTrainer(objective=package.Objective, **kwargs)
