# PyTorch
import torch
import torch.nn as nn

# Local
from utils.DiffJPEG.compression import compress_jpeg
from utils.DiffJPEG.decompression import decompress_jpeg

from utils.DiffJPEG.utils import diff_round, quality_to_factor


class DiffJPEG(nn.Module):
    """JPEG round-trip layer: compresses its input, then decompresses it."""

    def __init__(self, height, width, differentiable=True, quality=80):
        """Build the compression and decompression stages of the layer.

        Inputs:
            height(int): Original image height
            width(int): Original image width
            differentiable(bool): If true, both stages round with the custom
                differentiable function (diff_round); if false, they round
                with the standard torch.round
            quality(float): Quality factor for the jpeg compression scheme;
                converted with quality_to_factor and handed to both stages.
        """
        super().__init__()

        # One rounding function and one factor are shared by both stages so
        # that compression and decompression stay consistent with each other.
        round_fn = diff_round if differentiable else torch.round
        scale = quality_to_factor(quality)

        self.compress = compress_jpeg(rounding=round_fn, factor=scale)
        self.decompress = decompress_jpeg(
            height, width, rounding=round_fn, factor=scale
        )

        # requires_grad is switched off on both stages by default
        for stage in (self.compress, self.decompress):
            stage.requires_grad_(False)

    def forward(self, x):
        """Compress x, decompress the three resulting parts, return the result.

        Inputs:
            x: Input passed unchanged to the compression stage.
        Returns:
            Whatever the decompression stage produces from the three values
            (y, cb, cr) emitted by the compression stage.
        """
        y_part, cb_part, cr_part = self.compress(x)
        return self.decompress(y_part, cb_part, cr_part)
