
import os
# Restrict the CUDA devices visible to this process to device index 2.
# NOTE(review): the assignment sits above the `import torch` line below,
# presumably so it is in place before CUDA is initialised -- keep that order
# unless confirmed otherwise.
os.environ["CUDA_VISIBLE_DEVICES"] = '2'

import torch
import random
import numpy as np
import torchvision
from torchvision import transforms
import argparse
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
import torch.nn.functional as F
from torchmetrics.image import MultiScaleStructuralSimilarityIndexMeasure
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from PIL import Image

def parse_args():
    """Parse the command-line options of the JPEG benchmark script.

    Returns:
        argparse.Namespace: namespace exposing ``batch_size`` (int),
        ``imagenet_root`` (str) and ``output_dir`` (str).
    """
    # Each entry is (option flag, keyword arguments for ``add_argument``).
    option_table = (
        ('--batch_size',
         {'type': int, 'default': 8, 'help': 'Batch size for training'}),
        ("--imagenet_root",
         {'type': str, 'default': "/data22/datasets/ilsvrc2012/",
          'help': "ImageNet dataset root directory"}),
        ("--output_dir",
         {'type': str, 'default': "./results/image_compression/JPEG/",
          'help': "Directory to save output images"}),
    )

    cli = argparse.ArgumentParser()
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()

class CustomImageDataset(Dataset):
    """Dataset over the regular files found directly inside one directory.

    Every item is a ``(image, 0)`` pair; the constant ``0`` is a placeholder
    label so the dataset can be consumed like a labelled one.

    Args:
        root_dir: Directory whose regular files are treated as images.
            Sub-directories are ignored.
        transform: Optional callable applied to the loaded PIL image.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping stable between runs and machines.
        self.image_files = sorted(
            entry for entry in os.listdir(root_dir)
            if os.path.isfile(os.path.join(root_dir, entry))
        )

    def __len__(self):
        """Return the number of files discovered in ``root_dir``."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load the ``idx``-th file as RGB and return ``(image, 0)``."""
        img_name = os.path.join(self.root_dir, self.image_files[idx])

        # Image.open is lazy and keeps the file open; convert() forces the
        # pixel data to be read so the handle can be closed right away.
        # Converting to RGB also gives every sample the same channel layout
        # (grayscale / palette / RGBA files would otherwise differ).
        with Image.open(img_name) as handle:
            image = handle.convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, 0

def get_test_data(args, resize):
    """Build the evaluation dataset of square, tensor-converted images.

    Args:
        args: Parsed command-line namespace. Currently unused: the dataset
            directory is hard-coded below and ``args.imagenet_root`` is not
            consulted. Kept so existing callers keep working.
        resize: Target size in pixels. It is passed to both ``Resize`` and
            ``CenterCrop``, so images come out as ``resize x resize``.

    Returns:
        CustomImageDataset over '/data22/aho/KODAK/' with the transform
        pipeline attached.
    """
    # Honour the caller-supplied size instead of a hard-coded 224 so the
    # parameter actually controls the output resolution.
    eval_transform = transforms.Compose([
        transforms.Resize(resize),
        transforms.CenterCrop(resize),
        transforms.ToTensor(),
    ])
    return CustomImageDataset(root_dir='/data22/aho/KODAK/', transform=eval_transform)

def psnr_batch(img1, img2, data_range=1.0):
    """Compute the PSNR in dB for every image pair of a batch.

    Args:
        img1: Tensor whose first dimension indexes the images of the batch.
        img2: Tensor compared element-wise against ``img1``.
        data_range: Peak signal value (1.0 for images scaled to [0, 1]).

    Returns:
        1-D tensor with one PSNR value per batch entry. Identical images
        (zero MSE) yield ``inf``.
    """
    squared_error = F.mse_loss(img1, img2, reduction='none')
    # reshape (rather than view) also accepts non-contiguous layouts, e.g.
    # tensors obtained through permute(); view would raise on those.
    per_image_mse = squared_error.reshape(squared_error.size(0), -1).mean(dim=1)
    psnr_values = 20 * torch.log10(data_range / torch.sqrt(per_image_mse))
    return psnr_values

def ms_ssim_to_db(ms_ssim):
    """Convert an MS-SSIM score to decibels: ``-10 * log10(1 - ms_ssim)``."""
    residual = 1 - ms_ssim
    return np.log10(residual) * -10

def ms_ssim_batch(img1, img2, data_range=1.0):
    """Return the MS-SSIM, expressed in dB, of each image pair in a batch.

    Args:
        img1: Batch tensor; ``img1.size(0)`` determines how many pairs are scored.
        img2: Batch tensor indexed in lock-step with ``img1``.
        data_range: Value range forwarded to the MS-SSIM metric.

    Returns:
        list: one dB value per pair, in batch order.
    """
    metric = MultiScaleStructuralSimilarityIndexMeasure(data_range=data_range)

    db_scores = []
    for pair_idx in range(img1.size(0)):
        # Score one pair at a time, restoring the batch axis with [None].
        first = img1[pair_idx][None]
        second = img2[pair_idx][None]
        raw_score = metric(first, second).item()
        db_scores.append(ms_ssim_to_db(raw_score))

    return db_scores

def save_image(img, path):
    """Save ``img`` to ``path`` with matplotlib after reordering its axes.

    The axes are permuted with ``(1, 2, 0)`` before the array is handed to
    ``plt.imsave``.
    """
    reordered = np.transpose(img, (1, 2, 0))
    plt.imsave(path, reordered)

def compress_and_evaluate_jpeg(images, quality, output_dir):
    """JPEG-compress every image of a batch and measure size and fidelity.

    Each image is written to ``output_dir`` as ``compressed_{quality}_{i}.jpg``
    (``i`` is the position inside ``images``), read back, and compared with
    the uncompressed input. Because the file name only contains the quality
    and the in-batch index, a later call with the same ``quality`` and
    ``output_dir`` overwrites these files.

    Args:
        images: Iterable of image tensors accepted by ``transforms.ToPILImage``.
        quality: JPEG quality setting passed to ``PIL.Image.save``.
        output_dir: Directory receiving the compressed files; created if it
            does not exist.

    Returns:
        tuple: ``(mean file size in KiB, mean PSNR in dB, mean MS-SSIM in dB)``
        over the images of the batch.
    """
    # Saving fails with FileNotFoundError when the target directory is missing.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    to_pil = transforms.ToPILImage()
    to_tensor = transforms.ToTensor()

    size_list = []
    psnr_list = []
    ms_ssim_list = []

    for i, image in enumerate(images):
        # The decoded JPEG always lives on the CPU, so compare against a CPU
        # copy of the reference to avoid a device mismatch.
        reference = image.cpu()

        output_path = os.path.join(output_dir, f"compressed_{quality}_{i}.jpg")
        to_pil(reference).save(output_path, 'JPEG', quality=quality)

        # Decode inside the context manager so the file handle is released.
        with Image.open(output_path) as jpeg_file:
            compressed_image = to_tensor(jpeg_file).unsqueeze(0)

        # File size in KiB
        size_list.append(os.path.getsize(output_path) / 1024)

        reference_batch = reference.unsqueeze(0)

        # Store plain floats (not 0-dim tensors) so np.mean is well defined.
        psnr_list.extend(psnr_batch(compressed_image, reference_batch).tolist())
        ms_ssim_list.extend(ms_ssim_batch(compressed_image, reference_batch))

    return np.mean(size_list), np.mean(psnr_list), np.mean(ms_ssim_list)

def main():
    """Run the JPEG rate/distortion sweep and print one result per quality.

    For each JPEG quality in ``range(1, 32, 3)`` the whole test set is
    compressed, and the dataset-wide bits per pixel, PSNR (dB) and MS-SSIM
    (dB) are printed. The list of ``(bpp, psnr, ms_ssim)`` tuples is printed
    at the end.
    """
    args = parse_args()
    torch.manual_seed(0)
    random.seed(10)
    np.random.seed(0)

    # The compressed files are written here; make sure the directory exists.
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)

    test_dataset = get_test_data(args, 224)

    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=1)

    qualities = range(1, 32, 3)
    results = []

    for quality in qualities:
        num_images = 0
        size_sum = 0.0
        psnr_sum = 0.0
        ms_ssim_sum = 0.0
        pixels_per_image = 224 * 224  # fallback when the loader yields nothing

        for images, _labels in test_loader:
            batch_count = images.size(0)
            # Take the pixel count from the data instead of assuming it.
            pixels_per_image = images.shape[-2] * images.shape[-1]

            batch_size_kb, batch_psnr, batch_ms_ssim = compress_and_evaluate_jpeg(
                images, quality, args.output_dir
            )

            # Weight each per-batch mean by its batch size: a plain mean of
            # batch means is biased whenever the last batch is smaller.
            num_images += batch_count
            size_sum += float(batch_size_kb) * batch_count
            psnr_sum += float(batch_psnr) * batch_count
            ms_ssim_sum += float(batch_ms_ssim) * batch_count

        if num_images:
            avg_size = size_sum / num_images
            avg_psnr = psnr_sum / num_images
            avg_ms_ssim = ms_ssim_sum / num_images
        else:
            # Empty dataset: report NaN rather than dividing by zero.
            avg_size = avg_psnr = avg_ms_ssim = float('nan')

        # Convert image size from KiB to bits
        avg_size = avg_size * 8 * 1024

        # Calculate bpp (bits per pixel)
        bpp = avg_size / pixels_per_image

        results.append((float(bpp), float(avg_psnr), float(avg_ms_ssim)))
        print(f'Quality: {quality}, bpp: {bpp:.2f}, PSNR: {avg_psnr:.2f}dB, MS_SSIM: {avg_ms_ssim:.2f}dB')

    print(results)
# Run the benchmark only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()
