from typing import Optional

import lightning as L
import torch
from ito_vision.discretizations import Discretization
from ito_vision.methods import IterativeRefinementMethod
from ito_vision.samplers import Sampler
from torchmetrics.image import PeakSignalNoiseRatio as PSNR
from torchmetrics.image import StructuralSimilarityIndexMeasure as SSIM
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity as LPIPS


class IterativeIMG2IMG(L.LightningModule):
    """Lightning module for image-to-image translation by iterative refinement.

    Training minimises ``method.loss`` on the backbone. Validation samples a
    prediction with ``method.sample`` and scores it against the target with
    PSNR, SSIM and LPIPS. When a VAE is supplied, the method operates on the
    VAE's latents and the prediction is decoded before the metrics are
    computed.

    Args:
        backbone: Network handed to the method; the only module optimised.
        method: Provides ``loss``, ``sample`` and ``base_distribution``.
        sampler: Passed through to ``method.sample``.
        discretization: Passed through to ``method.sample``.
        optimizer: Called as ``optimizer(params=...)`` in
            ``configure_optimizers``, so it must be an optimizer factory
            (e.g. a partially applied optimizer class), not an instance.
        vae: Optional module exposing ``encode`` and ``decode``. It is put in
            eval mode and frozen.
        lr_scheduler: Optional factory called as
            ``lr_scheduler(optimizer=..., T_max=...)``.
    """

    def __init__(
        self,
        backbone: torch.nn.Module,
        method: IterativeRefinementMethod,
        sampler: Sampler,
        discretization: Discretization,
        optimizer: torch.optim.Optimizer,
        vae: Optional[torch.nn.Module] = None,
        lr_scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None,
    ):
        super().__init__()
        self.save_hyperparameters(logger=False)
        self.backbone = backbone
        self.method = method
        self.sampler = sampler
        self.discretization = discretization

        # Compare against None explicitly: container modules define __len__,
        # so an empty one would be falsy under a plain truthiness check.
        if vae is not None:
            # Only the backbone is optimised (see configure_optimizers), so
            # the VAE is frozen as well as switched to eval mode. Leaving its
            # parameters trainable would make autograd track them for nothing.
            self.vae = vae.eval().requires_grad_(False)
        else:
            self.vae = None  # type: ignore

        self.val_psnr = PSNR(data_range=(-1, 1))
        self.val_ssim = SSIM(data_range=(-1, 1))
        self.val_lpips = LPIPS().eval()

        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

    @staticmethod
    def _extra_kwargs(batch):
        """Return every batch entry except the raw ``y`` and ``x0`` tensors."""
        # NOTE(review): precomputed "y_latent"/"x0_latent" entries, when
        # present, are forwarded as well -- confirm the method expects or
        # ignores them.
        return {k: v for k, v in batch.items() if k not in ("y", "x0")}

    def training_step(self, batch, batch_idx: int):
        """Compute and log the method's training loss for one batch.

        Args:
            batch: Mapping with at least ``"y"`` (input) and ``"x0"``
                (target). With a VAE, ``"y_latent"`` and ``"x0_latent"`` are
                used instead when both are present. All remaining entries are
                forwarded to ``method.loss`` as keyword arguments.
            batch_idx: Index of the batch (unused).

        Returns:
            The loss returned by ``method.loss``.
        """
        y, x0 = batch["y"], batch["x0"]

        if self.vae is not None:
            if "y_latent" in batch and "x0_latent" in batch:
                # Use precomputed latents if available.
                y, x0 = batch["y_latent"], batch["x0_latent"]
            else:
                # The VAE is a fixed codec; do not record an autograd graph
                # through it.
                with torch.no_grad():
                    y = self.vae.encode(y)
                    x0 = self.vae.encode(x0)

        loss = self.method.loss(
            self.backbone, x0, y=y, **self._extra_kwargs(batch)
        )

        self.log("train/loss", loss, on_epoch=False, on_step=True)

        return loss

    @torch.no_grad()
    def validation_step(self, batch, batch_idx: int):
        """Sample a prediction, log image metrics and return images to log.

        Args:
            batch: Mapping with at least ``"y"`` (input) and ``"x0"``
                (target). All remaining entries are forwarded to
                ``method.sample`` as keyword arguments.
            batch_idx: Index of the batch (unused).

        Returns:
            A nested dict holding a uint8 tensor with input, prediction and
            target concatenated along dim 2, stored under
            ``["wandb_image_logger"]["val/samples"]["images"]``.
        """
        y, x0 = batch["y"], batch["x0"]

        if self.vae is not None:
            y_latent = self.vae.encode(y)
        else:
            # clamp is out-of-place, so the batch tensor is left untouched.
            y_latent = y.clamp(-1, 1)

        # Only the first element of the returned triple is needed here.
        pred_z0, _, _ = self.method.sample(
            self.discretization,
            self.sampler,
            self.backbone,
            self.method.base_distribution(y_latent),
            y_latent,
            return_trajectory=False,
            **self._extra_kwargs(batch),
        )

        if self.vae is not None:
            pred_x0 = self.vae.decode(pred_z0)
        else:
            pred_x0 = pred_z0

        # Clamp in both branches: a decoder may overshoot the [-1, 1] range
        # that the metrics (LPIPS in particular) expect.
        pred_x0 = pred_x0.clamp(-1, 1)

        psnr = self.val_psnr(pred_x0, x0)
        ssim = self.val_ssim(pred_x0, x0)
        lpips = self.val_lpips(pred_x0, x0)

        self.log("val/psnr", psnr, on_epoch=True, on_step=False, sync_dist=True)
        self.log("val/ssim", ssim, on_epoch=True, on_step=False, sync_dist=True)
        self.log("val/lpips", lpips, on_epoch=True, on_step=False, sync_dist=True)

        # Map [-1, 1] to uint8 [0, 255] for image logging.
        pred_x0 = ((pred_x0 + 1.0) * 127.5).clamp(0, 255).byte()
        x0 = ((x0 + 1.0) * 127.5).clamp(0, 255).byte()
        y = ((y + 1.0) * 127.5).clamp(0, 255).byte()

        x_log = torch.cat([y, pred_x0, x0], dim=2)  # input, output, target

        return {"wandb_image_logger": {"val/samples": {"images": x_log}}}

    def configure_optimizers(self):
        """Build the optimizer and the optional per-step LR scheduler.

        Only backbone parameters with ``requires_grad`` set are optimised.

        Returns:
            A dict with ``"optimizer"`` and, when a scheduler factory was
            given, an ``"lr_scheduler"`` config stepping every batch.
        """
        params = [p for p in self.backbone.parameters() if p.requires_grad]
        optimizer = self.optimizer(params=params)

        out = {"optimizer": optimizer}

        if self.lr_scheduler is not None:
            # NOTE(review): the factory must accept ``T_max`` (as cosine
            # annealing does) -- confirm for other scheduler types.
            lr_scheduler = self.lr_scheduler(
                optimizer=optimizer,
                T_max=self.trainer.estimated_stepping_batches,
            )
            out["lr_scheduler"] = {
                "scheduler": lr_scheduler,
                "interval": "step",
                "frequency": 1,
            }

        return out
