import torch
import torch.nn.functional as F
from einops import rearrange, reduce

from models import OmlVae
from models.model import Output
from utils import binarize, kl_div, nll_to_bpd, reparameterize


class MamlVae(OmlVae):
    """VAE meta-learner that adapts fast parameters with per-example inner updates.

    The encoder (`x_encoder`, `enc_mlp`), decoder (`dec_mlp`, `decoder`), the
    fast-parameter machinery (`reset_fast_params`, `inner_update`,
    `reptile_update`) and the attributes `log_inner_lr`, `kl_weight`, `config`,
    `x_dims` and `reptile` are not defined here; presumably they come from
    `OmlVae` -- confirm against the base class.
    """

    def forward(self, train_x, train_y, test_x, test_y, summarize, meta_split):
        """Run one meta-episode: adapt on `train_x`, then score `test_x`.

        Args:
            train_x: Inner-loop data. The first two dimensions are read as
                (batch, train_num); examples along the second dimension are
                consumed one at a time.
            train_y: Not used by this method.
            test_x: Data the adapted model is evaluated on. The reconstruction
                reduction pattern ``'b l c h w -> b l'`` implies five
                dimensions.
            test_y: Not used by this method.
            summarize: When falsy, only the meta loss is returned.
            meta_split: Name of the split. ``'train'`` enables the KL warm-up,
                the meta-training flag passed to `inner_update`, and the
                optional Reptile update; ``'test'`` switches to
                ``config['eval_latent_samples']`` latent draws.

        Returns:
            An `Output` holding ``loss/meta_{meta_split}`` (one value per batch
            element). With `summarize` set it also holds the mean KL and
            reconstruction terms, a reconstruction image summary and, for the
            training split, the current inner learning rate.
        """
        # The fast parameters are adapted on the training data, so the batch
        # size is taken from `train_x`.
        batch, train_num = train_x.shape[:2]
        is_meta_training = meta_split == 'train'

        train_x = binarize(train_x)
        test_x = binarize(test_x)

        inner_lr = self.log_inner_lr.exp()

        # Advance the KL warm-up schedule only on meta-training passes.
        # Evaluation always uses the full weight of 1.0 and must not consume
        # warm-up steps, otherwise validation frequency changes the schedule.
        if is_meta_training:
            self.kl_weight = min(self.kl_weight + 1. / self.config['kl_warmup'], 1.0)
            kl_weight = self.kl_weight
        else:
            kl_weight = 1.0

        # Autograd is forced on for the inner loop, presumably because
        # `inner_update` differentiates the loss even when the caller runs
        # evaluation under `torch.no_grad()` -- confirm against the base class.
        with torch.enable_grad():
            self.reset_fast_params(batch)

            # Inner loop: feed the training examples sequentially, one update each.
            for i in range(train_num):
                # Keep the sliced dimension (length 1) so shapes match the test path.
                x_i = train_x[:, i:i + 1]
                # `2 * x - 1` rescales the binarized input before encoding
                # (presumably {0, 1} -> {-1, 1}).
                x_enc_i = self.x_encoder(2 * x_i - 1)
                latent_mean, latent_log_var = torch.unbind(self.enc_mlp(x_enc_i), dim=-2)
                latent = reparameterize(latent_mean, latent_log_var)
                logit = self.decode(latent)
                # Both terms are summed over every element, batch included.
                recon_loss = F.binary_cross_entropy_with_logits(logit, x_i, reduction='none').sum()
                kl_loss = kl_div(latent_mean, latent_log_var).sum()
                loss = nll_to_bpd(recon_loss + kl_loss * kl_weight, self.x_dims)
                self.inner_update(loss, inner_lr, is_meta_training=is_meta_training)

        # Forward the test data through the adapted model.
        test_x_enc = self.x_encoder(2 * test_x - 1)
        latent_mean, latent_log_var = torch.unbind(self.enc_mlp(test_x_enc), dim=-2)
        kl_loss = kl_div(latent_mean, latent_log_var)

        # Average the reconstruction term over several latent draws on the
        # test split; every other split uses a single draw.
        latent_samples = self.config['eval_latent_samples'] if meta_split == 'test' else 1
        recon_loss = torch.zeros_like(kl_loss)
        for _ in range(latent_samples):
            latent = reparameterize(latent_mean, latent_log_var)
            logit = self.decode(latent)
            bce = F.binary_cross_entropy_with_logits(logit, test_x, reduction='none')
            recon_loss = recon_loss + reduce(bce, 'b l c h w -> b l', 'sum')
        recon_loss = recon_loss / latent_samples

        meta_loss = nll_to_bpd(recon_loss + kl_loss * kl_weight, self.x_dims)
        meta_loss = reduce(meta_loss, 'b l -> b', 'mean')

        if self.reptile and is_meta_training:
            self.reptile_update(self.config['reptile_lr'])

        output = Output()
        output[f'loss/meta_{meta_split}'] = meta_loss
        if not summarize:
            return output

        if is_meta_training:
            output['lr_inner'] = rearrange(inner_lr.detach(), '-> 1')

        output[f'loss/kl/meta_{meta_split}'] = reduce(kl_loss, 'b l -> b', 'mean')
        output[f'loss/recon/meta_{meta_split}'] = reduce(recon_loss, 'b l -> b', 'mean')
        # The image summary shows the reconstruction from the last latent draw.
        output.add_image_comparison_summary(test_x, torch.sigmoid(logit), key=f'recon/meta_{meta_split}')
        return output

    def decode(self, latent):
        """Map latent codes to reconstruction logits via `dec_mlp` then `decoder`."""
        dec_in = self.dec_mlp(latent)
        return self.decoder(dec_in)
