import torch
import torch.nn as nn
from torchvision import transforms

import augly.image.functional as aug_functional


# Per-channel statistics, stored with shape (3, 1, 1) so they broadcast
# against the trailing (3, H, W) dimensions of an image tensor.
image_mean = torch.Tensor([0.485, 0.456, 0.406]).reshape(3, 1, 1)
image_std = torch.Tensor([0.229, 0.224, 0.225]).reshape(3, 1, 1)

def normalize_img(x):
    """
    Standardize an image tensor with the module-level channel statistics.
    Args:
        x: Image tensor whose channel dimension broadcasts against the
           (3, 1, 1)-shaped `image_mean` / `image_std`
    Returns:
        (x - image_mean) / image_std, computed on the device of x.
        NOTE: for x in [0,1] the result spans roughly [-2.1, 2.6] with
        these statistics, not [-1,1].
    """
    # Move the statistics next to the input so the arithmetic never mixes devices.
    mean = image_mean.to(x.device)
    std = image_std.to(x.device)
    return (x - mean) / std

def unnormalize_img(x):
    """
    Invert `normalize_img`: map a standardized image tensor back to pixel scale.
    Args:
        x: Image tensor previously standardized with `image_mean` / `image_std`
    Returns:
        x * image_std + image_mean, computed on the device of x
        (values in [0,1] when x came from a [0,1] image).
    """
    # Move the statistics next to the input so the arithmetic never mixes devices.
    std = image_std.to(x.device)
    mean = image_mean.to(x.device)
    return x * std + mean

def round_pixel(x):
    """
    Snap a normalized image onto the 8-bit pixel grid.
    Args:
        x: Image tensor in the normalized space produced by `normalize_img`
    Returns:
        Tensor in the same normalized space whose underlying pixel values
        are integers in [0, 255]
    """
    # Go to the 0..255 pixel scale, round to the nearest level, keep it in range.
    pixels = unnormalize_img(x) * 255
    quantized = pixels.round().clamp(0, 255)
    # Back to the normalized space the rest of the pipeline works in.
    return normalize_img(quantized / 255.0)

def clamp_pixel(x):
    """
    Restrict a normalized image to the valid pixel range, without rounding.
    Args:
        x: Image tensor in the normalized space produced by `normalize_img`
    Returns:
        Tensor in the same normalized space whose underlying pixel values
        are clamped to [0, 255]
    """
    # Clamp on the 0..255 pixel scale; values already in range are only
    # affected by floating-point round-off of the round trip.
    pixels = unnormalize_img(x) * 255
    pixels = torch.clamp(pixels, min=0, max=255)
    return normalize_img(pixels / 255.0)


class Jpeg(nn.Module):
    """
    JPEG compression augmentation with a straight-through gradient.

    The forward value is the JPEG-compressed image; the compression itself
    is computed without gradient tracking and added back as a constant
    offset, so gradients flow to the input as if the layer were the identity.

    Args:
        Q: Quality setting forwarded to `augly.image.functional.encoding_quality`
           (augly documents it as an integer, 0 = lowest, 100 = highest).
    """

    def __init__(self, Q):
        super().__init__()
        self.Q = Q

        self.to_pil = transforms.ToPILImage()
        self.to_tensor = transforms.ToTensor()

        # Forward / inverse channel normalization expressed as Normalize transforms;
        # the second one undoes the first: (x - (-m/s)) / (1/s) == x * s + m.
        self.normalize_img = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.unnormalize_img = transforms.Normalize(mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225], std=[1/0.229, 1/0.224, 1/0.225])

    def forward(self, image):
        """
        Apply JPEG compression to a batch of normalized images.
        Args:
            image: Batch of normalized 3-channel images, iterated over its
                   first dimension (one JPEG round trip per element)
        Returns:
            Tensor shaped like `image` holding the compressed images in
            normalized space; differentiable w.r.t. `image` with identity gradient.
        """
        with torch.no_grad():
            img_clip = clamp_pixel(image)

            # Quantize to uint8 explicitly with round-to-nearest. Handing a float
            # tensor to ToPILImage makes it scale by 255 and cast to uint8, which
            # truncates: a pixel that comes back from the normalize/unnormalize
            # round trip as 127.99999 would be encoded as 127 instead of 128.
            # Done once for the whole batch, with a single device-to-host copy.
            pixels = self.unnormalize_img(img_clip).mul(255).round().clamp(0, 255)
            pixels = pixels.to(torch.uint8).cpu()

            img_aug = torch.zeros_like(img_clip)
            for ii, img in enumerate(pixels):
                compressed = aug_functional.encoding_quality(self.to_pil(img), quality=self.Q)
                # ToTensor yields a CPU float tensor in [0,1]; the indexed
                # assignment copies it onto img_aug's device and dtype.
                img_aug[ii] = self.to_tensor(compressed)

            img_aug = self.normalize_img(img_aug)

            # Computed under no_grad, so this offset is already a constant
            # with respect to autograd.
            img_gap = img_aug - image

        # Straight-through estimator: compressed values forward, identity backward.
        return image + img_gap

    def __repr__(self):
        return f"Jpeg(Q={self.Q})"
