import logging
import os
from pathlib import Path
from typing import List, Set

from PIL import Image
import cv2
import numpy as np
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
import torch
import torch.distributed as dist
from torchvision.transforms.functional import resize
import wandb
import wandb.util
from moviepy import ImageSequenceClip

from mdt.utils.utils import add_text

log = logging.getLogger(__name__)


def flatten(t):
    """Concatenate the sub-iterables of ``t`` into a single list."""
    flat = []
    for sublist in t:
        flat.extend(sublist)
    return flat


def flatten_list_of_dicts(t):
    """Merge an iterable of dicts into one dict; later dicts win on duplicate keys."""
    merged = {}
    for d in t:
        for key, value in d.items():
            merged[key] = value
    return merged


def _unnormalize(img):
    return img / 2 + 0.5


def delete_tmp_video(path):
    """Remove the temporary video file at ``path``; a file that is already gone is not an error."""
    try:
        os.unlink(path)
    except FileNotFoundError:
        # Nothing to clean up.
        return None


def add_modality(tasks, mod):
    """Return the set of task names, each prefixed with ``"<mod>/"``."""
    prefixed = set()
    for task in tasks:
        prefixed.add(f"{mod}/{task}")
    return prefixed


class RolloutVideo:
    """
    Collect rollout frames into videos and log them to TensorBoard, WandB or the local filesystem.

    Each entry of ``self.videos`` is a float tensor of shape (1, T, C, H, W) with values in [0, 1],
    built frame by frame by `update`. ``self.tags`` and ``self.captions`` get one entry per
    `new_video` call.
    """

    def __init__(self, logger, empty_cache, log_to_file, save_dir, resolution_scale=1):
        """
        Args:
            logger: experiment logger; `log` supports ``WandbLogger`` and ``TensorBoardLogger``.
            empty_cache: if True, free reserved CUDA memory before gathering videos for TensorBoard.
            log_to_file: if True, `log` writes videos to ``save_dir`` instead of using the logger.
            save_dir: output directory for files written by this class.
            resolution_scale: factor applied to frame height/width when writing videos to file.
        """
        self.videos = []
        self.video_paths = {}  # tag -> temporary GIF path, filled by `write_to_tmp` (WandB only)
        self.tags = []
        self.captions = []
        self.logger = logger
        self.empty_cache = empty_cache
        self.log_to_file = log_to_file
        self.save_dir = Path(save_dir)
        self.sub_task_beginning = 0  # frame index where the current subtask starts
        self.step_counter = 0  # number of frames in the current video
        self.resolution_scale = resolution_scale
        if self.log_to_file:
            os.makedirs(self.save_dir, exist_ok=True)
        if (
            isinstance(self.logger, TensorBoardLogger)
            and dist.is_available()
            and dist.is_initialized()
            and not self.log_to_file
        ):
            log.warning("Video logging with tensorboard and ddp can lead to OOM errors.")

    def new_video(self, tag: str, caption: str = None) -> None:
        """
        Start a new, empty video; frames are appended with `update`.

        Args:
            tag: name of the video
            caption: caption of the video (shown when logging to WandB)
        """
        # Empty placeholder; the first `update` call turns it into a (1, 1, C, H, W) tensor.
        self.videos.append(torch.Tensor())
        self.tags.append(tag)
        self.captions.append(caption)
        self.step_counter = 0
        self.sub_task_beginning = 0

    def draw_outcome(self, successful):
        """
        Draw a green (success) or red (failure) border on the last frame of the current video,
        then append 5 copies of that frame.

        Args:
            successful: bool
        """
        c = 1 if successful else 0
        not_c = list({0, 1, 2} - {c})
        border = 10
        frames = 5
        self.videos[-1][:, -1:, c, :, :border] = 1
        self.videos[-1][:, -1:, not_c, :, :border] = 0
        self.videos[-1][:, -1:, c, :, -border:] = 1
        self.videos[-1][:, -1:, not_c, :, -border:] = 0
        self.videos[-1][:, -1:, c, :border, :] = 1
        self.videos[-1][:, -1:, not_c, :border, :] = 0
        self.videos[-1][:, -1:, c, -border:, :] = 1
        self.videos[-1][:, -1:, not_c, -border:, :] = 0
        repeat_frames = torch.repeat_interleave(self.videos[-1][:, -1:], repeats=frames, dim=1)
        self.videos[-1] = torch.cat([self.videos[-1], repeat_frames], dim=1)
        self.step_counter += frames

    def new_subtask(self):
        """Mark the current frame count as the start of a new subtask (used by the overlay methods)."""
        self.sub_task_beginning = self.step_counter

    def update(self, rgb_obs: np.ndarray) -> None:
        """
        Append a frame to the current video.

        Args:
            rgb_obs: numpy array of shape [H, W, C]; uint8 in [0, 255], or float in [0, 1] or [0, 255].
        """
        # Normalize to float32 in [0, 1].
        if rgb_obs.dtype == np.uint8:
            img = rgb_obs.astype(np.float32) / 255.0
        else:
            img = rgb_obs.astype(np.float32)
            if img.max() > 1.1:  # heuristic: float frame given in [0, 255]
                img = img / 255.0

        # [H, W, C] -> [1, 1, C, H, W]
        img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).unsqueeze(0)
        self.videos[-1] = torch.cat([self.videos[-1], img], dim=1)
        self.step_counter += 1

    def add_goal_thumbnail(self, goal_img):
        """
        Paste a one-third-size copy of ``goal_img`` into the bottom-left corner of every frame of
        the current subtask.

        Args:
            goal_img: image tensor, presumably (C, H, W) in [-1, 1] since it is mapped with
                ``img / 2 + 0.5`` before pasting (TODO confirm with callers).
        """
        size = self.videos[-1].shape[-2:]
        i_h = int(size[0] / 3)
        i_w = int(size[1] / 3)
        img = resize(_unnormalize(goal_img.detach().cpu()), [i_h, i_w])
        self.videos[-1][:, self.sub_task_beginning :, ..., -i_h:, :i_w] = img

    def add_language_instruction(self, instruction):
        """
        Overlay ``instruction`` as text on every frame of the current subtask.
        """
        _, _, channels, height, width = self.videos[-1].shape
        # Canvas in (H, W, C) layout (the old (W, H, C) canvas broke non-square frames).
        # Mid-gray 127 maps to ~0 below, so only the pixels drawn by `add_text` change the frames.
        # NOTE(review): assumes `add_text` draws in place on an (H, W, C) uint8 image.
        img_text = np.full((height, width, channels), 127, dtype=np.uint8)
        add_text(img_text, instruction)
        overlay = ((img_text.transpose(2, 0, 1).astype(np.float32) / 255.0) * 2) - 1
        self.videos[-1][:, self.sub_task_beginning :, ...] += torch.from_numpy(overlay)
        # Frames are stored in [0, 1] (see `update`), so clip back into that range.
        self.videos[-1] = torch.clip(self.videos[-1], 0, 1)

    def write_to_tmp(self):
        """
        In case of logging with WandB, save the videos as GIF in tmp directory,
        then log them at the end of the validation epoch from rank 0 process.
        """
        if isinstance(self.logger, WandbLogger) and not self.log_to_file:
            for video, tag in zip(self.videos, self.tags):
                video = np.clip(video.numpy() * 255, 0, 255).astype(np.uint8)
                wandb_vid = wandb.Video(video, fps=10, format="gif")
                # NOTE: relies on the private `_path` attribute of wandb.Video.
                self.video_paths[tag] = wandb_vid._path
            # Captions are kept on purpose: `_log_videos_to_wandb` pairs them with
            # `video_paths` by position, so tags are expected to be unique.
            self.videos = []
            self.tags = []

    @staticmethod
    def _empty_cache():
        """
        Clear GPU reserved memory on this process's current CUDA device. Do not call this unnecessarily.
        """
        # Use the current device, not the global rank: on multi-node runs the global rank exceeds
        # the number of local GPUs. Assumes each process selected its GPU via torch.cuda.set_device.
        device = torch.cuda.current_device()
        mem1 = torch.cuda.memory_reserved(device)
        torch.cuda.empty_cache()
        mem2 = torch.cuda.memory_reserved(device)
        log.info(f"GPU: {dist.get_rank()} freed {(mem1 - mem2) / 10**9:.1f}GB of reserved memory")

    def log(self, global_step: int) -> None:
        """
        Call this method at the end of a validation epoch to log videos to tensorboard, wandb or filesystem.
        Afterwards all collected videos, tags, captions and temp paths are discarded.

        Args:
            global_step: global step of the training
        Raises:
            NotImplementedError: if not logging to file and the logger is neither WandB nor TensorBoard.
        """
        if self.log_to_file:
            self._log_videos_to_file(global_step)
        elif isinstance(self.logger, WandbLogger):
            self._log_videos_to_wandb()
        elif isinstance(self.logger, TensorBoardLogger):
            self._log_videos_to_tb(global_step)
        else:
            raise NotImplementedError
        self.videos = []
        self.tags = []
        self.captions = []
        self.video_paths = {}

    def _log_videos_to_tb(self, global_step):
        """
        Log videos to TensorBoard. Under DDP the videos of all ranks are gathered and logged by rank 0.
        """
        if dist.is_available() and dist.is_initialized():
            if self.empty_cache:
                self._empty_cache()

            world_size = torch.distributed.get_world_size()
            all_videos = [None] * world_size
            all_tags = [None] * world_size
            try:
                torch.distributed.all_gather_object(all_videos, self.videos)
                torch.distributed.all_gather_object(all_tags, self.tags)
            except RuntimeError as e:
                log.warning(e)
                return
            # only log videos from rank 0 process
            if dist.get_rank() != 0:
                return
            videos = flatten(all_videos)
            tags = flatten(all_tags)
        else:
            videos, tags = self.videos, self.tags

        for video, tag in zip(videos, tags):
            self._plot_video_tb(video, tag, global_step)

    def _plot_video_tb(self, video, tag, global_step):
        """
        Write one video to TensorBoard under the name ``video<tag>``.

        Args:
            video: (1, T, C, H, W) tensor as stored in ``self.videos``, or a single (T, C, H, W) clip.
        """
        if video.numel() == 0:
            log.warning(f"Skipping video '{tag}': it contains no frames.")
            return
        # add_video expects (N, T, C, H, W); stored videos already carry the batch axis,
        # so only bare 4-D clips get one (unconditional unsqueeze produced 6-D input).
        if video.dim() == 4:
            video = video.unsqueeze(0)
        self.logger.experiment.add_video(f"video{tag}", video, global_step=global_step, fps=10)

    def _log_videos_to_wandb(self):
        """
        Log the GIFs written by `write_to_tmp` to WandB (from rank 0 under DDP) and delete them.
        """
        if dist.is_available() and dist.is_initialized():
            world_size = torch.distributed.get_world_size()
            all_video_paths = [None] * world_size
            all_captions = [None] * world_size
            try:
                torch.distributed.all_gather_object(all_video_paths, self.video_paths)
                torch.distributed.all_gather_object(all_captions, self.captions)
            except RuntimeError as e:
                log.warning(e)
                # The caller resets `video_paths` afterwards, so clean up this rank's temp files now.
                for path in self.video_paths.values():
                    delete_tmp_video(path)
                return
            # only log videos from rank 0 process
            if dist.get_rank() != 0:
                return
            video_paths = flatten_list_of_dicts(all_video_paths)
            captions = flatten(all_captions)
        else:
            video_paths = self.video_paths
            captions = self.captions
        # Paths and captions are paired by position (dict insertion order).
        for (task, path), caption in zip(video_paths.items(), captions):
            self.logger.experiment.log({f"video{task}": wandb.Video(path, fps=20, format="gif", caption=caption)})
            delete_tmp_video(path)

    def _resize_video(self, video_tensor):
        """
        Scale the height and width of every frame by ``self.resolution_scale``.

        Args:
            video_tensor: uint8 array of shape (T, H, W, C), as returned by `_prepare_video`.
        Returns:
            uint8 array of shape (T, new_H, new_W, C).
        """
        _, height, width, _ = video_tensor.shape
        new_h = max(1, int(height * self.resolution_scale))
        new_w = max(1, int(width * self.resolution_scale))
        # INTER_AREA for downscaling, INTER_LINEAR for upscaling.
        interpolation = cv2.INTER_AREA if self.resolution_scale < 1 else cv2.INTER_LINEAR
        resized_frames = []
        for frame in video_tensor:
            # cv2.resize takes dsize as (width, height).
            resized = cv2.resize(np.ascontiguousarray(frame), (new_w, new_h), interpolation=interpolation)
            if resized.ndim == 2:  # cv2 drops a singleton channel axis
                resized = resized[..., None]
            resized_frames.append(resized)
        return np.stack(resized_frames)

    def _log_videos_to_file(self, global_step, save_as_video=False):
        """
        Save videos to ``save_dir`` as ``<tag>_<global_step>.gif`` (20 fps) or ``.mp4`` (30 fps, H.264).

        Args:
            global_step: training step appended to the file name.
            save_as_video: write mp4 instead of gif.
        """
        for video, tag in zip(self.videos, self.tags):
            if video.numel() == 0:
                log.warning(f"Skipping video '{tag}': it contains no frames.")
                continue
            # Stored videos are (1, T, C, H, W); also accept a bare (T, C, H, W) clip.
            if video.dim() == 4:
                video = video.unsqueeze(0)

            video = np.clip(video.numpy() * 255, 0, 255).astype(np.uint8)

            # (B, T, C, H, W) -> (T, H', W', C) with the batch tiled into a grid.
            frames = self._prepare_video(video)
            if self.resolution_scale != 1.0:
                frames = self._resize_video(frames)

            tag = tag.replace("/", "_")
            filename = self.save_dir / f"{tag}_{global_step}.{'mp4' if save_as_video else 'gif'}"

            clip = ImageSequenceClip(list(frames), fps=30 if save_as_video else 20)
            if save_as_video:
                clip.write_videofile(str(filename), codec='libx264', bitrate="5000k")
            else:
                clip.write_gif(str(filename), logger=None)

    def save_frames_to_subfolder(self, n, rollout_index):
        """
        Save every ``n``-th frame of each collected video as PNG to
        ``save_dir/rollout_<rollout_index>/video_<i>/frame_<t>.png``.

        Args:
            n: positive integer frame stride.
            rollout_index: used to name the rollout subfolder.
        Raises:
            ValueError: if ``n`` is not a positive integer.
        """
        # Type check first, so non-numeric input raises ValueError rather than a TypeError from `<=`.
        if not isinstance(n, int) or n <= 0:
            raise ValueError("n must be a positive integer.")

        subfolder_path = self.save_dir / f'rollout_{rollout_index}'
        os.makedirs(subfolder_path, exist_ok=True)

        for video_idx, video_tensor in enumerate(self.videos):
            if video_tensor.numel() == 0:
                continue  # video was started but never received frames
            # Stored videos are (1, T, C, H, W).
            _, total_frames, channels, _, _ = video_tensor.shape

            video_subfolder_path = subfolder_path / f'video_{video_idx}'
            os.makedirs(video_subfolder_path, exist_ok=True)

            for frame_index in range(0, total_frames, n):
                frame = video_tensor[0, frame_index].permute(1, 2, 0).cpu().numpy()
                if channels == 1:  # PIL expects (H, W) for grayscale
                    frame = frame.squeeze(-1)
                # Frames are in [0, 1]; clip before the uint8 cast so stray values cannot wrap around.
                frame_uint8 = (np.clip(frame, 0.0, 1.0) * 255).astype(np.uint8)
                Image.fromarray(frame_uint8).save(video_subfolder_path / f'frame_{frame_index}.png')

        log.info(f'Saved frames from {len(self.videos)} videos to {subfolder_path}')

    @staticmethod
    def _prepare_video(video):
        """
        Tile a batch of videos into a single video grid (logic mostly taken from tensorboardX).

        Args:
            video: array of shape (T, C, H, W) or (B, T, C, H, W).
        Returns:
            uint8 array of shape (T, n_rows * H, n_cols * W, C).
        Raises:
            ValueError: if ``video`` has fewer than 4 dimensions.
        """
        if video.ndim < 4:
            raise ValueError("Video must be atleast 4 dimensions: time, channels, height, width")
        if video.ndim == 4:
            video = video.reshape(1, *video.shape)
        b, t, c, h, w = video.shape

        if video.dtype != np.uint8:
            log.warning("Converting video data to uint8")
            video = video.astype(np.uint8)

        def is_power2(num):
            return num != 0 and ((num & (num - 1)) == 0)

        # Pad the batch with black videos up to the next power of 2 so it tiles into a grid.
        # Keep the dtype: np.zeros defaults to float64, which would upcast the whole video.
        if not is_power2(video.shape[0]):
            len_addition = int(2 ** video.shape[0].bit_length() - video.shape[0])
            padding = np.zeros((len_addition, t, c, h, w), dtype=video.dtype)
            video = np.concatenate((video, padding), axis=0)

        n_rows = 2 ** ((b.bit_length() - 1) // 2)
        n_cols = video.shape[0] // n_rows

        video = video.reshape(n_rows, n_cols, t, c, h, w)
        video = video.transpose(2, 0, 4, 1, 5, 3)
        return video.reshape(t, n_rows * h, n_cols * w, c)
