import itertools
import math
from typing import Iterable, Optional

import pandas as pd
import torch
from torchvision.io import encode_jpeg, decode_jpeg

from data_utils.statistics import psnr, mse
from nn_compression._interfaces import CompressedTensor


class JpegArtifact(CompressedTensor):
    """JPEG-compressed version of an image tensor, with rate and distortion statistics.

    The image is converted to ``uint8`` (per-image min-max normalisation unless it already
    is ``uint8``), encoded with torchvision's JPEG encoder at the requested quality and decoded
    again. Distortion statistics compare the decoded image with the ``uint8`` image that was
    actually encoded, i.e. they live on the 0-255 pixel scale.
    """

    def __init__(self, img: torch.Tensor, quality: int) -> None:
        """Creates a JPEG artifact from an image tensor.

        Args:
            img: Image tensor of shape (C, H, W). Non-``uint8`` tensors are min-max scaled
                to 0-255 per image before encoding.
            quality: JPEG encoding quality in [1, 100].

        Raises:
            ValueError: If ``quality`` is outside [1, 100].

        Usage example:
        >>> import torch
        >>> from nn_compression.evaluation import JpegArtifact
        >>> img = torch.rand(3, 256, 256)
        >>> artifact = JpegArtifact(img, quality=90)
        >>> print(artifact.psnr, artifact.bpp)
        """
        if quality < 1 or quality > 100:
            raise ValueError("Quality must be between 1 and 100.")
        self.original_img = img
        # Range of the original input, kept so `unnormalise_jpeg` can invert the scaling.
        self.img_min = img.min()
        self.img_max = img.max()

        # encode_jpeg needs uint8 input; any other dtype is scaled to 0-255 first.
        self.img = img if img.dtype == torch.uint8 else self.normalise_jpeg(img)

        self.jpeg = encode_jpeg(self.img, quality=quality)
        self.rec = decode_jpeg(self.jpeg)

        # Bits of the encoded stream per tensor element (C*H*W), i.e. per channel value.
        self._bpp = len(self.jpeg) * 8 / math.prod(self.img.shape)
        # Cast to float before subtracting so uint8 differences cannot wrap around;
        # the channel-mean residual keeps a leading singleton axis: (1, H, W).
        self.error_residuals = (self.rec.to(float) - self.img.to(float)).abs().mean(dim=0, keepdim=True)  # type: ignore
        # NOTE(review): psnr/mse receive the raw uint8 tensors; confirm those helpers widen
        # the dtype internally, otherwise uint8 differences would wrap around.
        self.psnr = psnr(self.rec, self.img, pixel_max=255)
        self.mse = mse(self.rec, self.img)

    def normalise_jpeg(self, img: torch.Tensor) -> torch.Tensor:
        """Min-max scale ``img`` to 0-255 and convert it to ``uint8``.

        Values are rounded to the nearest level; plain truncation would make level 255
        reachable only by the exact maximum. A constant image has no range to scale and
        maps to all zeros instead of producing NaN (0/0).
        """
        lo, hi = img.min(), img.max()
        span = hi - lo
        if span == 0:
            return torch.zeros_like(img, dtype=torch.uint8)
        scaled = (img - lo) / span * 255
        return scaled.round().clamp(0, 255).to(torch.uint8)

    def unnormalise_jpeg(self, img: torch.Tensor) -> torch.Tensor:
        """Map a 0-255 image back to the value range of the original input.

        Inverts ``normalise_jpeg`` using the original min/max. If the original input was
        already ``uint8``, no normalisation was applied, so the values are returned on the
        0-255 scale as a floating-point tensor.
        """
        if self.original_img.dtype == torch.uint8:
            # Multiplying by 1.0 yields a float tensor, matching the other branch's dtype.
            return img * 1.0
        return img / 255 * (self.img_max - self.img_min) + self.img_min

    @property
    def bpp(self) -> float:
        """Bits of the JPEG stream per tensor element of the encoded image."""
        return self._bpp

    @property
    def x(self) -> torch.Tensor:
        """Decoded JPEG reconstruction (``uint8``, 0-255 scale)."""
        return self.rec

    @staticmethod
    def rd_curve(
        dataset: Iterable[torch.Tensor],
        max_imgs: Optional[int] = None,
        quality_points: int = 20,
        reduce: bool = True,
    ) -> pd.DataFrame:
        """Creates an RD curve for the given dataset. The RD curve is a DataFrame with
        columns 'quality' (referring to encoding quality of JPEG), 'psnr', 'mse' and 'bpp'.

        MSE is calculated IN PIXEL SPACE (0-255), not in the normalised space. This is
        important. If you do not want to bother with this, use PSNR, which is computed
        relative to a peak value of 255.

        The input must be a list or similar iterable with tensors of shape (C,H,W). Preferably,
        pass tensors with uint8 as dtype; then normalisation is skipped, as we assume that the
        full data range is simply 0-255. Otherwise, normalisation is done on a per-image basis,
        which might be a bit skewed towards better performance of JPEG.

        This function can take a long time to run, depending on the size of the dataset and
        on quality_points.

        Usage example:
        >>> import matplotlib.pyplot as plt
        >>> import torch
        >>> from nn_compression.evaluation import JpegArtifact
        >>> imgs = [torch.rand(3, 256, 256) for _ in range(100)]
        >>> df = JpegArtifact.rd_curve(imgs)
        >>> plt.plot(df.bpp, df.psnr, ".")
        >>> plt.show()

        Args:
            dataset: Iterable of image tensors of shape (C, H, W). It is consumed exactly
                once, so generators and data loaders work as well as lists.
            max_imgs: Maximum number of images to process. If None, all images are processed.
            quality_points: Number of JPEG quality levels, spread evenly over 1-100 (both
                ends included), at which to evaluate. Levels that coincide after rounding to
                integers are evaluated once. Default is 20 points for the RD curve.
            reduce: If False, each row corresponds to one (image, quality) pair, ordered by
                quality and then by dataset order. Otherwise, the RD curve is reduced to a
                single point per quality level by averaging over all images.

        Raises:
            ValueError: If ``quality_points`` is smaller than 1.
        """
        if quality_points < 1:
            raise ValueError("quality_points must be at least 1.")
        qualities = torch.linspace(1, 100, quality_points).round().long().unique().tolist()

        rows = []
        # Images form the outer loop so the dataset is iterated only once; islice stops
        # after max_imgs images without pulling an extra one (None means no limit).
        for img in itertools.islice(dataset, max_imgs):
            for quality in qualities:
                artifact = JpegArtifact(img, quality)
                rows.append(
                    dict(quality=quality, psnr=artifact.psnr, bpp=artifact.bpp, mse=artifact.mse)
                )

        df = pd.DataFrame(rows, columns=["quality", "psnr", "bpp", "mse"])
        if reduce:
            return df.groupby("quality").mean().reset_index()
        # Stable sort restores quality-major order while keeping dataset order per quality.
        return df.sort_values("quality", kind="stable").reset_index(drop=True)
