import torch
import torch.nn.functional as F
import pytorch_lightning as pl
import wandb
from jaxtyping import Float, Int, jaxtyped

from typing import List

from utils.Logging.LogStrategy import LogStrategy
from utils.utils import normalize_value_range, undersample_list, broadcast_modes_to_pixels_shape, get_hack_mode
from beartype import beartype as typechecker
from pytorch_lightning.utilities.rank_zero import rank_zero_only
from utils.utils import display_tensor, display_mask


class LogBlender1(LogStrategy):
    """
    Used when generating blender sample 1 domain at a time.

    The tensors handled here are single-domain RGB images of shape
    (batch, 3, 64, 64), as enforced by the jaxtyping annotations of `log_train`.
    """
    @jaxtyped
    @typechecker
    @rank_zero_only
    def log_train(
        self,
        plMod: pl.LightningModule,
        prompt_img: str,
        prediction: Float[torch.Tensor, 'batch 3 64 64'],
        input_to_model: Float[torch.Tensor, 'batch 3 64 64'],
        batch: Float[torch.Tensor, 'batch 3 64 64'],
        x0_hat: Float[torch.Tensor, 'batch 3 64 64'],
        ts: Int[torch.Tensor, 'batch'],
        batch_idx: int,
    ) -> None:
        """
        Log images for the step.

        Each logged image is the horizontal (width-axis) concatenation
        [batch, input_to_model, prediction, x0_hat]. The count returned by
        `self.already_logged` caps how many samples are logged; nothing is logged
        when it is <= 0. Images go to the key f'{prompt_img}/image'.

        `ts` is accepted (and shape-checked) for interface compatibility but is
        not used.
        """
        images = torch.cat([batch, input_to_model, prediction, x0_hat], dim=-1)

        remaining = self.already_logged(plMod, batch_idx, batch_size=images.shape[0])
        if remaining <= 0:
            return
        images = images[:remaining]
        images = normalize_value_range(images, plMod.params.logging.value_range, clip=True)

        caption = '[batch, input_to_model, prediction, x0_hat]'
        plMod.logger.experiment.log({f'{prompt_img}/image': [wandb.Image(image, caption=caption) for image in images]})

    @rank_zero_only
    def log_generate(self, stage_prefix: str, plMod: pl.LightningModule, batch_idx: int, batch, idx=None):
        """
        Generate samples and log their denoising trajectory.

        Each logged image shows the sub-sampled time steps side by side on the
        width axis, with the x_t samples on the top row and the matching x0_hat
        predictions on the bottom row. `batch` is only used for its batch size
        (to decide how many samples to generate); `idx` is unused.
        """
        remaining = self.already_logged(plMod, batch_idx, batch_size=batch.shape[0])

        if remaining <= 0:
            return

        with torch.no_grad():
            _, generated_data, generated_x0 = plMod.generate_samples(remaining, return_samples=True, return_pred_x0=True)

        # sub-sample the time steps, keeping each x_t paired with its x0_hat
        generated_data, generated_x0 = zip(*undersample_list(
            list(zip(generated_data, generated_x0)),
            plMod.params.logging.time_step_in_process,
            plMod.params.logging.strategy,
            plMod.params.logging.quad_factor,
        ))

        full_tensor    = torch.cat(generated_data, dim=-1)
        full_tensor_x0 = torch.cat(generated_x0  , dim=-1)
        # batch_size, channels, height, width * time_steps

        full_tensor = torch.cat([full_tensor, full_tensor_x0], dim=-2)  # concat them in height dimension

        full_tensor = normalize_value_range(full_tensor, plMod.params.logging.value_range, clip=True)

        plMod.logger.experiment.log({f'{stage_prefix}/image': [wandb.Image(image, caption='xt and x0_hat') for image in full_tensor]})


class LogBlender3(LogStrategy):
    """
    Used when generating blender sample 3 domain at a time.

    The 3 domains are stacked on the channel axis (domain-major, 3 RGB channels
    each, hence 9 channels). Before logging they are moved to the height axis so
    that each logged image shows the 3 domains one above the other.
    """

    @staticmethod
    def _domains_on_height(x: torch.Tensor) -> torch.Tensor:
        """Reshape (b, 3 * 3, H, W) into (b, 3, 3 * H, W): move the 3 domains from channels to height."""
        height, width = x.shape[-2], x.shape[-1]
        x = x.reshape(-1, 3, 3, height, width)  # (b, domain, rgb, H, W)
        return x.transpose(1, 2).reshape(-1, 3, 3 * height, width)

    @staticmethod
    def _l1_map(batch_cat: torch.Tensor, x0_hat: torch.Tensor, time_steps: int) -> torch.Tensor:
        """
        Per-domain L1 error map between the ground truth and each x0 prediction.

        `batch_cat` is (b, 9, h, w); `x0_hat` holds `time_steps` predictions side
        by side on the width axis, (b, 9, h, w * time_steps). The error is averaged
        over the 3 RGB channels of each domain and repeated back to 3 channels.
        Inputs are in [-1, 1] while outputs are unbounded, so the prediction is
        clamped to [-1, 1]: the error then lies in [0, 2] and is shifted by -1 to
        land in [-1, 1], the same range as the other rows of the logged image.
        """
        batch_size, _, h, w = batch_cat.shape
        l1 = F.l1_loss(batch_cat.repeat(1, 1, 1, time_steps), x0_hat.clamp(-1, 1), reduction='none')
        l1 = l1.reshape(batch_size, 3, 3, h, w * time_steps) \
               .mean(dim=2, keepdim=True) \
               .repeat(1, 1, 3, 1, 1)
        return l1.reshape(batch_size, 3 * 3, h, w * time_steps) - 1

    @staticmethod
    def _supervision_strip(
        modes: torch.Tensor,
        height: int,
        dtype: torch.dtype,
        device: torch.device,
        color_width: int = 10,
        space_width: int = 10,
    ) -> torch.Tensor:
        """
        Build the colored strip drawn on the left of the generation images.

        Returns a (b, 3 * n_dom, height, color_width + space_width) tensor with
        values in [-1, 1]: for each domain, a green band where `modes == 1` for
        that domain and a red band otherwise, followed by black spacing.
        """
        batch_size, n_dom = modes.shape
        supervised = (modes == 1).to(device)[..., None]  # (b, n_dom, 1)
        green = torch.tensor([-1, 1, -1], dtype=dtype, device=device)
        red = torch.tensor([1, -1, -1], dtype=dtype, device=device)
        colors = torch.where(supervised, green, red)  # (b, n_dom, 3)
        color_band = colors.reshape(batch_size, n_dom * 3, 1, 1).expand(-1, -1, height, color_width)
        # explicit dtype: torch.full with an int fill value would build an int64 tensor
        spacing = torch.full([batch_size, n_dom * 3, height, space_width], fill_value=-1, dtype=dtype, device=device)
        return torch.cat([color_band, spacing], dim=-1)

    @jaxtyped
    @typechecker
    @rank_zero_only
    def log_train(
        self,
        stage_prefix  : str,
        plMod         : pl.LightningModule,
        ts            : Int[torch.Tensor, 'b n_dom'],
        modes         : Float[torch.Tensor, 'b n_dom'],
        batch         : Float[torch.Tensor, 'b 9 h w'],
        input_to_model: Float[torch.Tensor, 'b 9 h w'],
        prediction    : Float[torch.Tensor, 'b 9 h w'],
        x0_hat        : Float[torch.Tensor, 'b 9 h w'],
        batch_idx     : int,
    ) -> None:
        """
        Log image for the train, with the supervision modes overlaid as a mask.

        Each logged image is the width-axis concatenation
        [batch, input_to_model, prediction, x0_hat], with the 3 domains stacked
        vertically. Mask class labels: 0 unavailable, 1 available, 2 padding.
        `ts` is part of the interface but is not used here.
        """
        b, _, h, w = batch.shape
        c = 3
        n_dom = 3

        remaining = self.already_logged(plMod, batch_idx, batch_size=b)
        if remaining <= 0:
            return

        images = torch.cat([batch, input_to_model, prediction, x0_hat], dim=-1)
        # expected to be (b, 9, h, w) like `batch` -- TODO confirm in broadcast_modes_to_pixels_shape;
        # repeated once per panel of `images`
        mode_per_pixels = broadcast_modes_to_pixels_shape(b, n_dom, c, h, w, modes)
        masks = torch.cat([mode_per_pixels] * 4, dim=-1).long()

        images = normalize_value_range(images, plMod.params.logging.value_range, clip=True)

        # keep only the samples still allowed by the logging budget, then move the
        # 3 domains from the channel axis to the height axis
        images = self._domains_on_height(images[:remaining])
        masks = self._domains_on_height(masks[:remaining])

        images_list = [
            wandb.Image(
                image,
                caption='data, D(data), model(D(data)) rinsed, x0_hat',
                masks={
                    'supervision': {"mask_data": mask.detach().cpu().numpy(), "class_labels": {0: "unavailable", 1: "available", 2: "padding"}},
                },
            )
            for image, mask in zip(images, masks[:, 0])
        ]
        plMod.logger.experiment.log({f'{stage_prefix}/image': images_list})

    @jaxtyped
    @typechecker
    def log_generate(
        self,
        stage_prefix: str,
        plMod       : pl.LightningModule,
        batch_idx   : int,
        batch       : List[Float[torch.Tensor, 'b 3 h w']],
        modes       : Float[torch.Tensor, 'b 3'],
        idx=None,
    ):
        """
        Log image for the generation and compute metric on generation.

        Not restricted to rank zero: the metrics are computed on every rank,
        only the image logging (`log_generate_log_images`) is rank-zero only.
        `idx` is unused.
        """
        # decide of the mode for the generation current batch
        if plMod.params.logging.hack_mode is not None:
            modes = get_hack_mode(hack_mode=plMod.params.logging.hack_mode, modes_init=modes)

        # generate samples; they are only logged and scored, so no autograd graph
        # needs to be kept across the time steps
        with torch.no_grad():
            _, generated_data, generated_x0 = plMod.generate_samples(
                modes=modes,
                condition=batch,

                # under sample in the time steps
                undersampling=plMod.params.logging.time_step_in_process,
                strategy=plMod.params.logging.strategy,
                quad_factor=plMod.params.logging.quad_factor,
            )

        # log metrics
        self.log_generate_metrics(
            stage_prefix=stage_prefix,
            plMod=plMod,
            batch=batch,
            modes=modes,
            generated_x0=generated_x0,
        )

        # log images
        self.log_generate_log_images(
            plMod=plMod,
            stage_prefix=stage_prefix,
            batch_idx=batch_idx,
            batch=batch,
            modes=modes,
            generated_data=generated_data,
            generated_x0=generated_x0,
        )

    @jaxtyped
    @typechecker
    def log_generate_metrics(
        self,
        stage_prefix: str,
        plMod: pl.LightningModule,
        batch: List[Float[torch.Tensor, 'b 3 h w']],
        modes: Float[torch.Tensor, 'b 3'],

        generated_x0: List[Float[torch.Tensor, 'b 9 h w']],
    ) -> None:
        """
        Compute the generation metrics on the last x0 prediction and log them.

        The last element of `generated_x0` is split back into its 3 domains
        (channels 0:3, 3:6, 6:9) to match the per-domain layout of `batch`.
        """
        final_x0 = generated_x0[-1]
        generated_prediction = [final_x0[:, 0:3], final_x0[:, 3:6], final_x0[:, 6:9]]

        metrics_obs = plMod.get_metric_object()
        metric_dict = metrics_obs.get_dict_generation(data=batch, prediction=generated_prediction, mode=modes)
        for metric_name, value in metric_dict.items():
            plMod.log_g(stage_prefix, metric_name, value)

    @jaxtyped
    @typechecker
    @rank_zero_only
    def log_generate_log_images(
        self,
        stage_prefix: str,
        plMod       : pl.LightningModule,
        batch_idx   : int,
        batch       : List[Float[torch.Tensor, 'b 3 h w']],
        modes       : Float[torch.Tensor, 'b 3'],

        generated_data: List[Float[torch.Tensor, 'b 9 h w']],
        generated_x0  : List[Float[torch.Tensor, 'b 9 h w']],
    ) -> None:
        """
        Log image for the generation.

        Layout of each logged image, for each domain (domains stacked vertically):
          - a strip on the left: green if the domain is supervised (mode == 1),
            red otherwise;
          - three rows: x_t samples, x0_hat predictions and the L1 error map,
            with the time steps side by side; each row ends with a black
            separator followed by the ground truth.
        """
        remaining = self.already_logged(plMod, batch_idx, batch_size=batch[0].shape[0])
        if remaining <= 0:
            return

        # remove useless data
        batch = [b[:remaining] for b in batch]
        modes = modes[:remaining]
        generated_data = [g[:remaining] for g in generated_data]
        generated_x0 = [g[:remaining] for g in generated_x0]

        batch_cat = torch.cat(batch, dim=1)  # (b, 9, h, w): domains stacked on channels

        # lay the time steps side by side on the width axis
        time_steps = len(generated_data)
        x_s = torch.cat(generated_data, dim=-1)
        x0_hat = torch.cat(generated_x0, dim=-1)
        l1_map = self._l1_map(batch_cat, x0_hat, time_steps)

        # append a black separator then the ground truth at the end of each row;
        # explicit dtype: torch.full with an int fill value would build an int64 tensor
        separator_width = 10
        separator = torch.full(
            [x_s.shape[0], x_s.shape[1], x_s.shape[2], separator_width],
            fill_value=-1, dtype=x_s.dtype, device=x_s.device,
        )
        x_s = torch.cat([x_s, separator, batch_cat], dim=-1)
        x0_hat = torch.cat([x0_hat, separator, batch_cat], dim=-1)
        l1_map = torch.cat([l1_map, separator, batch_cat], dim=-1)

        # stack x_s, x0_hat and the L1 map on the height axis
        img_cat = torch.cat([x_s, x0_hat, l1_map], dim=-2)

        # supervision strip on the left
        strip = self._supervision_strip(modes, img_cat.shape[2], img_cat.dtype, img_cat.device)
        img_cat = torch.cat([strip, img_cat], dim=-1)

        # the 3 domains are cat on the channel dims, they should be cat on the h dim
        img_cat = self._domains_on_height(img_cat)

        img_cat = normalize_value_range(img_cat, plMod.params.logging.value_range, clip=True)
        wandb_images = [wandb.Image(image, caption=f'{batch_idx=} xt and x0_hat') for image in img_cat]
        plMod.logger.experiment.log({f'{stage_prefix}/image': wandb_images})
