import json
import wandb
import torch
from PIL import Image
from pathlib import Path

from .base import BaseMetric

import logging

log = logging.getLogger(__name__)


class CTSaver(BaseMetric):
    """Dump ground-truth / reconstructed CT images to disk and to wandb.

    For every sample passed to :meth:`forward`, a numbered directory
    ``<wandb run dir>/media/rsna/<idx>/`` is created containing:

    * ``x_0_brain.png`` / ``x_0_hat_brain.png`` -- 8-bit PNGs of the image
      clamped to the HU window ``[hu_min, hu_max]`` and scaled to [0, 255].
    * ``x_0.png`` / ``x_0_hat.png`` -- uncompressed 16-bit PNGs of the full HU
      range, min-max normalized per image to [0, 65535].
    * ``<idx>.json`` -- the per-image HU min/max needed to undo that 16-bit
      normalization.

    The 8-bit windowed images are also logged to wandb under the keys
    ``rsna/x_0_brain`` and ``rsna/x_0_hat_brain``.

    Args:
        hu_min: Lower edge of the display window in HU (default 0.0).
        hu_max: Upper edge of the display window in HU (default 80.0).

    Raises:
        ValueError: if ``hu_max <= hu_min``.
    """

    def __init__(self, hu_min=0.0, hu_max=80.0):
        super(CTSaver, self).__init__()
        if hu_max <= hu_min:
            raise ValueError(
                f"hu_max ({hu_max}) must be greater than hu_min ({hu_min})"
            )
        self.hu_min = float(hu_min)
        self.hu_max = float(hu_max)
        # wandb.run is None until wandb.init() has been called, so this
        # constructor requires an active run.
        self.save_dir = Path(wandb.run.dir) / "media" / "rsna"
        # Running sample index. It is never reset, so output directories stay
        # unique across forward() calls and across reset().
        self.idx = -1
        self.reset()

    def __str__(self):
        return "CTSaver"

    def _window_to_uint8(self, hu):
        """Clamp ``hu`` to [hu_min, hu_max] and scale it linearly to uint8 [0, 255]."""
        windowed = torch.clamp(hu, min=self.hu_min, max=self.hu_max)
        scaled = (windowed - self.hu_min) / (self.hu_max - self.hu_min) * 255
        # Float -> uint8 cast truncates toward zero.
        return scaled.to(torch.uint8)

    @staticmethod
    def _minmax_to_uint16(hu, lo, hi):
        """Min-max normalize ``hu`` from [lo, hi] to [0, 65535]; return a uint16 numpy array."""
        span = hi - lo
        if span > 0:
            scaled = (hu.float() - lo) / span * 65535.0
        else:
            # Constant image: avoid 0/0 -> NaN, whose integer cast is undefined.
            scaled = torch.zeros_like(hu, dtype=torch.float32)
        # Guard against float rounding pushing a value just outside the range
        # before the truncating integer cast.
        scaled = scaled.clamp(0.0, 65535.0)
        return scaled.numpy().astype("uint16")

    @torch.no_grad()
    def forward(self, x_0, x_0_hat):
        """Save every sample of a batch to disk and log it to wandb.

        Args:
            x_0: Batch of ground-truth images in attenuation units, iterated
                along dim 0. Only channel/slice ``[0]`` of each sample is saved.
                NOTE(review): layout assumed to be (B, C, H, W); confirm with callers.
            x_0_hat: Batch of reconstructed images with the same layout as ``x_0``.
        """
        self.counter += x_0.shape[0]

        for x_0_, x_0_hat_ in zip(x_0, x_0_hat):
            self.idx += 1
            i = self.idx

            # One output directory per sample.
            data_dir = self.save_dir / str(i)
            data_dir.mkdir(parents=True, exist_ok=True)

            hu_x_0_ = _attenuation_to_HU(x_0_).cpu()
            hu_x_0_hat_ = _attenuation_to_HU(x_0_hat_).cpu()

            # --- 8-bit windowed ("brain") images ---
            x_0_brain_img = Image.fromarray(self._window_to_uint8(hu_x_0_)[0].numpy())
            x_0_brain_img.save(data_dir / "x_0_brain.png")

            x_0_hat_brain_img = Image.fromarray(
                self._window_to_uint8(hu_x_0_hat_)[0].numpy()
            )
            x_0_hat_brain_img.save(data_dir / "x_0_hat_brain.png")

            # --- 16-bit full-range images (per-image min-max normalized) ---
            hu_x_0_min, hu_x_0_max = hu_x_0_.min().item(), hu_x_0_.max().item()
            hu_x_0_hat_min, hu_x_0_hat_max = (
                hu_x_0_hat_.min().item(),
                hu_x_0_hat_.max().item(),
            )

            x_0_u16 = self._minmax_to_uint16(hu_x_0_, hu_x_0_min, hu_x_0_max)
            Image.fromarray(x_0_u16[0]).save(
                data_dir / "x_0.png", format="PNG", compress_level=0
            )

            x_0_hat_u16 = self._minmax_to_uint16(
                hu_x_0_hat_, hu_x_0_hat_min, hu_x_0_hat_max
            )
            Image.fromarray(x_0_hat_u16[0]).save(
                data_dir / "x_0_hat.png", format="PNG", compress_level=0
            )

            # The min/max values are needed to map the 16-bit PNGs back to HU.
            metadata = {
                "hu_x_0_min": hu_x_0_min,
                "hu_x_0_max": hu_x_0_max,
                "hu_x_0_hat_min": hu_x_0_hat_min,
                "hu_x_0_hat_max": hu_x_0_hat_max,
            }

            with open(data_dir / f"{i}.json", "w") as f:
                json.dump(metadata, f)

            # Log the same windowed images that were written to disk, so both
            # images share the fixed HU window rather than per-image contrast.
            # NOTE: one wandb.log call per sample; each call advances wandb's
            # step counter.
            wandb.log(
                {
                    "rsna/x_0_brain": wandb.Image(x_0_brain_img),
                    "rsna/x_0_hat_brain": wandb.Image(x_0_hat_brain_img),
                }
            )

    def compute_and_log(self, fabric, log_prefix=""):
        """Finish an evaluation round.

        Nothing is aggregated here, because images are written in
        :meth:`forward`; this only resets the counter. ``fabric`` and
        ``log_prefix`` are unused (presumably required by the BaseMetric
        interface).
        """
        self.reset()

    def reset(self):
        """Reset the per-round sample counter; ``self.idx`` is intentionally kept."""
        self.counter = 0


def _attenuation_to_HU(image, scale_only=False):
    """
    scale_only == False:
    μ = (HU + 1000) * 0.1 / 1000

    scale_only == True:
    μ = HU * 0.1 / 1000
    """
    if scale_only:
        return image * 1000.0 / 0.1
    else:
        return (image * 1000.0 / 0.1) - 1000.0
