import os
import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
import argparse
import matplotlib.pyplot as plt
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchmetrics.image import MultiScaleStructuralSimilarityIndexMeasure
import torch.nn.functional as F

def parse_args():
    """Parse the command-line options of the JPEG evaluation script.

    Returns:
        argparse.Namespace: carries ``batch_size`` (int, default 2000),
        ``video_path`` (str, required) and ``output_dir`` (str, default
        ``./results/JPEG/``).
    """
    # Each entry is (flag, keyword arguments forwarded to add_argument).
    option_specs = (
        ('--batch_size', {'type': int, 'default': 2000, 'help': 'Batch size for processing frames'}),
        ('--video_path', {'type': str, 'required': True, 'help': 'Path to the input video file'}),
        ('--output_dir', {'type': str, 'default': "./results/JPEG/", 'help': 'Directory to save output images and videos'}),
    )
    cli = argparse.ArgumentParser()
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()

class VideoDataset(Dataset):
    """Dataset holding every frame of one video file in memory.

    All frames are decoded eagerly in the constructor and kept as RGB PIL
    images in ``self.frames``; the optional ``transform`` is applied lazily
    each time an item is fetched.

    Args:
        video_path: Path handed to ``cv2.VideoCapture``.
        transform: Optional callable applied to a frame in ``__getitem__``.
    """

    def __init__(self, video_path, transform=None):
        self.video_path = video_path
        self.transform = transform
        self.frames = self._load_frames()

    def _load_frames(self):
        """Decode the whole video and return its frames as RGB PIL images.

        Returns an empty list when the capture cannot be opened or yields
        no frames.
        """
        cap = cv2.VideoCapture(self.video_path)
        frames = []
        try:
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                # OpenCV decodes to BGR; PIL expects RGB channel order.
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(Image.fromarray(frame))
        finally:
            # Release the capture even if a conversion above raises.
            cap.release()
        return frames

    def __len__(self):
        """Return the number of decoded frames."""
        return len(self.frames)

    def __getitem__(self, frame_index):
        """Return ``(frame, frame_index)``, with the transform applied if set."""
        frame = self.frames[frame_index]
        if self.transform:
            frame = self.transform(frame)
        return frame, frame_index

def get_video_frames(args, resize):
    """Build the frame dataset for the video named in ``args.video_path``.

    Args:
        args: Object exposing a ``video_path`` attribute.
        resize: Edge length; every frame is resized to ``resize x resize``
            before being converted to a tensor.

    Returns:
        VideoDataset: dataset whose items are ``(frame_tensor, frame_index)``.
    """
    target_size = (resize, resize)
    preprocessing = transforms.Compose([
        transforms.Resize(target_size, antialias=True),
        transforms.ToTensor(),
    ])
    return VideoDataset(video_path=args.video_path, transform=preprocessing)

def psnr_batch(img1, img2, data_range=1.0):
    """Compute the per-sample PSNR (in dB) between two batches of images.

    Args:
        img1: Tensor whose first dimension indexes the samples.
        img2: Tensor compared against ``img1`` element by element.
        data_range: Peak signal value of the images (1.0 for images scaled
            to [0, 1], 255 for 8-bit value ranges).

    Returns:
        torch.Tensor: 1-D tensor with one PSNR value per sample. Samples
        that are identical yield ``inf``.
    """
    squared_error = F.mse_loss(img1, img2, reduction='none')
    # reshape (rather than view) so non-contiguous inputs, e.g. permuted
    # tensors, are accepted instead of raising.
    mse = squared_error.reshape(squared_error.size(0), -1).mean(dim=1)
    return 20 * torch.log10(data_range / torch.sqrt(mse))

def ms_ssim_to_db(ms_ssim):
    """Convert an MS-SSIM score to decibels as ``-10 * log10(1 - ms_ssim)``."""
    residual = 1 - ms_ssim
    return -10 * np.log10(residual)

def ms_ssim_batch(img1, img2, data_range=1.0):
    """Compute the MS-SSIM, expressed in dB, for each pair of samples.

    Args:
        img1: Batch of images; its first dimension sets the sample count.
        img2: Batch of images indexed in step with ``img1``.
        data_range: Value range forwarded to the MS-SSIM metric.

    Returns:
        list: One dB value per sample of ``img1``.
    """
    metric = MultiScaleStructuralSimilarityIndexMeasure(data_range=data_range)

    # Score the samples one by one so each gets its own value.
    raw_scores = []
    for idx, sample in enumerate(img1):
        score = metric(sample.unsqueeze(0), img2[idx].unsqueeze(0))
        raw_scores.append(score.item())

    return [ms_ssim_to_db(score) for score in raw_scores]

def compress_and_evaluate_jpeg(images, quality, output_dir):
    """JPEG-compress each image and measure size and reconstruction quality.

    Every image is written to ``output_dir`` as
    ``compressed_{i}_{quality}.jpg`` (``i`` is the position inside
    ``images``, so a later call with the same quality overwrites these
    files), read back, and compared with the uncompressed input.

    Args:
        images: Iterable of image tensors accepted by ``ToPILImage``.
        quality: JPEG quality setting passed to ``PIL.Image.save``.
        output_dir: Directory receiving the JPEG files; created if missing.

    Returns:
        tuple: ``(mean_size_kib, mean_psnr_db, mean_ms_ssim_db,
        compressed_images)`` where ``compressed_images`` is a list of the
        decoded tensors, each with a leading batch dimension of 1.
    """
    # The default output directory usually does not exist yet.
    os.makedirs(output_dir, exist_ok=True)

    # Build the converters once instead of once per image.
    to_pil = transforms.ToPILImage()
    to_tensor = transforms.ToTensor()

    size_list = []
    psnr_list = []
    ms_ssim_list = []
    compressed_images = []

    for i, image in enumerate(images):
        # Metrics are computed on CPU because the decoded JPEG lives there.
        reference = image.cpu()
        output_path = os.path.join(output_dir, f"compressed_{i}_{quality}.jpg")
        to_pil(reference).save(output_path, 'JPEG', quality=quality)

        # Close the file handle deterministically once the pixels are read.
        with Image.open(output_path) as decoded:
            compressed_image = to_tensor(decoded).unsqueeze(0)
        compressed_images.append(compressed_image)

        # File size in KiB.
        size_list.append(os.path.getsize(output_path) / 1024)

        reference = reference.unsqueeze(0)
        # Store plain floats so the averaging below works on numbers,
        # not on 0-d tensors.
        psnr_list.extend(psnr_batch(compressed_image, reference).tolist())
        ms_ssim_list.extend(ms_ssim_batch(compressed_image, reference))

    return np.mean(size_list), np.mean(psnr_list), np.mean(ms_ssim_list), compressed_images

def save_video_for_quality(compressed_images, quality, output_dir, frame_size, fps=20):
    """Write the compressed frames of one quality level to an MP4 file.

    The video is saved as ``Q_{quality}_output_video.mp4`` inside
    ``output_dir`` using the ``mp4v`` codec.

    Args:
        compressed_images: Iterable of RGB image tensors with values in
            [0, 1], shaped ``(1, 3, H, W)`` or ``(3, H, W)``.
        quality: JPEG quality level, used only in the file name.
        output_dir: Destination directory; created if missing.
        frame_size: ``(width, height)`` of the output video. Frames of a
            different size are resized to it.
        fps: Frame rate of the output video.
    """
    import warnings

    os.makedirs(output_dir, exist_ok=True)
    frame_size = (int(frame_size[0]), int(frame_size[1]))

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    output_video_path = os.path.join(output_dir, f'Q_{quality}_output_video.mp4')
    video_writer = cv2.VideoWriter(output_video_path, fourcc, fps, frame_size)

    try:
        if not video_writer.isOpened():
            # Report the failure instead of silently producing no video.
            warnings.warn(f"Could not open video writer for {output_video_path}")
            return

        for img in compressed_images:
            # Drop only the batch dimension, then go from CHW to HWC.
            img_np = img.squeeze(0).cpu().numpy().transpose(1, 2, 0)
            # Round (not truncate) to 8 bit so k/255 maps back to k.
            img_u8 = np.clip(np.rint(img_np * 255), 0, 255).astype(np.uint8)
            # The tensors are RGB; OpenCV writers expect BGR.
            img_bgr = cv2.cvtColor(np.ascontiguousarray(img_u8), cv2.COLOR_RGB2BGR)
            # The writer requires frames of the size it was opened with.
            if (img_bgr.shape[1], img_bgr.shape[0]) != frame_size:
                img_bgr = cv2.resize(img_bgr, frame_size)
            video_writer.write(img_bgr)
    finally:
        video_writer.release()

def eval_video_and_save(args):
    """Evaluate JPEG compression of a video over a range of quality levels.

    For each quality in ``range(1, 32, 3)`` every frame (resized to
    224 x 224) is JPEG-compressed, the average bits per pixel, PSNR and
    MS-SSIM (both in dB) are printed, and the compressed frames are written
    to a video in ``args.output_dir``. The collected
    ``(bpp, psnr, ms_ssim)`` tuples are printed at the end.

    Args:
        args: Object exposing ``video_path``, ``batch_size`` and
            ``output_dir``.
    """
    resolution = 224
    video_dataset = get_video_frames(args, resolution)
    video_loader = DataLoader(video_dataset, batch_size=args.batch_size, num_workers=1)

    # JPEG files and videos are written here; make sure it exists.
    os.makedirs(args.output_dir, exist_ok=True)

    qualities = range(1, 32, 3)
    results = []

    for quality in qualities:
        frame_count = 0
        size_sum = 0.0
        psnr_sum = 0.0
        ms_ssim_sum = 0.0
        all_compressed_images = []

        for images, _ in video_loader:
            batch_size_kib, batch_psnr, batch_ms_ssim, compressed_images = compress_and_evaluate_jpeg(images, quality, args.output_dir)
            # Weight each batch mean by its frame count so a smaller final
            # batch does not skew the overall average.
            batch_frames = len(images)
            frame_count += batch_frames
            size_sum += float(batch_size_kib) * batch_frames
            psnr_sum += float(batch_psnr) * batch_frames
            ms_ssim_sum += float(batch_ms_ssim) * batch_frames
            all_compressed_images.extend(compressed_images)

        if frame_count:
            avg_size = size_sum / frame_count
            avg_psnr = psnr_sum / frame_count
            avg_ms_ssim = ms_ssim_sum / frame_count
        else:
            # No frames were read: report NaN rather than failing.
            avg_size = avg_psnr = avg_ms_ssim = float('nan')

        # Convert image size from KiB to bits
        avg_size = avg_size * 8 * 1024

        # Calculate bpp (bits per pixel)
        bpp = avg_size / (resolution * resolution)

        results.append((float(bpp), float(avg_psnr), float(avg_ms_ssim)))
        print(f'Quality: {quality}, bpp: {bpp:.2f}, PSNR: {avg_psnr:.2f}dB, MS_SSIM: {avg_ms_ssim:.2f}dB')

        # Save video for this quality. The frames were resized to
        # resolution x resolution, and a VideoWriter requires frames of the
        # size it was opened with, so pass that size rather than the
        # original video's.
        if all_compressed_images:
            frame_size = (resolution, resolution)
            save_video_for_quality(all_compressed_images, quality, args.output_dir, frame_size)

    print(results)

if __name__ == "__main__":
    # Script entry point: parse the CLI options and run the evaluation.
    eval_video_and_save(parse_args())
