from yaml import load, Loader
import argparse
import json
import random
import os
import os.path as osp

import torch
import torch.nn.functional as F
from accelerate.logging import get_logger
from accelerate import DistributedType
import kornia
from einops import rearrange
from diffusers.training_utils import compute_density_for_timestep_sampling, \
    compute_loss_weighting_for_sd3
from diffusers.utils.torch_utils import randn_tensor

from lib.unified_dataset_wm.dataset.ac_dataset import MixAC
from lib.pipelines.pipeline_cosmos2_canny_aug import CannyAugCosmos2Pipeline
from lib.trainers.base_trainer import State
from lib.trainers.action_cond_trainer import ACWMTrainer
from lib.utils.memory_utils import get_memory_statistics, free_memory
from lib.utils.misc import ProgressTracker
from lib.utils.torch_utils import apply_color_jitter_to_video, get_latents, unwrap_model, \
                                  save_video


# Module-level logger obtained through accelerate's logging helper.
logger = get_logger(__name__)
logger.setLevel('INFO')

# Down-sampling ratios used below to derive latent sizes from pixel-space sizes:
# latent_height/width = raw_height/width // SPATIAL_DOWN_RATIO and
# latent_frames = raw_frames // TEMPORAL_DOWN_RATIO + 1.
SPATIAL_DOWN_RATIO = 8
TEMPORAL_DOWN_RATIO = 4


class CannyAugTrainer(ACWMTrainer):

    def __init__(self, config_file=None, args=None, checkpoint_root=None, val_only=False):
        """Set up the trainer from a YAML config file or a ready-made namespace.

        Args:
            config_file: Path to a YAML file whose top-level mapping becomes the
                argument namespace. Only read when ``args`` is None.
            args: Pre-built ``argparse.Namespace``. Takes precedence over
                ``config_file`` when both are given.
            checkpoint_root: When not None, overrides ``args.output_dir``.
            val_only: When True, logging and output-directory initialisation
                are skipped; only the distributed setup runs.

        Raises:
            AssertionError: If both ``config_file`` and ``args`` are None.
        """
        if args is None:
            assert config_file is not None, "either config_file or args must be provided"
            # NOTE(review): yaml's full ``Loader`` can construct arbitrary Python
            # objects; only point this at trusted config files.
            # Use a context manager so the file handle is closed deterministically.
            with open(config_file, 'r') as config_stream:
                cd = load(config_stream, Loader=Loader)
            args = argparse.Namespace(**cd)

        if checkpoint_root is not None:
            args.output_dir = checkpoint_root

        # YAML may parse values such as "1e-4" as strings; coerce them to float.
        args.lr = float(args.lr)
        args.epsilon = float(args.epsilon)
        args.weight_decay = float(args.weight_decay)

        self.args = args
        # NOTE(review): ``State`` is assigned as the class object itself, not an
        # instance -- confirm this is intended.
        self.state = State

        # Model components; all start as None here and are assigned elsewhere.
        # Tokenizers
        self.tokenizer = None

        # Text encoders
        self.text_encoder = None

        # Denoisers
        self.transformer = None

        # Autoencoders
        self.vae = None

        # Scheduler
        self.scheduler = None

        self._init_distributed()
        if not val_only:
            self._init_logging()
            self._init_directories_and_repositories()

    def prepare_dataset(self):
        """Build the training dataset/dataloader and, if configured, a fixed validation subset.

        Sets ``self.cur_batch_size``, ``self.train_dataset`` and
        ``self.train_dataloader``. When ``self.args.data`` has a ``'val'`` entry
        and ``load_val`` is not disabled, also sets ``self.val_dataset``,
        ``self.val_index`` and ``self.val_dataloader``; the main process writes
        the sampled validation indices to ``idx.txt`` under ``self.save_folder``.
        """
        logger.info("Initializing AgiBot dataset and dataloader")

        rank_on_node = int(os.environ.get("LOCAL_RANK", 0))
        dataset_cls = MixAC

        train_cfg = self.args.data['train']
        if not getattr(self.args, "multi_resolution", False):
            self.cur_batch_size = self.args.batch_size
        else:
            # Each local rank picks its own resolution / batch size, cycling
            # through the configured lists. The train config dict is updated in
            # place on purpose: 'sample_size' is read back from it later.
            slot = rank_on_node % len(self.args.resolution_list)
            train_cfg['sample_size'] = self.args.resolution_list[slot]
            self.cur_batch_size = self.args.batch_size_list[slot]
        self.train_dataset = dataset_cls(gpuid=rank_on_node, **train_cfg)

        # Datasets that decode on GPU get 'spawn'-started dataloader workers.
        decode_type = self.train_dataset.decode_type
        worker_ctx = None
        if decode_type == 'gpu':
            import multiprocessing
            worker_ctx = multiprocessing.get_context('spawn')
        logger.info(f'>>>>>>>>>>>>>>>>>Video decode mode:{decode_type}<<<<<<<<<<<<<<<<<<<<')

        self.train_dataloader = torch.utils.data.DataLoader(
            dataset=self.train_dataset,
            shuffle=True,
            batch_size=self.cur_batch_size,
            num_workers=self.args.dataloader_num_workers,
            multiprocessing_context=worker_ctx,
        )
        logger.info(f">>>>>>>>>>>>>Total eps:{len(self.train_dataset)}<<<<<<<<<<<<<<<<<<")

        # Nothing more to do unless a validation split is configured and wanted.
        if 'val' not in self.args.data or not getattr(self.args, "load_val", True):
            return

        self.val_dataset = dataset_cls(**self.args.data['val'])

        # Draw one batch worth of random sample indices (with replacement).
        last_index = len(self.val_dataset) - 1
        self.val_index = [random.randint(0, last_index) for _ in range(self.args.batch_size)]

        if self.state.accelerator.is_main_process:
            index_file = os.path.join(self.save_folder, 'idx.txt')
            with open(index_file, "w") as file:
                file.write(", ".join(str(idx) for idx in self.val_index))

        val_subset = torch.utils.data.Subset(self.val_dataset, self.val_index)
        self.val_dataloader = torch.utils.data.DataLoader(
            val_subset,
            batch_size=self.args.batch_size,
            shuffle=getattr(self.args, "val_shuffle", False),
        )


    def train(self):
        """Run the training loop for the canny-conditioned transformer.

        For every micro-batch this method encodes the memory and future frames
        with the VAE, builds the canny condition, samples sigmas, blends latents
        with noise, runs the transformer, and optimises a weighted squared error
        between the prediction and ``noise - latents`` on the future frames only.

        Logging, validation and checkpointing are keyed on ``global_step`` (the
        number of optimizer steps) and only run on iterations where gradients
        were synchronised, so gradient accumulation does not repeat them. The
        final transformer is saved when training ends.

        Reads ``self.args``, ``self.state``, ``self.train_dataset``,
        ``self.train_dataloader``, ``self.optimizer``, ``self.lr_scheduler``,
        ``self.writer`` and ``self.save_folder``; these are expected to be set
        up before this method is called.
        """
        logger.info("Starting training")
        memory_statistics = get_memory_statistics()
        logger.info(f"Memory before training start: {json.dumps(memory_statistics, indent=4)}")

        self.state.train_batch_size = (
            self.cur_batch_size * self.state.accelerator.num_processes * self.args.gradient_accumulation_steps
        )
        info = {
            "trainable parameters": self.state.num_trainable_parameters,
            "total samples": len(self.train_dataset),
            "train epochs": self.state.train_epochs,
            "train steps": self.state.train_steps,
            "batches per device": self.cur_batch_size,
            "total batches observed per epoch": len(self.train_dataloader),
            "train batch size": self.state.train_batch_size,
            "gradient accumulation steps": self.args.gradient_accumulation_steps,
        }
        logger.info(f"Training configuration: {json.dumps(info, indent=4)}")

        global_step = 0
        first_epoch = 0

        accelerator = self.state.accelerator
        weight_dtype = self.state.weight_dtype
        scheduler_sigmas = self.scheduler.sigmas.clone().to(device=accelerator.device, dtype=weight_dtype)  # pyright: ignore
        generator = torch.Generator(device=accelerator.device)
        if self.args.seed is not None:
            generator = generator.manual_seed(self.args.seed)
        self.state.generator = generator

        # null prompt for classifier-free guidance
        negative_prompt = ""
        negative_prompt_embed = self.encode_prompt(negative_prompt)

        # Padding mask requested by the transformer; it is always all-zero here.
        h, w = self.args.data['train']['sample_size']
        padding_mask = torch.zeros(1, 1, h, w, device=accelerator.device, dtype=weight_dtype)

        # Loop invariants.
        mem_size = self.args.data['train']['n_previous']
        # Same default as prepare_dataset uses for this flag.
        load_val = getattr(self.args, "load_val", True)

        tracker = ProgressTracker(self.state.train_steps, description='Training Iterations')
        tracker.start()
        for epoch in range(first_epoch, self.state.train_epochs):
            logger.info(f"Starting epoch ({epoch + 1}/{self.state.train_epochs})")

            self.transformer.train()  # pyright: ignore

            running_loss = 0.0
            batches_seen = 0

            for step, batch in enumerate(self.train_dataloader):
                logger.debug(f"Starting step {step + 1}")
                logs = {}
                model_list = [self.transformer, ]

                with accelerator.accumulate(model_list):
                    video = batch['video']

                    # shape b, c, v, t, h, w ranging from -1 to 1
                    video = video.to(accelerator.device, dtype=weight_dtype).contiguous()
                    batch_size, _, n_view = video.shape[:3]

                    video = rearrange(video, 'b c v t h w -> (b v) c t h w')
                    if random.random() <= self.args.use_color_jitter:
                        video = apply_color_jitter_to_video(video, same_jitter_within_view=True, n_view=n_view)

                    # slice out memory & future
                    mem = video[:, :, :mem_size]
                    future_video = video[:, :, mem_size:]

                    if self.train_dataset.ignore_seek:  # in this case the future frame of video is not provided
                        future_video = future_video.repeat(1, 1, self.args.data['train']['chunk'], 1, 1)

                    # get the shape params
                    _, _, raw_frames, raw_height, raw_width = future_video.shape
                    latent_frames = raw_frames // TEMPORAL_DOWN_RATIO + 1  # future only
                    latent_height = raw_height // SPATIAL_DOWN_RATIO
                    latent_width = raw_width // SPATIAL_DOWN_RATIO

                    with torch.no_grad():
                        # vae encode && reshape
                        # time shrink for future video, but not for mem
                        mem_latents, future_video_latents = get_latents(self.vae, mem, future_video)[:2]  # pyright: ignore

                    mem_latents = rearrange(mem_latents, '(b v m) (h w) c -> (b v) c m h w', b=batch_size, m=mem_size, h=latent_height)
                    future_video_latents = rearrange(future_video_latents, '(b v) (f h w) c -> (b v) c f h w', b=batch_size, h=latent_height, w=latent_width)

                    # concat memory and future video
                    latents = torch.cat((mem_latents, future_video_latents), dim=2)  # bv c m+f h w

                    # get canny
                    canny = self.get_video_canny(mem, future_video,
                                                 mem_size, latent_frames,
                                                 latent_height, latent_width)  # bv c=1 m+f h w

                    # gen noise; forward the actual view count so the per-view
                    # masks are built correctly when the data is not 3-view
                    noise, cond_indicator, conditioning_mask = self.gen_noise_from_condition_frame_latent(
                            mem_latents, latent_frames, latent_height, latent_width,
                            noise_to_condition_frames=self.args.noise_to_first_frame,
                            device=accelerator.device, dtype=weight_dtype,
                            n_view=n_view
                        )  # bv c m+f h w

                    # encode text prompt
                    prompt_embeds = self.encode_prompt(batch['caption'])
                    if self.args.train_w_cfg:
                        # dropout a portion of the prompts for classifier-free guidance
                        dropout_factor = torch.rand(batch_size).to(accelerator.device, dtype=weight_dtype)
                        dropout_mask_prompt = dropout_factor < self.args.caption_dropout_p
                        dropout_mask_prompt = dropout_mask_prompt.unsqueeze(1).unsqueeze(2)
                        prompt_embeds = negative_prompt_embed.repeat(batch_size, 1, 1) * dropout_mask_prompt + \
                                        prompt_embeds * ~dropout_mask_prompt

                    # choose timesteps (one draw per sample, shared by all its views)
                    timestep_weights = compute_density_for_timestep_sampling(
                        weighting_scheme=self.args.flow_weighting_scheme,
                        batch_size=batch_size,
                        logit_mean=self.args.flow_logit_mean,
                        logit_std=self.args.flow_logit_std,
                        mode_scale=self.args.flow_mode_scale,
                    )
                    timestep_weights = timestep_weights.unsqueeze(1).repeat(1, n_view)
                    timestep_weights = timestep_weights.reshape(-1)
                    indices = (timestep_weights * self.scheduler.config.num_train_timesteps).long()  # pyright: ignore
                    sigmas = scheduler_sigmas[indices]  # ranges from 0 to 1
                    # ATTN! cosmos-predict2 uses a different timestep conditioning
                    # timesteps = (sigmas * 1000.0).long()  # LTX: ranges from 0 to 1000
                    timesteps = sigmas  # Cosmos2: ranges from 0 to 1
                    # The sigma value used for scaling conditioning latents. Ideally, it should not be changed or should be
                    # set to a small value close to zero.
                    sigma_conditioning = torch.tensor(0.0001, dtype=torch.float32, device=accelerator.device)
                    t_conditioning = sigma_conditioning / (sigma_conditioning + 1)  # a very small number
                    cond_timestep = cond_indicator * t_conditioning + (1 - cond_indicator) * timesteps.view(-1, 1, 1, 1, 1)

                    # adding noise to latents
                    ss = sigmas.reshape(-1, 1, 1, 1, 1)
                    noisy_latents = (1.0 - ss) * latents + ss * noise  # bv c m+f h w

                    # set hand-view memory to all-zero
                    noisy_latents = rearrange(noisy_latents, '(b v) c t h w -> b c v t h w', v=n_view)
                    noisy_latents[:, :, 1:, :mem_size] = 0.0
                    noisy_latents = rearrange(noisy_latents, 'b c v t h w -> (b v) c t h w')

                    # loss weight
                    loss_weights = compute_loss_weighting_for_sd3(
                        weighting_scheme=self.args.flow_weighting_scheme, sigmas=sigmas).reshape(-1, 1, 1, 1, 1)

                    # concat canny
                    noisy_latents = torch.cat([noisy_latents, canny], dim=1)

                    pred = self.transformer(  # pyright: ignore
                        hidden_states=noisy_latents,
                        timestep=cond_timestep,
                        encoder_hidden_states=prompt_embeds,
                        fps=self.args.data['train']['fps'],
                        condition_mask=conditioning_mask,
                        padding_mask=padding_mask,
                        return_dict=False,
                        n_view=n_view
                    )[0]  # bv vae_z_dim m+f h w

                    # TODO: we may need to linearly interpolate noise_pred with noisy_latents,
                    # but assume the interpolation as None before NVIDIA releases training code
                    target = noise - latents

                    loss_video = loss_weights.float() * (pred.float() - target.float()).pow(2)
                    # only the future (non-memory) frames contribute to the loss
                    loss_video = loss_video[:, :, mem_size:]

                    # Average loss across all non-batch dimensions
                    loss_video = loss_video.mean(list(range(1, loss_video.ndim)))
                    # Average loss across batch dimension
                    loss_video = loss_video.mean()

                    assert not torch.isnan(loss_video), "NaN loss detected"
                    accelerator.backward(loss_video)
                    if accelerator.sync_gradients and accelerator.distributed_type != DistributedType.DEEPSPEED:
                        grad_norm = accelerator.clip_grad_norm_(self.transformer.parameters(), self.args.max_grad_norm)  # pyright: ignore
                        logs["grad_norm"] = grad_norm

                    self.optimizer.step()  # pyright: ignore
                    self.lr_scheduler.step()  # pyright: ignore
                    self.optimizer.zero_grad()  # pyright: ignore

                # gather loss info outside of accelerator.accumulate
                loss_video = accelerator.reduce(loss_video.detach(), reduction='mean')

                running_loss += loss_video.item()
                batches_seen += 1

                # Checks if the accelerator has performed an optimization step behind the scenes
                optimizer_stepped = accelerator.sync_gradients
                if optimizer_stepped:
                    tracker.update()
                    global_step += 1

                # Merge into ``logs`` instead of rebinding it, so the grad_norm
                # recorded above is not thrown away.
                logs.update({
                    "loss": loss_video.detach().item(),
                    "lr": self.lr_scheduler.get_last_lr()[0],  # pyright: ignore
                })
                accelerator.log(logs, step=global_step)

                if global_step >= self.state.train_steps:
                    logger.info(">>> max train step reached")
                    break

                # ``global_step`` only changes on optimizer steps. Without this
                # guard the periodic actions below would fire on every
                # micro-batch of an accumulation window (including step 0).
                if not optimizer_stepped:
                    continue

                if global_step % self.args.steps_to_log == 0:
                    if accelerator.is_main_process:
                        self.writer.add_scalar("Training Loss", loss_video.item(), global_step)
                        print(f'loss: {logs["loss"]:.6f} lr: {logs["lr"]:.6f} | {tracker.get_progress_string()}')

                if (load_val and global_step % self.args.steps_to_val == 0) or global_step == 1:
                    accelerator.wait_for_everyone()
                    with torch.no_grad():
                        model_save_dir = osp.join(self.save_folder, f'Validation_step_{global_step}')

                        _ = self.validate(
                            accelerator, model_save_dir,
                            dataloader=self.train_dataloader, n_view=n_view,
                            fps=self.args.data['train']['fps']
                        )
                        # The validation dataloader only exists when
                        # prepare_dataset created it; skip it otherwise.
                        val_dataloader = getattr(self, 'val_dataloader', None)
                        if val_dataloader is not None:
                            _ = self.validate(
                                accelerator, model_save_dir,
                                dataloader=val_dataloader, n_view=n_view,
                                fps=self.args.data['val']['fps']
                            )
                    accelerator.wait_for_everyone()

                if global_step % self.args.steps_to_save == 0:

                    accelerator.wait_for_everyone()
                    if accelerator.is_main_process:
                        model_to_save = unwrap_model(accelerator, self.transformer)
                        model_save_dir = osp.join(self.save_folder, f'step_{global_step}')
                        os.makedirs(model_save_dir, exist_ok=True)
                        model_to_save.save_pretrained(model_save_dir, safe_serialization=True)
                        del model_to_save

            # get mem info after each epoch
            memory_statistics = get_memory_statistics()
            logger.info(f"Memory after epoch {epoch + 1}: {json.dumps(memory_statistics, indent=4)}")

            if accelerator.is_main_process:
                # Divide by the batches actually processed so an early stop does
                # not deflate the average.
                avg_loss = running_loss / max(batches_seen, 1)
                self.writer.add_scalar("Average Training Loss", avg_loss, epoch)

            # Leave the epoch loop too; otherwise each remaining epoch would run
            # one more optimizer step before hitting the inner ``break`` again.
            if global_step >= self.state.train_steps:
                break

        # training finished, save final model
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            self.transformer = unwrap_model(accelerator, self.transformer)
            model_save_dir = osp.join(self.save_folder, f'step_{global_step}')
            self.transformer.save_pretrained(model_save_dir, safe_serialization=True)

        del self.transformer, self.scheduler
        free_memory()
        memory_statistics = get_memory_statistics()
        logger.info(f"Memory after training end: {json.dumps(memory_statistics, indent=4)}")

        accelerator.end_training()

    @torch.no_grad()
    def get_video_canny(self, mem, future_video, mem_size, future_size, height, width):
        """Compute canny edge maps for memory and future frames at latent resolution.

        Args:
            mem: Memory frames, ``(bv, c, t, h, w)`` with ``t == mem_size``.
            future_video: Future frames, ``(bv, c, t, h, w)``.
            mem_size: Number of memory frames; they are only resized spatially.
            future_size: Number of frames the future clip is resampled to.
            height: Target height of the edge maps.
            width: Target width of the edge maps.

        Returns:
            Tensor ``(bv, c, mem_size + future_size, height, width)`` scaled to
            ``[-1, 1]``, with the dtype and device of ``mem``. The second output
            of ``kornia.filters.canny`` is used.
        """
        def _min_max_normalize(x):
            # Rescale to [0, 1] using the global min/max of the tensor. The
            # denominator is clamped so a constant clip yields zeros instead of
            # NaN from 0 / 0.
            lo, hi = x.min(), x.max()
            return (x - lo) / (hi - lo).clamp_min(1e-6)

        # Memory frames: spatial resize only, one edge map per frame.
        mem = rearrange(mem, 'bv c t h w -> (bv t) c h w')
        mem = F.interpolate(mem, (height, width), mode='bilinear')
        mem = _min_max_normalize(mem)
        _, canny_mem = kornia.filters.canny(mem.float())
        canny_mem = rearrange(canny_mem, '(bv t) c h w -> bv c t h w', t=mem_size)

        # Future frames: resample in time and space, then detect edges per frame.
        future_video = F.interpolate(future_video, (future_size, height, width), mode='trilinear')
        future_video = _min_max_normalize(future_video)
        future_video = rearrange(future_video, 'bv c t h w -> (bv t) c h w')
        _, canny_future = kornia.filters.canny(future_video.float())
        canny_future = rearrange(canny_future, '(bv t) c h w -> bv c t h w', t=future_size)

        out = torch.cat([canny_mem, canny_future], dim=2)
        out = (out * 2.0) - 1.0  # (0, 1) -> (-1, 1)
        return out.to(mem)

    def gen_noise_from_condition_frame_latent(self, mem_latents, future_latent_frames,  # pyright: ignore
                                              latent_height, latent_width,
                                              noise_to_condition_frames=0.2,
                                              noise=None, generator=None,
                                              device='cuda', dtype=torch.bfloat16,
                                              view_1_only=True, n_view=3
    ):
        '''
        Build the (memory + noise) starting latents together with the
        conditioning masks.

        mem_latents: (b v) c m h w
        future_latent_frames: number of future latent frames
        latent_height, latent_width: spatial size of the noise and the masks
        noise_to_condition_frames: upper bound of the random noise ratio blended
            into the memory frames; a value <= 0 leaves them clean
        noise: optional pre-sampled noise; drawn with randn_tensor when None
        generator: forwarded to randn_tensor only
        device, dtype: placement of the created tensors
        view_1_only: if True, only the first of every n_view consecutive entries
            along dim 0 keeps its memory conditioning
        n_view: number of views packed into dim 0

        returns (latents, cond_indicator, conditioning_mask)
            latents:           (b v) c m+f h w
            cond_indicator:    (b v) 1 m+f 1 1
            conditioning_mask: (b v) 1 m+f h w
        '''
        bv, n_channels, mem_size = mem_latents.shape[:3]
        total_frames = mem_size + future_latent_frames

        def _first_view_only(mask):
            # Zero every view but the first; a no-op when view_1_only is False.
            if not view_1_only:
                return mask
            mask = rearrange(mask, '(b v) c f h w -> b v c f h w', v=n_view)
            mask[:, 1:] = 0.0
            return rearrange(mask, 'b v c f h w -> (b v) c f h w', v=n_view)

        # noise with latent shape
        if noise is None:
            noise = randn_tensor(
                (bv, n_channels, total_frames, latent_height, latent_width),
                generator=generator, device=device, dtype=dtype,
            )

        # 1 on memory frames (first view only when view_1_only), 0 on frames to predict
        conditioning_mask = torch.zeros(
            (bv, 1, total_frames, latent_height, latent_width), device=device, dtype=dtype)
        conditioning_mask[:, :, :mem_size] = 1.0
        conditioning_mask = _first_view_only(conditioning_mask)

        # same layout without spatial extent; used to build per-frame timesteps
        cond_indicator = torch.zeros((bv, 1, total_frames, 1, 1), device=device, dtype=dtype)
        cond_indicator[:, :, :mem_size] = 1.0
        cond_indicator = _first_view_only(cond_indicator)

        # This extra step exists because the first frame was accidentally packed
        # when the dataset was originally generated: the last memory latent is
        # tiled over every frame, then the memory frames are written back.
        init_latents = mem_latents[:, :, -1:].repeat(1, 1, total_frames, 1, 1)
        init_latents[:, :, :mem_size] = mem_latents

        first_frame_mask = conditioning_mask.clone()
        if noise_to_condition_frames > 0:
            # Two uniform draws per entry give the ends of a linear noise ramp
            # across the memory frames (smaller value first).
            draw_a = torch.rand(bv) * noise_to_condition_frames
            draw_b = torch.rand(bv) * noise_to_condition_frames
            ramp_start = torch.minimum(draw_a, draw_b)
            ramp_end = torch.maximum(draw_a, draw_b)
            noise_ramp = torch.stack(
                [torch.linspace(lo, hi, mem_size) for lo, hi in zip(ramp_start, ramp_end)],
                dim=0,
            )
            noise_ramp = noise_ramp.reshape(bv, 1, mem_size, 1, 1).to(dtype=dtype, device=device)
            first_frame_mask[:, :, :mem_size] = 1.0 - noise_ramp
        first_frame_mask = _first_view_only(first_frame_mask)

        # before mem_size it's memory info; after mem_size it's noise
        latents = init_latents * first_frame_mask + noise * (1 - first_frame_mask)

        return latents, cond_indicator, conditioning_mask


    @torch.no_grad()
    def validate(self, accelerator, model_save_dir,  # pyright: ignore
                 dataloader=None, video=None, video_for_canny=None, prompt=None,
                 n_prev=None, n_view=3, chunk_size=None,
                 merge_view_into_width=True, fps=30, video_path=None,
                 vis_cat_gt=True, write_video_to_disk=True,
                 guidance_scale=1.0, save_tag='', pipeline_progress=True,
                 compile_transformer=False, num_inference_steps=35
    ):
        """Generate a video with the current transformer and optionally save it.

        Input comes either from ``dataloader`` (one batch is drawn) or from the
        explicit ``video`` / ``prompt`` / ``video_path`` arguments.

        Args:
            accelerator: Accelerator used to unwrap the model and pick the device.
            model_save_dir: Output directory; created if missing.
            dataloader: DataLoader or iterator yielding dicts with ``'video'``,
                ``'caption'`` and ``'path'``. Required when ``video`` is None.
            video: Tensor ``(b, c, v, t, h, w)``; needs ``prompt`` and
                ``video_path`` as well.
            video_for_canny: Video the canny condition is computed from;
                defaults to a copy of ``video``.
            prompt: Text prompt(s) for ``video``.
            n_prev: Number of memory frames; defaults to the ``'val'`` config.
            n_view: Overwritten by the view count read from ``video``'s shape.
            chunk_size: Number of frames to generate; defaults to the
                ``'train'`` config ``'chunk'``.
            merge_view_into_width: Passed to the pipeline; also selects how the
                result is written to disk.
            fps: Frame rate passed to the pipeline and to ``save_video``.
            video_path: Sequence of source paths; the first one names the output.
            vis_cat_gt: Concatenate the ground truth below the prediction
                (only implemented for ``merge_view_into_width=True``).
            write_video_to_disk: Save the result as ``.mp4`` when True.
            guidance_scale: Passed to the pipeline.
            save_tag: Inserted into the output file name.
            pipeline_progress: Passed to the pipeline as ``show_progress``.
            compile_transformer: Wrap the transformer with ``torch.compile``.
            num_inference_steps: Passed to the pipeline.

        Returns:
            The pipeline's ``'frames'`` output, with the ground truth
            concatenated along the height axis when ``vis_cat_gt`` is set.

        Raises:
            NotImplementedError: ``vis_cat_gt`` with ``merge_view_into_width=False``.
        """
        os.makedirs(model_save_dir, exist_ok=True)

        transformer = unwrap_model(accelerator, self.transformer)
        if compile_transformer:
            print('[INFO] Compiling transformer. This may take a long time.')
            # Compile the unwrapped module (previously the wrapped
            # ``self.transformer`` was compiled and the unwrap was discarded).
            transformer = torch.compile(transformer)
            print('[INFO] Compiling transformer finished.')

        pipe = CannyAugCosmos2Pipeline(
            self.text_encoder, self.tokenizer, transformer,
            self.vae, self.scheduler)  # pyright: ignore

        if isinstance(dataloader, torch.utils.data.dataloader.DataLoader):
            dataloader = iter(dataloader)  # pyright: ignore

        if video is None:
            assert dataloader is not None
            batch = next(dataloader)
            video = batch['video'].to(accelerator.device, dtype=self.state.weight_dtype)
            prompt = batch['caption']
            video_path = batch['path']
        else:
            assert prompt is not None
            assert video_path is not None
        gt_video = video.clone()

        # Derive the output name from the first path without mutating the
        # caller's (or the batch's) path list.
        clip_path = video_path[0]
        if clip_path.endswith('/{}'):
            clip_path = clip_path.replace('/{}', '')
        clip_name = clip_path.split('/')[-1]

        _, _, n_view, _, h, w = video.shape
        n_prev = n_prev if n_prev is not None else self.args.data['val']['n_previous']
        chunk_size = chunk_size if chunk_size is not None else self.args.data['train']['chunk']
        fps = fps if fps is not None else self.args.data['val']['fps']

        if video_for_canny is None:  # video: edited appearance, video_for_canny: full length original video
            video_for_canny = video.clone()
        video_for_canny = rearrange(video_for_canny, 'b c v t h w -> (b v) c t h w')
        canny = self.get_video_canny(video_for_canny[:, :, :n_prev], video_for_canny[:, :, n_prev:],
                                     n_prev, chunk_size//TEMPORAL_DOWN_RATIO+1,
                                     h//SPATIAL_DOWN_RATIO, w//SPATIAL_DOWN_RATIO)  # bv c t h w
        video = rearrange(video, 'b c v t h w -> (b v) t c h w')
        negative_prompt = "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality."

        # Use the resolution of the actual input rather than a fixed 384x512, so
        # the generated size always matches the canny condition built from h, w.
        preds = pipe(
            video=video,
            cond_to_concat=canny,  # pyright: ignore
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            height=h, width=w, num_frames=chunk_size,
            n_view=n_view, n_prev=n_prev,
            guidance_scale=guidance_scale,
            merge_view_into_width=merge_view_into_width, fps=fps,
            postprocess_video=False,
            show_progress=pipeline_progress,
            )['frames']  # (b v) c t h w or b c t h (v w), range -1 to 1 (could exceed range)

        if vis_cat_gt:
            if merge_view_into_width:
                gt_video = rearrange(gt_video, 'b c v t h w -> b c t h (v w)')[:, :, n_prev:]
                preds = torch.cat([preds, gt_video], dim=3)  # concat on height
            else:
                # TODO
                raise NotImplementedError

        if write_video_to_disk:
            save_tag = f'_{save_tag}_'
            if merge_view_into_width:
                # Only the first sample of the batch is written.
                save_name = osp.join(model_save_dir, clip_name + save_tag + '.mp4')
                # Pass fps here too, consistent with the per-view branch below.
                save_video(preds[0], save_name, fps=fps)
                print(f'Result saved to {save_name}')
            else:
                for iv, view_pred in enumerate(preds):  # (b v) c t h w
                    save_name = osp.join(model_save_dir,
                                         clip_name + save_tag + f'_view{iv}.mp4')
                    save_video(view_pred, save_name, fps=fps)
                    print(f'Result saved to {save_name}')

        return preds