import time
from types import SimpleNamespace

import numpy as np
import torch

from code.realnvp_v2 import RealNVP
from code.composed_model import CondModel
from code.exp_utils import ForwardOp, normal_logp, EMA


def train_cond_model(*,
                     base_model: RealNVP,
                     cond_model: CondModel,
                     x_star: torch.Tensor,
                     forward_op: ForwardOp,
                     device: torch.device,
                     sigma: float,
                     batch_size: int,
                     learning_rate: float,
                     num_steps: int,
                     temps,
                     use_ema: bool,
                     sample_steps,
                     x_star_obs: torch.Tensor=None,
                     ):
    """Train ``cond_model`` so its latents, decoded by ``base_model``, match an observation.

    Each step draws ``batch_size`` latents ``zhat`` (with log-probabilities)
    from ``cond_model.sample``. It maps them to images with
    ``base_model(zhat, inverse=True)`` and minimises

        mean(logp_zhat - sum(normal_logp(zhat)))
        + mean(||observe(clamp(x, 0, 1)) - x_star_obs||^2) / (2 * sigma**2)

    with Adam.

    Exactly one of ``x_star`` / ``x_star_obs`` must be given. When only
    ``x_star_obs`` is given, a zero image is recorded as ``stats.x_star``.

    Args:
        base_model: flow used (in inverse direction) to decode latents;
            ``base_model.image_shape`` is the per-sample shape.
        cond_model: model being trained; must provide
            ``sample(n, device=..., temp=...)`` returning ``(zhat, logp)``.
        x_star: ground-truth image with 3 dims (a batch dim is added) or
            4 dims. Not modified.
        forward_op: provides ``observe(x) -> (obs, _)``.
        device: device on which latents are sampled.
        sigma: noise scale of the reconstruction term.
        batch_size: number of latents per optimisation step.
        learning_rate: Adam learning rate.
        num_steps: number of optimisation steps.
        temps: temperatures used when drawing monitoring samples.
        use_ema: keep an exponential moving average (decay 0.95) of
            ``cond_model``'s weights and use it for sampling / the result.
        sample_steps: steps at which to draw monitoring samples; the final
            step is always sampled.
        x_star_obs: precomputed target observation (alternative to x_star).

    Returns:
        ``(cond_model, steps_to_samples, stats)``. ``steps_to_samples`` maps
        step -> {temp: CPU tensor of 32 decoded (unclamped) samples}.
        ``stats`` holds per-step lists ``loss``, ``loss_kl``, ``loss_rec``,
        ``steps_per_sec``, ``total_time`` and the (batched, CPU) ``x_star``.
        The returned model is left in eval mode.
    """
    assert (x_star is None) != (x_star_obs is None), \
        'exactly one of x_star and x_star_obs must be given'

    if x_star is None:
        # Placeholder only recorded in stats: the observation is supplied.
        x_star = torch.zeros(*base_model.image_shape, device=device)

    # Add the batch dim out of place; ``unsqueeze_`` would silently reshape
    # the caller's tensor.
    if x_star.ndim == 3:
        x_star = x_star.unsqueeze(0)
    assert x_star.ndim == 4

    if use_ema:
        ema = EMA(cond_model, decay=0.95)

    # Get target observation
    if x_star_obs is None:
        x_star_obs, _ = forward_op.observe(x_star)

    optimizer = torch.optim.Adam(cond_model.parameters(), lr=learning_rate)

    stats = SimpleNamespace(
        loss                    = [],
        loss_kl                 = [],
        loss_rec                = [],

        steps_per_sec           = [],
        total_time              = [],

        x_star                  = x_star.cpu(),
    )

    steps_to_samples = {}
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start_time = time.perf_counter()
    N = batch_size
    for step in range(1, num_steps + 1):
        zhat, logp_zhat = cond_model.sample(N, device=device)
        x, _ = base_model(zhat, inverse=True)
        # The clamped image is only used by the reconstruction term.
        x = x.clamp(0, 1)
        assert zhat.shape == x.shape == (N, *base_model.image_shape)

        # Loss: KL term (per-sample log-density difference over all dims).
        normal_zhat = normal_logp(zhat).reshape(N, -1).sum(dim=1)
        assert logp_zhat.shape == normal_zhat.shape == (N,)
        loss_kl = logp_zhat - normal_zhat

        # Loss: Reconstruction (squared error summed per sample)
        # TODO: support DistFn
        x_obs, _ = forward_op.observe(x)
        assert x_obs.shape[1:] == x_star_obs.shape[1:]
        loss_rec = (x_obs - x_star_obs).reshape(N, -1).pow(2).sum(dim=1)

        assert loss_kl.shape == loss_rec.shape == (N,)
        loss_kl = loss_kl.mean()
        loss_rec = loss_rec.mean()
        loss = loss_kl + loss_rec / (2.0 * sigma**2)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if use_ema:
            ema(cond_model, num_updates=step)

        # Monitoring
        total_time = time.perf_counter() - start_time
        stats.loss.append(loss.item())
        stats.loss_kl.append(loss_kl.item())
        stats.loss_rec.append(loss_rec.item())
        stats.steps_per_sec.append(step / total_time)
        stats.total_time.append(total_time)

        print(f'\rstep {step:05d}/{num_steps:05d} '
              f'loss {np.mean(stats.loss[-10:]):.3f} ',
              f'loss_kl {np.mean(stats.loss_kl[-10:]):.3f} ',
              f'loss_rec {np.mean(stats.loss_rec[-10:]):.3f} ',
              f'time {stats.total_time[-1]:.2f} sec ',
              f'steps/sec {stats.steps_per_sec[-1]:.2f} ', end='')

        # Get samples (the final step is always sampled).
        if step in sample_steps or step == num_steps:
            # Remember the mode so training resumes in it afterwards;
            # otherwise every later step would run in eval mode.
            was_training = cond_model.training
            cond_model.eval()
            if use_ema:
                # Presumably swaps in the averaged weights; undone by
                # ema.restore below.
                ema.assign(cond_model)
            steps_to_samples[step] = {}
            with torch.no_grad():
                for temp in temps:
                    batch = []
                    for _ in range(8):
                        zhat = cond_model.sample(4, device=device, temp=temp)[0]
                        x = base_model(zhat, inverse=True)[0]
                        batch.append(x.cpu())
                    batch = torch.cat(batch, dim=0)
                    assert len(batch) == 32
                    steps_to_samples[step][temp] = batch
            print(f' --> sampled at step {step}')
            if use_ema:
                ema.restore(cond_model)
            cond_model.train(was_training)

    print()
    # Hand back the model in eval mode (with the EMA weights if use_ema).
    cond_model.eval()
    if use_ema:
        ema.assign(cond_model)

    return cond_model, steps_to_samples, stats


def run_naive_vi(*,
                 base_model: RealNVP,
                 vi_model: RealNVP,
                 x_star: torch.Tensor,
                 forward_op: ForwardOp,
                 device: torch.device,
                 sigma: float,
                 batch_size: int,
                 learning_rate: float,
                 num_steps: int,
                 temps,
                 use_ema: bool,
                 sample_steps,
                 kl_warmup,
                 ):
    """Fit an image-space flow ``vi_model`` against ``base_model`` and an observation.

    Each step draws ``batch_size`` prior latents ``z`` from ``vi_model`` and
    decodes them with ``vi_model(z, inverse=True)``. The result is clamped to
    [0, 1], and the step minimises

        kl_coeff * mean(log_qx - log_px)
        + mean(||observe(x) - observe(x_star)||^2) / (2 * sigma**2)

    with Adam. Here ``log_qx = vi_model.log_prior(z) - logdet`` and
    ``log_px = base_model.log_prob(x)[0]``.

    Args:
        base_model: flow providing ``log_prob(x)``; acts as the prior term.
        vi_model: flow being trained; must provide ``sample_prior``,
            ``log_prior``, ``sample(n, device=..., temp=...)`` and an
            inverse pass returning ``(x, logdet)``.
        x_star: ground-truth image with 3 dims (a batch dim is added) or
            4 dims. Not modified.
        forward_op: provides ``observe(x) -> (obs, _)``.
        device: device on which latents are sampled.
        sigma: noise scale of the reconstruction term.
        batch_size: number of samples per optimisation step.
        learning_rate: Adam learning rate.
        num_steps: number of optimisation steps.
        temps: temperatures used when drawing monitoring samples.
        use_ema: keep an exponential moving average (decay 0.95) of
            ``vi_model``'s weights and use it for sampling / the result.
        sample_steps: steps at which to draw monitoring samples; the final
            step is always sampled.
        kl_warmup: number of steps over which the KL weight ramps linearly
            up to 1. ``None`` or a non-positive value disables warm-up.

    Returns:
        ``(vi_model, steps_to_samples, stats)``. ``steps_to_samples`` maps
        step -> {temp: CPU tensor of 32 samples}. ``stats`` holds per-step
        lists ``loss``, ``loss_kl``, ``loss_rec``, ``steps_per_sec``,
        ``total_time`` and the (batched, CPU) ``x_star``.

        Returns ``(None, None, None)`` if the inverse pass of ``vi_model``
        raises ``AssertionError``.
    """
    # Add the batch dim out of place; ``unsqueeze_`` would silently reshape
    # the caller's tensor.
    if x_star.ndim == 3:
        x_star = x_star.unsqueeze(0)
    assert x_star.ndim == 4

    if use_ema:
        ema = EMA(vi_model, decay=0.95)

    # Get target observation
    x_star_obs, _ = forward_op.observe(x_star)

    optimizer = torch.optim.Adam(vi_model.parameters(), lr=learning_rate)

    stats = SimpleNamespace(
        loss                    = [],
        loss_kl                 = [],
        loss_rec                = [],

        steps_per_sec           = [],
        total_time              = [],

        x_star                  = x_star.cpu(),
    )

    # A zero/negative warm-up length previously divided by zero (or produced
    # a negative KL weight); treat it like "no warm-up".
    use_warmup = kl_warmup is not None and kl_warmup > 0

    steps_to_samples = {}
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start_time = time.perf_counter()
    N = batch_size
    for step in range(1, num_steps + 1):
        kl_coeff = min(float(step) / kl_warmup, 1.0) if use_warmup else 1.0

        z = vi_model.sample_prior(N, device=device)
        try:
            x, logdet = vi_model(z, inverse=True)
        except AssertionError as e:
            # Deliberate best effort: abandon this run and signal it to the
            # caller instead of propagating the failure.
            print(f'AssertionError caught!  {e}')
            return None, None, None
        # NOTE(review): log_qx below refers to the unclamped sample, while
        # log_px and the reconstruction use the clamped one.
        x = x.clamp(0, 1)

        # Loss: KL term
        # Assumes logdet is the log|det| of the inverse map z -> x, giving
        # log q(x) by change of variables. TODO confirm sign convention.
        log_qx = vi_model.log_prior(z) - logdet
        log_px = base_model.log_prob(x)[0]
        assert log_qx.shape == log_px.shape == (N,)
        loss_kl = log_qx - log_px

        # Loss: Reconstruction (squared error summed per sample)
        x_obs, _ = forward_op.observe(x)
        assert x_obs.shape[1:] == x_star_obs.shape[1:]
        loss_rec = (x_obs - x_star_obs).reshape(N, -1).pow(2).sum(dim=1)

        assert loss_kl.shape == loss_rec.shape == (N,)
        loss_kl = loss_kl.mean()
        loss_rec = loss_rec.mean()
        loss = kl_coeff * loss_kl + loss_rec / (2.0 * sigma**2)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if use_ema:
            ema(vi_model, num_updates=step)

        # Monitoring
        total_time = time.perf_counter() - start_time
        stats.loss.append(loss.item())
        stats.loss_kl.append(loss_kl.item())
        stats.loss_rec.append(loss_rec.item())
        stats.steps_per_sec.append(step / total_time)
        stats.total_time.append(total_time)

        print(f'\rstep {step:05d}/{num_steps:05d} '
              f'loss {np.mean(stats.loss[-10:]):.3f} ',
              f'loss_kl {np.mean(stats.loss_kl[-10:]):.3f} ',
              f'loss_rec {np.mean(stats.loss_rec[-10:]):.3f} ',
              f'kl_coeff {kl_coeff:.3f} ',
              f'log_qx {log_qx.mean().item():.4f} '
              f'log_px {log_px.mean().item():.4f} '
              f'time {stats.total_time[-1]:.2f} sec ',
              f'steps/sec {stats.steps_per_sec[-1]:.2f} ', end='')

        # Get samples (the final step is always sampled).
        if step in sample_steps or step == num_steps:
            # Remember the mode so training resumes in it afterwards;
            # otherwise every later step would run in eval mode.
            was_training = vi_model.training
            vi_model.eval()
            if use_ema:
                # Presumably swaps in the averaged weights; undone by
                # ema.restore below.
                ema.assign(vi_model)
            steps_to_samples[step] = {}
            with torch.no_grad():
                for temp in temps:
                    batch = []
                    for _ in range(4):
                        x = vi_model.sample(8, device=device, temp=temp)
                        batch.append(x.cpu())
                    batch = torch.cat(batch, dim=0)
                    assert len(batch) == 32
                    steps_to_samples[step][temp] = batch
            print(f' --> sampled at step {step}')
            if use_ema:
                ema.restore(vi_model)
            vi_model.train(was_training)

    print()
    # Hand back the model in eval mode (with the EMA weights if use_ema).
    vi_model.eval()
    if use_ema:
        ema.assign(vi_model)

    return vi_model, steps_to_samples, stats
