import torch
import lightning as L
import torch.nn as nn
import torch.nn.functional as F

import torch._dynamo

from einops import rearrange
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from torch.distributions import Normal, Independent
from torch.distributions.kl import kl_divergence as KLD
from torch.nn.utils.stateless import _reparametrize_module
from contextlib import contextmanager

from utils.module.distribution import PowerSpherical, HypersphericalUniform
from utils.module.cnn import Encoder, Decoder


class VIAE(L.LightningModule):
    """Variational autoencoder with a power-spherical posterior and an
    optional idempotence regulariser.

    The encoder output is split channel-wise into a mean direction and a
    concentration term, which parameterise a ``PowerSpherical`` posterior
    over the flattened latent (``get_dist``). The decoder output is split
    channel-wise into a ``tanh``-squashed mean and a scale, which
    parameterise a diagonal ``Normal`` likelihood over images
    (``p_theta``). The prior is ``HypersphericalUniform`` over the same
    flattened latent (``get_prior``).

    When ``use_idempotent`` is set, extra decode->encode round trips are
    scored by ``_distillation`` and combined in ``_compute_idem_loss``.

    Two-element list hyperparameters (``ch``, ``ch_mult``, ``num_res_blocks``,
    ``attn_resolutions``, ``attn_heads``) hold the encoder setting at index 0
    and the decoder setting at index 1.

    ``configure_optimizers`` also reads optional ``total_steps`` and
    ``warmup_steps`` entries from ``self.hparams`` (defaults 800000 and
    10000). NOTE(review): these are presumably supplied through ``**kwargs``,
    which ``save_hyperparameters`` appears to merge into ``hparams`` --
    confirm for the Lightning version in use.
    """

    def __init__(self,
                 input_size=32, in_chans=3, latent_dim=256,
                 ch=[64, 64], ch_mult=[[1, 2, 2, 4, 4], [1, 2, 2, 4, 8]],
                 num_res_blocks=[3, 3], attn_type='multi_heads',
                 attn_resolutions=[[8, 4, 2], [8, 4, 2]],
                 attn_heads=[[2, 4, 4, 4], [2, 4, 8, 8]],
                 dropout=0.0,

                 init_sigma=1e-3, max_sigma=2.,

                 kl_weight=1.,
                 use_idempotent=True,
                 idem_alpha=0.05, idem_beta=5,

                 lr=1e-4,
                 *args, **kwargs):
        super().__init__()
        # NOTE: the list defaults above are shared between instances; this
        # class never mutates them, and they are kept as lists so the saved
        # hyperparameters stay unchanged.
        self.save_hyperparameters()
        hp = self.hparams

        # Side length of the latent grid: input_size divided by
        # 2 ** (number of encoder ch_mult levels - 1).
        # NOTE(review): assumed to match the encoder's actual downsampling,
        # and the decoder is assumed to use the same number of levels
        # (len(ch_mult[1]) == len(ch_mult[0])) -- confirm.
        self.latent_res = hp.input_size // (2 ** (len(hp.ch_mult[0]) - 1))

        def net_kwargs(i):
            # Constructor arguments shared by both networks; index 0 selects
            # the encoder setting, index 1 the decoder setting.
            return dict(
                ch=hp.ch[i],
                out_ch=hp.in_chans,
                ch_mult=hp.ch_mult[i],
                num_res_blocks=hp.num_res_blocks[i],
                attn_resolutions=hp.attn_resolutions[i],
                attn_heads=hp.attn_heads[i],
                dropout=hp.dropout,
                in_channels=hp.in_chans,
                resolution=hp.input_size,
                z_channels=hp.latent_dim,
                attn_type=hp.attn_type,
                # Presumably doubles the output channels: both outputs are
                # split in two along dim 1 in get_dist / p_theta.
                double_out=True,
            )

        self.encoder = Encoder(**net_kwargs(0))
        self.decoder = Decoder(**net_kwargs(1))

        # Flattened latent size (latent_dim * latent_res * latent_res).
        self.dims = hp.latent_dim * self.latent_res * self.latent_res
        # Parameter-free LayerNorm applied to the flattened mean direction
        # before it is L2-normalised in get_dist.
        self.mu_norm = nn.LayerNorm(self.dims,
                                    bias=False,
                                    elementwise_affine=False)

        # Width of the sigmoid-bounded decoder scale range. A falsy value
        # (None or 0) makes p_theta use the unbounded softplus scale instead.
        if hp.max_sigma:
            self.inc_sigma = hp.max_sigma - hp.init_sigma
        else:
            self.inc_sigma = None

        # tlr: multiplier on the idempotence gate, decayed in val_epoch_end.
        self.register_buffer('tlr', torch.tensor(1.0))
        # basic_kappa: constant added to the posterior concentration.
        self.register_buffer('basic_kappa', torch.tensor(1024.0))

    @contextmanager
    def star_mode(self):
        """Temporarily run the module with detached copies of its parameters.

        Inside the context, forward passes see the same parameter values but
        no gradient reaches the parameters; gradients can still flow through
        the inputs. Buffers are passed through unchanged. Uses the private
        ``torch.nn.utils.stateless._reparametrize_module`` helper.
        """
        detached_params = {
            name: param.detach()
            for name, param in self.named_parameters()
        }
        buffers = dict(self.named_buffers())
        params_and_buffers = {**detached_params, **buffers}
        with _reparametrize_module(self, params_and_buffers):
            yield self

    def get_dist(self, params, sample=True):
        """Build the latent posterior from encoder output ``params``.

        ``params`` is split in two along dim 1: the first half becomes the
        (LayerNorm'd, L2-normalised, flattened) mean direction; the second
        half is softplus'd, averaged over dims 1-3 and offset by
        ``basic_kappa`` to give one concentration per sample.

        Args:
            params: encoder output, assumed 4-D ``(b, 2*c, h, w)`` (the
                ``rearrange`` pattern requires rank 4).
            sample: if True, also draw a reparameterised sample.

        Returns:
            ``dist`` if ``sample`` is False, otherwise ``(dist, z)`` where
            ``z`` is reshaped back to ``(b, c, latent_res, latent_res)``.
        """
        mu, log_kappa = params.chunk(2, dim=1)
        kappa = F.softplus(log_kappa).mean(dim=[1, 2, 3]) + self.basic_kappa
        mu = rearrange(mu, 'b c h w -> b (c h w)')
        mu = self.mu_norm(mu)
        mu = F.normalize(mu, p=2, dim=-1)
        dist = PowerSpherical(mu, kappa)

        if not sample:
            return dist
        z = dist.rsample()
        z = rearrange(z, 'b (c h w) -> b c h w',
                      h=self.latent_res, w=self.latent_res)
        return dist, z

    def p_theta(self, z, noise=None, decoder=None,
                return_dist=False, exception=False):
        """Decode latent ``z`` into the image likelihood or a sample from it.

        The decoder output is split in two along dim 1 into a mean (passed
        through ``tanh``) and a scale (``sigmoid * inc_sigma + init_sigma``
        when ``inc_sigma`` is truthy, else ``softplus + init_sigma``).

        Args:
            z: latent tensor fed to the decoder.
            noise: optional standard-normal noise for the sample; drawn with
                ``randn_like`` when None.
            decoder: network to use; defaults to ``self.decoder``.
            return_dist: return ``Independent(Normal(mu, sigma), 3)``.
            exception: return the mean ``mu`` instead of a sample (ignored
                when ``return_dist`` is True).

        Returns:
            The distribution, the mean, or ``mu + sigma * noise``.
        """
        if decoder is None:
            decoder = self.decoder

        params = decoder(z)
        mu, sigma = params.chunk(2, dim=1)
        mu = torch.tanh(mu)
        if self.inc_sigma:
            sigma = torch.sigmoid(sigma) * self.inc_sigma + self.hparams.init_sigma
        else:
            sigma = F.softplus(sigma) + self.hparams.init_sigma

        if return_dist:
            return Independent(Normal(mu, sigma), 3)
        if exception:
            return mu
        if noise is None:
            noise = torch.randn_like(mu)
        return mu + sigma * noise

    def get_prior(self,):
        """Return a ``HypersphericalUniform`` prior over the flattened latent."""
        prior = HypersphericalUniform(dim=self.dims, device=self.device)
        return prior

    def forward(self, x):
        """Encode ``x``, sample a latent and build the reconstruction likelihood.

        Returns a dict with ``param_x``, ``dist_zx`` and ``dist_x``. When
        ``use_idempotent`` is set, it also holds ``param_xd`` (detached
        ``param_x``) and the entries from ``_calculate_idem_terms``.
        """
        param_x = self.encoder(x)
        dist_zx, zx = self.get_dist(param_x)

        outputs = {'param_x': param_x, 'dist_zx': dist_zx}

        dist_x = self.p_theta(zx, return_dist=True)
        outputs.update({'dist_x': dist_x})

        if not self.hparams.use_idempotent:
            return outputs

        param_xd = param_x.detach()
        outputs.update({'param_xd': param_xd})

        # Random encoder-shaped parameters used to draw "off-data" latents.
        param_u = torch.randn_like(param_xd)
        outputs.update(self._calculate_idem_terms(param_xd, param_u))

        return outputs

    def _calculate_idem_terms(self, param_z, param_u):
        """Run the decode->encode round trips needed by the idempotence loss.

        Args:
            param_z: detached encoder output for real data.
            param_u: random tensor with the same shape as ``param_z``.

        Returns:
            Dict with:
              * ``param_f_x``: encoder(decode(sample of ``param_z``)).
              * ``param_f_u``: encoder(decode(u)), u sampled from ``param_u``.
              * ``param_fu``: detached ``param_f_u``.
              * ``param_f_fu``: encoder(decode(f_u)) with the sample f_u
                detached, so gradients reach the weights but not f_u.
              * ``param_ff_u``: the same round trip on the non-detached f_u,
                run in ``star_mode`` so gradients reach only f_u.
        """
        _, z = self.get_dist(param_z)
        x = self.p_theta(z, decoder=self.decoder)
        param_f_z = self.encoder(x)

        _, u = self.get_dist(param_u)

        # One shared noise draw for all u-based decodes below.
        noise = torch.randn_like(x)
        param_f_u = self.encoder(self.p_theta(u, noise, self.decoder))
        param_fu = param_f_u.detach()

        _, f_u = self.get_dist(param_f_u)
        fu = f_u.detach()
        param_f_fu = self.encoder(self.p_theta(fu, noise, self.decoder))

        with self.star_mode():
            param_ff_u = self.encoder(self.p_theta(f_u, noise, self.decoder))

        return {
            'param_f_x': param_f_z,
            'param_fu': param_fu,
            'param_f_fu': param_f_fu,
            'param_f_u': param_f_u,
            'param_ff_u': param_ff_u,
        }

    def _calculate_losses(self, fwd_outputs, x, stage):
        """Combine the NLL, KL and (optional) idempotence losses.

        ``stage`` is accepted for interface symmetry but is not used here.
        Returns a dict with ``nll``, ``kl``, the idempotence entries (when
        enabled) and the weighted total under ``loss``.
        """
        dist_x = fwd_outputs['dist_x']
        nll_loss = -dist_x.log_prob(x).mean()

        dist_zx = fwd_outputs['dist_zx']
        prior = self.get_prior()
        kl_loss = KLD(dist_zx, prior).mean()
        losses = {'nll': nll_loss, 'kl': kl_loss}
        total_loss = nll_loss + kl_loss * self.hparams.kl_weight

        if self.hparams.use_idempotent:
            idem_loss, idem_losses = self._compute_idem_loss(fwd_outputs)
            # Out-of-place add: avoids mutating a tensor in the autograd graph.
            total_loss = total_loss + self.hparams.idem_alpha * idem_loss
            losses.update(idem_losses)

        losses['loss'] = total_loss
        return losses

    def _compute_idem_loss(self, outputs):
        """Compute the gated idempotence loss from ``forward`` outputs.

        With s = stable, t = tight and i = idem distillation terms and
        r = s.detach() + 1e-5, the loss is

            (tanh(s/r) - tanh(t/r) + tanh(i/r) + g * log(1 + i/r)) * r,

        where the gate is g = exp(-r / (t.detach() + 1e-5)) * idem_beta * tlr.

        Returns:
            ``(total_idem_loss, loss_dict)``; ``loss_dict`` holds ``stable``,
            ``tight``, ``idem`` and ``total_idem`` for logging.
        """
        loss_fn = self._distillation

        param_x = outputs['param_xd']
        param_f_x = outputs['param_f_x']
        param_fu = outputs['param_fu']
        param_f_fu = outputs['param_f_fu']
        param_f_u = outputs['param_f_u']
        param_ff_u = outputs['param_ff_u']

        # Each term is evaluated exactly once. The original evaluated all three
        # twice and discarded the first results, which wasted decoder passes.
        kl_stable = loss_fn(param_f_x, param_x)
        kl_tight = loss_fn(param_f_fu, param_fu)
        kl_idem = loss_fn(param_ff_u, param_f_u)

        beta = self.hparams.idem_beta

        # Detached normaliser so every term is measured relative to `stable`.
        ratio = kl_stable.mean().detach() + 1e-5
        far = ratio / (kl_tight.mean().detach() + 1e-5)
        # clamp is a safeguard; exp(-far) is already in (0, 1] for far >= 0.
        gate = torch.exp(-far).clamp(0., 1.) * beta * self.tlr

        kl_s = torch.tanh(kl_stable / ratio).mean()
        kl_t = torch.tanh(kl_tight / ratio).mean()
        kl_i_scaled = kl_idem / ratio
        kl_i = (torch.tanh(kl_i_scaled) + gate * torch.log(1 + kl_i_scaled)).mean()
        total_idem_loss = (kl_s - kl_t + kl_i) * ratio

        loss_dict = {
            'stable': kl_stable.mean(),
            'tight': kl_tight.mean(),
            'idem': kl_idem.mean(),
            'total_idem': total_idem_loss,
        }
        return total_idem_loss, loss_dict

    def _distillation(self, p1, p2):
        """Score how well encoder parameters ``p1`` match target ``p2``.

        Returns the batch mean of KL(q(z | p1) || prior) + KL(p(x | z2) ||
        p(x | z1)), where z1 and z2 are samples from ``p1`` and ``p2``. The
        decoder likelihoods are built in ``star_mode``, so gradients reach
        the parameters only through ``p1``, ``p2`` and the samples.
        """
        prior = self.get_prior()
        dist_z1, z1 = self.get_dist(p1)
        _, z2 = self.get_dist(p2)

        with self.star_mode():
            dist_x1 = self.p_theta(z1, return_dist=True)
            dist_x2 = self.p_theta(z2, return_dist=True)

        loss = (KLD(dist_z1, prior) + KLD(dist_x2, dist_x1)).mean()

        return loss

    def _common_step(self, batch, stage):
        """Run forward + loss computation on ``batch`` (used as the input ``x``)."""
        x = batch
        forward_outputs = self.forward(x)
        losses = self._calculate_losses(forward_outputs, x, stage)

        return losses

    def training_step(self, batch, batch_idx):
        """Lightning training hook; returns the loss dict (key ``loss``)."""
        return self._common_step(batch, 'train')

    def on_train_batch_end(self, outputs, batch, batch_idx):
        """Log the loss dict returned by ``training_step``."""
        self._log_metrics(outputs, 'train')

    @torch._dynamo.disable
    def val_step(self, batch):
        """Validation loss computation, excluded from torch.compile tracing."""
        return self._common_step(batch, 'val')

    @torch._dynamo.disable
    def val_batch_end(self, outputs):
        """Log validation losses, excluded from torch.compile tracing."""
        self._log_metrics(outputs, 'val')

    @torch._dynamo.disable
    def val_epoch_end(self):
        """Decay the idempotence gate multiplier ``tlr`` by 0.999.

        Skipped during Lightning's sanity-check validation pass. That pass is
        not a real validation epoch and must not advance the schedule.
        """
        trainer = getattr(self, '_trainer', None)
        if trainer is not None and trainer.sanity_checking:
            return
        self.tlr *= 0.999

    def validation_step(self, batch, batch_idx):
        """Lightning validation hook; delegates to ``val_step``."""
        return self.val_step(batch)

    def on_validation_batch_end(self, outputs, batch, batch_idx):
        """Lightning hook; delegates to ``val_batch_end``."""
        self.val_batch_end(outputs)

    def on_validation_epoch_end(self):
        """Lightning hook; delegates to ``val_epoch_end``."""
        self.val_epoch_end()

    def configure_optimizers(self):
        """AdamW with a linear warmup followed by cosine annealing, stepped per step.

        ``total_steps`` and ``warmup_steps`` are read from ``self.hparams``
        when present (defaults 800000 and 10000).
        """
        hp = self.hparams
        optimizer = AdamW(self.parameters(), lr=hp.lr,
                          betas=(0.5, 0.999), weight_decay=5e-4)

        total_steps = hp.get('total_steps', 800000)
        warmup_steps = hp.get('warmup_steps', 10000)

        # Constructed before the warmup scheduler, as in the original ordering.
        # T_max is kept >= 1 so a short run cannot produce a non-positive period.
        main_scheduler = CosineAnnealingLR(
            optimizer,
            T_max=max(1, total_steps - warmup_steps),
            eta_min=1e-7
        )

        warmup_scheduler = LinearLR(
            optimizer,
            start_factor=1e-3,
            total_iters=warmup_steps
        )

        scheduler = SequentialLR(optimizer,
                                 schedulers=[warmup_scheduler, main_scheduler],
                                 milestones=[warmup_steps])

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
                "frequency": 1
            },
        }

    def _log_metrics(self, losses, stage):
        """Log selected entries of ``losses`` under ``{stage}/{name}``.

        Short aliases (first letter of ``stage`` + first two letters of the
        name) go to the progress bar only. Train metrics are logged per step;
        other stages are logged per epoch.
        """
        bar_metrics = ['nll', 'kl']
        if self.hparams.use_idempotent:
            bar_metrics += ['stable', 'idem']
        log_metrics = ['loss'] + bar_metrics

        log_dict = {
            f'{stage}/{name}': value
            for name, value in losses.items()
            if name in log_metrics
        }

        bar_dict = {
            f'{stage[0]}{name[:2]}': value
            for name, value in losses.items()
            if name in bar_metrics
        }

        log_on_step = (stage == 'train')
        self.log_dict(log_dict, on_step=log_on_step, on_epoch=not log_on_step, prog_bar=False)
        self.log_dict(bar_dict, on_step=log_on_step, on_epoch=not log_on_step, prog_bar=True, logger=False)

    @torch.no_grad()
    def gen_data(self, num_samples=1, steps=1, exception=False):
        """Sample images from the prior, optionally refined by round trips.

        Draws ``num_samples`` latents from the prior. When ``use_idempotent``
        is set, it applies ``steps`` decode->encode->resample iterations and
        then decodes once more.

        Args:
            num_samples: number of latents to draw.
            steps: number of refinement round trips (ignored if <= 0).
            exception: decode to the likelihood mean instead of a sample.

        Returns:
            ``(images clamped to [-1, 1], final latent z)``.
        """
        prior = self.get_prior()
        z = prior.sample((num_samples, ))
        z = rearrange(z, 'b (c h w) -> b c h w',
                      h=self.latent_res,
                      w=self.latent_res)

        if self.hparams.use_idempotent and steps > 0:
            for _ in range(steps):
                params = self.encoder(self.p_theta(z))
                _, z = self.get_dist(params)

        return self.p_theta(z, exception=exception).clamp(-1, 1), z
