import itertools
import os
from typing import List, Optional

import torch
import torch.nn.functional as F
import pytorch_lightning as pl
import wandb

from jaxtyping import jaxtyped, Int, Float
from beartype import beartype as typechecker
from lightning_utilities.core.rank_zero import rank_zero_only
from torchvision.utils import save_image

from conf.dataset import BRATS2020Params
from utils.Logging.LogStrategy import LogStrategy
from utils.utils import normalize_value_range, mask2rgb, augmentWithBackground, get_hack_mode

from utils.utils import display_tensor, display_mask


class LogBRATS2020(LogStrategy):
    """
    Logging strategy for the BRATS2020 dataset.

    Data tensors handled here stack the scan domains first along the channel dimension (the first
    ``n_dom_scan`` channels, one channel per scan domain), followed by the segmentation channels, whose
    encoding depends on ``data_params.segmentation_mode`` (see ``from_data_get_seg``). Images are converted
    to RGB and laid out with one domain per vertical band before being sent to wandb or written to disk.
    """

    @jaxtyped
    @typechecker
    def scan_to_rgb(self, val: Float[torch.Tensor, 'b n_dom h w']) -> Float[torch.Tensor, 'b n_dom*3 h w']:
        """
        Turn each single-channel scan domain into a grey RGB triplet.

        Every channel is repeated 3 times in place, so the output channel layout is
        ``[d0, d0, d0, d1, d1, d1, ...]``.
        """
        return val.repeat_interleave(3, dim=1)

    @jaxtyped
    @typechecker
    def from_data_get_seg(self, data: Float[torch.Tensor, 'b c_n_dom h w']) -> Int[torch.Tensor, 'b h w']:
        """
        Extract a categorical (per-pixel class index) segmentation from the full data tensor.

        The segmentation channels are the ones after the first ``n_dom_scan`` channels. They are decoded
        according to ``data_params.segmentation_mode``:

        - 1: expected to be a single channel, thresholded at 0.5 (classes 0 / 1);
        - 2 and 4: one channel per class, decoded with an argmax over the channels;
        - 3: channels are first passed through ``augmentWithBackground`` (presumably adding a background
          channel — TODO confirm), then decoded with an argmax.

        Raises:
            ValueError: for any other segmentation mode.
        """
        data_params = self.params_data.data_params
        seg = data[:, data_params.n_dom_scan:]
        scan_mode = data_params.segmentation_mode

        if scan_mode == 1:
            return (seg > 0.5).int().squeeze(1)
        if scan_mode in (2, 4):
            return seg.argmax(dim=1)
        if scan_mode == 3:
            return augmentWithBackground(segmentations_maps=seg).argmax(dim=1)
        raise ValueError(f'unsupported segmentation_mode: {scan_mode!r}')

    @jaxtyped
    @typechecker
    def data_to_rgb(self, data: Float[torch.Tensor, 'b c h w']) -> Float[torch.Tensor, 'b 3*5 h w']:
        """
        Convert a full data tensor (scans + segmentation) to RGB channels.

        The output is the scans as grey RGB triplets (``scan_to_rgb``) followed by the colourised
        categorical segmentation (``mask2rgb``), concatenated along the channel dimension.
        """
        # slice the scans with n_dom_scan (not a hard-coded 4) so that it agrees with the
        # segmentation slice taken in from_data_get_seg
        n_dom_scan = self.params_data.data_params.n_dom_scan
        return torch.cat([
            self.scan_to_rgb(data[:, :n_dom_scan]),
            mask2rgb(self.from_data_get_seg(data)),
        ], dim=1)

    @jaxtyped
    @typechecker
    @rank_zero_only
    def log_train(
        self,
        stage_prefix  : str,
        plMod         : pl.LightningModule,
        ts            : Int[torch.Tensor  , 'b n_dom'],
        modes         : Float[torch.Tensor, 'b n_dom'],
        batch         : Float[torch.Tensor, 'b c_n_dom h w'],
        input_to_model: Float[torch.Tensor, 'b c_n_dom h w'],
        prediction    : Float[torch.Tensor, 'b c_n_dom h w'],
        x0_hat        : Float[torch.Tensor, 'b c_n_dom h w'],
        batch_idx     : int,
    ) -> None:
        """
        Log images of a training step to wandb (rank zero only).

        For each logged sample, ``batch``, ``input_to_model``, ``prediction`` and ``x0_hat`` are placed
        side by side along the width, each domain is converted to RGB, and the domains are stacked along
        the height, giving one image per sample. At most ``already_logged(...)`` samples are logged.

        ``ts`` and ``modes`` belong to the shared logging signature but are not used here.
        """
        batch_size = batch.shape[0]
        remaining = self.already_logged(plMod, batch_idx, batch_size=batch_size)
        if remaining <= 0:
            return

        paramsBRATS: BRATS2020Params = self.params_data.data_params
        n_dom_scan = paramsBRATS.n_dom_scan
        # read from the instance, consistently with log_generate_log_images
        n_dom = paramsBRATS.n_dom
        rgb_dim = 3

        # the four tensors side by side along the width
        images = torch.cat([batch, input_to_model, prediction, x0_hat], dim=-1)
        # only normalize the scans, not the segmentation
        images[:, :n_dom_scan] = normalize_value_range(
            images[:, :n_dom_scan], value_range=paramsBRATS.value_range, clip=True,
        )

        # transpose each domain into its RGB channels
        images = torch.cat([
            self.scan_to_rgb(images[:, :n_dom_scan]),
            mask2rgb(self.from_data_get_seg(images)),
        ], dim=1)
        images = images[:remaining]

        # the domains are concatenated on the channel dim; move them to the height dim instead
        height, width = images.shape[-2:]
        images = images.reshape([-1, n_dom, rgb_dim, height, width])
        images = images.transpose(1, 2).reshape([-1, rgb_dim, n_dom * height, width])
        images = images.clamp(0, 1)

        plMod.logger.experiment.log({f'{stage_prefix}/image': [
            wandb.Image(
                image,
                caption='data, D(data), model(D(data)) rinsed, x0_hat',
            )
            for image in images
        ]})

    @jaxtyped
    @typechecker
    def log_generate(
        self,
        stage_prefix: str,
        plMod: pl.LightningModule,
        batch_idx: int,
        batch: List[Float[torch.Tensor, 'b c h w']],
        modes: Float[torch.Tensor, 'b n_dom'],
        idx=None,
    ):
        """
        Generate samples conditioned on ``batch``, log generation metrics, then log the resulting images.

        When ``self.params.early_leave`` is set, the whole step (generation included) is skipped as soon as
        ``already_logged`` reports nothing left to log for this batch. If ``plMod.params.logging.hack_mode``
        is set, ``modes`` is replaced by ``get_hack_mode(...)`` before generating.

        Images go to wandb through ``log_generate_to_wandb`` (rank zero only) and to disk through
        ``log_image_to_disk`` (test stage only, when enabled).

        :param batch: list of per-domain tensors used as the generation condition
        :param modes: per-sample, per-domain mode; a value of 1 is drawn as "supervised" in the images
        :param idx: optional per-sample identifiers, used in captions and output file names
        """
        if self.params.early_leave:
            remaining = self.already_logged(plMod, batch_idx, batch_size=batch[0].shape[0])
            if remaining <= 0:
                return

        # decide of the mode for the generation of the current batch
        if plMod.params.logging.hack_mode is not None:
            modes = get_hack_mode(hack_mode=plMod.params.logging.hack_mode, modes_init=modes)

        # generate samples
        with torch.no_grad():
            _, generated_data, generated_x0 = plMod.generate_samples(
                modes=modes,
                condition=list(batch),

                # under sample in the time steps
                undersampling=plMod.params.logging.time_step_in_process,
                strategy=plMod.params.logging.strategy,
                quad_factor=plMod.params.logging.quad_factor,
            )

        # generation metrics
        self.log_generate_metrics(
            stage_prefix=stage_prefix,
            plMod=plMod,
            batch=batch,
            modes=modes,
            generated_x0=generated_x0,
            idx=idx,
        )

        # images
        images = self.log_generate_log_images(
            stage_prefix=stage_prefix,
            plMod=plMod,
            batch_idx=batch_idx,
            batch=batch,
            modes=modes,
            generated_data=generated_data,
            generated_x0=generated_x0,
            idx=idx,
        )
        self.log_generate_to_wandb(
            stage_prefix=stage_prefix,
            batch_idx=batch_idx,
            img_cat=images,
            plMod=plMod,
            idx=idx,
        )
        self.log_image_to_disk(
            img_cat=images,
            plMod=plMod,
            idx=idx,
        )

    @jaxtyped
    @typechecker
    def log_generate_metrics(
        self,
        stage_prefix: str,
        plMod: pl.LightningModule,
        batch: List[Float[torch.Tensor, 'b ci h w']],
        modes: Float[torch.Tensor, 'b n_dom'],

        generated_x0: List[Float[torch.Tensor, 'b n_dom_c h w']],
        idx=None,
    ) -> None:
        """
        Compute the generation metrics on the last entry of ``generated_x0`` and log them via ``plMod.log_g``.

        The last prediction is split along the channels into one tensor per domain, using
        ``data_params.dimension_per_domain`` as the channel count of each domain, so it can be compared
        with the per-domain ``batch`` list.

        ``idx`` belongs to the shared signature and is not used here.
        """
        # channel boundaries of each domain: [0, d0, d0 + d1, ...]
        cumulative_channels = list(itertools.accumulate(self.params_data.data_params.dimension_per_domain))
        last_x0 = generated_x0[-1]
        generated_prediction = [
            last_x0[:, start:end]
            for start, end in zip([0, *cumulative_channels], cumulative_channels)
        ]

        metrics_obs = plMod.get_metric_object()
        metric_dict = metrics_obs.get_dict_generation(
            data=batch,
            prediction=generated_prediction,
            mode=modes,
        )
        for metric_name, value in metric_dict.items():
            plMod.log_g(stage_prefix, metric_name, value)

    @jaxtyped
    @typechecker
    def log_generate_log_images(
        self,
        stage_prefix: str,
        plMod: pl.LightningModule,
        batch_idx: int,
        batch: List[Float[torch.Tensor, 'b ci h w']],  # length of the list is the number of domains
        modes: Float[torch.Tensor, 'b n_dom'],

        generated_data: List[Float[torch.Tensor, 'b n_dom_c h w']],  # the length of the list is the number of time steps
        generated_x0  : List[Float[torch.Tensor, 'b n_dom_c h w']],  # the length of the list is the number of time steps
        idx=None,
    ) -> Float[torch.Tensor, 's 3 hg wg']:
        """
        Build one RGB image per sample summarizing a generation.

        Three bands are stacked along the height:
        - the generated samples (``generated_data``), one time step after the other along the width;
        - the x0 predictions (``generated_x0``), laid out the same way;
        - the l1 error map between the ground truth and the x0 predictions, drawn in red.

        Each band ends with a black separator followed by the ground truth. A strip on the left is green
        for the domains whose mode is 1 and red otherwise. Finally the domains, which are concatenated
        along the channels, are stacked along the height.

        Scans are normalized with ``plMod.params.logging.value_range`` (with clipping); the segmentation
        is not normalized.
        """
        paramsBRATS: BRATS2020Params = self.params_data.data_params
        n_dom_scan = paramsBRATS.n_dom_scan
        n_dom = paramsBRATS.n_dom
        dim_rgb = 3
        batch_size, _, h, w = batch[0].shape
        time_steps = len(generated_data)
        value_range = plMod.params.logging.value_range

        # cat the domains along the channels, and the time steps along the width
        gt     = torch.cat(batch, dim=1)[:batch_size]
        modes  = modes[:batch_size]
        x_s    = torch.cat(generated_data, dim=-1)[:batch_size]
        x0_hat = torch.cat(generated_x0, dim=-1)[:batch_size]

        # region normalize the value range -> only the scans, not the segmentation
        gt[:, :n_dom_scan]     = normalize_value_range(gt[:, :n_dom_scan], value_range, clip=True)
        x_s[:, :n_dom_scan]    = normalize_value_range(x_s[:, :n_dom_scan], value_range, clip=True)
        x0_hat[:, :n_dom_scan] = normalize_value_range(x0_hat[:, :n_dom_scan], value_range, clip=True)
        # endregion

        # region l1 error map
        # The segmentation is reduced to one categorical channel (class indices) so that its error is a
        # single band. Labels are integers, so any disagreement gives an error >= 1, clamped to 1 below.
        x0_hat_for_l1 = torch.cat([
            x0_hat[:, :n_dom_scan],
            self.from_data_get_seg(x0_hat).unsqueeze(dim=1),
        ], dim=1)
        gt_for_l1 = torch.cat([
            gt[:, :n_dom_scan],
            self.from_data_get_seg(gt).unsqueeze(dim=1),
        ], dim=1)

        # the ground truth is tiled along the width so that it lines up with every time step
        l1_map = F.l1_loss(gt_for_l1.repeat(1, 1, 1, time_steps), x0_hat_for_l1, reduction='none')
        # l1_map is [b, n_dom_scan + 1, h, w * time_steps]
        l1_map = torch.cat([
            l1_map[:, :n_dom_scan],
            l1_map[:, n_dom_scan:].sum(dim=1, keepdim=True).clamp(0., 1.),
        ], dim=1)

        # keep the error in the red channel only: [b, c, h, W] -> [b, c, rgb, h, W] -> [b, c * rgb, h, W]
        l1_map = l1_map.unsqueeze(2).repeat(1, 1, dim_rgb, 1, 1)
        l1_map[:, :, 1:] = 0
        l1_map = l1_map.reshape(batch_size, -1, h, w * time_steps)
        # endregion

        # region put data as rgb
        gt     = self.data_to_rgb(gt)
        x_s    = self.data_to_rgb(x_s)
        x0_hat = self.data_to_rgb(x0_hat)
        # endregion

        # region ground truth at the end of each band, after a black separator
        supervision_black_strip_width = 10
        black_strip = torch.zeros([batch_size, x_s.shape[1], x_s.shape[2], supervision_black_strip_width],
                                  device=x_s.device)
        x_s = torch.cat([x_s, black_strip, gt], dim=-1)
        x0_hat = torch.cat([x0_hat, black_strip, gt], dim=-1)
        l1_map = torch.cat([l1_map, black_strip, gt], dim=-1)
        # endregion

        # stack the bands along the height
        img_cat = torch.cat([x_s, x0_hat, l1_map], dim=-2)

        # region supervision strip on the left: green when the domain mode is 1, red otherwise
        strip_color_width = 10
        strip_space = 10
        strip_width = strip_color_width + strip_space
        img_height = img_cat.shape[2]
        strip = torch.zeros([batch_size, n_dom * dim_rgb, img_height, strip_width], device=img_cat.device)

        # vectorized colours (one per sample and domain) instead of a per-element Python loop, which
        # forced a device sync for every element
        supervised = (modes == 1).to(dtype=strip.dtype, device=strip.device)  # [b, n_dom_modes]
        colors = torch.stack([1. - supervised, supervised, torch.zeros_like(supervised)], dim=-1)  # [b, n_dom_modes, rgb]
        # view of the strip split per domain and per rgb channel; writing into it fills `strip`
        strip_per_dom = strip.view(batch_size, n_dom, dim_rgb, img_height, strip_width)
        strip_per_dom[:, :modes.shape[1], :, :, :strip_color_width] = colors[:, :, :, None, None]

        img_cat = torch.cat([strip, img_cat], dim=-1)
        # endregion

        # the domains are concatenated on the channel dim; move them to the height dim instead
        img_cat = img_cat.reshape([batch_size, n_dom, dim_rgb, img_cat.shape[-2], img_cat.shape[-1]])
        img_cat = img_cat.transpose(1, 2).reshape([-1, dim_rgb, n_dom * img_cat.shape[-2], img_cat.shape[-1]])

        return img_cat

    @jaxtyped
    @typechecker
    @rank_zero_only
    def log_generate_to_wandb(
        self,
        stage_prefix: str,
        batch_idx: int,
        img_cat: Float[torch.Tensor, 'b 3 hg wg'],
        plMod,
        idx=None,
    ) -> None:
        """
        Send the first images of ``img_cat`` to wandb under ``<stage_prefix>/image`` (rank zero only).

        The number of images is bounded by ``already_logged(...)`` and by the batch size. Captions
        include the sample id when ``idx`` is given.
        """
        batch_size = img_cat.shape[0]
        remaining = self.already_logged(plMod, batch_idx, batch_size=batch_size)
        if remaining <= 0:
            return
        remaining = min(remaining, batch_size)

        wandb_images = [wandb.Image(
            img_cat[i],
            caption=f'id:{idx[i]} xt and x0_hat' if idx is not None else 'xt and x0_hat',
        ) for i in range(remaining)]
        plMod.logger.experiment.log({f'{stage_prefix}/image': wandb_images})

    @jaxtyped
    @typechecker
    def log_image_to_disk(
        self,
        img_cat: Float[torch.Tensor, 'b 3 hg wg'],
        plMod,
        idx=None,
    ) -> None:
        """
        Save every image of ``img_cat`` to ``_results/<idx>.png`` and, when ``self.params.log_pt`` is set,
        the raw tensor to ``_results/<idx>.pt``.

        Only done in the ``'test'`` stage when ``plMod.params.logging.save_image_to_disk`` is set; ``idx``
        is then required. The ``_results`` directory is created if missing.
        """
        if plMod.get_stage() != 'test' or not plMod.params.logging.save_image_to_disk:
            return
        assert idx is not None, 'idx should be provided when saving image to disk'

        out_dir = '_results'
        # neither save_image nor torch.save create missing directories
        os.makedirs(out_dir, exist_ok=True)
        for i, img_i in enumerate(img_cat):
            save_image(img_i, fp=f'{out_dir}/{idx[i]}.png')
            if self.params.log_pt:
                # img_i is a view into img_cat and torch.save serializes the whole storage behind a view;
                # clone so that each file holds only its own image
                torch.save(img_i.clone(), f'{out_dir}/{idx[i]}.pt')
