import os
import torch
import torch.nn.functional as F
import pytorch_lightning as pl

import sys
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', '..', 'src'))

from models.setvae.networks import SetVAE
from models.setvae.criterion import ChamferCriterion, EMDCriterion, CombinedCriterion, ApproxEMDCriterion
import matplotlib.pyplot as plt


class SetVAEModule(pl.LightningModule):
    """Lightning wrapper that trains and evaluates a ``SetVAE`` on batches of sets.

    Batches are dicts with at least ``'set'`` and ``'set_mask'`` entries. The
    optimised objective is ``beta * KL + reconstruction``, where the
    reconstruction term is the mean of the configured set-matching criterion.

    Args:
        model_args: Model configuration forwarded to ``SetVAE``; ``z_dim`` and
            ``z_scales`` are also read here to normalise the per-layer KL.
        opt_args: Optimisation configuration (optimizer, scheduler, matcher,
            KL warm-up). It is read via attribute access, so presumably a
            namespace / ``DictConfig``-like object — TODO confirm.
        max_outputs: Forwarded to ``SetVAE`` unchanged.
    """

    # Submodules of ``self.model`` whose parameters are reported as "generator params".
    _GENERATOR_SUBMODULES = ('init_set', 'pre_decoder', 'decoder', 'post_decoder', 'output')

    def __init__(self, model_args, opt_args, max_outputs):
        super().__init__()

        # Stores these init args in self.hparams; logger=False keeps them out of the logger.
        self.save_hyperparameters('model_args', 'opt_args', logger=False)
        self.model = SetVAE(model_args, max_outputs)

        n_parameters = self._count_trainable(self.model)
        print(f'number of params: {n_parameters}')
        try:
            n_gen_parameters = sum(self._count_trainable(getattr(self.model, name))
                                   for name in self._GENERATOR_SUBMODULES)
            print(f'number of generator params: {n_gen_parameters}')
        except AttributeError:
            # Best effort: skip the generator breakdown if any of these submodules is missing.
            pass

        print(self.hparams)
        self.criterion = self.get_criterion(self.hparams.opt_args.matcher)

    @staticmethod
    def _count_trainable(module):
        """Return the total element count of ``module``'s parameters that require grad."""
        return sum(p.numel() for p in module.parameters() if p.requires_grad)

    def get_criterion(self, matcher):
        """Build the reconstruction criterion selected by ``matcher``.

        Args:
            matcher: one of ``'hungarian'`` (EMD), ``'chamfer'``, ``'approxEMD'``
                (loads ``opt_args.matcher_ckpt_path``) or ``'all'`` (combined).

        Raises:
            ValueError: if ``matcher`` is not one of the names above. (Explicit
                raise instead of ``assert`` so the check survives ``python -O``.)
        """
        if matcher == 'hungarian':
            return EMDCriterion()
        if matcher == 'chamfer':
            return ChamferCriterion()
        if matcher == 'approxEMD':
            return ApproxEMDCriterion(ckpt_path=self.hparams.opt_args.matcher_ckpt_path)
        if matcher == 'all':
            return CombinedCriterion()
        raise ValueError(f"unknown matcher {matcher!r}; expected one of "
                         f"'hungarian', 'chamfer', 'approxEMD' or 'all'")

    def configure_optimizers(self):
        """Build the optimizer and LR scheduler described by ``hparams.opt_args``.

        Returns:
            ``([optimizer], [scheduler])``.

        Raises:
            ValueError: if ``opt_args.type`` is not ``'adam'`` or ``'sgd'``, or if
                ``warmup_epochs > 0`` is combined with a scheduler that does not
                implement warm-up (``'exponential'``, ``'step'``, ``'cosine'``).
        """
        opt_args = self.hparams.opt_args

        if opt_args.type == 'adam':
            optimizer = torch.optim.Adam(self.model.parameters(), lr=opt_args.lr,
                                         betas=(opt_args.beta1, opt_args.beta2),
                                         weight_decay=opt_args.weight_decay)
        elif opt_args.type == 'sgd':
            optimizer = torch.optim.SGD(self.model.parameters(), lr=opt_args.lr,
                                        momentum=opt_args.momentum)
        else:
            raise ValueError(f"opt_args.type should be either 'adam' or 'sgd', got {opt_args.type!r}")

        sch_type = opt_args.sch_type
        # Only the 'linear' schedule implements warm-up; reject configs that request it elsewhere.
        if sch_type in ('exponential', 'step', 'cosine') and opt_args.warmup_epochs > 0:
            raise ValueError(f"warmup_epochs > 0 is only supported with sch_type='linear', "
                             f"got sch_type={sch_type!r}")

        if sch_type == 'exponential':
            scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, opt_args.exp_decay)
        elif sch_type == 'step':
            # Decay by 10x at the midpoint of training. Clamp to >= 1: StepLR with
            # step_size=0 raises ZeroDivisionError on its first step (epochs < 2).
            scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=max(1, opt_args.epochs // 2),
                                                        gamma=0.1)
        elif sch_type == 'linear':
            def lambda_rule(ep):
                # Linear warm-up factor over the first `warmup_epochs` epochs (1.0 if disabled).
                warmup = opt_args.warmup_epochs
                lr_w = min(1., ep / warmup) if warmup > 0 else 1.
                # Constant for the first half of training, then linear decay reaching 0 at `epochs`.
                half = 0.5 * opt_args.epochs
                lr_l = 1.0 - max(0, ep - half) / float(half)
                return lr_l * lr_w

            scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
        elif sch_type == 'cosine':
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt_args.epochs)
        else:
            # Any other sch_type: constant LR (multiplier 1.0 every epoch).
            def lambda_rule(ep):
                return 1.0

            scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)

        return [optimizer], [scheduler]

    def compute_loss(self, data, output):
        """Compute the total objective ``beta * KL + reconstruction``.

        Args:
            data: batch dict with ground-truth ``'set'`` and ``'set_mask'``.
            output: model output dict with ``'set'``, ``'set_mask'`` and ``'kls'``.

        Returns:
            dict with ``'loss'`` (total), ``'l2'`` (mean of the criterion output —
            named 'l2' but it is whichever matcher is configured), ``'dists'``
            (unreduced criterion output), plus ``'kl'``, ``'topdown_kl'`` and
            ``'beta'`` from :meth:`compute_kl_loss`.
        """
        # Ground truth is used as a constant target: copied and cut from any graph.
        gt, gt_mask = data['set'].clone().detach(), data['set_mask'].clone().detach()
        recon, recon_mask = output['set'], output['set_mask']

        loss_intermediate = self.criterion(recon, recon_mask, gt, gt_mask, reduced=False)
        l2_loss = loss_intermediate.mean()
        loss_dict = self.compute_kl_loss(output['kls'])

        loss = loss_dict['beta'] * loss_dict['kl'] + l2_loss
        loss_dict.update({'loss': loss, 'l2': l2_loss, 'dists': loss_intermediate})
        return loss_dict

    def compute_kl_loss(self, kls):
        """Aggregate per-layer KL terms and compute the current KL weight ``beta``.

        Args:
            kls: sequence of per-layer KL tensors that are stacked along dim 1;
                each is presumably ``[B, Dz]`` — TODO confirm.

        Returns:
            dict with ``'kl'`` (scalar: summed over layers, then averaged),
            ``'topdown_kl'`` (list of detached per-layer KLs averaged over dim 0
            and divided by ``scale * z_dim``) and ``'beta'`` (``opt_args.beta``,
            ramped linearly over ``kl_warmup_epochs`` epochs when that is > 0).
        """
        kl_loss = torch.stack(kls, dim=1).sum(dim=1).mean()  # sum over layers, then mean -> scalar

        opt_args = self.hparams.opt_args
        if opt_args.kl_warmup_epochs > 0:
            assert self.current_epoch is not None
            beta = opt_args.beta * min(1, self.current_epoch / opt_args.kl_warmup_epochs)
        else:
            beta = opt_args.beta

        model_args = self.hparams.model_args
        # Per-layer diagnostic only (detached): normalised by that layer's scale * z_dim.
        topdown_kl = [kl.detach().mean(dim=0) / float(scale * model_args.z_dim)
                      for scale, kl in zip(model_args.z_scales, kls)]

        return {'kl': kl_loss,
                'topdown_kl': topdown_kl,
                'beta': beta}

    def forward(self, set, set_mask):  # `set` shadows the builtin; name kept for keyword callers
        """Run the wrapped ``SetVAE`` on a batch of sets and their masks."""
        return self.model(set, set_mask)

    def training_step(self, batch, batch_idx):
        """Run one training step and log epoch-level training metrics.

        Returns:
            dict with ``'loss'`` (the optimised objective) plus the input set
            (``'source'``), the reconstruction (``'target'``) and the unreduced
            criterion distances (``'dists'``).
        """
        gt, gt_mask = batch['set'], batch['set_mask']

        output = self(gt, gt_mask)
        recon = output['set']
        # Keep .grad on this non-leaf tensor after backward; presumably read by a
        # consumer of the returned 'target' — TODO confirm it is still needed.
        recon.retain_grad()
        losses = self.compute_loss(batch, output)
        loss = losses['loss']

        # NOTE: gradient-norm logging does not belong here — backward for this batch
        # has not run yet, so p.grad is None or stale. Use an after-backward hook instead.
        self.log('train/loss', loss, on_step=False, on_epoch=True, sync_dist=True, prog_bar=True)
        self.log('train/kl_loss', losses['kl'], on_step=False, on_epoch=True, sync_dist=True, prog_bar=True)
        self.log('train/recon_loss', losses['l2'], on_step=False, on_epoch=True, sync_dist=True, prog_bar=True)
        self.log('train/beta', losses['beta'], on_step=False, on_epoch=True, sync_dist=False, prog_bar=False)

        # TODO: plot losses['topdown_kl'] (per-layer KL) to the experiment logger.

        return {"source": gt, "target": recon, "loss": loss, "dists": losses['dists']}

    def on_validation_epoch_start(self):
        """No-op hook (placeholder)."""
        pass

    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch; metrics are logged under ``val/``."""
        return self._shared_eval_step(batch, 'val')

    def test_step(self, batch, batch_idx):
        """Evaluate one test batch; metrics are logged under ``test/``."""
        return self._shared_eval_step(batch, 'test')

    def _shared_eval_step(self, batch, stage):
        """Forward a batch, compute losses and log them as epoch-level ``{stage}/...`` metrics.

        Returns:
            dict with the input set (``'source'``), the reconstruction
            (``'target'``) and the unreduced criterion distances (``'dists'``).
        """
        gt, gt_mask = batch['set'], batch['set_mask']

        output = self(gt, gt_mask)
        losses = self.compute_loss(batch, output)

        self.log(f'{stage}/loss', losses['loss'], on_step=False, on_epoch=True, sync_dist=True, prog_bar=False)
        self.log(f'{stage}/kl_loss', losses['kl'], on_step=False, on_epoch=True, sync_dist=True, prog_bar=False)
        self.log(f'{stage}/recon_loss', losses['l2'], on_step=False, on_epoch=True, sync_dist=True, prog_bar=False)

        return {"source": gt, "target": output['set'], "dists": losses['dists']}

    def generate_sample(self):
        """Placeholder; not implemented."""
        pass

    def on_validation_epoch_end(self):
        """No-op hook (placeholder)."""
        pass
