from typing import Any
from einops import rearrange
import torch
from algorithms.common.base_pytorch_algo import BasePytorchAlgo
from .model import GaussianDiffusion, Unet
from .df_unet import DFUnetWrapper
from utils.logging_utils import log_video


class ImageDiffusion(BasePytorchAlgo):
    """Diffusion model over frame-stacked image chunks.

    Each batch's first element is split along its second axis into groups of
    ``frame_stack`` consecutive frames. Every group is folded into the channel
    axis and treated as one independent sample by ``GaussianDiffusion``.
    """

    def __init__(self, cfg):
        """Store the config fields used by ``_build_model`` and init the base.

        Args:
            cfg: Config object exposing ``frame_stack``, ``network_size``,
                ``resolution``, ``sampling_timesteps`` and ``objective``.
        """
        # NOTE(review): these are assigned before super().__init__ presumably
        # because the base constructor calls _build_model(), which reads them
        # -- confirm against BasePytorchAlgo before reordering.
        self.cfg = cfg
        self.frame_stack = cfg.frame_stack
        self.network_size = cfg.network_size
        super().__init__(cfg)

    def _build_model(self) -> None:
        """Create the denoising network and wrap it in ``GaussianDiffusion``."""
        # Alternative backbone kept for reference:
        # unet = Unet(dim=32, dim_mults=(1, 2, 4, 8), flash_attn=True)
        unet = DFUnetWrapper(
            # One stacked sample carries frame_stack frames; the factor 3
            # presumably means 3-channel (RGB) frames -- confirm with dataset.
            x_channel=3 * self.frame_stack,
            z_channel=16,
            network_size=self.network_size,
            num_gru_layers=0,
        )
        self.diffusion = GaussianDiffusion(
            unet,
            image_size=self.cfg.resolution,
            sampling_timesteps=self.cfg.sampling_timesteps,
            objective=self.cfg.objective,
        )

    def _preprocess_batch(self, batch):
        """Fold frame stacks into channels and chunks into the batch axis.

        Args:
            batch: Indexable whose first element is a tensor laid out as
                ``(b, t * frame_stack, c, ...)``.

        Returns:
            Tensor laid out as ``(b * t, frame_stack * c, ...)``.
        """
        xs = batch[0]
        xs = rearrange(xs, "b (t fs) c ... -> (b t) (fs c) ...", fs=self.frame_stack)

        return xs

    def _unstack_frames(self, x: torch.Tensor, batch_size: int) -> torch.Tensor:
        """Undo ``_preprocess_batch`` and move the frame axis to the front.

        Args:
            x: Tensor laid out as ``(b * t, frame_stack * c, ...)``.
            batch_size: The original ``b``, needed to split the merged axis.

        Returns:
            Tensor laid out as ``(t * frame_stack, b, c, ...)``, the layout
            this class passes to ``log_video``.
        """
        return rearrange(
            x,
            "(b t) (fs c) ... -> (t fs) b c ...",
            fs=self.frame_stack,
            b=batch_size,
        )

    def training_step(self, batch, batch_idx):
        """Compute the diffusion loss, logging scalars and videos periodically.

        Args:
            batch: Indexable whose first element is the frame tensor.
            batch_idx: Index of the batch; used only to throttle logging.

        Returns:
            The loss produced by ``self.diffusion``.
        """
        batch_size = batch[0].shape[0]
        xs = self._preprocess_batch(batch)
        loss, xs_pred = self.diffusion(xs, return_x_start=True)

        if batch_idx % 100 == 0:
            self.log_dict({"training/loss": loss})

        if batch_idx % 5000 == 0:
            # The video layout is only needed for visualisation, so build it
            # here instead of on every training step.
            log_video(
                self._unstack_frames(xs_pred, batch_size),
                self._unstack_frames(xs, batch_size),
                step=self.global_step,
                namespace="training_vis",
                logger=self.logger.experiment,
            )

        return loss

    @torch.no_grad()
    def validation_step(self, batch, batch_idx) -> None:
        """Draw unconditional samples and log them next to the ground truth.

        Args:
            batch: Indexable whose first element is the frame tensor.
            batch_idx: Unused.
        """
        batch_size = batch[0].shape[0]
        xs = self._preprocess_batch(batch)
        # One sample is drawn per stacked chunk so that the result can be
        # rearranged with the same pattern as the ground truth.
        xs_pred = self.diffusion.sample(batch_size=xs.shape[0])

        xs = self._unstack_frames(xs, batch_size)
        # NOTE(review): the clone presumably guards against log_video editing
        # its input in place -- confirm before removing.
        xs_pred = self._unstack_frames(xs_pred, batch_size).clone()

        log_video(
            xs_pred,
            xs,
            step=self.global_step,
            namespace="validation_vis",
            logger=self.logger.experiment,
        )
