import os
import os.path as osp
from yaml import load, Loader
import argparse

import torch
import torch.nn.functional as F
import kornia

from accelerate.logging import get_logger
from einops import rearrange

from lib.trainers.canny_aug_trainer import CannyAugTrainer
from lib.trainers.base_trainer import State
from lib.pipelines.pipeline_cosmos2_canny_aug_full_view import CannyAugFullViewCosmos2Pipeline
from lib.utils.torch_utils import unwrap_model, save_video

# Down-sampling factors applied to the input resolution / frame count when
# sizing the Canny conditioning (presumably the VAE's spatial and temporal
# compression factors — confirm against the VAE config).
SPATIAL_DOWN_RATIO = 8
TEMPORAL_DOWN_RATIO = 4

# Module logger, set to INFO.
logger = get_logger(__name__)
logger.setLevel('INFO')

class CannyAugFullViewTrainer(CannyAugTrainer):
    """Canny-augmented trainer whose validation uses the full-view Cosmos2 pipeline.

    Only ``validate`` is overridden here; everything else is inherited from
    ``CannyAugTrainer``.
    """

    # Negative prompt passed to the pipeline on every validation call.
    VAL_NEGATIVE_PROMPT = (
        "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, "
        "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, "
        "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, "
        "jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, "
        "fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. "
        "Overall, the video is of poor quality."
    )

    @torch.no_grad()
    def validate(self, accelerator, model_save_dir,
                 dataloader=None, video=None, video_for_canny=None, prompt=None,
                 n_prev=None, n_view=3, chunk_size=None,
                 merge_view_into_width=True, fps=30, video_path=None,
                 vis_cat_gt=True, write_video_to_disk=True,
                 guidance_scale=1.0, save_tag='', pipeline_progress=True
    ):
        """Generate one validation sample with the full-view pipeline and optionally save it.

        The input is either drawn from ``dataloader`` (one batch via ``next``) or
        supplied directly through ``video`` / ``prompt`` / ``video_path``. The first
        ``n_prev`` frames are passed to the pipeline as context, and Canny maps from
        ``self.get_video_canny`` are passed as ``cond_to_concat``.

        Args:
            accelerator: Accelerator object; provides ``.device`` and is used to
                unwrap the transformer.
            model_save_dir: Output directory for videos; created if missing.
            dataloader: DataLoader or iterator yielding dicts with ``'video'``,
                ``'caption'`` and ``'path'``. Required when ``video`` is None.
            video: Tensor laid out as ``(b, c, v, t, h, w)``. When given,
                ``prompt`` and ``video_path`` are required.
            video_for_canny: Optional tensor with the same layout as ``video``
                used for Canny extraction; defaults to a copy of ``video``.
            prompt: Text prompt(s) passed to the pipeline.
            n_prev: Number of leading context frames; defaults to
                ``self.args.data['val']['n_previous']``.
            n_view: Ignored — the view count is always taken from
                ``video.shape[2]``. Kept for API compatibility.
            chunk_size: Frame count passed as ``num_frames``; defaults to
                ``self.args.data['train']['chunk']``.
            merge_view_into_width: Passed to the pipeline. When True, a single
                video is written; otherwise one file per entry of the output.
            fps: Frame rate for the pipeline and the saved videos; ``None``
                falls back to ``self.args.data['val']['fps']``.
            video_path: Source path(s); the last ``/``-component of the first
                entry names the output file. A single ``str`` is accepted.
            vis_cat_gt: Concatenate ground-truth frames below the predictions
                (only supported with ``merge_view_into_width``).
            write_video_to_disk: Whether to write the result as ``.mp4``.
            guidance_scale: Guidance scale passed to the pipeline.
            save_tag: Extra tag inserted into the output filename.
            pipeline_progress: Whether the pipeline shows progress.

        Returns:
            The pipeline's ``'frames'`` output, with ground truth concatenated
            along dim 3 when ``vis_cat_gt`` is set.

        Raises:
            NotImplementedError: if ``vis_cat_gt`` is set while
                ``merge_view_into_width`` is False.
        """
        if vis_cat_gt and not merge_view_into_width:
            # Unsupported combination: fail before running the (expensive) pipeline.
            raise NotImplementedError

        os.makedirs(model_save_dir, exist_ok=True)
        pipe = CannyAugFullViewCosmos2Pipeline(
            self.text_encoder, self.tokenizer, unwrap_model(accelerator, self.transformer),
            self.vae, self.scheduler)

        if isinstance(dataloader, torch.utils.data.dataloader.DataLoader):
            dataloader = iter(dataloader)

        if video is None:
            assert dataloader is not None
            batch = next(dataloader)
            video = batch['video'].to(accelerator.device, dtype=self.state.weight_dtype)
            prompt = batch['caption']
            video_path = batch['path']
        else:
            assert prompt is not None
            assert video_path is not None
        if isinstance(video_path, str):
            # A bare str indexed with [0] below would yield its first character.
            video_path = [video_path]
        gt_video = video.clone()

        # The view count always comes from the tensor; the ``n_view`` argument is overridden.
        _, _, n_view, _, h, w = video.shape
        n_prev = n_prev if n_prev is not None else self.args.data['val']['n_previous']
        chunk_size = chunk_size if chunk_size is not None else self.args.data['train']['chunk']
        fps = fps if fps is not None else self.args.data['val']['fps']

        if video_for_canny is None:
            video_for_canny = video.clone()
        video_for_canny = rearrange(video_for_canny, 'b c v t h w -> (b v) c t h w')
        # Canny conditioning is sized to chunk_size // 4 + 1 frames and (h // 8, w // 8)
        # pixels — presumably the latent grid; confirm in get_video_canny.
        canny = self.get_video_canny(
            video_for_canny[:, :, :n_prev], video_for_canny[:, :, n_prev:],
            n_prev, chunk_size // TEMPORAL_DOWN_RATIO + 1,
            h // SPATIAL_DOWN_RATIO, w // SPATIAL_DOWN_RATIO
        )
        video = rearrange(video, 'b c v t h w -> (b v) t c h w')

        # NOTE(review): height/width are fixed at 384x512 while the Canny maps
        # above are sized from the input's h/w; presumably inputs must be
        # 384x512 for these to agree — confirm before using other resolutions.
        preds = pipe(
            video=video,
            cond_to_concat=canny,
            prompt=prompt,
            negative_prompt=self.VAL_NEGATIVE_PROMPT,
            height=384, width=512, num_frames=chunk_size,
            n_view=n_view, n_prev=n_prev,
            guidance_scale=guidance_scale,
            merge_view_into_width=merge_view_into_width, fps=fps,
            postprocess_video=False,
            show_progress=pipeline_progress,
        )['frames']

        if vis_cat_gt:
            # Merged layout is guaranteed here (checked at entry): tile the views
            # along width and drop the n_prev context frames before stacking.
            gt_video = rearrange(gt_video, 'b c v t h w -> b c t h (v w)')[:, :, n_prev:]
            preds = torch.cat([preds, gt_video], dim=3)  # concat on height

        if write_video_to_disk:
            stem = video_path[0].split('/')[-1]
            tag = f'_{save_tag}_'
            if merge_view_into_width:
                save_name = osp.join(model_save_dir, stem + tag + '.mp4')
                # Pass fps explicitly, matching the per-view branch below.
                save_video(preds[0], save_name, fps=fps)
                print(f'Result saved to {save_name}')
            else:
                # NOTE(review): every entry of preds is named after video_path[0];
                # with batch size > 1 the index presumably spans (batch, view) pairs.
                for iv, view_pred in enumerate(preds):
                    save_name = osp.join(model_save_dir, stem + tag + f'_view{iv}.mp4')
                    save_video(view_pred, save_name, fps=fps)
                    print(f'Result saved to {save_name}')

        return preds