import torch
import wandb
from torchvision.utils import make_grid, draw_segmentation_masks
from torchvision.transforms.functional import convert_image_dtype
from explainers.utils import normalize_attr
import math
from abc import ABC, abstractmethod
import numpy as np
from typing import Iterable
from .misc import generate_number_image


def convert_to_video(volume):
    """Reinterpret a batch of 3D volumes as a batch of videos.

    The trailing depth axis is moved to position 1 and becomes the time axis:
    (N, C, H, W, D) -> (N, D, C, H, W). No data is copied; the result is a
    view of ``volume``.

    Raises:
        ValueError: if ``volume`` is not 5-dimensional.
    """
    # The unpacking doubles as a rank check: any non-5D input raises ValueError.
    _batch, _channels, _height, _width, _depth = volume.shape
    return volume.movedim(-1, 1)  # depth -> time axis


def convert_video_to_rgb(video):
    """Turn a (N, T, C, H, W) video batch into 3-channel uint8 frames.

    Single-channel input (C == 1) is replicated into three channels; any other
    channel count is passed through unchanged. Non-uint8 input is assumed to
    lie in [0, 1] and is scaled by 255 and truncated to uint8. Input that is
    already uint8 is returned as-is, with no rescale and no copy when C != 1.
    """
    if video.shape[2] == 1:
        # Grayscale: copy the single channel into R, G and B.
        video = video.repeat(1, 1, 3, 1, 1)
    if video.dtype == torch.uint8:
        return video
    return video.mul(255).byte()


def draw_segmask_overlay(img, mask, alpha=0.4):
    """Blend a yellow mask highlight onto a batch of RGB images.

    Args:
        img: uint8 tensor (N, C, H, W). C must be 3 so the RGB color can be
            broadcast into it.
        mask: bool tensor (N, H, W) selecting the pixels to highlight.
        alpha: weight of the yellow layer in the blend.

    Returns:
        uint8 tensor (N, C, H, W) equal to ``img * (1 - alpha) + painted * alpha``,
        truncated to uint8. ``painted`` is ``img`` with masked pixels set to
        yellow. ``img`` itself is not modified.
    """
    yellow = torch.tensor([255, 255, 0], dtype=torch.uint8, device=img.device)

    painted = img.detach().clone()
    # A channels-last view shares storage with `painted`, so writing through it
    # colors every masked (n, h, w) pixel in place.
    channels_last = painted.permute(0, 2, 3, 1)  # (N, H, W, C)
    channels_last[mask] = yellow

    blended = img * (1 - alpha) + painted * alpha
    return blended.byte()


def make_video_grid(video, nrow=None, padding=1, background=0):
    """Tile a batch of videos into a single grid video.

    Args:
        video: 5D tensor (N, T, C, H, W), a batch of N videos with T frames each.
        nrow: number of grid ROWS. Unlike ``torchvision.utils.make_grid``, where
            ``nrow`` is images per row, here it counts rows. Defaults to
            ceil(sqrt(N)). Cells are filled row-major and the column count is
            ceil(N / nrow).
        padding: background pixels around every cell, including the outer border.
        background: fill value for the padding and for unused cells.

    Returns:
        4D tensor (T, C, H', W') with H' = (padding + H) * nrow + padding and
        W' = (padding + W) * ncol + padding. The device and dtype of ``video``
        are preserved (bool inputs included).
    """
    n_videos, n_frames, channels, height, width = video.shape
    frames = video.permute(0, 1, 3, 4, 2)  # (N, T, H, W, C)
    if nrow is None:
        # max(1, ...) keeps an empty batch from dividing by zero below.
        nrow = max(1, math.ceil(math.sqrt(n_videos)))
    ncol = math.ceil(n_videos / nrow)

    # torch.full works for every dtype, including bool. The previous
    # `ones(...) *= background` raised for bool tensors.
    grid = torch.full(
        (
            n_frames,
            (padding + height) * nrow + padding,
            (padding + width) * ncol + padding,
            channels,
        ),
        background,
        dtype=video.dtype,
        device=video.device,
    )  # (T, H', W', C)

    for i in range(n_videos):
        row, col = divmod(i, ncol)
        # Offset by a leading `padding` so that every cell, including the first
        # row/column, has a border on all sides. This matches the extra
        # `+ padding` in the grid size and torchvision's make_grid layout.
        top = padding + (padding + height) * row
        left = padding + (padding + width) * col
        grid[:, top:top + height, left:left + width] = frames[i]

    return grid.permute(0, 3, 1, 2)  # (T, C, H', W')


def log_videos(batch_idx, batch_imgs, tag):
    """Log a batch of volumes to wandb as grid videos, one video per chunk.

    ``batch_imgs`` is a 5D tensor (N, C, H, W, D); its depth axis is played
    back as time. The batch is cut into chunks of
    ``N // batch_idx.shape[0]`` items. Each chunk is tiled into one grid and
    logged under ``tag`` as a ``wandb.Video``, which accepts a 4D (T, C, H, W)
    or 5D (N, T, C, H, W) array.

    Returns:
        The numpy grid (T, C, H', W') of the last chunk that was logged.
    """
    chunk_size = batch_imgs.shape[0] // batch_idx.shape[0]

    for chunk in batch_imgs.split(chunk_size):
        frames = convert_video_to_rgb(convert_to_video(chunk))
        grid = make_video_grid(frames).cpu().numpy()
        wandb.log({tag: wandb.Video(grid)})

    return grid


def log_attr_maps(batch_idx, batch_attrs):
    """Log 2D attribution maps to wandb as positive/negative/absolute image grids.

    The maps are cut into chunks of ``batch_attrs.shape[0] // batch_idx.shape[0]``
    items. For each chunk, ``normalize_attr`` is applied in three modes. Each
    result gets a channel axis and is tiled with ``make_grid`` (per-image
    min-max scaling). The three grids are sent in a single ``wandb.log`` call.
    """
    # wandb key -> normalize_attr mode
    panel_modes = {
        'attr_maps/positive': 'positive',
        'attr_maps/negative': 'negative',
        'attr_maps/absolute': 'absolute_value',
    }
    chunk_size = batch_attrs.shape[0] // batch_idx.shape[0]

    for chunk in batch_attrs.split(chunk_size):
        panels = {}
        for key, mode in panel_modes.items():
            maps = normalize_attr(chunk, mode).unsqueeze(1)  # add channel axis
            grid = make_grid(maps, normalize=True, scale_each=True)
            panels[key] = wandb.Image(grid)
        wandb.log(panels)


def log_attr_maps_3d(batch_idx, batch_attrs):
    """Log 3D attribution maps to wandb as videos, one per normalization mode.

    The maps are cut into chunks of ``batch_attrs.shape[0] // batch_idx.shape[0]``
    items. For each chunk and each mode ('positive', 'negative',
    'absolute_value'), the normalized maps are turned into a video (depth
    becomes time), lightly jittered, tiled into a grid and logged under
    ``'attr_maps/<mode>'``.
    """
    chunk_size = batch_attrs.shape[0] // batch_idx.shape[0]

    for chunk in batch_attrs.split(chunk_size):
        for mode in ('positive', 'negative', 'absolute_value'):
            frames = convert_to_video(normalize_attr(chunk, mode).unsqueeze(1))

            # Add a little noise so consecutive frames are never identical:
            # wandb does not seem to log frames that are too similar.
            frames += torch.randn_like(frames) * 0.002
            frames = convert_video_to_rgb(frames.clamp(0, 1))

            grid = make_video_grid(frames).cpu().numpy()
            wandb.log({f'attr_maps/{mode}': wandb.Video(grid)})


def log_attr_maps_3d_overlay(batch_idx, batch_imgs, batch_maps_post):
    """Log each volume with its post-processed mask overlaid in yellow.

    Volumes and masks are paired item by item. Each pair is rendered as a
    video (depth becomes time) and logged under 'attr_maps_post_overlayed'.
    ``batch_idx`` is accepted for parity with the other loggers but is unused.
    """
    for volume, mask_volume in zip(batch_imgs, batch_maps_post):
        frames = convert_video_to_rgb(convert_to_video(volume.unsqueeze(0)))
        frame_grid = make_video_grid(frames)

        mask_frames = convert_to_video(mask_volume.unsqueeze(0))
        # Keep only channel 0 of the (T, C, H', W') mask grid, as booleans.
        # Presumably the mask has a single channel; confirm against callers.
        mask_grid = make_video_grid(mask_frames)[:, 0].bool()

        overlay = draw_segmask_overlay(frame_grid, mask_grid).cpu().numpy()
        wandb.log({'attr_maps_post_overlayed': wandb.Video(overlay)})


def create_single_video_summary(batch_imgs, batch_masks, batch_inpaints):
    """Build one summary video: labels, originals, mask overlays and inpaints.

    NOTE(review): an identical ``create_single_video_summary`` is defined again
    later in this module. The later definition replaces this one when the
    module is imported, so this copy is dead code. TODO: delete one of them.

    Each input is presumably a 5D tensor (N, C, H, W, D) whose depth axis is
    played back as time via ``convert_to_video``. TODO: confirm against
    callers. Every input becomes one row of N cells (``nrow=1``,
    ``padding=2``). The rows are stacked top to bottom in this order:
    number labels, images, images with the mask overlaid in yellow, inpaints.

    Returns:
        numpy array (T, C, H', W') on the CPU.
    """
    N = batch_imgs.shape[0]
    # Index -2 is the width when the layout is (N, C, H, W, D); used as label width.
    W = batch_imgs.shape[-2]
    img = convert_to_video(batch_imgs)
    img = convert_video_to_rgb(img)
    img_grid = make_video_grid(img, nrow=1, padding=2, background=255)
    T = img_grid.shape[0]

    attr_map = convert_to_video(batch_masks)
    mask_grid = make_video_grid(attr_map, nrow=1, padding=2, background=0) # (T, C, H', W'); presumably C == 1
    mask_grid = mask_grid[:, 0] # (T, H', W'), keep channel 0 only
    mask_grid = mask_grid.bool()

    overlay_grid = draw_segmask_overlay(img_grid, mask_grid)

    inpaint = convert_to_video(batch_inpaints)
    inpaint = convert_video_to_rgb(inpaint)
    inpaint_grid = make_video_grid(inpaint, nrow=1, padding=2, background=255)

    # One label image per item, scaled from [0, 1] to uint8 and repeated over all T frames.
    numbers = torch.stack([
        generate_number_image(
            n, width=W
        )
        for n in range(N)
    ]).to(inpaint_grid.device)
    numbers = (numbers.clamp(0, 1) * 255).byte()
    numbers = numbers.unsqueeze(1).repeat(1, T, 1, 1, 1)
    numbers_grid = make_video_grid(numbers, nrow=1, padding=2, background=255)

    # dim=2 is the height axis of (T, C, H', W'), so the four rows are stacked vertically.
    video = torch.cat([
        numbers_grid, img_grid, overlay_grid, inpaint_grid
    ], dim=2)
    video = video.cpu().numpy()
    return video


def log_imgs(batch_idx, batch_imgs, tag):
    """Log a batch of 2D images to wandb as image grids, one grid per chunk.

    The images are cut into chunks of ``batch_imgs.shape[0] // batch_idx.shape[0]``
    items. Each chunk is cast to float, tiled with ``make_grid`` and logged
    under ``tag``.
    """
    chunk_size = batch_imgs.shape[0] // batch_idx.shape[0]

    for chunk in batch_imgs.split(chunk_size):
        wandb.log({tag: wandb.Image(make_grid(chunk.float()))})


def log_img_mask_overlay(batch_maps_post, batch_imgs):
    """Log each 2D image with its post-processed mask drawn on in yellow.

    Images and masks are paired item by item. Grayscale images (one channel)
    are first replicated to three channels, because ``draw_segmentation_masks``
    needs a 3-channel uint8 image. Each overlay is logged under
    'attr_maps_post_overlayed'.
    """
    for image, mask in zip(batch_imgs, batch_maps_post):
        rgb = image.repeat(3, 1, 1) if image.shape[0] == 1 else image
        overlayed = draw_segmentation_masks(
            convert_image_dtype(rgb, torch.uint8),
            masks=mask.bool(),
            colors="yellow",
            alpha=0.4,
        )
        wandb.log({"attr_maps_post_overlayed": wandb.Image(overlayed.float())})


def create_single_video_summary(batch_imgs, batch_masks, batch_inpaints):
    """Render a labelled summary video comparing inputs, masks and inpaints.

    Each argument is presumably a 5D batch (N, C, H, W, D) whose last (depth)
    axis is played back as time. TODO: confirm against callers. Four rows of
    N cells each (``nrow=1``, ``padding=2``) are stacked vertically:

    1. the item index rendered by ``generate_number_image``,
    2. the inputs as RGB,
    3. the inputs with ``batch_masks`` painted on in yellow,
    4. the inpainted volumes as RGB.

    NOTE(review): this module defines ``create_single_video_summary`` twice
    with the same body. Being later, this definition is the one in effect.

    Returns:
        A numpy array (T, C, H', W') on the CPU.
    """
    row_layout = {'nrow': 1, 'padding': 2}

    def rgb_row(batch):
        # Volume batch -> uint8 RGB video row on a white (255) background.
        frames = convert_video_to_rgb(convert_to_video(batch))
        return make_video_grid(frames, background=255, **row_layout)

    n_items = batch_imgs.shape[0]
    cell_width = batch_imgs.shape[-2]

    image_row = rgb_row(batch_imgs)
    n_frames = image_row.shape[0]

    mask_row = make_video_grid(convert_to_video(batch_masks), background=0, **row_layout)
    # Only channel 0 of the mask grid is used, as a boolean mask.
    overlay_row = draw_segmask_overlay(image_row, mask_row[:, 0].bool())

    inpaint_row = rgb_row(batch_inpaints)

    labels = torch.stack(
        [generate_number_image(idx, width=cell_width) for idx in range(n_items)]
    ).to(inpaint_row.device)
    labels = labels.clamp(0, 1).mul(255).byte()
    # Repeat each static label image across every frame.
    labels = labels.unsqueeze(1).repeat(1, n_frames, 1, 1, 1)
    label_row = make_video_grid(labels, background=255, **row_layout)

    # dim=2 is the height axis of (T, C, H', W'): rows stack top to bottom.
    stacked = torch.cat([label_row, image_row, overlay_row, inpaint_row], dim=2)
    return stacked.cpu().numpy()




class UtilsLogger(ABC):
    """Abstract interface for modality-specific logging hooks.

    Every hook is abstract, so the class cannot be instantiated directly.
    A concrete subclass must implement all five.
    """

    @abstractmethod
    def log_original(self, batch_idx, batch_imgs):
        """Log the original (unmodified) input batch."""

    @abstractmethod
    def log_original_mask_overlay(self, batch_idx, batch_imgs, batch_masks):
        """Log the inputs with ``batch_masks`` overlaid on them."""

    @abstractmethod
    def log_attr_maps(self, batch_idx, batch_masks):
        """Log raw attribution maps."""

    @abstractmethod
    def log_attr_maps_post(self, batch_idx, batch_masks):
        """Log post-processed attribution maps."""

    @abstractmethod
    def log_inpaints(self, batch_idx, batch_imgs):
        """Log inpainted versions of the inputs."""


class ImageUtilsLogger(UtilsLogger):
    """Logging hooks for 2D image batches; each one delegates to a module helper."""

    def log_original(self, batch_idx, batch_imgs):
        """Log the input images as grids under the 'original' key."""
        log_imgs(batch_idx, batch_imgs, 'original')

    def log_original_mask_overlay(self, batch_idx, batch_imgs, batch_masks):
        """Log each image with its mask drawn on. ``batch_idx`` is unused."""
        log_img_mask_overlay(batch_masks, batch_imgs)

    def log_attr_maps(self, batch_idx, batch_masks):
        """Log attribution maps as positive/negative/absolute grids."""
        log_attr_maps(batch_idx, batch_masks)

    def log_attr_maps_post(self, batch_idx, batch_masks):
        """Log post-processed maps through the same helper as the raw maps.

        NOTE(review): this reuses the 'attr_maps/*' wandb keys, so post maps
        and raw maps share panels. Confirm this is intended.
        """
        log_attr_maps(batch_idx, batch_masks)

    def log_inpaints(self, batch_idx, batch_imgs):
        """Log inpainted images as grids under the 'inpaints' key."""
        log_imgs(batch_idx, batch_imgs, 'inpaints')



class VideoUtilsLogger(UtilsLogger):
    """Logging hooks for volumetric batches, rendered as videos (depth becomes time)."""

    def log_original(self, batch_idx, batch_imgs):
        """Log the input volumes as grid videos under the 'original' key."""
        log_videos(batch_idx, batch_imgs, 'original')

    def log_original_mask_overlay(self, batch_idx, batch_imgs, batch_maps_post):
        """Log each volume with its post-processed mask overlaid, as a video."""
        log_attr_maps_3d_overlay(batch_idx, batch_imgs, batch_maps_post)

    def log_attr_maps(self, batch_idx, batch_masks):
        """Log attribution volumes as one video per normalization mode."""
        log_attr_maps_3d(batch_idx, batch_masks)

    def log_attr_maps_post(self, batch_idx, batch_masks):
        """Not supported for videos: only prints a notice and logs nothing."""
        print('log_attr_maps_post: Not implemented for videos')

    def log_inpaints(self, batch_idx, batch_imgs):
        """Log inpainted volumes as grid videos under the 'inpaints' key."""
        log_videos(batch_idx, batch_imgs, 'inpaints')

