# -*- coding: utf-8 -*-
import os
import sys
import math
import torch
from trainer.base_trainer import BaseTrainer
from utils import inf_loop
import trainer._utils as utils
import numpy as np
from trainer.geco import GECO

class ModelTrainer(BaseTrainer):
    """Single-GPU trainer for the full model.

    Runs epoch- or iteration-based training with optional GECO or linear KL
    annealing of ``model.config.elbo_weights``, periodic validation and
    novel-view sampling, and periodic checkpointing.
    """

    def __init__(self, model, loss, metrics, optimizer, config, train_data_loader,
                 valid_data_loader=None, lr_scheduler=None, step_per_epoch=None, device=None):
        super().__init__(model, loss, metrics, optimizer, config, device)
        self.config = config
        self.train_loader = train_data_loader
        # Target KL weights; the KL-annealing schedule scales these values.
        self.kl_latent = config.elbo_weights["kl_latent"]
        self.kl_global = config.elbo_weights["kl_global"]
        self.kl_comp = config.elbo_weights["kl_component"]
        if not step_per_epoch:
            # epoch-based training: one full pass over the loader per epoch
            self.step_per_epoch = len(self.train_loader)
        else:
            # iteration-based training: endless loader, fixed number of steps per epoch
            self.train_loader = inf_loop(self.train_loader)
            self.step_per_epoch = step_per_epoch

        self.do_validation = valid_data_loader is not None
        if self.do_validation:
            self.valid_loader = valid_data_loader

        self.save_period = config.save_period
        self.lr_scheduler = lr_scheduler

        # setup GECO; its loss() output is used as the KL weight beta in _train_epoch
        if self.config.geco:
            geco_goal = config.geco_cfg["goal"]
            geco_lr = config.geco_cfg["lr"]
            g_alpha = 0.99
            g_init = config.geco_cfg["min"]    # start beta at its lower bound
            g_min = config.geco_cfg["min"]
            g_max = config.geco_cfg["max"]
            g_speedup = config.geco_cfg["speedup"]
            self.geco = GECO(self.device, geco_goal, geco_lr, g_alpha, g_init, g_min, g_max, g_speedup)

        if self.config.resume_epoch is not None:
            self.start_epoch = self.config.resume_epoch + 1
            if self.lr_scheduler is not None:
                # Fast-forward the per-iteration scheduler to the resumed global step.
                # NOTE(review): get_lr() is only exact for closed-form schedulers
                # (e.g. LambdaLR); chainable schedulers derive it from the current lr.
                self.lr_scheduler.last_epoch = self.config.resume_epoch * self.step_per_epoch
                self.optimizer.param_groups[0]['lr'] = self.lr_scheduler.get_lr()[0]

        # self.save_dir is expected to be set by BaseTrainer
        assert config.vis_train_dir is not None
        self.vis_train_dir = config.vis_train_dir
        assert config.generated_dir is not None
        self.generated_dir = config.generated_dir

    def train(self):
        """Train from ``start_epoch`` to ``epochs`` (inclusive), validating,
        sampling and checkpointing at the configured periods."""
        print('\n================== Start training ===================')
        assert (self.epochs + 1) > self.start_epoch
        for epoch in range(self.start_epoch, self.epochs + 1):
            print("Epoch {}".format(epoch))
            self._train_epoch(epoch)

            if epoch % self.config.val_period == 0:
                # validation needs a loader; sampling only needs self.targets
                if self.do_validation:
                    self._validate_epoch(epoch)
                self._sample(epoch)    # save generation results

            if epoch % self.save_period == 0:
                self._save_checkpoint(epoch)

        print('models have been saved to {}'.format(self.save_dir))
        print('================= Training finished =================\n')

    def _annealed_std(self, step):
        """Return the pixel std for global iteration ``step``.

        With a non-zero ``config.temperature`` the std decays linearly from
        ``temperature`` towards 0 over 20000 iterations, floored at
        ``config.pixel_sigma``; otherwise it is ``config.pixel_sigma``.
        """
        if self.config.temperature != 0.0:
            return max(self.config.temperature * (1.0 - float(step) / 20000.0),
                       self.config.pixel_sigma)
        return self.config.pixel_sigma

    def _balance_kl_weights(self, loss_dict, step):
        """Update the KL weights in ``self.model.config.elbo_weights`` in place.

        GECO: every non-zero KL weight is replaced by the GECO beta.
        KL annealing: each target weight is scaled by
        ``clip((step - 50000) / 50000, 0.01, 1)``, i.e. 0.01 until ~50.5k
        iterations, then rising linearly to 1.0 at 100k.

        :return: the GECO beta when GECO is enabled, otherwise None.
        """
        weights = self.model.config.elbo_weights
        if self.config.geco:
            beta = self.geco.loss(loss_dict['query_nll'])
            for key in ("kl_global", "kl_component", "kl_latent"):
                if weights[key] != 0.:    # terms switched off stay off
                    weights[key] = beta
            return beta
        if self.config.kl_anl:
            scale = np.clip(float((step - 50000) / 50000), 0.01, 1.)
            weights["kl_latent"] = self.kl_latent * scale
            weights["kl_global"] = self.kl_global * scale
            weights["kl_component"] = self.kl_comp * scale
        return None

    def _train_epoch(self, epoch):
        """
        Train the model for exactly ``self.step_per_epoch`` iterations (fewer
        only if a finite loader runs out first).

        The targets of the last batch are stored in ``self.targets`` so that
        ``_sample`` can reuse those view points for generation.
        :param epoch: Current training epoch (1-based).
        """
        self.model.train()
        metric_logger = utils.MetricLogger(delimiter="  ")
        metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
        header = 'Epoch: [{}]'.format(epoch)
        iter_count = (epoch - 1) * self.step_per_epoch
        last_targets = None

        for bid, (images, targets) in enumerate(
                metric_logger.log_every(self.train_loader, self.config.log_period, len_iter=self.step_per_epoch,
                                        header=header)):
            step = iter_count + bid    # global iteration index, used for logging/annealing
            images = [image.to(self.device).detach() for image in images]
            targets = [{k: v.to(self.device).detach() for k, v in t.items()} for t in targets]
            std_global = self._annealed_std(step)

            self.optimizer.zero_grad()
            loss_dict = self.model(images, targets, std=std_global)
            if "PoE" in self.config.arch:
                # not a loss term: keep it out of the loss sum and the scalar logs
                loss_dict.pop('info_gain')

            # The KL-weight update only mutates config values, so it cannot change the
            # already-computed loss tensors; presumably the model reads the new
            # weights on its next forward pass.
            losses = sum(loss_dict.values())
            beta = self._balance_kl_weights(loss_dict, step)

            # --- back prop gradients ---
            losses.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.clip_norm)
            self.optimizer.step()

            # reduce losses over all GPUs for logging purposes
            loss_dict_reduced = utils.reduce_dict(loss_dict)

            if self.writer is not None:
                for kwd, value in loss_dict_reduced.items():
                    self.writer.add_scalar('train/{}'.format(kwd), value, step)
                self.writer.add_scalar('std_annealing', std_global, step)
                if self.config.geco:
                    self.writer.add_scalar('beta', beta, step)
                else:
                    self.writer.add_scalar('beta', self.model.config.elbo_weights["kl_global"], step)

            losses_reduced = sum(loss_dict_reduced.values())
            loss_value = losses_reduced.item()

            # sanity check
            if not math.isfinite(loss_value):
                print("Loss is {}, stopping training".format(loss_value))
                print(loss_dict_reduced)
                sys.exit(1)

            if self.lr_scheduler is not None:
                self.lr_scheduler.step()
                if self.writer is not None:
                    # Log the lr actually set on the optimizer: get_lr() called outside
                    # step() is deprecated and can report the wrong step's value.
                    self.writer.add_scalar('lr/optimizer', self.optimizer.param_groups[0]['lr'], step)

            metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
            metric_logger.update(lr=self.optimizer.param_groups[0]["lr"])

            last_targets = targets
            # bid is 0-based: stop after exactly step_per_epoch iterations
            if bid + 1 >= self.step_per_epoch:
                break

        if last_targets is not None:
            self.targets = last_targets    # save view points for generation

    def _validate_epoch(self, epoch):
        """
        Run ``self.model.predict`` on the first batch of the validation loader.

        On ``show_period`` epochs, ``vis_train_dir/epoch_<epoch>`` is created and
        passed as ``save_sample_to``; otherwise ``None`` is passed.
        :return: the output of ``self.model.predict``
        """
        self.model.eval()
        if epoch % self.config.show_period == 0:
            vis_epo_dir = os.path.join(self.vis_train_dir, 'epoch_{}'.format(epoch))
            os.makedirs(vis_epo_dir, exist_ok=True)
        else:
            vis_epo_dir = None
        with torch.no_grad():
            images, targets = next(iter(self.valid_loader))
            images = [image.to(self.device) for image in images]
            targets = [{k: v.to(self.device) for k, v in t.items()} for t in targets]
            val_outs = self.model.predict(images, targets, save_sample_to=vis_epo_dir)
        return val_outs

    def _sample(self, epoch):
        """
        Generate novel views for the view points in ``self.targets``.

        Only runs on ``show_period`` epochs, writing into
        ``vis_train_dir/epoch_<epoch>``.
        :return: the output of ``self.model.sample``, or None when skipped.
        """
        if epoch % self.config.show_period != 0:
            return None
        self.model.eval()
        vis_epo_dir = os.path.join(self.vis_train_dir, 'epoch_{}'.format(epoch))
        os.makedirs(vis_epo_dir, exist_ok=True)
        with torch.no_grad():
            return self.model.sample(vis_epo_dir, self.targets, std=1.0)

    def _save_checkpoint(self, epoch):
        """
        Save the model and optimizer state to ``save_dir/checkpoint-epoch<epoch>.pth``.

        :param epoch: current epoch number
        """
        d = {
            'epoch': epoch,
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
        }
        filename = os.path.abspath(os.path.join(self.save_dir,
                                                'checkpoint-epoch{}.pth'.format(epoch)))
        torch.save(d, filename)

    def _resume_checkpoint(self, resume_path, optimizer=None):
        """
        Resume from a saved checkpoint.

        NOTE(review): start_epoch is set to the saved epoch itself, whereas
        __init__ resumes at ``resume_epoch + 1`` — confirm which is intended.

        :param resume_path: Checkpoint path to be resumed
        :param optimizer: if given, its state is restored from the checkpoint;
                          otherwise the optimizer state is not loaded
        """
        ckpt = torch.load(resume_path)
        self.model.load_state_dict(ckpt['model'], strict=True)
        self.start_epoch = ckpt['epoch']
        if optimizer is not None:
            optimizer.load_state_dict(ckpt['optimizer'])

class TwoStageTrainer(BaseTrainer):
    """Single-GPU two-stage trainer.

    Stage 1 trains the model on ``neg_elbo + query_nll + global_kl``
    (forward with ``prior_only=False``). Stage 2 fits the component prior on
    ``comp_kl`` (forward with ``prior_only=True``) using a dedicated Adam
    optimizer over ``model.comp_prior`` parameters.
    """

    def __init__(self, model, loss, metrics, optimizer, config, train_data_loader,
                 valid_data_loader=None, lr_scheduler=None, step_per_epoch=None, device=None):
        super().__init__(model, loss, metrics, optimizer, config, device)
        self.config = config
        self.train_loader = train_data_loader
        self.epochs_prior = self.config.epochs_prior
        self.step_per_epoch_prior = self.config.step_per_epoch_prior
        # Target KL weights; the KL-annealing schedule scales these values.
        self.kl_latent = config.elbo_weights["kl_latent"]
        self.kl_global = config.elbo_weights["kl_global"]
        self.kl_comp = config.elbo_weights["kl_component"]

        # optimizer for the component prior (stage 2 only)
        self.prior_opt = torch.optim.Adam(model.comp_prior.parameters(), lr=5e-4)

        if not step_per_epoch:
            # epoch-based training: one full pass over the loader per epoch
            self.step_per_epoch = len(self.train_loader)
        else:
            # iteration-based training: endless loader, fixed number of steps per epoch
            self.train_loader = inf_loop(self.train_loader)
            self.step_per_epoch = step_per_epoch

        self.do_validation = valid_data_loader is not None
        if self.do_validation:
            self.valid_loader = valid_data_loader

        self.save_period = config.save_period
        self.lr_scheduler = lr_scheduler

        # setup GECO; its loss() output is used as the KL weight beta in _train_epoch
        if self.config.geco:
            geco_goal = config.geco_cfg["goal"]
            geco_lr = config.geco_cfg["lr"]
            g_alpha = 0.99
            g_init = config.geco_cfg["min"]    # start beta at its lower bound
            g_min = config.geco_cfg["min"]
            g_max = config.geco_cfg["max"]
            g_speedup = config.geco_cfg["speedup"]
            self.geco = GECO(self.device, geco_goal, geco_lr, g_alpha, g_init, g_min, g_max, g_speedup)

        if self.config.resume_epoch is not None:
            self.start_epoch = self.config.resume_epoch + 1
            if self.lr_scheduler is not None:
                # Fast-forward the per-iteration scheduler to the resumed global step.
                # NOTE(review): get_lr() is only exact for closed-form schedulers
                # (e.g. LambdaLR); chainable schedulers derive it from the current lr.
                self.lr_scheduler.last_epoch = self.config.resume_epoch * self.step_per_epoch
                self.optimizer.param_groups[0]['lr'] = self.lr_scheduler.get_lr()[0]

        # self.save_dir is expected to be set by BaseTrainer
        assert config.vis_train_dir is not None
        self.vis_train_dir = config.vis_train_dir
        assert config.generated_dir is not None
        self.generated_dir = config.generated_dir

    def train(self):
        """Run stage 1 (skipped when ``config.prior_only``), then stage 2.

        Stage-2 epochs are numbered after stage 1 (``self.epochs + k``). Every
        5 prior epochs, samples are generated and a ``prior-fitting`` checkpoint
        is written.
        """
        if not self.config.prior_only:
            print('\n================== Start training ===================')
            assert (self.epochs + 1) > self.start_epoch
            for epoch in range(self.start_epoch, self.epochs + 1):
                print("Epoch {}".format(epoch))
                self._train_epoch(epoch)

                # validation needs a loader
                if epoch % self.config.val_period == 0 and self.do_validation:
                    self._validate_epoch(epoch)

                if epoch % self.save_period == 0:
                    self._save_checkpoint(epoch)

        print('\n================== Start prior training ===================')
        for prior_epoch in range(1, self.epochs_prior + 1):
            epoch = prior_epoch + self.epochs
            print("Epoch {}".format(epoch))
            self._train_epoch_prior(epoch)
            if prior_epoch % 5 == 0:
                self._sample(epoch)    # save generation results
                self._save_checkpoint(epoch, name="prior-fitting")
        print('models have been saved to {}'.format(self.save_dir))
        print('================= Training finished =================\n')

    def _annealed_std(self, step):
        """Return the pixel std for global iteration ``step``.

        With a non-zero ``config.temperature`` the std decays linearly from
        ``temperature`` towards 0 over 20000 iterations, floored at
        ``config.pixel_sigma``; otherwise it is ``config.pixel_sigma``.
        """
        if self.config.temperature != 0.0:
            return max(self.config.temperature * (1.0 - float(step) / 20000.0),
                       self.config.pixel_sigma)
        return self.config.pixel_sigma

    def _balance_kl_weights(self, loss_dict, step):
        """Update the KL weights in ``self.model.config.elbo_weights`` in place.

        GECO: every non-zero KL weight is replaced by the GECO beta.
        KL annealing: each target weight is scaled by
        ``clip((step - 50000) / 50000, 0.01, 1)``, i.e. 0.01 until ~50.5k
        iterations, then rising linearly to 1.0 at 100k.

        :return: the GECO beta when GECO is enabled, otherwise None.
        """
        weights = self.model.config.elbo_weights
        if self.config.geco:
            beta = self.geco.loss(loss_dict['query_nll'])
            for key in ("kl_global", "kl_component", "kl_latent"):
                if weights[key] != 0.:    # terms switched off stay off
                    weights[key] = beta
            return beta
        if self.config.kl_anl:
            scale = np.clip(float((step - 50000) / 50000), 0.01, 1.)
            weights["kl_latent"] = self.kl_latent * scale
            weights["kl_global"] = self.kl_global * scale
            weights["kl_component"] = self.kl_comp * scale
        return None

    def _train_epoch(self, epoch):
        """
        Stage-1 training for one epoch of exactly ``self.step_per_epoch``
        iterations (fewer only if a finite loader runs out first).

        The component prior is not trained here: the model is called with
        ``prior_only=False`` and only ``neg_elbo + query_nll + global_kl`` is
        back-propagated.
        :param epoch: Current training epoch (1-based).
        """
        self.model.train()
        metric_logger = utils.MetricLogger(delimiter="  ")
        metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
        header = 'Epoch: [{}]'.format(epoch)
        iter_count = (epoch - 1) * self.step_per_epoch
        last_targets = None

        for bid, (images, targets) in enumerate(
                metric_logger.log_every(self.train_loader, self.config.log_period, len_iter=self.step_per_epoch,
                                        header=header)):
            step = iter_count + bid    # global iteration index, used for logging/annealing
            images = [image.to(self.device).detach() for image in images]
            targets = [{k: v.to(self.device).detach() for k, v in t.items()} for t in targets]
            std_global = self._annealed_std(step)

            self.optimizer.zero_grad()
            loss_dict = self.model(images, targets, std=std_global, prior_only=False)  # do not update comp prior
            loss_dict.pop('info_gain')    # not a loss term: keep it out of the logs

            # (1st stage) Only Global KL and NLLs. Computed for every configuration
            # (previously unbound when neither GECO nor KL annealing was enabled).
            # The KL-weight update below only mutates config values, so it cannot
            # change these already-computed tensors.
            losses = loss_dict['neg_elbo'] + loss_dict["query_nll"] + loss_dict["global_kl"]
            beta = self._balance_kl_weights(loss_dict, step)

            # --- back prop gradients ---
            losses.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.clip_norm)
            self.optimizer.step()

            # reduce losses over all GPUs for logging purposes
            loss_dict_reduced = utils.reduce_dict(loss_dict)

            if self.writer is not None:
                for kwd, value in loss_dict_reduced.items():
                    self.writer.add_scalar('train/{}'.format(kwd), value, step)
                self.writer.add_scalar('std_annealing', std_global, step)
                if self.config.geco:
                    self.writer.add_scalar('beta', beta, step)
                else:
                    self.writer.add_scalar('beta', self.model.config.elbo_weights["kl_global"], step)

            # NOTE: the logged/sanity-checked total sums every reported term,
            # not only the terms optimized in stage 1.
            losses_reduced = sum(loss_dict_reduced.values())
            loss_value = losses_reduced.item()

            # sanity check
            if not math.isfinite(loss_value):
                print("Loss is {}, stopping training".format(loss_value))
                print(loss_dict_reduced)
                sys.exit(1)

            if self.lr_scheduler is not None:
                self.lr_scheduler.step()
                if self.writer is not None:
                    # Log the lr actually set on the optimizer: get_lr() called outside
                    # step() is deprecated and can report the wrong step's value.
                    self.writer.add_scalar('lr/optimizer', self.optimizer.param_groups[0]['lr'], step)

            metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
            metric_logger.update(lr=self.optimizer.param_groups[0]["lr"])

            last_targets = targets
            # bid is 0-based: stop after exactly step_per_epoch iterations
            if bid + 1 >= self.step_per_epoch:
                break

        if last_targets is not None:
            self.targets = last_targets    # save view points for generation

    def _train_epoch_prior(self, epoch):
        """
        Stage-2 training (component-prior fitting) for one epoch of exactly
        ``self.step_per_epoch_prior`` iterations.

        Only ``comp_kl`` is back-propagated and only ``self.prior_opt`` is stepped.
        NOTE(review): other parameters still receive gradients unless the model
        detaches them when called with ``prior_only=True`` — confirm.
        The targets of the last batch are stored in ``self.targets`` for ``_sample``.
        :param epoch: Global epoch number (stage-1 epochs + prior epoch).
        """
        self.model.train()
        metric_logger = utils.MetricLogger(delimiter="  ")
        header = 'Epoch: [{}]'.format(epoch)
        step_per_epoch = self.step_per_epoch_prior
        iter_count = (epoch - 1) * step_per_epoch
        last_targets = None

        for bid, (images, targets) in enumerate(
                metric_logger.log_every(self.train_loader, self.config.log_period, len_iter=step_per_epoch,
                                        header=header)):
            step = iter_count + bid    # global iteration index, used for logging/annealing
            images = [image.to(self.device).detach() for image in images]
            targets = [{k: v.to(self.device).detach() for k, v in t.items()} for t in targets]
            std_global = self._annealed_std(step)

            # Clear both optimizers: if comp_prior params are not registered in
            # self.optimizer, only prior_opt.zero_grad() stops their gradients
            # from accumulating across iterations.
            self.optimizer.zero_grad()
            self.prior_opt.zero_grad()
            loss_dict = self.model(images, targets, std=std_global, prior_only=True)  # update only comp prior
            loss_dict.pop('info_gain')    # not a loss term: keep it out of the logs

            losses = loss_dict["comp_kl"]

            # --- back prop gradients ---
            losses.backward()
            torch.nn.utils.clip_grad_norm_(self.model.comp_prior.parameters(), 5.0)
            self.prior_opt.step()

            # reduce losses over all GPUs for logging purposes
            loss_dict_reduced = utils.reduce_dict(loss_dict)

            if self.writer is not None and "comp_kl" in loss_dict_reduced:
                self.writer.add_scalar('train/{}'.format("comp_kl"), loss_dict_reduced["comp_kl"], step)

            losses_reduced = loss_dict["comp_kl"]
            loss_value = losses_reduced.item()

            # sanity check
            if not math.isfinite(loss_value):
                print("Loss is {}, stopping training".format(loss_value))
                print(loss_dict_reduced)
                sys.exit(1)

            metric_logger.update(loss=losses_reduced, **loss_dict_reduced)

            last_targets = targets
            # bid is 0-based: stop after exactly step_per_epoch iterations
            if bid + 1 >= step_per_epoch:
                break

        if last_targets is not None:
            self.targets = last_targets    # save view points for generation

    def _validate_epoch(self, epoch):
        """
        Run ``self.model.predict`` on the first batch of the validation loader.

        On ``show_period`` epochs, ``vis_train_dir/epoch_<epoch>`` is created and
        passed as ``save_sample_to``; otherwise ``None`` is passed.
        :return: the output of ``self.model.predict``
        """
        self.model.eval()
        if epoch % self.config.show_period == 0:
            vis_epo_dir = os.path.join(self.vis_train_dir, 'epoch_{}'.format(epoch))
            os.makedirs(vis_epo_dir, exist_ok=True)
        else:
            vis_epo_dir = None
        with torch.no_grad():
            images, targets = next(iter(self.valid_loader))
            images = [image.to(self.device) for image in images]
            targets = [{k: v.to(self.device) for k, v in t.items()} for t in targets]
            val_outs = self.model.predict(images, targets, save_sample_to=vis_epo_dir)
        return val_outs

    def _sample(self, epoch):
        """
        Generate novel views for the view points in ``self.targets`` into
        ``vis_train_dir/epoch_prior_<epoch>``.

        :return: the output of ``self.model.sample``
        """
        self.model.eval()
        vis_epo_dir = os.path.join(self.vis_train_dir, 'epoch_prior_{}'.format(epoch))
        os.makedirs(vis_epo_dir, exist_ok=True)
        with torch.no_grad():
            samples_outs = self.model.sample(vis_epo_dir, self.targets, std=1.0)
        return samples_outs

    def _save_checkpoint(self, epoch, name=None):
        """
        Save model, main-optimizer and prior-optimizer state.

        The file is ``save_dir/checkpoint_<name>-epoch<epoch>.pth`` when ``name``
        is given, else ``save_dir/checkpoint-epoch<epoch>.pth``.
        :param epoch: current epoch number
        :param name: optional tag inserted into the file name
        """
        d = {
            'epoch': epoch,
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            # extra key so stage-2 (prior fitting) optimizer state is not lost
            'prior_optimizer': self.prior_opt.state_dict(),
        }
        if name is not None:
            basename = 'checkpoint_' + name + '-epoch{}.pth'.format(epoch)
        else:
            basename = 'checkpoint-epoch{}.pth'.format(epoch)
        filename = os.path.abspath(os.path.join(self.save_dir, basename))
        torch.save(d, filename)

    def _resume_checkpoint(self, resume_path, optimizer=None):
        """
        Resume from a saved checkpoint.

        NOTE(review): start_epoch is set to the saved epoch itself, whereas
        __init__ resumes at ``resume_epoch + 1`` — confirm which is intended.
        The prior optimizer is deliberately not touched here: this may run from
        BaseTrainer.__init__, before ``self.prior_opt`` exists.

        :param resume_path: Checkpoint path to be resumed
        :param optimizer: if given, its state is restored from the checkpoint;
                          otherwise the optimizer state is not loaded
        """
        ckpt = torch.load(resume_path)
        self.model.load_state_dict(ckpt['model'], strict=True)
        self.start_epoch = ckpt['epoch']
        if optimizer is not None:
            optimizer.load_state_dict(ckpt['optimizer'])