import logging
import os

import numpy as np
import torch
from tqdm import tqdm

import utils
from hyperparams.load import get_config
from mhvae_vasco.objective.load import get_objective

# Project-wide logger; the 'custom' name presumably matches a logger/handler
# configured elsewhere in the project — TODO confirm.
logger = logging.getLogger('custom')
# Hyperparameter configuration, loaded once at import time. Trainer reads
# eval_freq and save_freq from it when its class body is executed.
config = get_config()


class Trainer:
    """Train a model epoch by epoch with periodic evaluation and checkpointing.

    Evaluation runs every ``eval_freq`` epochs. A checkpoint is written after
    every evaluation, every ``save_freq`` epochs, and after the final epoch.
    Each epoch's checkpoint is written at most once. Both frequencies are read
    from the module-level ``config`` when the class is defined.
    """

    eval_freq = config.eval_freq
    save_freq = config.save_freq

    def __init__(self,
                 model,
                 loader,
                 args,
                 run_path,
                 evaluators,
                 device,
                 debug=False,
                 checkpoint=None):
        """Store the training components and build the optimiser.

        Args:
            model: Module to train (presumably a ``torch.nn.Module``; it must
                provide ``parameters``, ``train`` and ``state_dict``).
            loader: Iterable of training batches. Each batch is moved to
                ``device`` and unpacked as ``(x, _)``.
            args: Run arguments; must support ``vars()`` and provide ``lr``,
                ``epochs``, ``kl_end_warmup_z``, ``kl_end_warmup_g``,
                ``dset_name``, ``run_id``, ``model`` and ``trial``.
            run_path: Directory under which ``models/`` checkpoints are written.
            evaluators: Mapping whose values expose ``evaluate(model, epoch)``.
            device: Target passed to ``utils.to_device`` for every batch.
            debug: If true, the wandb run is tagged ``'debug'``.
            checkpoint: Optional dict with ``'optimizer'`` and ``'epoch'``
                entries to resume from. Only the optimiser state and the epoch
                counter are restored here. Model weights are presumably loaded
                by the caller — TODO confirm.
        """
        self.model = model
        self.loader = loader
        self.args = args
        self.run_path = run_path
        self.evaluators = evaluators
        self.device = device
        self.debug = debug
        self.objective = get_objective(args.dset_name)
        self._prepare(args, checkpoint)

    def _prepare(self, args, checkpoint):
        """Create the Adam optimiser and work out the epoch range.

        Only parameters with ``requires_grad`` set are optimised. When
        ``checkpoint`` is given, the optimiser state is restored and training
        resumes at the epoch after the saved one.
        """
        trainable_params = (p for p in self.model.parameters()
                            if p.requires_grad)
        self.opt = torch.optim.Adam(trainable_params, args.lr)
        if checkpoint:
            logger.info('\nResuming checkpoint.')
            self.opt.load_state_dict(checkpoint['optimizer'])
            self.start_epoch = checkpoint['epoch'] + 1
        else:
            self.start_epoch = 1

        self.epochs = args.epochs
        self.kl_end_warmup_z = args.kl_end_warmup_z
        self.kl_end_warmup_g = args.kl_end_warmup_g

    def train(self):
        """Run epochs ``start_epoch`` .. ``epochs`` (inclusive).

        Evaluates every ``eval_freq`` epochs. Checkpoints on evaluation epochs,
        every ``save_freq`` epochs, and on the last epoch, writing each
        epoch's checkpoint at most once.
        """
        self._spawn_wandb()
        for epoch in range(self.start_epoch, self.epochs + 1):
            self.run_epoch(epoch)
            saved = False
            if epoch % self.eval_freq == 0:
                # _validate_epoch already writes model_epoch_{epoch}.pt, so
                # the periodic/final save below must not write it again.
                self._validate_epoch(epoch)
                saved = True
                # Re-initialise the training wandb run after evaluation;
                # presumably the evaluators switch the active run — TODO confirm.
                self._spawn_wandb()
            is_last_epoch = epoch == self.epochs
            if not saved and (epoch % self.save_freq == 0 or is_last_epoch):
                self._save_checkpoint(epoch, f'model_epoch_{epoch}.pt')

    def run_epoch(self, epoch):
        """Make one optimisation pass over ``loader`` and log the mean loss.

        Args:
            epoch: 1-based epoch index. It drives the linear KL warm-up
                weights passed to the objective as ``beta``.
        """
        with utils.Timer(f'Epoch: {epoch:03d}', event_frequency='medium'):
            self.model.train()
            beta = {'z': self._kl_weight(epoch, self.kl_end_warmup_z),
                    'g': self._kl_weight(epoch, self.kl_end_warmup_g)}
            losses = []

            for inp in tqdm(self.loader, desc='Training'):
                inp = utils.to_device(inp, self.device)
                x, _ = inp
                # Call the module rather than .forward() so that any
                # registered forward hooks are executed.
                output = self.model(x)
                loss = self.objective(self.model, inp, output, beta)
                self.opt.zero_grad()
                loss.backward()
                self.opt.step()
                losses.append(loss.item())

            if not losses:
                # np.average([]) would only warn and yield nan.
                logger.warning('Loader yielded no batches; no loss to report.')
                return
            logger.info(f'Loss: {np.average(losses):.1f}')

    @staticmethod
    def _kl_weight(epoch, end_warmup):
        """Return the linear KL warm-up weight ``epoch / end_warmup``, capped at 1.

        A non-positive ``end_warmup`` means "no warm-up" and yields 1 instead
        of dividing by zero.
        """
        if end_warmup <= 0:
            return 1
        return min(1, epoch / end_warmup)

    def _validate_epoch(self, epoch):
        """Run every evaluator on the current model, then save a checkpoint."""
        for evaluator in self.evaluators.values():
            evaluator.evaluate(self.model, epoch)
        self._save_checkpoint(epoch, f'model_epoch_{epoch}.pt')

    def _save_checkpoint(self, epoch, filename):
        """Write model/optimiser state and the epoch to ``run_path/models/filename``.

        The ``models`` directory is created if missing, and an existing file
        with the same name is overwritten by ``torch.save``.
        """
        checkpoint = {
            'state_dict': self.model.state_dict(),
            'epoch': epoch,
            'optimizer': self.opt.state_dict()
        }
        dst = os.path.join(self.run_path, 'models')
        os.makedirs(dst, exist_ok=True)
        dst = os.path.join(dst, filename)
        torch.save(checkpoint, f=dst)
        logger.info(f'Saved checkpoint at {dst}.')

    def _spawn_wandb(self):
        """(Re)initialise the wandb run for training via ``utils.init_wandb``.

        Uses the ``'hmvae_images'`` project, the ``'train'`` group and the full
        argument namespace as config. Runs are tagged ``'debug'`` in debug mode.
        """
        utils.init_wandb(
            run_id=self.args.run_id,
            project='hmvae_images',
            group='train',
            wandb_config=vars(self.args),
            name=f'{self.args.model}_{self.args.trial}',
            tags=['debug'] if self.debug else None,
        )
