from typing import Dict, List

import pytorch_lightning as pl
import torch
import torch.nn as nn
from beartype import beartype as typechecker
from jaxtyping import Float, Int, jaxtyped
from torchmetrics import (JaccardIndex, MeanAbsoluteError, MeanSquaredError,
                          PeakSignalNoiseRatio)
from torchmetrics.image import (LearnedPerceptualImagePatchSimilarity,
                                StructuralSimilarityIndexMeasure)

from conf.dataset import (BRATS2020Params, CelebAParams, DatasetParams,
                          ValueRange)
from conf.model import MetricsParams, ModelParams
from utils.MetricsImplem.DiversityMetric import DiversityMetric
from utils.utils import augmentWithBackground, display_mask, display_tensor


class CelebA3Metrics(pl.LightningModule):
    """Metric container for the ``celeba3`` configuration.

    The concatenated sample is split channel-wise into a 3-channel photo, a
    1-channel sketch and the remaining segmentation channels (see
    ``get_dict_generation``).  Three groups of torchmetrics objects are kept:

    * step metrics updated by ``get_dict`` (supervised / unsupervised MAE),
    * generation metrics updated by ``get_dict_generation`` and read back by
      ``compute_and_get``,
    * a diversity metric updated by ``get_dict_generation_diversity``.

    Every ``get_dict*`` method returns the metric *objects*, not computed
    values.
    """

    def __init__(self, params: ModelParams, params_data: DatasetParams):
        """Instantiate all metric objects.

        Args:
            params: model parameters; ``params.loss.predict_noise`` is read by
                ``get_dict``.
            params_data: dataset parameters; ``params_data.data_params`` is
                stored as ``self.celeba_params`` and provides ``n_class``,
                ``return_background`` and ``dimension_per_domain``.
        """
        super().__init__()
        self.params = params
        self.celeba_params: CelebAParams = params_data.data_params

        # training
        self.step_supervised_mae = MeanAbsoluteError()
        self.step_unsupervised_mae = MeanAbsoluteError()

        # GENERATION
        # face metrics
        self.lpips_clamp_face = LearnedPerceptualImagePatchSimilarity(
            net_type='alex',
            reduction='mean',
            normalize=False,  # images are in [-1, 1]
        )
        self.ssim_clamp_face = StructuralSimilarityIndexMeasure(data_range=(-1., 1.))

        # sketch metrics
        self.psnr_clamp = PeakSignalNoiseRatio(
            data_range=(-1., 1.),
            base=10.,
            reduction='elementwise_mean',
            dim=None,
        )
        self.mae_clamp = MeanAbsoluteError()
        self.mse_clamp = MeanSquaredError()
        self.ssim_clamp = StructuralSimilarityIndexMeasure(data_range=(-1., 1.))
        self.lpips_clamp_sketch = LearnedPerceptualImagePatchSimilarity(
            net_type='alex',
            reduction='mean',
            normalize=False,  # images are in [-1, 1]
        )

        # segmentation metrics
        self.gen_jaccards = JaccardIndex(task="multiclass", num_classes=self.celeba_params.n_class)

        # GENERATION DIVERSITY
        self.diversity = DiversityMetric()

    def compute_and_get(self):
        """Call ``compute()`` on every generation metric.

        Returns:
            dict mapping the same keys as ``get_dict_generation`` to the
            values returned by each metric's ``compute()``.
        """
        return {
            'lpips_clamp_face': self.lpips_clamp_face.compute(),
            'ssim_clamp_face': self.ssim_clamp_face.compute(),

            'psnr_clamp': self.psnr_clamp.compute(),
            'mae_clamp': self.mae_clamp.compute(),
            'mse_clamp': self.mse_clamp.compute(),
            'ssim_clamp': self.ssim_clamp.compute(),
            'lpips_clamp_sketch': self.lpips_clamp_sketch.compute(),

            'jaccard': self.gen_jaccards.compute(),
        }

    @jaxtyped
    @typechecker
    def get_dict_generation_diversity(
        self,
        batch: List[Float[torch.Tensor, 'b diversity ci h w']],
        prediction: List[Float[torch.Tensor, 'b diversity ci h w']],
        modes: Float[torch.Tensor, 'b diversity n_dom'],
    ):
        """Update the diversity metric on the first domain of the batch.

        Only ``batch[0]`` / ``prediction[0]`` are used.  Both are clamped to
        [-1, 1] and restricted to the samples whose ``modes[:, 0, 0]`` entry
        is falsy.
        NOTE(review): presumably a zero mode marks a domain that was generated
        rather than given — confirm against the caller.

        Returns:
            ``{'diversity_face': self.diversity}`` (the metric object).
        """
        # The mode is identical along the diversity axis, so the first entry
        # of that axis is enough to build the per-sample mask.
        face_given = modes[:, 0, 0].bool()
        face_generated = ~face_given

        target_face = batch[0].clamp(-1, 1)[face_generated]
        pred_face = prediction[0].clamp(-1, 1)[face_generated]

        self.diversity.update(preds=pred_face, target=target_face)

        return {
            'diversity_face': self.diversity,
        }

    @jaxtyped
    @typechecker
    def segmentation_to_categorical(self, seg: Float[torch.Tensor, 'b c h w'], return_background: int) -> Int[torch.Tensor, 'b h w']:
        """Convert a per-class segmentation map into categorical labels.

        Args:
            seg: per-class maps; the label is the ``argmax`` over dim 1.
            return_background: when falsy, ``seg`` is first passed through
                ``augmentWithBackground``.
                NOTE(review): presumably that helper prepends an explicit
                background channel — confirm in ``utils.utils``.

        Returns:
            Integer tensor holding the arg-max channel index per pixel.
        """
        if not return_background:
            seg = augmentWithBackground(segmentations_maps=seg)
        return seg.argmax(dim=1)

    @jaxtyped
    @typechecker
    def get_dict_generation(
        self, *,
        data: List[Float[torch.Tensor, 'b c h w']],
        prediction: List[Float[torch.Tensor, 'b c h w']],
        mode: Float[torch.Tensor, 'b n_dom'],
    ) -> Dict:
        """Update all generation metrics with one batch.

        ``data`` and ``prediction`` are concatenated along the channel axis
        and split into photo (3 channels), sketch (1 channel) and masks (all
        remaining channels).

        Photo and sketch metrics are updated on the whole batch, with both
        predictions and targets clamped to [-1, 1].  The Jaccard index is
        updated only on samples whose segmentation mode entry is falsy.
        NOTE(review): the photo/sketch metrics ignore ``mode`` — confirm this
        is intended.

        Args:
            data: ground-truth tensors, one per domain.
            prediction: generated tensors, one per domain.
            mode: per-sample, per-domain flags; columns from index 2 onward
                are used for the segmentation mask.

        Returns:
            dict mapping metric names to the (updated) metric objects.
        """
        data_cat = torch.cat(data, dim=1).to(self.device)
        prediction_cat = torch.cat(prediction, dim=1).to(self.device)

        mode_bool = mode.bool().to(self.device)

        # 3 photo channels + 1 sketch channel; everything else is masks.
        dim_mask = data_cat.shape[1] - 4
        target_photo, target_sketch, target_masks = data_cat.split([3, 1, dim_mask], dim=1)
        prediction_photo, prediction_sketch, prediction_masks = prediction_cat.split([3, 1, dim_mask], dim=1)

        # Clamp once and reuse.  Every "*_clamp" metric is configured for the
        # [-1, 1] range, and LPIPS with normalize=False raises on inputs
        # below -1, so un-clamped predictions must never reach them.
        photo_pred = prediction_photo.clamp(-1, 1)
        photo_target = target_photo.clamp(-1, 1)
        sketch_pred = prediction_sketch.clamp(-1, 1)
        sketch_target = target_sketch.clamp(-1, 1)

        # region FACE METRICS
        self.lpips_clamp_face.update(img1=photo_pred, img2=photo_target)
        self.ssim_clamp_face.update(preds=photo_pred, target=photo_target)
        # endregion

        # region SKETCH METRICS
        self.psnr_clamp.update(preds=sketch_pred, target=sketch_target)
        self.mae_clamp.update(preds=sketch_pred, target=sketch_target)
        self.mse_clamp.update(preds=sketch_pred, target=sketch_target)
        self.ssim_clamp.update(preds=sketch_pred, target=sketch_target)
        # LPIPS expects 3-channel images: replicate the single sketch channel.
        self.lpips_clamp_sketch.update(
            img1=sketch_pred.repeat(1, 3, 1, 1),
            img2=sketch_target.repeat(1, 3, 1, 1),
        )
        # endregion

        # region SEGMENTATION METRICS
        mode_bool_seg = mode_bool[:, 2:]

        return_background = self.celeba_params.return_background
        target_masks = self.segmentation_to_categorical(seg=target_masks, return_background=return_background)
        prediction_masks = self.segmentation_to_categorical(seg=prediction_masks, return_background=return_background)

        # unsqueeze(1) gives [b, 1, h, w] so the boolean mask can index the
        # first two dims.
        # NOTE(review): this assumes mode_bool_seg is [b, 1] — confirm.
        seg_generated = ~mode_bool_seg
        prediction_masks_unsupervised = prediction_masks.unsqueeze(1)[seg_generated]
        target_masks_unsupervised = target_masks.unsqueeze(1)[seg_generated]

        self.gen_jaccards.update(preds=prediction_masks_unsupervised, target=target_masks_unsupervised)
        # endregion

        return {
            'lpips_clamp_face': self.lpips_clamp_face,
            'ssim_clamp_face': self.ssim_clamp_face,

            'psnr_clamp': self.psnr_clamp,
            'mae_clamp': self.mae_clamp,
            'mse_clamp': self.mse_clamp,
            'ssim_clamp': self.ssim_clamp,
            'lpips_clamp_sketch': self.lpips_clamp_sketch,

            'jaccard': self.gen_jaccards,
        }

    @jaxtyped
    @typechecker
    def get_dict(
        self, *,
        data       : List[Float[torch.Tensor, 'b _chan_per_dom h w']],
        mode       : Float[torch.Tensor, 'b n_dom      '],
        batch_mixed: Float[torch.Tensor, 'b c_x_dom h w'],
        noise      : Float[torch.Tensor, 'b c_x_dom h w'],
        batch_recon: Float[torch.Tensor, 'b c_x_dom h w'],
        data0      : Float[torch.Tensor, 'b c_x_dom h w'],
    ) -> Dict:
        """Update the step-level supervised / unsupervised MAE.

        When ``params.loss.predict_noise`` is truthy the comparison is
        ``batch_recon`` vs ``noise``; otherwise ``data0`` vs the channel-wise
        concatenation of ``data``.  Channels whose (expanded) mode flag is
        truthy feed the "supervised" metric, the others the "unsupervised"
        one.  ``batch_mixed`` is only shape-checked by the annotations.

        Returns:
            ``{'supervised_mae': ..., 'unsupervised_mae': ...}`` holding the
            metric objects.
        """
        _, n_dom = mode.shape

        if self.params.loss.predict_noise:
            target = noise
            prediction = batch_recon
        else:
            target = torch.cat(data, dim=1)
            prediction = data0

        # Expand mode from [b, n_dom] to [b, total_channels] by repeating each
        # domain flag once per channel of that domain.
        c_per_dom = self.celeba_params.dimension_per_domain
        mode = torch.cat([
            mode[:, i:i+1].repeat(1, c_per_dom[i]) for i in range(n_dom)
        ], dim=1)

        mode_bool = mode.bool()

        supervised_prediction = prediction[mode_bool]
        supervised_target = target[mode_bool]
        unsupervised_prediction = prediction[~mode_bool]
        unsupervised_target = target[~mode_bool]

        self.step_supervised_mae.update(preds=supervised_prediction, target=supervised_target)
        self.step_unsupervised_mae.update(preds=unsupervised_prediction, target=unsupervised_target)

        return {
            'supervised_mae': self.step_supervised_mae,
            'unsupervised_mae': self.step_unsupervised_mae,
        }


class CelebAMetrics1(pl.LightningModule):
    """Step-level MAE holder registered under ``'celeba1'`` in ``get_metrics``."""

    def __init__(self, params: ModelParams, params_data: DatasetParams):
        """Store ``params`` and create the MAE metric.

        ``params_data`` is accepted but not read by this class.
        """
        super().__init__()
        self.params = params

        # training
        self.mae = MeanAbsoluteError()

    @jaxtyped
    @typechecker
    def get_dict(
        self, *,
        data       : Float[torch.Tensor, 'b c h w'],
        batch_mixed: Float[torch.Tensor, 'b c h w'],
        noise      : Float[torch.Tensor, 'b c h w'],
        batch_recon: Float[torch.Tensor, 'b c h w'],
        data0      : Float[torch.Tensor, 'b c h w'],
    ) -> Dict:
        """Update the MAE with one batch and return the metric object.

        With ``params.loss.predict_noise`` truthy the pair compared is
        (``batch_recon``, ``noise``); otherwise it is (``data0``, ``data``).
        ``batch_mixed`` is only shape-checked by the annotations.

        Returns:
            ``{'mae': self.mae}``.
        """
        if self.params.loss.predict_noise:
            reference, estimate = noise, batch_recon
        else:
            reference, estimate = data, data0

        self.mae.update(preds=estimate, target=reference)

        return {'mae': self.mae}


class Blender1Metrics(pl.LightningModule):
    """Step-level MAE holder registered under ``'blender1'`` in ``get_metrics``."""

    def __init__(self, params: ModelParams, params_data: DatasetParams):
        """Keep a reference to ``params`` and build the MAE metric.

        ``params_data`` is part of the constructor signature but is not used.
        """
        super().__init__()
        self.params = params

        # training
        self.mae = MeanAbsoluteError()

    @jaxtyped
    @typechecker
    def get_dict(
        self, *,
        data       : Float[torch.Tensor, 'b c h w'],
        batch_mixed: Float[torch.Tensor, 'b c h w'],
        noise      : Float[torch.Tensor, 'b c h w'],
        batch_recon: Float[torch.Tensor, 'b c h w'],
        data0      : Float[torch.Tensor, 'b c h w'],
    ) -> Dict:
        """Feed one batch to the MAE metric.

        The target/prediction pair depends on ``params.loss.predict_noise``:
        truthy -> ``noise`` / ``batch_recon``; falsy -> ``data`` / ``data0``.
        ``batch_mixed`` is not read beyond the annotation shape check.

        Returns:
            dict with the single key ``'mae'`` mapped to the metric object.
        """
        noise_objective = self.params.loss.predict_noise

        target = noise if noise_objective else data
        prediction = batch_recon if noise_objective else data0

        self.mae.update(preds=prediction, target=target)

        result = {'mae': self.mae}
        return result


class BRATS2020_Metrics(pl.LightningModule):
    """Metric container for the ``brats2020`` configuration.

    The first ``n_dom_scan`` channels of a concatenated sample are scans and
    the remaining channels are the segmentation.  The class keeps:

    * step metrics updated by ``get_dict`` (supervised / unsupervised
      MAE and MSE),
    * one PSNR / MAE / MSE / SSIM metric per scan and a Jaccard index for the
      segmentation, updated by ``get_dict_generation`` and read back by
      ``compute_and_get``.

    The ``*_full_clamp`` metrics are instantiated but never updated or
    reported by this class.
    """

    def __init__(self, params: ModelParams, params_data: DatasetParams):
        """Instantiate all metric objects.

        Args:
            params: model parameters; ``params.loss.predict_noise`` is read by
                ``get_dict``.
            params_data: dataset parameters; ``params_data.data_params`` is
                stored as ``self.brats2020_params``.

        Raises:
            Exception: if ``value_range`` is neither ``ValueRange.Zero`` nor
                ``ValueRange.One``.
            ValueError: if ``segmentation_mode`` is not in 1..4.
        """
        super().__init__()
        self.params = params
        self.params_data = params_data
        self.brats2020_params: BRATS2020Params = params_data.data_params

        # Value range the clamped metrics operate in.
        if self.brats2020_params.value_range == ValueRange.Zero:
            self.val_min, self.val_max = 0., 1.
        elif self.brats2020_params.value_range == ValueRange.One:
            self.val_min, self.val_max = -1., 1.
        else:
            raise Exception(f"ValueRange {self.brats2020_params.value_range} not supported")

        data_range = (self.val_min, self.val_max)
        n_scan = self.brats2020_params.n_dom_scan

        # for step training
        self.step_supervised_mae = MeanAbsoluteError()
        self.step_unsupervised_mae = MeanAbsoluteError()
        self.step_supervised_mse = MeanSquaredError()
        self.step_unsupervised_mse = MeanSquaredError()

        # for generation
        self.full_psnr_clamp = PeakSignalNoiseRatio(
            data_range=data_range,
            base=10.,
            reduction='elementwise_mean',
            dim=None
        )
        self.mae_full_clamp = MeanAbsoluteError()
        self.mse_full_clamp = MeanSquaredError()
        self.ssim_full_clamp = StructuralSimilarityIndexMeasure(data_range=data_range)

        # One metric instance per scan domain.
        self.gen_psnr_clamp = nn.ModuleList([
            PeakSignalNoiseRatio(data_range=data_range, base=10., reduction='elementwise_mean', dim=None)
            for _ in range(n_scan)
        ])
        self.gen_maes_clamp = nn.ModuleList([MeanAbsoluteError() for _ in range(n_scan)])
        self.gen_mses_clamp = nn.ModuleList([MeanSquaredError() for _ in range(n_scan)])
        self.gen_ssim_clamp = nn.ModuleList([
            StructuralSimilarityIndexMeasure(data_range=data_range)
            for _ in range(n_scan)
        ])

        # Modes 1 and 2 are scored as binary, 3 and 4 as multiclass.
        if self.brats2020_params.segmentation_mode in [1, 2]:
            task = "binary"
            num_classes = None
        elif self.brats2020_params.segmentation_mode in [3, 4]:
            task = "multiclass"
            num_classes = self.brats2020_params.n_class
        else:
            raise ValueError(f"segmentation_mode {self.brats2020_params.segmentation_mode} not supported")

        self.gen_jaccards = JaccardIndex(task=task, num_classes=num_classes)

    def compute_and_get(self):
        """Call ``compute()`` on every per-scan metric and on the Jaccard index.

        Returns:
            dict with keys ``psnr_<scan>_clamp``, ``mae_<scan>_clamp``,
            ``mse_<scan>_clamp``, ``ssim_<scan>_clamp`` for every name in
            ``scan_names``, plus ``jaccard``.
        """
        scan_names = self.brats2020_params.scan_names

        res = dict()
        res |= {f'psnr_{domain}_clamp': self.gen_psnr_clamp[i].compute() for i, domain in enumerate(scan_names)}
        res |= {f'mae_{domain}_clamp': self.gen_maes_clamp[i].compute() for i, domain in enumerate(scan_names)}
        res |= {f'mse_{domain}_clamp': self.gen_mses_clamp[i].compute() for i, domain in enumerate(scan_names)}
        res |= {f'ssim_{domain}_clamp': self.gen_ssim_clamp[i].compute() for i, domain in enumerate(scan_names)}
        res |= {'jaccard': self.gen_jaccards.compute()}
        return res

    @jaxtyped
    @typechecker
    def from_data_get_seg(self, data: Float[torch.Tensor, 'b c_n_dom h w']) -> Int[torch.Tensor, 'b h w']:
        """Extract the categorical segmentation from a full concatenated sample.

        The channels after the first ``n_dom_scan`` are taken as the
        segmentation and converted with ``segmentation_to_categorical`` using
        the configured ``segmentation_mode``.
        """
        n_dom_scan = self.brats2020_params.n_dom_scan
        scan_mode = self.brats2020_params.segmentation_mode
        return self.segmentation_to_categorical(seg=data[:, n_dom_scan:], scan_mode=scan_mode)

    @jaxtyped
    @typechecker
    def segmentation_to_categorical(self, seg: Float[torch.Tensor, 'b c h w'], scan_mode: int) -> Int[torch.Tensor, 'b h w']:
        """Convert a segmentation tensor into categorical labels.

        Args:
            seg: segmentation channels.
            scan_mode: segmentation mode:
                1 -> threshold at 0.5 and drop the channel axis;
                2, 4 -> ``argmax`` over the channel axis;
                3 -> ``augmentWithBackground`` followed by ``argmax``.

        Raises:
            ValueError: for any other ``scan_mode``.
        """
        if scan_mode == 1:
            return (seg > 0.5).int().squeeze(1)

        if scan_mode in (2, 4):
            return seg.argmax(dim=1)

        if scan_mode == 3:
            seg = augmentWithBackground(segmentations_maps=seg)
            return seg.argmax(dim=1)

        raise ValueError(f"segmentation_mode {scan_mode} not supported")

    @jaxtyped
    @typechecker
    def get_dict_generation(
        self, *,
        data: List[Float[torch.Tensor, 'b c h w']],
        prediction: List[Float[torch.Tensor, 'b c h w']],
        mode: Float[torch.Tensor, 'b n_dom'],
    ) -> Dict:
        """Update all generation metrics with one batch.

        For every scan ``i`` only the samples whose ``mode[:, i]`` entry is
        falsy are scored, after clamping to ``[val_min, val_max]``.  A scan
        with no such sample in the batch is skipped.  The Jaccard index is
        updated on the samples whose segmentation mode entry is falsy.
        NOTE(review): presumably a zero mode marks a domain that was generated
        rather than given — confirm against the caller.

        Args:
            data: ground-truth tensors, one per domain.
            prediction: generated tensors, one per domain.
            mode: per-sample, per-domain flags; the first ``n_dom_scan``
                columns refer to scans, the rest to the segmentation.

        Returns:
            dict mapping metric names to the (updated) metric objects.
        """
        scan_mode = self.brats2020_params.segmentation_mode
        n_scan = self.brats2020_params.n_dom_scan

        data_cat = torch.cat(data, dim=1).to(self.device)
        prediction_cat = torch.cat(prediction, dim=1).to(self.device)

        mode_bool = mode.bool().to(self.device)

        target_scans = data_cat[:, :n_scan]
        target_masks = data_cat[:, n_scan:]

        prediction_scans = prediction_cat[:, :n_scan]
        prediction_masks = prediction_cat[:, n_scan:]

        # region SCAN METRICS
        # compute metrics for the scans individually
        for i in range(n_scan):
            generated = ~mode_bool[:, i]
            target_scan_i_clamp = target_scans[:, i][generated].clamp(self.val_min, self.val_max)
            if target_scan_i_clamp.shape[0] == 0:
                # No sample of this batch is scored for scan i.
                continue
            prediction_scan_i_clamp = prediction_scans[:, i][generated].clamp(self.val_min, self.val_max)

            self.gen_psnr_clamp[i].update(preds=prediction_scan_i_clamp, target=target_scan_i_clamp)
            self.gen_maes_clamp[i].update(preds=prediction_scan_i_clamp, target=target_scan_i_clamp)
            self.gen_mses_clamp[i].update(preds=prediction_scan_i_clamp, target=target_scan_i_clamp)
            # SSIM needs an explicit channel axis.
            self.gen_ssim_clamp[i].update(preds=prediction_scan_i_clamp.unsqueeze(1), target=target_scan_i_clamp.unsqueeze(1))
        # endregion

        # region SEGMENTATION METRICS
        # Segmentation flags follow the scan flags; use n_dom_scan rather than
        # a hard-coded 4 so the slice agrees with the channel split above.
        mode_bool_seg = mode_bool[:, n_scan:]

        # unsqueeze(1) gives [b, 1, c_seg, h, w] so the boolean mask can index
        # the first two dims.
        # NOTE(review): this assumes mode_bool_seg is [b, 1] — confirm.
        seg_generated = ~mode_bool_seg
        prediction_masks_unsupervised = prediction_masks.unsqueeze(1)[seg_generated]
        prediction_mask_catego = self.segmentation_to_categorical(prediction_masks_unsupervised, scan_mode=scan_mode)

        target_masks_unsupervised = target_masks.unsqueeze(1)[seg_generated]
        target_mask_catego = self.segmentation_to_categorical(target_masks_unsupervised, scan_mode=scan_mode)

        self.gen_jaccards.update(preds=prediction_mask_catego, target=target_mask_catego)
        # endregion

        scan_names = self.brats2020_params.scan_names

        res = dict()
        res |= {f'psnr_{domain}_clamp': self.gen_psnr_clamp[i] for i, domain in enumerate(scan_names)}
        res |= {f'mae_{domain}_clamp': self.gen_maes_clamp[i] for i, domain in enumerate(scan_names)}
        res |= {f'mse_{domain}_clamp': self.gen_mses_clamp[i] for i, domain in enumerate(scan_names)}
        res |= {f'ssim_{domain}_clamp': self.gen_ssim_clamp[i] for i, domain in enumerate(scan_names)}
        res |= {'jaccard': self.gen_jaccards}
        return res

    @jaxtyped
    @typechecker
    def get_dict(
        self, *,
        data       : List[Float[torch.Tensor, 'b _chan_per_dom h w']],
        mode       : Float[torch.Tensor, 'b n_dom      '],
        batch_mixed: Float[torch.Tensor, 'b c_x_dom h w'],
        noise      : Float[torch.Tensor, 'b c_x_dom h w'],
        batch_recon: Float[torch.Tensor, 'b c_x_dom h w'],
        data0      : Float[torch.Tensor, 'b c_x_dom h w'],
    ) -> Dict:
        """Update the step-level supervised / unsupervised MAE and MSE.

        When ``params.loss.predict_noise`` is truthy the comparison is
        ``batch_recon`` vs ``noise``; otherwise ``data0`` vs the channel-wise
        concatenation of ``data``.  Channels whose (expanded) mode flag is
        truthy feed the "supervised" metrics, the others the "unsupervised"
        ones.  ``batch_mixed`` is only shape-checked by the annotations.

        Returns:
            dict with keys ``supervised_mae``, ``unsupervised_mae``,
            ``supervised_mse``, ``unsupervised_mse`` holding the metric
            objects.
        """
        _, n_dom = mode.shape

        if self.params.loss.predict_noise:
            target = noise
            prediction = batch_recon
        else:
            target = torch.cat(data, dim=1)
            prediction = data0

        # Expand mode from [b, n_dom] to [b, total_channels] by repeating each
        # domain flag once per channel of that domain.
        c_per_dom = self.brats2020_params.dimension_per_domain
        mode = torch.cat([
            mode[:, i:i+1].repeat(1, c_per_dom[i]) for i in range(n_dom)
        ], dim=1)

        mode_bool = mode.bool()

        supervised_prediction = prediction[mode_bool]
        supervised_target = target[mode_bool]
        unsupervised_prediction = prediction[~mode_bool]
        unsupervised_target = target[~mode_bool]

        self.step_supervised_mae.update(preds=supervised_prediction, target=supervised_target)
        self.step_unsupervised_mae.update(preds=unsupervised_prediction, target=unsupervised_target)
        self.step_supervised_mse.update(preds=supervised_prediction, target=supervised_target)
        self.step_unsupervised_mse.update(preds=unsupervised_prediction, target=unsupervised_target)

        return {
            'supervised_mae': self.step_supervised_mae,
            'unsupervised_mae': self.step_unsupervised_mae,
            'supervised_mse': self.step_supervised_mse,
            'unsupervised_mse': self.step_unsupervised_mse,
        }


class Blender3Metrics(pl.LightningModule):
    """Metric container for the ``blender3`` configuration (three domains).

    Keeps step-level supervised / unsupervised MAE (``get_dict``) and
    generation MAE metrics — raw, clamped, and clamped per domain —
    (``get_dict_generation`` / ``compute_and_get``).
    """

    def __init__(self, params: ModelParams, params_data: DatasetParams):
        """Store ``params`` and create the metric objects.

        ``params_data`` is accepted but not read by this class.
        """
        super().__init__()
        self.params = params

        # training
        self.supervised_mae = MeanAbsoluteError()
        self.unsupervised_mae = MeanAbsoluteError()

        # generation
        self.generation_mae = MeanAbsoluteError()
        self.generation_mae_clamp = MeanAbsoluteError()

        self.generation_mae_clamp_1 = MeanAbsoluteError()
        self.generation_mae_clamp_2 = MeanAbsoluteError()
        self.generation_mae_clamp_3 = MeanAbsoluteError()

    @jaxtyped
    @typechecker
    def get_dict_generation(
        self, *,
        data: List[Float[torch.Tensor, 'b 3 h w']],
        prediction: List[Float[torch.Tensor, 'b 3 h w']],
        mode: Float[torch.Tensor, 'b 3'],
    ) -> Dict:
        """Update the generation MAE metrics with one batch.

        Only (sample, domain) pairs whose ``mode`` entry is falsy are scored.
        NOTE(review): presumably a zero mode marks a domain that was generated
        rather than given — confirm against the caller.

        Args:
            data: ground-truth images, one tensor per domain.
            prediction: generated images, one tensor per domain.
            mode: per-sample, per-domain flags.

        Returns:
            dict with keys ``mae``, ``mae_clamp``, ``mae_clamp_1``,
            ``mae_clamp_2``, ``mae_clamp_3`` holding the metric objects.
        """
        _, c, h, w = data[0].shape
        b, n_dom = mode.shape

        # Flatten to one image per (sample, domain) pair.  The row order
        # (sample-major, domain-minor) matches mode.reshape(-1) below.
        data_cat = torch.cat(data, dim=1).reshape(b * n_dom, c, h, w)
        prediction_cat = torch.cat(prediction, dim=1).reshape(b * n_dom, c, h, w)

        mode_bool_d = mode.bool()
        generated_flat = ~mode_bool_d.reshape(-1)

        preds = prediction_cat[generated_flat]
        target = data_cat[generated_flat]

        self.generation_mae.update(preds=preds, target=target)
        self.generation_mae_clamp.update(preds=preds.clamp(-1, 1), target=target.clamp(-1, 1))

        # Per-domain clamped MAE, restricted to the samples scored for that
        # domain.
        per_domain_metrics = (
            self.generation_mae_clamp_1,
            self.generation_mae_clamp_2,
            self.generation_mae_clamp_3,
        )
        for i, metric in enumerate(per_domain_metrics):
            generated_i = ~mode_bool_d[:, i]
            metric.update(
                preds=prediction[i][generated_i].clamp(-1, 1),
                target=data[i][generated_i].clamp(-1, 1),
            )

        return {
            'mae': self.generation_mae,
            'mae_clamp': self.generation_mae_clamp,
            'mae_clamp_1': self.generation_mae_clamp_1,
            'mae_clamp_2': self.generation_mae_clamp_2,
            'mae_clamp_3': self.generation_mae_clamp_3,
        }

    @jaxtyped
    @typechecker
    def get_dict(
        self, *,
        data       : List[Float[torch.Tensor, 'b _chan_per_dom h w']],
        mode       : Float[torch.Tensor, 'b n_dom      '],
        batch_mixed: Float[torch.Tensor, 'b c_x_dom h w'],
        noise      : Float[torch.Tensor, 'b c_x_dom h w'],
        batch_recon: Float[torch.Tensor, 'b c_x_dom h w'],
        data0      : Float[torch.Tensor, 'b c_x_dom h w'],
    ) -> Dict:
        """Update the step-level supervised / unsupervised MAE.

        When ``params.loss.predict_noise`` is truthy the comparison is
        ``batch_recon`` vs ``noise``; otherwise ``data0`` vs the channel-wise
        concatenation of ``data``.  Tensors are reshaped to one image per
        (sample, domain) pair; pairs with a truthy mode feed the "supervised"
        metric, the others the "unsupervised" one.

        Returns:
            ``{'supervised_mae': ..., 'unsupervised_mae': ...}`` holding the
            metric objects.
        """
        b, n_dom = mode.shape
        _, c_x_dom, h, w = batch_mixed.shape
        # Every domain contributes the same number of channels here.
        c = c_x_dom // n_dom

        if self.params.loss.predict_noise:
            target = noise
            prediction = batch_recon
        else:
            target = torch.cat(data, dim=1)
            prediction = data0

        target = target.reshape(b * n_dom, c, h, w)
        prediction = prediction.reshape(b * n_dom, c, h, w)

        mode_bool = mode.bool().reshape(-1)

        supervised_prediction = prediction[mode_bool]
        supervised_target = target[mode_bool]
        unsupervised_prediction = prediction[~mode_bool]
        unsupervised_target = target[~mode_bool]

        self.supervised_mae.update(preds=supervised_prediction, target=supervised_target)
        self.unsupervised_mae.update(preds=unsupervised_prediction, target=unsupervised_target)
        return {
            'supervised_mae': self.supervised_mae,
            'unsupervised_mae': self.unsupervised_mae,
        }

    def compute_and_get(self):
        """Call ``compute()`` on every generation metric.

        Returns:
            dict with the same keys as ``get_dict_generation`` mapped to the
            values returned by each metric's ``compute()``.
        """
        return {
            'mae': self.generation_mae.compute(),
            'mae_clamp': self.generation_mae_clamp.compute(),
            'mae_clamp_1': self.generation_mae_clamp_1.compute(),
            'mae_clamp_2': self.generation_mae_clamp_2.compute(),
            'mae_clamp_3': self.generation_mae_clamp_3.compute(),
        }


def get_metrics(params: MetricsParams):
    """Return the metrics class registered under ``params.name``.

    Args:
        params: metrics parameters; only ``params.name`` is read.

    Returns:
        The metrics class itself (not an instance).

    Raises:
        ValueError: if ``params.name`` is not a registered metric name.
    """
    metrics = {
        'blender1': Blender1Metrics,
        'blender3': Blender3Metrics,
        'brats2020': BRATS2020_Metrics,
        'celeba1': CelebAMetrics1,
        'celeba3': CelebA3Metrics,
    }

    # SunRGBDMetrics is neither imported nor defined in the visible part of
    # this module.  Naming it directly in the literal above would raise
    # NameError on *every* call, whichever metric is requested, so it is
    # registered only when the class actually exists at module level.
    sunrgbd_metrics = globals().get('SunRGBDMetrics')
    if sunrgbd_metrics is not None:
        metrics['sunrgbd'] = sunrgbd_metrics

    if params.name not in metrics:
        raise ValueError(f'Unknown metric {params.name}')

    return metrics[params.name]
