import cv2
import itertools
import numpy as np
import random
import torch
import torch.nn.functional as F
import torch.nn as nn
from kornia import color
class Dense(nn.Module):
	""" Fully connected layer with an optional ReLU on its output
	Input:
		inputs(tensor): ... x in_features
	Output:
		outputs(tensor): ... x out_features
	"""

	def __init__(self, in_features, out_features, activation='relu', kernel_initializer='he_normal'):
		super().__init__()
		# All constructor arguments are kept as attributes. kernel_initializer
		# is only recorded: the layer keeps nn.Linear's default initialisation.
		self.in_features, self.out_features = in_features, out_features
		self.activation, self.kernel_initializer = activation, kernel_initializer
		self.linear = nn.Linear(in_features, out_features)

	def forward(self, inputs):
		projected = self.linear(inputs)
		# Only 'relu' is acted upon; None or any other value returns the
		# linear output unchanged.
		if self.activation == 'relu':
			return F.relu(projected, inplace=True)
		return projected


class Conv2D(nn.Module):
	""" 2-D convolution padded by half the kernel size, with an optional ReLU
	Input:
		inputs(tensor): batch x in_channels x height x width
	Output:
		outputs(tensor): batch x out_channels x height' x width'
	Raises:
		NotImplementedError: in forward, when activation is neither None nor 'relu'
	"""

	def __init__(self, in_channels, out_channels, kernel_size=3, activation='relu', strides=1):
		super().__init__()
		self.in_channels, self.out_channels = in_channels, out_channels
		self.kernel_size, self.strides = kernel_size, strides
		self.activation = activation

		# Half the kernel extent on each side: with an odd kernel and a
		# stride of 1 the spatial size is preserved.
		padding = int((kernel_size - 1) / 2)
		self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=strides, padding=padding)

	def forward(self, inputs):
		outputs = self.conv(inputs)
		if self.activation is None:
			return outputs
		if self.activation != 'relu':
			raise NotImplementedError
		return F.relu(outputs, inplace=True)


class Flatten(nn.Module):
	""" Collapse every dimension after the first into one
	Input:
		input(tensor): batch x ...
	Output:
		result(tensor): batch x (product of the remaining dimensions)
	"""

	def __init__(self):
		super().__init__()

	def forward(self, input):
		# reshape (unlike view) also accepts non-contiguous tensors.
		batch = input.size(0)
		return input.reshape(batch, -1)



class Discriminator(nn.Module):
	""" Small fully convolutional critic
	Four stride-2 Conv2D+ReLU layers (3 -> 8 -> 16 -> 32 -> 64 channels)
	followed by a linear 1-channel Conv2D.
	Input:
		image(tensor): batch x 3 x height x width
	Output:
		output(tensor): scalar mean of the score map
		x(tensor): batch x 1 x height' x width' score map
	"""

	def __init__(self):
		super().__init__()
		widths = [3, 8, 16, 32, 64]
		layers = [
			Conv2D(c_in, c_out, 3, strides=2, activation='relu')
			for c_in, c_out in zip(widths[:-1], widths[1:])
		]
		layers.append(Conv2D(64, 1, 3, activation=None))
		self.model = nn.Sequential(*layers)

	def forward(self, image):
		# Shift the input by -0.5 before scoring.
		x = self.model(image - .5)
		return torch.mean(x), x


def transform_net(encoded_image, args, global_step, resolution:list):
	""" Apply a random chain of image distortions to a batch
	The steps are, in order: blur, additive Gaussian noise, contrast scaling
	plus a brightness/hue offset, desaturation and (unless args.no_jpeg is
	set) an approximate JPEG round trip. The strength of each distortion is
	scaled by min(global_step / ramp, 1) using its own ramp length.
	Every tensor created here is placed on encoded_image's device.
	Input:
		encoded_image(tensor): batch x 3 x height x width, assumed to lie in [0, 1]
		args: namespace providing rnd_bri, rnd_hue, rnd_noise, rnd_sat,
			contrast_low, contrast_high, jpeg_quality, no_jpeg and the
			matching *_ramp values
		global_step(number): current training step
		resolution(list): [height, width] the batch is reshaped to before the JPEG step
	Output:
		encoded_image(tensor): batch x 3 x height x width, clamped to [0, 1]
	"""
	device = encoded_image.device
	bs = encoded_image.size(0)

	def ramp_fn(ramp):
		# Linear warm-up factor in [0, 1].
		return np.min([global_step / ramp, 1.])

	# Random parameters are drawn on the CPU in a fixed order so that the
	# random stream does not depend on the target device.
	rnd_bri = ramp_fn(args.rnd_bri_ramp) * args.rnd_bri
	rnd_hue = ramp_fn(args.rnd_hue_ramp) * args.rnd_hue
	rnd_brightness = get_rnd_brightness_torch(rnd_bri, rnd_hue, bs)  # [batch_size, 3, 1, 1]
	jpeg_quality = 100. - torch.rand(1)[0] * ramp_fn(args.jpeg_quality_ramp) * (100. - args.jpeg_quality)

	rnd_noise = torch.rand(1)[0] * ramp_fn(args.rnd_noise_ramp) * args.rnd_noise

	contrast_low = 1. - (1. - args.contrast_low) * ramp_fn(args.contrast_ramp)
	contrast_high = 1. + (args.contrast_high - 1.) * ramp_fn(args.contrast_ramp)

	rnd_sat = torch.rand(1)[0] * ramp_fn(args.rnd_sat_ramp) * args.rnd_sat

	# blur
	N_blur = 7
	f = random_blur_kernel(probs=[.25, .25], N_blur=N_blur, sigrange_gauss=[1., 3.], sigrange_line=[.25, 1.],
								 wmin_line=3)
	f = f.to(device)
	encoded_image = F.conv2d(encoded_image, f, bias=None, padding=int((N_blur - 1) / 2))

	# noise
	noise = torch.normal(mean=0, std=rnd_noise, size=encoded_image.size(), dtype=torch.float32)
	encoded_image = encoded_image + noise.to(device)
	encoded_image = torch.clamp(encoded_image, 0., 1.)

	# contrast & brightness: one contrast factor per sample
	contrast_scale = torch.zeros(bs).uniform_(contrast_low, contrast_high)
	contrast_scale = contrast_scale.reshape(bs, 1, 1, 1).to(device)
	encoded_image = encoded_image * contrast_scale
	encoded_image = encoded_image + rnd_brightness.to(device)
	encoded_image = torch.clamp(encoded_image, 0., 1.)

	# saturation: blend every channel towards a weighted channel mean
	sat_weight = torch.tensor([.3, .6, .1], dtype=torch.float32, device=device).reshape(1, 3, 1, 1)
	encoded_image_lum = torch.mean(encoded_image * sat_weight, dim=1, keepdim=True)
	encoded_image = (1 - rnd_sat) * encoded_image + rnd_sat * encoded_image_lum

	# jpeg
	encoded_image = encoded_image.reshape([-1, 3] + list(resolution))
	if not args.no_jpeg:
		encoded_image = jpeg_compress_decompress(encoded_image, rounding=round_only_at_0,
													   quality=jpeg_quality)
	encoded_image = torch.clamp(encoded_image, 0., 1.)
	return encoded_image

def image_loss_fn(image_input, encoded_image, l2_edge_gain, yuv_scales):
	""" Edge-weighted mean squared error between two images in YUV space
	The per-pixel difference is amplified by (1 + l2_edge_gain * falloff),
	where falloff is 1 on the image border and decays to 0 over the outer
	quarter of each side.
	Input:
		image_input(tensor): batch x 3 x height x width RGB reference
		encoded_image(tensor): batch x 3 x height x width RGB image to compare
		l2_edge_gain(number): extra weight applied at the image border
		yuv_scales(sequence): three weights for the Y, U and V channel losses
	Output:
		image_loss(tensor): scalar weighted loss
	"""
	height, width = int(image_input.shape[2]), int(image_input.shape[3])
	falloff_speed = 4
	falloff_im = np.ones((height, width))
	# Raised-cosine ramp from 0 at the border to 1 a quarter of the way in.
	# Each axis uses its own length so that non-square images get the same
	# ramp shape along rows and columns.
	for i in range(int(height / falloff_speed)):
		ramp = (np.cos(4 * np.pi * i / height + np.pi) + 1) / 2
		falloff_im[-i, :] *= ramp
		falloff_im[i, :] *= ramp
	for j in range(int(width / falloff_speed)):
		ramp = (np.cos(4 * np.pi * j / width + np.pi) + 1) / 2
		falloff_im[:, -j] *= ramp
		falloff_im[:, j] *= ramp
	# Invert: 1 on the border, 0 in the interior.
	falloff_im = 1 - falloff_im
	falloff_im = torch.from_numpy(falloff_im).float().to(encoded_image.device)
	falloff_im = falloff_im * l2_edge_gain

	encoded_image_yuv = color.rgb_to_yuv(encoded_image)
	image_input_yuv = color.rgb_to_yuv(image_input)
	im_diff = encoded_image_yuv - image_input_yuv
	im_diff = im_diff + im_diff * falloff_im.unsqueeze(0)
	# One mean squared error per colour channel.
	yuv_loss = torch.mean(im_diff ** 2, dim=[0, 2, 3])
	yuv_scales = torch.as_tensor(yuv_scales, dtype=torch.float32, device=yuv_loss.device)
	return torch.dot(yuv_loss, yuv_scales)

def random_blur_kernel(probs, N_blur, sigrange_gauss, sigrange_line, wmin_line):
	""" Sample a random N_blur x N_blur blur kernel for a 3-channel conv2d
	With probability probs[0] the kernel is an isotropic Gaussian, with
	probability probs[1] a Gaussian ridge along a random direction, otherwise
	a single centre tap (no blur). The kernel is normalised to sum to 1.
	Input:
		probs(sequence): [p_gauss, p_line]
		N_blur(int): kernel side length
		sigrange_gauss(sequence): [min, max] sigma of the Gaussian kernel
		sigrange_line(sequence): [min, max] sigma across the line kernel
		wmin_line(number): lower bound of the line kernel's Manhattan-distance cut-off
	Output:
		f(tensor): 3 x 3 x N_blur x N_blur, kernel on the channel diagonal, zeros elsewhere
	"""
	half = 0.5 * (N_blur - 1)
	axis = torch.arange(N_blur) - half
	# coords[i, j] = (i - half, j - half), shape (N, N, 2)
	coords = torch.stack([
		axis.reshape(N_blur, 1).expand(N_blur, N_blur),
		axis.reshape(1, N_blur).expand(N_blur, N_blur),
	], dim=-1)
	manhat = coords.abs().sum(dim=-1)  # (N, N)

	# The five random draws below happen in a fixed order:
	# sig_gauss, theta, sig_line, w_line, t.
	sig_gauss = torch.rand(1)[0] * (sigrange_gauss[1] - sigrange_gauss[0]) + sigrange_gauss[0]
	theta = torch.rand(1)[0] * 2. * np.pi
	sig_line = torch.rand(1)[0] * (sigrange_line[1] - sigrange_line[0]) + sigrange_line[0]
	w_line = torch.rand(1)[0] * (half + 0.1 - wmin_line) + wmin_line
	t = torch.rand(1)[0]

	# nothing: only the centre tap is non-zero
	vals_nothing = (manhat < 0.5).float()

	# gauss: isotropic Gaussian around the centre
	vals_gauss = torch.exp(-torch.sum(coords ** 2, dim=-1) / 2. / sig_gauss ** 2)

	# line: Gaussian profile of the projection onto the direction theta,
	# truncated to Manhattan distance below w_line
	direction = torch.stack([torch.cos(theta), torch.sin(theta)])  # (2)
	dists = (coords * direction).sum(dim=-1)  # (N, N)
	vals_line = torch.exp(-dists ** 2 / 2. / sig_line ** 2) * (manhat < w_line)

	if t < probs[0]:
		vals = vals_gauss
	elif t < (probs[0] + probs[1]):
		vals = vals_line
	else:
		vals = vals_nothing

	kernel = vals / torch.sum(vals)  # normalise, (N, N)
	# Same kernel for each channel, no mixing between channels.
	f = torch.zeros(3, 3, N_blur, N_blur, dtype=kernel.dtype)
	for ch in range(3):
		f[ch, ch] = kernel
	return f


def get_rand_transform_matrix(image_size, d, batch_size):
	""" Draw one random perspective transform pair per batch element
	Each corner of the image_size x image_size square is displaced by
	independent x and y offsets drawn uniformly from [-d, d].
	Input:
		image_size(number): side length of the square image
		d(number): maximum corner displacement
		batch_size(int): number of transform pairs to draw
	Output:
		Ms(tensor): batch_size x 2 x 3 x 3 float tensor; Ms[:, 1] maps the
			displaced quadrilateral onto the square and Ms[:, 0] is its inverse
	"""
	# Square corners in the order top-left, top-right, bottom-right, bottom-left.
	corners = np.array([
		[0, 0],
		[image_size, 0],
		[image_size, image_size],
		[0, image_size]])
	dst = corners.astype("float32")

	Ms = np.zeros((batch_size, 2, 3, 3))
	for i in range(batch_size):
		# Offsets are drawn in the order tl, bl, tr, br (x before y).
		tl_x, tl_y, bl_x, bl_y, tr_x, tr_y, br_x, br_y = [
			random.uniform(-d, d) for _ in range(8)]
		offsets = np.array([
			[tl_x, tl_y],
			[tr_x, tr_y],
			[br_x, br_y],
			[bl_x, bl_y]])
		rect = (corners + offsets).astype("float32")

		forward = cv2.getPerspectiveTransform(rect, dst)
		Ms[i, 0, :, :] = np.linalg.inv(forward)
		Ms[i, 1, :, :] = forward

	return torch.from_numpy(Ms).float()


def get_rnd_brightness_torch(rnd_bri, rnd_hue, batch_size):
	""" Random additive colour offset per sample
	Input:
		rnd_bri(number): half-width of the uniform brightness range (shared by all channels)
		rnd_hue(number): half-width of the uniform per-channel range
		batch_size(int): number of samples
	Output:
		offset(tensor): batch_size x 3 x 1 x 1 float32 tensor
	"""
	# The per-channel offsets are drawn before the shared brightness offset.
	hue_shift = torch.empty(batch_size, 3, 1, 1, dtype=torch.float32).uniform_(-rnd_hue, rnd_hue)
	brightness_shift = torch.empty(batch_size, 1, 1, 1, dtype=torch.float32).uniform_(-rnd_bri, rnd_bri)
	return hue_shift + brightness_shift


# reference: https://github.com/mlomnitz/DiffJPEG.git
# 8 x 8 quantisation table for the Y channel, stored transposed as a frozen
# float32 parameter.
y_table = nn.Parameter(
	torch.from_numpy(np.array([
		[16, 11, 10, 16, 24, 40, 51, 61],
		[12, 12, 14, 19, 26, 58, 60, 55],
		[14, 13, 16, 24, 40, 57, 69, 56],
		[14, 17, 22, 29, 51, 87, 80, 62],
		[18, 22, 37, 56, 68, 109, 103, 77],
		[24, 35, 55, 64, 81, 104, 113, 92],
		[49, 64, 78, 87, 103, 121, 120, 101],
		[72, 92, 95, 98, 112, 100, 103, 99],
	], dtype=np.float32).T),
	requires_grad=False)

# 8 x 8 quantisation table for the chroma channels: 99 everywhere except the
# top-left 4 x 4 corner.
c_table = np.full((8, 8), 99, dtype=np.float32)
c_table[:4, :4] = np.array([
	[17, 18, 24, 47],
	[18, 21, 26, 66],
	[24, 26, 56, 99],
	[47, 66, 99, 99],
]).T
c_table = nn.Parameter(torch.from_numpy(c_table), requires_grad=False)


# 1. RGB -> YCbCr
class rgb_to_ycbcr_jpeg(nn.Module):
	""" Converts RGB image to YCbCr
	Input:
		image(tensor): batch x 3 x height x width
	Output:
		result(tensor): batch x height x width x 3
	"""

	def __init__(self):
		super(rgb_to_ycbcr_jpeg, self).__init__()
		# Rows are the Y, Cb and Cr coefficients; transposed so that the
		# first axis indexes the input (RGB) channel.
		matrix = np.array(
			[[0.299, 0.587, 0.114], [-0.168736, -0.331264, 0.5],
			 [0.5, -0.418688, -0.081312]], dtype=np.float32).T
		# The constants are kept as plain tensors and moved to the input's
		# device in forward(), so the module is not tied to CUDA.
		self.shift = torch.tensor([0., 128., 128.], dtype=torch.float32)
		self.matrix = torch.from_numpy(np.ascontiguousarray(matrix))

	def forward(self, image):
		# Channels last, so the colour matrix contracts the final axis.
		image = image.permute(0, 2, 3, 1)
		matrix = self.matrix.to(image.device)
		shift = self.shift.to(image.device)
		return torch.tensordot(image, matrix, dims=1) + shift


# 2. Chroma subsampling
class chroma_subsampling(nn.Module):
	""" Chroma subsampling on the Cb and Cr channels
	Input:
		image(tensor): batch x height x width x 3
	Output:
		y(tensor): batch x height x width
		cb(tensor): batch x height/2 x width/2
		cr(tensor): batch x height/2 x width/2
	"""

	def __init__(self):
		super().__init__()

	def forward(self, image):
		channels_first = image.permute(0, 3, 1, 2)

		def halve(index):
			# Average each 2 x 2 neighbourhood of a single channel.
			plane = channels_first[:, index:index + 1]
			pooled = F.avg_pool2d(plane, kernel_size=2, stride=(2, 2),
			                      count_include_pad=False)
			return pooled.squeeze(1)

		# Luma is passed through at full resolution.
		luma = image[:, :, :, 0]
		return luma, halve(1), halve(2)


# 3. Block splitting
class block_splitting(nn.Module):
	""" Splitting image into 8 x 8 patches
	Input:
		image(tensor): batch x height x width
	Output:
		patch(tensor): batch x h*w/64 x 8 x 8, patches in row-major order
	"""

	def __init__(self):
		super().__init__()
		self.k = 8

	def forward(self, image):
		k = self.k
		batch_size, height = image.shape[0], image.shape[1]
		# (batch, block_row, row_in_block, block_col, col_in_block)
		tiled = image.view(batch_size, height // k, k, -1, k)
		# Put the two block indices next to each other, then merge them.
		blocks = tiled.transpose(2, 3).contiguous()
		return blocks.view(batch_size, -1, k, k)


# 4. DCT
class dct_8x8(nn.Module):
	""" Discrete Cosine Transformation of 8 x 8 patches
	Input:
		image(tensor): batch x patches x 8 x 8
	Output:
		dcp(tensor): batch x patches x 8 x 8
	"""

	def __init__(self):
		super(dct_8x8, self).__init__()
		# basis[x, u] = cos((2x + 1) * u * pi / 16); the 2-D kernel is the
		# outer product tensor[x, y, u, v] = basis[x, u] * basis[y, v].
		idx = np.arange(8)
		basis = np.cos((2 * idx[:, None] + 1) * idx[None, :] * np.pi / 16)
		tensor = np.einsum('xu,yv->xyuv', basis, basis).astype(np.float32)
		alpha = np.array([1. / np.sqrt(2)] + [1] * 7)
		# The constants are kept as plain tensors and moved to the input's
		# device in forward(), so the module is not tied to CUDA.
		self.tensor = torch.from_numpy(tensor)
		self.scale = torch.from_numpy(np.outer(alpha, alpha) * 0.25).float()

	def forward(self, image):
		# Centre the samples around zero before transforming.
		image = image - 128
		tensor = self.tensor.to(image.device)
		scale = self.scale.to(image.device)
		# Contract the two spatial axes (x, y) against the basis.
		return scale * torch.tensordot(image, tensor, dims=2)


# 5. Quantization
class y_quantize(nn.Module):
	""" JPEG Quantization for Y channel
	Input:
		image(tensor): batch x patches x 8 x 8
		rounding(function): rounding function to use
		factor(float or tensor): Degree of compression
	Output:
		image(tensor): batch x patches x 8 x 8
	"""

	def __init__(self, rounding, factor=1):
		super(y_quantize, self).__init__()
		self.rounding = rounding
		# Accept plain numbers as well as tensors: the default factor=1 is an
		# int and has no tensor methods.
		self.factor = torch.as_tensor(factor, dtype=torch.float32)
		self.y_table = y_table.detach()

	def forward(self, image):
		image = image.float()
		# Table and factor follow the input's device instead of assuming CUDA.
		step = self.y_table.to(image.device) * self.factor.to(image.device)
		return self.rounding(image / step)


class c_quantize(nn.Module):
	""" JPEG Quantization for CrCb channels
	Input:
		image(tensor): batch x patches x 8 x 8
		rounding(function): rounding function to use
		factor(float or tensor): Degree of compression
	Output:
		image(tensor): batch x patches x 8 x 8
	"""

	def __init__(self, rounding, factor=1):
		super(c_quantize, self).__init__()
		self.rounding = rounding
		# Accept plain numbers as well as tensors: the default factor=1 is an
		# int and has no tensor methods.
		self.factor = torch.as_tensor(factor, dtype=torch.float32)
		self.c_table = c_table.detach()

	def forward(self, image):
		image = image.float()
		# Table and factor follow the input's device instead of assuming CUDA.
		step = self.c_table.to(image.device) * self.factor.to(image.device)
		return self.rounding(image / step)


class compress_jpeg(nn.Module):
	""" Full JPEG compression algorithm
	Input:
		imgs(tensor): batch x 3 x height x width, scaled by 255 before conversion
		rounding(function): rounding function to use
		factor(float): Compression factor
	Output:
		y, cb, cr(tuple of tensors): quantized coefficients, each
			batch x patches x 8 x 8
	"""

	def __init__(self, rounding=torch.round, factor=1):
		super().__init__()
		# l1: colour conversion followed by chroma subsampling
		self.l1 = nn.Sequential(rgb_to_ycbcr_jpeg(), chroma_subsampling())
		# l2: 8 x 8 block splitting followed by the DCT
		self.l2 = nn.Sequential(block_splitting(), dct_8x8())
		self.c_quantize = c_quantize(rounding=rounding, factor=factor)
		self.y_quantize = y_quantize(rounding=rounding, factor=factor)

	def forward(self, image):
		y, cb, cr = self.l1(image * 255)
		# Luma uses the Y quantizer, both chroma planes share the C quantizer.
		y = self.y_quantize(self.l2(y))
		cb = self.c_quantize(self.l2(cb))
		cr = self.c_quantize(self.l2(cr))
		return y, cb, cr


# -5. Dequantization
class y_dequantize(nn.Module):
	""" Dequantize Y channel
	Inputs:
		image(tensor): batch x patches x 8 x 8
		factor(float or tensor): compression factor
	Outputs:
		image(tensor): batch x patches x 8 x 8
	"""

	def __init__(self, factor=1):
		super(y_dequantize, self).__init__()
		self.y_table = y_table.detach()
		# Accept plain numbers as well as tensors: the default factor=1 is an
		# int and has no tensor methods.
		self.factor = torch.as_tensor(factor, dtype=torch.float32)

	def forward(self, image):
		# Table and factor follow the input's device instead of assuming CUDA.
		step = self.y_table.to(image.device) * self.factor.to(image.device)
		return image * step


class c_dequantize(nn.Module):
	""" Dequantize CbCr channel
	Inputs:
		image(tensor): batch x patches x 8 x 8
		factor(float or tensor): compression factor
	Outputs:
		image(tensor): batch x patches x 8 x 8
	"""

	def __init__(self, factor=1):
		super(c_dequantize, self).__init__()
		# Accept plain numbers as well as tensors: the default factor=1 is an
		# int and has no tensor methods.
		self.factor = torch.as_tensor(factor, dtype=torch.float32)
		self.c_table = c_table.detach()

	def forward(self, image):
		# Table and factor follow the input's device instead of assuming CUDA.
		step = self.c_table.to(image.device) * self.factor.to(image.device)
		return image * step


# -4. Inverse DCT
class idct_8x8(nn.Module):
	""" Inverse discrete Cosine Transformation of 8 x 8 patches
	Input:
		dcp(tensor): batch x patches x 8 x 8
	Output:
		image(tensor): batch x patches x 8 x 8
	"""

	def __init__(self):
		super(idct_8x8, self).__init__()
		alpha = np.array([1. / np.sqrt(2)] + [1] * 7)
		# basis[x, u] = cos((2u + 1) * x * pi / 16); the 2-D kernel is the
		# outer product tensor[x, y, u, v] = basis[x, u] * basis[y, v].
		idx = np.arange(8)
		basis = np.cos((2 * idx[None, :] + 1) * idx[:, None] * np.pi / 16)
		tensor = np.einsum('xu,yv->xyuv', basis, basis).astype(np.float32)
		# The constants are kept as plain tensors and moved to the input's
		# device in forward(), so the module is not tied to CUDA.
		self.alpha = torch.from_numpy(np.outer(alpha, alpha)).float()
		self.tensor = torch.from_numpy(tensor)

	def forward(self, image):
		image = image * self.alpha.to(image.device)
		tensor = self.tensor.to(image.device)
		# Contract the two frequency axes, then undo the -128 centring.
		return 0.25 * torch.tensordot(image, tensor, dims=2) + 128


# -3. Block joining
class block_merging(nn.Module):
	""" Merge 8 x 8 patches into an image
	Inputs:
		patches(tensor): batch x height*width/64 x 8 x 8, patches in row-major order
		height(int)
		width(int)
	Output:
		image(tensor): batch x height x width
	"""

	def __init__(self):
		super().__init__()

	def forward(self, patches, height, width):
		k = 8
		batch_size = patches.shape[0]
		# (batch, block_row, block_col, row_in_block, col_in_block)
		grid = patches.view(batch_size, height // k, width // k, k, k)
		# Interleave block rows with in-block rows before flattening.
		rows_first = grid.transpose(2, 3).contiguous()
		return rows_first.view(batch_size, height, width)


# -2. Chroma upsampling
class chroma_upsampling(nn.Module):
	""" Upsample chroma layers by a factor of 2 and stack them with luma
	Input:
		y(tensor): batch x height x width
		cb(tensor): batch x height/2 x width/2
		cr(tensor): batch x height/2 x width/2
	Output:
		image(tensor): batch x height x width x 3
	"""

	def __init__(self):
		super().__init__()

	def forward(self, y, cb, cr):
		def repeat(x, k=2):
			# Duplicate every sample k times along both spatial axes.
			height, width = x.shape[1], x.shape[2]
			tiled = x.reshape(-1, height, 1, width, 1).expand(-1, height, k, width, k)
			return tiled.reshape(-1, height * k, width * k)

		return torch.stack([y, repeat(cb), repeat(cr)], dim=3)


# -1: YCbCr -> RGB
class ycbcr_to_rgb_jpeg(nn.Module):
	""" Converts YCbCr image to RGB JPEG
	Input:
		image(tensor): batch x height x width x 3
	Output:
		result(tensor): batch x 3 x height x width
	"""

	def __init__(self):
		super(ycbcr_to_rgb_jpeg, self).__init__()
		# Rows are the R, G and B coefficients; transposed so that the first
		# axis indexes the input (YCbCr) channel.
		matrix = np.array(
			[[1., 0., 1.402], [1, -0.344136, -0.714136], [1, 1.772, 0]],
			dtype=np.float32).T
		# The constants are kept as plain tensors and moved to the input's
		# device in forward(), so the module is not tied to CUDA.
		self.shift = torch.tensor([0, -128., -128.], dtype=torch.float32)
		self.matrix = torch.from_numpy(np.ascontiguousarray(matrix))

	def forward(self, image):
		matrix = self.matrix.to(image.device)
		shift = self.shift.to(image.device)
		# Remove the chroma offset, mix the channels, then go channels-first.
		result = torch.tensordot(image + shift, matrix, dims=1)
		return result.permute(0, 3, 1, 2)


class decompress_jpeg(nn.Module):
	""" Full JPEG decompression algorithm
	Input:
		y, cb, cr(tensors): quantized coefficients, each batch x patches x 8 x 8
		height(int), width(int): size of the image to rebuild
		rounding(function): accepted for symmetry with compress_jpeg, not used here
		factor(float): Compression factor
	Output:
		image(tensor): batch x 3 x height x width, values in [0, 1]
	"""

	def __init__(self, height, width, rounding=torch.round, factor=1):
		super().__init__()
		self.c_dequantize = c_dequantize(factor=factor)
		self.y_dequantize = y_dequantize(factor=factor)
		self.idct = idct_8x8()
		self.merging = block_merging()
		self.chroma = chroma_upsampling()
		self.colors = ycbcr_to_rgb_jpeg()

		self.height, self.width = height, width

	def forward(self, y, cb, cr):
		# Chroma planes are stored at half resolution.
		half_height, half_width = int(self.height / 2), int(self.width / 2)

		def restore(comp, dequantize, height, width):
			# dequantize -> inverse DCT -> stitch the patches back together
			return self.merging(self.idct(dequantize(comp)), height, width)

		y = restore(y, self.y_dequantize, self.height, self.width)
		cb = restore(cb, self.c_dequantize, half_height, half_width)
		cr = restore(cr, self.c_dequantize, half_height, half_width)

		image = self.colors(self.chroma(y, cb, cr))

		# Limit to [0, 255] with element-wise max/min, then rescale to [0, 1].
		floor = torch.zeros_like(image)
		ceiling = 255 * torch.ones_like(image)
		image = torch.min(ceiling, torch.max(floor, image))
		return image / 255


def diff_round(x):
	""" Differentiable rounding function
	Returns the nearest integer plus the cube of the rounding residual.
	Input:
		x(tensor)
	Output:
		x(tensor)
	"""
	nearest = torch.round(x)
	residual = x - nearest
	return nearest + residual ** 3


def round_only_at_0(x):
	""" Cube values with magnitude below 0.5, leave all others unchanged
	Input:
		x(tensor)
	Output:
		x(tensor)
	"""
	# 1.0 where |x| < 0.5, 0.0 elsewhere; used as a blend weight.
	mask = torch.lt(torch.abs(x), 0.5).float()
	cubed = x ** 3
	return mask * cubed + (1 - mask) * x


def quality_to_factor(quality):
	""" Calculate factor corresponding to quality
	Below quality 50 the factor is 50 / quality; from 50 upwards it falls
	linearly from 1 at quality 50 to 0 at quality 100.
	Input:
		quality(float): Quality for jpeg compression
	Output:
		factor(float): Compression factor
	"""
	percent = 5000. / quality if quality < 50 else 200. - quality * 2
	return percent / 100.


def jpeg_compress_decompress(image,
                             rounding=round_only_at_0,
                             quality=80):
	""" Differentiable JPEG round trip (compress, then decompress)
	Input:
		image(tensor): batch x 3 x height x width; it is scaled by 255 before
			conversion, so values are presumably in [0, 1]
		rounding(function): rounding function used by the quantizers
		quality(float or tensor): JPEG quality, see quality_to_factor
	Output:
		recovered(tensor): batch x 3 x height x width, values in [0, 1]
	"""
	# No padding is applied: chroma is subsampled by 2 and then cut into
	# 8 x 8 blocks, so height and width are expected to be multiples of 16.
	height, width = image.shape[2:4]

	# Normalise the factor to a tensor so that a plain-number quality
	# (including the default 80) is handled like a tensor quality.
	factor = torch.as_tensor(quality_to_factor(quality), dtype=torch.float32)

	compress = compress_jpeg(rounding=rounding, factor=factor)
	decompress = decompress_jpeg(height, width, rounding=rounding, factor=factor)

	y, cb, cr = compress(image)
	return decompress(y, cb, cr)
