#!/usr/bin/env python3
import os
import sys
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import random
import json
import argparse
from PIL import Image, ImageEnhance, ImageFilter, ImageOps
import io
import glob, pandas as pd
from sklearn.metrics import roc_curve, auc
from tqdm import tqdm
import nltk
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from json import JSONEncoder
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
from typing import Tuple
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.strategies import DDPStrategy
import pickle 
from transformers import AutoTokenizer
from scipy.special import betainc, betaln, loggamma, gammaincc
from scipy.stats import pearsonr, spearmanr 

# Make the local autoencoder / LLMZip code importable. The guarded imports below
# (text_zipper, modulations, data_finetune_, finetune_model) presumably live in
# this directory — confirm. The path is only added if it exists and is not
# already on sys.path.
trial_path = "/autoencoder"
if os.path.exists(trial_path) and trial_path not in sys.path:
    sys.path.append(trial_path)
    print(f"Added {trial_path} to Python path")
# Turn off Hugging Face `tokenizers` internal parallelism (commonly set to avoid
# fork-related warnings/deadlocks when DataLoader workers are used).
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from pathlib import Path
# Global compute device: the current CUDA device when available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
import rouge_score
from rouge_score import rouge_scorer

# Module-level ROUGE-L scorer with stemming enabled.
scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
try:
    from videoseal.utils.display import save_img
    from videoseal.utils import Timer
    from videoseal.evals.full import setup_model_from_checkpoint
    from videoseal.evals.metrics import bit_accuracy, pvalue, psnr as vs_psnr, ssim
    from videoseal.augmentation import Identity, JPEG, Crop, Resize
    from videoseal.modules.jnd import JND
    import torchvision
    
    VIDEOSEAL_AVAILABLE = True
    print("VideoSeal modules successfully imported")
except ImportError as e:
    print(f"Error importing VideoSeal modules: {e}")
    VIDEOSEAL_AVAILABLE = False

# Import from text_reconstruction/trial
try:
    from text_zipper import TextZipper
    from modulations import BPSKModulator
    LLMZIP_AVAILABLE = True
    print("TextZipper successfully imported")
except ImportError as e:
    print(f"Error importing TextZipper: {e}")
    LLMZIP_AVAILABLE = False

try:
    from data_finetune_ import WikiTextDataModule, Coco2017DataModule, PixmoCapDataModule
    from finetune_model import TextCompressor
    from peft import LoraConfig, get_peft_model
    DATA_MODULES_AVAILABLE = True
    print("Text compressor modules successfully imported")
except ImportError as e:
    print(f"Error importing text compressor modules: {e}")
    DATA_MODULES_AVAILABLE = False

class NumpyEncoder(json.JSONEncoder):
    """``json.JSONEncoder`` that also serialises NumPy scalars and arrays.

    NumPy integers/floats/bools become the matching Python scalars and arrays
    become (nested) lists. Any other unsupported object falls through to the
    base class, which raises ``TypeError``.
    """

    # (numpy type, converter) pairs, checked in order.
    _CONVERTERS = (
        (np.integer, int),
        (np.floating, float),
        (np.ndarray, lambda arr: arr.tolist()),
        (np.bool_, bool),
    )

    def default(self, obj):
        for numpy_type, convert in self._CONVERTERS:
            if isinstance(obj, numpy_type):
                return convert(obj)
        return super().default(obj)

# Download the NLTK tokenizer data required by nltk.word_tokenize (used by
# compute_bleu). NLTK >= 3.8.2 looks up the 'punkt_tab' resource instead of the
# pickled 'punkt' model, so make sure both are present. On older NLTK versions
# that do not know 'punkt_tab', nltk.download reports the error and returns
# False rather than raising.
for _nltk_resource in ("punkt", "punkt_tab"):
    try:
        nltk.data.find(f"tokenizers/{_nltk_resource}")
    except LookupError:
        nltk.download(_nltk_resource)

# Configuration
MAX_LENGTH = 30  # maximum text length handed to the text compressors (presumably tokens)

# Results JSON that supplies the texts when LOAD_FROM_JSON is True.
JSON_FILE_PATH = '/user/videoseal/watermark_comparison_results/nautilus_256_900_pixmo_pixmo_robust_0.01_best_contrast_1.5_results.json'

# LOAD_FROM_JSON: True -> use the JSON file above, False -> use the old method.
# LOAD_FROM_PKL: presumably the analogous switch for a PKL text source.
LOAD_FROM_JSON = False
LOAD_FROM_PKL = False
# NOTE: this unconditionally overrides the flag set by the TextZipper import
# guard, so LLMZip stays disabled even when that import succeeded.
LLMZIP_AVAILABLE = False
DATA_MODULE = "pixmo"
print(f"WORKING ON {DATA_MODULE}")

# VideoSeal watermark strengths to sweep; empty when VideoSeal is unavailable.
# 1.2 is enough for ~42 PSNR (previously swept: [1.4, 1.2, 1.0]).
VIDEOSEAL_STRENGTHS = [1.2] if VIDEOSEAL_AVAILABLE else []

# semanticseal models, name -> config (keys read by semanticsealWatermarker:
# encoder/decoder TorchScript paths, img_size, watermark_power).
SEMANTICSEAL_MODEL_CONFIGS = {
    "nautilus_256_900": {
        "encoder": "/user/videoseal/models/nautilus_900_enc.pth",
        "decoder": "/user/videoseal/models/nautilus_900_dec.pth",
        "img_size": 256,
        "watermark_power": 0.1555,
        "type": "semanticseal",
    },
}

# Text compressors to evaluate, name -> checkpoint: the LLMZip OPT-125m model
# when LLMZip is enabled, otherwise the fine-tuned autoencoder checkpoint(s).
TEXT_COMPRESSOR_CHECKPOINTS = (
    {"llmzip_opt125m": "facebook/opt-125m"}
    if LLMZIP_AVAILABLE
    else {
        "pixmo_robust_0.01_best": "/user/autoencoder_ckpts/coco_ft/text-200epochs_restart--cosineAnnealing--cls--30tok-epoch=91-val_bleu=99.127-0.01.ckpt",
        # Alternative: "wikitext_robust_0.01": "/user/autoencoder_ckpts/wikitext_ft/text-200epochs_restart--cosineAnnealing--cls--30tok-epoch=95-val_bleu=97.838-0.01.ckpt",
    }
)
def _bicubic_rotate(img, degrees):
    """Rotate a PIL image counter-clockwise by ``degrees`` (bicubic, same canvas size)."""
    return img.rotate(degrees, resample=Image.BICUBIC, expand=False)


# Robustness transformations: name -> callable(PIL image) -> PIL image.
# These are lambdas rather than functools.partial on purpose: most helpers they
# call are defined further down this module, so names must resolve at call time.
TRANSFORMATIONS_GLOBAL = {
    "jpeg_75": lambda img: apply_jpeg_compression(img, quality=75),
    "rotation_5": lambda img: _bicubic_rotate(img, 5),
    "crop_90": lambda img: apply_center_crop(img, percent=90),
    "gaussian_noise_10": lambda img: add_gaussian_noise(img, std=10),

    # --- weak augmentations ---
    "identity": lambda img: img,
    "saturation_0.5": lambda img: adjust_saturation(img, factor=0.5),
    "horizontal_flip": lambda img: apply_horizontal_flip(img),
    "saturation_1.5": lambda img: adjust_saturation(img, factor=1.5),
    "contrast_0.5": lambda img: adjust_contrast(img, factor=0.5),
    "brightness_0.5": lambda img: adjust_brightness(img, factor=0.5),

    # --- moderate augmentations ---
    "gaussian_blur_5": lambda img: apply_gaussian_blur(img, radius=5),
    "rotate_90": lambda img: _bicubic_rotate(img, 90),
    "hue_0.1": lambda img: adjust_hue(img, factor=0.1),
    "hue_-0.1": lambda img: adjust_hue(img, factor=-0.1),
    "perspective_0.3": lambda img: apply_perspective_transform(img, distortion_scale=0.3),
    "jpeg_80": lambda img: apply_jpeg_compression(img, quality=80),

    # --- strong augmentations ---
    "contrast_1.5": lambda img: adjust_contrast(img, factor=1.5),
    "brightness_1.5": lambda img: adjust_brightness(img, factor=1.5),
    "jpeg_70": lambda img: apply_jpeg_compression(img, quality=70),
    "rotate_10": lambda img: _bicubic_rotate(img, 10),
    "perspective_0.5": lambda img: apply_perspective_transform(img, distortion_scale=0.5),
    "jpeg_60": lambda img: apply_jpeg_compression(img, quality=60),
    "jpeg_50": lambda img: apply_jpeg_compression(img, quality=50),
    "gaussian_blur_17": lambda img: apply_gaussian_blur(img, radius=17),
    "jpeg_40": lambda img: apply_jpeg_compression(img, quality=40),
}

# Helper functions for transformations
def apply_jpeg_compression(img, quality=75):
    """Round-trip ``img`` through an in-memory JPEG encode/decode at ``quality``.

    Returns the decoded image converted to RGB.

    NOTE: another ``apply_jpeg_compression`` is defined later in this module and
    replaces this one when the module is imported.
    """
    encoded = io.BytesIO()
    img.save(encoded, format="JPEG", quality=quality)
    decoded = Image.open(io.BytesIO(encoded.getvalue()))
    return decoded.convert('RGB')

def add_gaussian_noise(img, std=10):
    """Return ``img`` with i.i.d. zero-mean Gaussian noise added per pixel/channel.

    ``std`` is in 0-255 intensity units. The noise comes from NumPy's global RNG,
    and the result is clipped to [0, 255] and cast back to uint8.
    """
    pixels = np.array(img).astype(np.float32)
    perturbation = np.random.normal(loc=0, scale=std, size=pixels.shape)
    noisy = (pixels + perturbation).clip(0, 255).astype(np.uint8)
    return Image.fromarray(noisy)

def apply_center_crop(img, percent=90):
    """Keep the central ``percent``% of ``img`` and scale it back to the original size.

    The crop keeps ``int(side * percent / 100)`` pixels of each side, centred, and
    is resized back with LANCZOS resampling.
    """
    width, height = img.size
    crop_w, crop_h = (int(side * percent / 100) for side in (width, height))
    left = (width - crop_w) // 2
    top = (height - crop_h) // 2
    box = (left, top, left + crop_w, top + crop_h)
    return img.crop(box).resize((width, height), Image.LANCZOS)

def apply_jpeg_compression(img: Image.Image, quality: int = 75) -> Image.Image:
    """Round-trip ``img`` through an in-memory JPEG encode/decode.

    This definition replaces the earlier ``apply_jpeg_compression`` in this
    module, so it keeps that version's contract: ``quality`` defaults to 75 and
    the result is always an RGB image.

    Args:
        img: Input PIL image. Non-RGB modes (e.g. RGBA, P, L) are converted to
            RGB first, because JPEG cannot store alpha/palette images.
        quality: JPEG quality passed to Pillow's encoder.

    Returns:
        The decoded, fully loaded RGB image.
    """
    if img.mode != "RGB":
        img = img.convert("RGB")
    buffer = io.BytesIO()
    img.save(buffer, format="JPEG", quality=quality)
    buffer.seek(0)
    # convert() forces a full decode, so the result no longer depends on `buffer`.
    return Image.open(buffer).convert("RGB")

def _enhance(enhancer_cls, img: Image.Image, factor: float) -> Image.Image:
    """Apply a ``PIL.ImageEnhance`` operator; ``factor`` 1.0 returns an unchanged copy."""
    return enhancer_cls(img).enhance(factor)


def adjust_saturation(img: Image.Image, factor: float) -> Image.Image:
    """Scale colour saturation by ``factor`` (0.0 gives a black-and-white image)."""
    return _enhance(ImageEnhance.Color, img, factor)


def apply_horizontal_flip(img: Image.Image) -> Image.Image:
    """Mirror the image left-to-right."""
    mirrored = ImageOps.mirror(img)
    return mirrored


def adjust_contrast(img: Image.Image, factor: float) -> Image.Image:
    """Scale contrast by ``factor`` (0.0 gives a solid grey image)."""
    return _enhance(ImageEnhance.Contrast, img, factor)


def adjust_brightness(img: Image.Image, factor: float) -> Image.Image:
    """Scale brightness by ``factor`` (0.0 gives a black image)."""
    return _enhance(ImageEnhance.Brightness, img, factor)


def apply_gaussian_blur(img: Image.Image, radius: int) -> Image.Image:
    """Blur ``img`` with Pillow's Gaussian filter.

    ``radius`` is halved before being passed to ``ImageFilter.GaussianBlur``, whose
    own ``radius`` argument is the kernel's standard deviation; e.g. ``radius=17``
    blurs with sigma 8.5.
    """
    sigma = radius / 2.0
    return img.filter(ImageFilter.GaussianBlur(radius=sigma))


def adjust_hue(img: Image.Image, factor: float) -> Image.Image:
    """Shift the hue channel by ``factor`` via torchvision; ``factor`` must lie in [-0.5, 0.5]."""
    shifted = F.adjust_hue(img, hue_factor=factor)
    return shifted

def apply_perspective_transform(img: Image.Image, distortion_scale: float) -> Image.Image:
    """
    Apply a simple "tilt" perspective distortion.

    Only the two top corners move: each is pulled inward horizontally and
    downward vertically by ``int(distortion_scale * (width // 2))`` pixels, while
    the bottom corners stay fixed. The output has the same size as the input.

    Note: This is a simplified implementation. The exact points would depend on
    the original paper's implementation.

    Args:
        img: Input PIL image.
        distortion_scale: Fraction of half the image width used as the corner
            displacement.

    Returns:
        The warped PIL image (bicubic interpolation).
    """
    width, height = img.size
    half_width = width // 2

    magnitude = int(distortion_scale * half_width)

    startpoints = [(0, 0), (width - 1, 0), (width - 1, height - 1), (0, height - 1)]
    endpoints = [
        (magnitude, magnitude),
        (width - 1 - magnitude, magnitude),
        (width - 1, height - 1),
        (0, height - 1),
    ]
    # torchvision expects an InterpolationMode; PIL integer constants are only
    # accepted through a deprecated backward-compatibility path.
    return F.perspective(img, startpoints, endpoints, interpolation=F.InterpolationMode.BICUBIC)


# Place this near your other helper classes/functions
from torch.utils.data import Dataset, DataLoader

class PairedImageTextDataset(Dataset):
    """Dataset of ``(PIL image, text)`` pairs, one item per text.

    Images are reused cyclically: text ``i`` is paired with a copy of image
    ``i % len(images)``, which is deterministic, not random. ``tokenizer``,
    ``max_length`` and ``img_transform`` are stored for callers but are not
    applied by ``__getitem__``.
    """

    def __init__(self, images, texts, tokenizer, max_length):
        self.images, self.texts = images, texts
        self.tokenizer, self.max_length = tokenizer, max_length

        # PIL -> tensor in [-1, 1]: ToTensor gives [0, 1], Normalize(0.5, 0.5) rescales.
        self.img_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        caption = self.texts[idx]
        # Return a copy, presumably so callers can transform it without
        # touching the shared source image.
        picture = self.images[idx % len(self.images)].copy()
        return picture, caption

def collate_fn(batch):
    """Turn a batch of ``(image, text)`` pairs into ``(list_of_images, list_of_texts)``.

    Leaves PIL images and strings untouched instead of letting the default
    collate try to stack them into tensors.
    """
    images, texts = zip(*batch)
    return [*images], [*texts]


def load_and_reconstruct_texts_from_pkl(pkl_file_path: str, tokenizer_path="answerdotai/ModernBERT-base", num_texts: int = None):
    """
    Load token-ID sequences from a PKL file and decode them back into text.

    The pickle is expected to hold a mapping with an ``'input_ids'`` entry (a
    sequence of token-ID sequences); other entries are ignored.

    Args:
        pkl_file_path (str): Path to the .pkl file. It is unpickled, so it must
            come from a trusted source (unpickling can execute arbitrary code).
        tokenizer_path (str): Hugging Face tokenizer name or local path used for
            decoding. Loaded with ``local_files_only=True``, so it must already be
            available locally.
        num_texts (int, optional): Maximum number of sequences to decode. If None,
            all sequences are decoded. Defaults to None.

    Returns:
        list: The decoded text strings (special tokens such as padding removed).
        An empty list is returned, with the reason printed, if the file is
        missing, cannot be unpickled, has no ``'input_ids'`` entry, or any other
        error occurs while reading/decoding.

    Raises:
        Errors raised while loading the tokenizer are not caught.
    """
    print(f"Attempting to load and reconstruct texts from '{pkl_file_path}'...")

    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_path,
        local_files_only=True,
        trust_remote_code=False
    )
    try:
        # SECURITY: pickle.load can execute arbitrary code; only load trusted files.
        with open(pkl_file_path, 'rb') as f:
            data = pickle.load(f)

        all_input_ids = data.get('input_ids')
        if all_input_ids is None:
            print(f"Error: Key 'input_ids' not found in '{pkl_file_path}'. Cannot reconstruct text.")
            return []

        ids_to_process = all_input_ids[:num_texts] if num_texts is not None else all_input_ids

        # skip_special_tokens=True drops padding and other special tokens.
        reconstructed_texts = [
            tokenizer.decode(token_ids, skip_special_tokens=True)
            for token_ids in ids_to_process
        ]

        print(f"Successfully reconstructed {len(reconstructed_texts)} texts from '{pkl_file_path}'")
        return reconstructed_texts

    except FileNotFoundError:
        print(f"Error: The file '{pkl_file_path}' was not found.")
        return []
    except pickle.UnpicklingError:
        print(f"Error: Failed to unpickle '{pkl_file_path}'. File may be corrupt.")
        return []
    except Exception as e:
        print(f"An unexpected error occurred during reconstruction: {e}")
        return []

def load_texts_from_json(json_file_path, num_texts=None):
    """
    Collect evaluation texts from a results-style JSON file.

    The file must contain an object whose ``'texts'`` entry is a list of
    objects. The ``'original'`` string of each is collected in file order,
    skipping entries where it is missing, empty or whitespace-only.

    Args:
        json_file_path (str): Path to the JSON file (read as UTF-8).
        num_texts (int, optional): Stop once this many texts are collected.
                                   None collects all of them.

    Returns:
        list: The collected texts. An empty list is returned, with the reason
        printed, if the file is missing, is not valid JSON, has no ``'texts'``
        entries, or any other error occurs.
    """
    print(f"Attempting to load texts from '{json_file_path}'...")

    def _usable_originals(entries):
        # Lazily yield each non-blank 'original' text, in file order.
        for entry in entries:
            candidate = entry.get('original')
            if candidate and candidate.strip():
                yield candidate

    try:
        with open(json_file_path, 'r', encoding='utf-8') as handle:
            payload = json.load(handle)

        entries = payload.get('texts', [])
        if not entries:
            print("Warning: The 'texts' key was not found or is empty in the JSON file.")
            return []

        collected = []
        source = _usable_originals(entries)
        # Check the limit before pulling, so no entry past the limit is inspected.
        while num_texts is None or len(collected) < num_texts:
            candidate = next(source, None)
            if candidate is None:
                break
            collected.append(candidate)

        print(f"Successfully loaded {len(collected)} texts from {json_file_path}")
        return collected

    except FileNotFoundError:
        print(f"Error: The file '{json_file_path}' was not found.")
        return []
    except json.JSONDecodeError:
        print(f"Error: The file '{json_file_path}' is not a valid JSON file.")
        return []
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
        return []
        
def compute_bleu(reference, candidate):
    """Compute sentence-level BLEU-1..BLEU-4 of ``candidate`` against one ``reference``.

    Both strings are lower-cased and tokenized with ``nltk.word_tokenize`` (which
    needs the NLTK punkt / punkt_tab data). NLTK smoothing ``method1`` is used so
    that a missing higher-order n-gram match does not zero the score.

    Returns:
        list[float]: ``[BLEU-1, BLEU-2, BLEU-3, BLEU-4]``.
    """
    reference_tokens = nltk.word_tokenize(reference.lower())
    candidate_tokens = nltk.word_tokenize(candidate.lower())

    smoothie = SmoothingFunction().method1

    # Uniform weights over the first n n-gram orders. Each tuple must sum to 1:
    # the previous BLEU-3 weights (0.33, 0.33, 0.33) summed to 0.99 and skewed it.
    third = 1.0 / 3.0
    weights = [
        (1.0, 0.0, 0.0, 0.0),
        (0.5, 0.5, 0.0, 0.0),
        (third, third, third, 0.0),
        (0.25, 0.25, 0.25, 0.25),
    ]

    return [
        sentence_bleu([reference_tokens], candidate_tokens, weights=weight, smoothing_function=smoothie)
        for weight in weights
    ]

def compute_psnr(A, B, max_value=255.0):
    """Compute the PSNR (in dB) between two images of identical shape.

    Inputs are converted to float64 before subtraction. Subtracting uint8
    arrays (e.g. ``np.asarray`` of PIL images) directly would wrap around and
    give a wrong MSE.

    Args:
        A, B: Images as array-likes (NumPy arrays, PIL images, ...).
        max_value: Peak signal value; 255 for 8-bit images.

    Returns:
        float: ``10 * log10(max_value**2 / MSE)``, or ``inf`` when the images
        are identical.

    Raises:
        ValueError: if the two images do not have the same shape.
    """
    A = np.asarray(A, dtype=np.float64)
    B = np.asarray(B, dtype=np.float64)
    if A.shape != B.shape:
        raise ValueError(f"Image shapes differ: {A.shape} vs {B.shape}")
    mse = np.mean((A - B) ** 2)
    if mse == 0:
        return float("inf")
    return 10 * np.log10(max_value ** 2 / mse)
    
def text_to_bits(text: str, text_zipper, max_bits: int = 256, max_length=30, truncate=True) -> Tuple[torch.Tensor, int]:
    """
    Compress ``text`` with LLMZip and return its bitstream as a float tensor.

    Args:
        text: Text to compress.
        text_zipper: LLMZip ``TextZipper``-like object whose ``encode(stream,
            text, max_length=...)`` writes the compressed bytes to ``stream``.
        max_bits: Target bit length. Shorter bitstreams are zero-padded to it.
        max_length: Passed through to ``text_zipper.encode``.
        truncate: If True, bitstreams longer than ``max_bits`` are cut to
            ``max_bits`` (and "truncated" is printed). If False they are
            returned at full length.

    Returns:
        Tuple of (float32 tensor of 0/1 values, bit length before
        padding/truncation).
    """
    sink = io.BytesIO()
    text_zipper.encode(sink, text, max_length=max_length)
    payload = sink.getvalue()

    # MSB-first expansion of every byte into 8 bits.
    bits = [(octet >> shift) & 1 for octet in payload for shift in range(7, -1, -1)]
    original_length = len(bits)

    if original_length <= max_bits:
        bits.extend([0] * (max_bits - original_length))
    elif truncate:
        bits = bits[:max_bits]
        print("truncated")

    return torch.tensor(bits, dtype=torch.float32), original_length

def bits_to_text(bits: torch.Tensor, text_zipper, max_length=30) -> str:
    """
    Pack a 0/1 bit tensor into bytes and decompress it with LLMZip.

    Bits are grouped MSB-first into bytes. A trailing group shorter than 8 bits
    is ignored.

    Args:
        bits: Tensor of bit values (cast to int before packing).
        text_zipper: LLMZip ``TextZipper``-like object with
            ``decode(stream, max_length=...)``.
        max_length: Passed through to ``text_zipper.decode``.

    Returns:
        The first line of the decoded text.
    """
    flat = bits.int().cpu().numpy()
    usable = len(flat) - len(flat) % 8

    packed = bytearray()
    for start in range(0, usable, 8):
        value = 0
        for bit in flat[start:start + 8]:
            value = (value << 1) | bit
        packed.append(value)

    decoded_text = text_zipper.decode(io.BytesIO(bytes(packed)), max_length=max_length)
    return decoded_text.split("\n")[0]

def bytes_to_text(bits: torch.Tensor, text_zipper, max_length=30) -> str:
    """
    Decompress an LLMZip byte payload back into text.

    Despite the parameter name and annotation, ``bits`` is passed straight to
    ``bytes()``. Each element therefore becomes one byte, so it must already
    hold byte values (e.g. a ``bytes`` object or a sequence of ints in 0..255),
    not one bit per element.

    Args:
        bits: Byte payload produced by LLMZip.
        text_zipper: LLMZip ``TextZipper``-like object with
            ``decode(stream, max_length=...)``.
        max_length: Passed through to ``text_zipper.decode``.

    Returns:
        The first line of the decoded text.
    """
    payload = io.BytesIO(bytes(bits))
    first_line, *_ = text_zipper.decode(payload, max_length=max_length).split("\n")
    return first_line


# Load COCO validation images
def load_coco_val_images(
    coco_val_dir: str = "/user/Paper2/coco_dataset/val2017",
    image_size: int = 256,
    num_images: int = 10
):
    """Load ``num_images`` randomly chosen images from a directory as square RGB PIL images.

    Each image is resized so its shorter side is ``image_size`` and then
    center-cropped to ``image_size`` x ``image_size``. The directory listing is
    sorted before shuffling, so the selection is reproducible for a given
    ``random`` seed regardless of filesystem listing order. Files that fail to
    load are skipped and the next candidate is tried. Synthetic pattern images
    are used only when the directory is missing, has no .jpg/.jpeg/.png files,
    or too few files could be decoded.

    Returns:
        list: Exactly ``num_images`` PIL images (none for ``num_images <= 0``).
    """
    import os
    import random
    from PIL import Image, ImageDraw
    from torchvision import transforms

    def create_pattern_image(size):
        # Deterministic placeholder: blue square with a red circle on white.
        img = Image.new('RGB', (size, size), color='white')
        draw = ImageDraw.Draw(img)
        draw.rectangle([size//4, size//4, 3*size//4, 3*size//4], fill='blue')
        draw.ellipse([size//3, size//3, 2*size//3, 2*size//3], fill='red')
        return img

    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
    ])

    if not os.path.exists(coco_val_dir):
        print(f"Warning: Directory not found: {coco_val_dir}")
        return [create_pattern_image(image_size) for _ in range(num_images)]

    valid_extensions = ('.jpg', '.jpeg', '.png')
    # os.listdir order is arbitrary; sort so shuffling is reproducible under a seed.
    image_files = sorted(
        os.path.join(coco_val_dir, name)
        for name in os.listdir(coco_val_dir)
        if name.lower().endswith(valid_extensions)
        and os.path.isfile(os.path.join(coco_val_dir, name))
    )

    if not image_files:
        print(f"Warning: No valid image files found in {coco_val_dir}")
        return [create_pattern_image(image_size) for _ in range(num_images)]

    random.shuffle(image_files)

    images = []
    for img_path in image_files:
        if len(images) >= num_images:
            break
        try:
            with Image.open(img_path) as raw:
                img = raw.convert('RGB')  # forces a full decode before the file closes
            images.append(transform(img))
        except Exception as e:
            # Skip unreadable files and try the next candidate.
            print(f"Error loading {img_path}: {e}")

    while len(images) < num_images:
        images.append(create_pattern_image(image_size))

    return images

class LLMZipTextCompressor:
    """Compress text to a fixed-size bit tensor with LLMZip (TextZipper), with caching.

    Encodings are memoised in memory and, when ``enable_disk_cache`` is True,
    as pickle files under ``cache_dir``. Cache keys include the model name,
    ``max_length`` and the text, because all three change the encoding. One
    cache directory can therefore be shared by compressors with different
    settings.
    """

    # Encodings are zero-padded to this many bits (or cut to it when truncating).
    MAX_BITS = 256

    def __init__(self, model_name, max_length=30, cache_dir="./cache", enable_disk_cache=True):
        self.model_name = model_name
        self.max_length = max_length
        self.enable_disk_cache = enable_disk_cache

        if not LLMZIP_AVAILABLE:
            raise ImportError("TextZipper not available")

        self.text_zipper = TextZipper(modelname=model_name)
        self.text_zipper.model.to(device)

        self.cache_dir = Path(cache_dir)
        # parents=True so nested cache paths (e.g. "./cache/llmzip") also work.
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        # In-memory caches: cache key -> tensor.
        self.encode_cache = {}
        # No method in this class populates this; kept for API compatibility.
        self.decode_cache = {}

    def _get_text_hash(self, text):
        """Return a stable SHA-256 hex key for ``text`` under this compressor's settings."""
        import hashlib

        model_name = getattr(self, "model_name", "")
        key = f"{model_name}\x00{self.max_length}\x00{text}"
        return hashlib.sha256(key.encode("utf-8")).hexdigest()

    def _get_cache_path(self, cache_key, cache_type="encode"):
        """Get the file path for a cache entry."""
        return self.cache_dir / f"{cache_type}_{cache_key}.pkl"

    def _load_from_disk(self, cache_key, cache_type="encode"):
        """Return the cached object for ``cache_key``, or None if missing or unreadable.

        Unreadable (truncated/corrupt) cache files are deleted. The files are
        unpickled, so ``cache_dir`` must only contain files written by this class.
        """
        cache_path = self._get_cache_path(cache_key, cache_type)
        if not cache_path.exists():
            return None
        try:
            with open(cache_path, 'rb') as f:
                return pickle.load(f)
        except Exception:
            cache_path.unlink(missing_ok=True)
            return None

    def _save_to_disk(self, cache_key, result, cache_type="encode"):
        """Best-effort write of ``result`` to the disk cache (no-op if disabled)."""
        if not self.enable_disk_cache:
            return

        cache_path = self._get_cache_path(cache_key, cache_type)
        # Store tensors on CPU so the cache can be read on hosts without the original GPU.
        if isinstance(result, torch.Tensor):
            result = result.detach().cpu()
        try:
            with open(cache_path, 'wb') as f:
                pickle.dump(result, f)
        except (OSError, pickle.PicklingError) as e:
            # A failed write only costs a recomputation next time.
            print(f"Warning: could not write cache file {cache_path}: {e}")

    def encode(self, text, truncate=True):
        """Return the LLMZip bit encoding of ``text`` as a (1, N) float32 tensor on ``device``.

        N is ``MAX_BITS`` when the bitstream fits (zero-padded) or ``truncate``
        is True. Otherwise the full, longer bitstream is returned. Callers
        always get a tensor they may modify without corrupting the cache.
        """
        cache_key = self._get_text_hash(text) + f"_{truncate}"

        if cache_key in self.encode_cache:
            return self.encode_cache[cache_key].clone()

        if self.enable_disk_cache:
            cached_result = self._load_from_disk(cache_key, "encode")
            if cached_result is not None:
                # Disk entries are stored on CPU; move to the working device.
                cached_result = cached_result.to(device)
                self.encode_cache[cache_key] = cached_result
                return cached_result.clone()

        bitstream = io.BytesIO()
        self.text_zipper.encode(bitstream, text, max_length=self.max_length)
        data = bitstream.getvalue()

        # MSB-first expansion of each byte into 8 bits.
        bit_array = [(byte >> i) & 1 for byte in data for i in range(7, -1, -1)]

        max_bits = self.MAX_BITS
        if len(bit_array) > max_bits:
            if truncate:
                bit_array = bit_array[:max_bits]
        else:
            bit_array += [0] * (max_bits - len(bit_array))

        result = torch.tensor(bit_array, dtype=torch.float32).unsqueeze(0).to(device)

        self.encode_cache[cache_key] = result.clone()
        self._save_to_disk(cache_key, result, "encode")

        return result

    def clear_cache(self):
        """Empty the in-memory caches and delete this class's cache files.

        Only ``encode_*.pkl`` / ``decode_*.pkl`` files in ``cache_dir`` are
        removed, so unrelated pickles in the same directory are left alone.
        """
        self.encode_cache.clear()
        self.decode_cache.clear()

        if self.enable_disk_cache:
            for cache_type in ("encode", "decode"):
                for cache_file in self.cache_dir.glob(f"{cache_type}_*.pkl"):
                    cache_file.unlink(missing_ok=True)

class FineTunedTextCompressor:
    """Encode text to latents and decode latents back to text with a fine-tuned ``TextCompressor``.

    The architecture, including the LoRA adapter on ``modern_bert``, is rebuilt
    with the hyper-parameters below. The weights are then restored from the
    checkpoint's ``'state_dict'`` entry, so those hyper-parameters must match
    the ones used for training.
    """
    def __init__(self, checkpoint_path, max_length=30):
        self.max_length = max_length

        if not DATA_MODULES_AVAILABLE:
            raise ImportError("Required modules not available")

        # Within this class the data module supplies only the tokenizer and
        # vocab_size. NOTE(review): Coco2017DataModule is used even when
        # DATA_MODULE is "pixmo" (WikiTextDataModule was used previously);
        # presumably they share the same tokenizer — confirm.
        self.data_module = Coco2017DataModule(
            batch_size=1,
            max_length=self.max_length,
            num_workers=1
        )
        self.data_module.setup("test")

        self.model = TextCompressor(
            vocab_size=self.data_module.vocab_size,
            latent_dim=256,
            hidden_dim=512,  # 768 in an earlier configuration
            num_layers=10,  # 8 in an earlier configuration
            num_heads=8,
            dropout=0.0,
            pooling_strategy="cls",  # "mean" or "cls"
            teacher_forcing_start_ratio=0.0,
            teacher_forcing_end_ratio=0.0,
            noise_sigma=None,
            tokenizer=self.data_module.tokenizer,
            max_length=self.max_length
        ).eval()

        # LoRA on ModernBERT's attention/MLP projections; r/alpha must match the
        # checkpoint (an earlier configuration used r=16, lora_alpha=32).
        peft_config = LoraConfig(
            target_modules=["attn.Wqkv", "attn.Wo", "mlp.Wi", "mlp.Wo"],
            r=32,
            lora_alpha=64
        )
        self.model.modern_bert = get_peft_model(self.model.modern_bert, peft_config)

        # map_location="cpu": a plain torch.load restores tensors to the device
        # they were saved from, which fails on CPU-only hosts or lands on the
        # wrong GPU. The model is moved to `device` below.
        # NOTE(review): torch>=2.6 defaults to weights_only=True; if this trusted
        # checkpoint fails to load for that reason, pass weights_only=False.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        self.model.load_state_dict(checkpoint['state_dict'])

        # eval() again so modules added by get_peft_model are in eval mode too.
        self.model.eval()
        self.model.to(device)

    def encode(self, text):
        """Tokenize ``text`` (padded/truncated to ``max_length``) and return the model latent ``z``.

        ``z`` is whatever ``TextCompressor.encode`` returns — presumably
        (batch, latent_dim); confirm.
        """
        encoded = self.data_module.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        encoded = {k: v.to(device) for k, v in encoded.items()}

        with torch.no_grad():
            z = self.model.encode(
                encoded['input_ids'],
                encoded['attention_mask'].bool(),
            )
        return z

    def decode(self, z):
        """Generate token IDs from latent ``z`` and return the decoded strings (a list)."""
        with torch.no_grad():
            output_ids = self.model.generate(z, max_length=self.max_length)
            output_texts = self.data_module.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        return output_texts

# semanticseal model helper functions
def resize_and_pad_to_numpy(img, target=128):
    """Letterbox a PIL image into a ``target`` x ``target`` RGB NumPy array.

    The longer side is resized to ``target`` (LANCZOS) while keeping the aspect
    ratio. The shorter side is then reflect-padded, split as evenly as possible
    between both edges (reflection instead of constant padding avoids border
    artifacts).

    Returns:
        (np.ndarray, float): the padded (target, target, 3) array and the input
        aspect ratio width / height.
    """
    width, height = img.size
    ar = width / height

    new_size = (target, int(target / ar)) if ar > 1 else (int(target * ar), target)
    resized = np.asarray(img.resize(new_size, Image.LANCZOS).convert('RGB'))

    pad_w = target - new_size[0]
    pad_h = target - new_size[1]
    padding = (
        (pad_h // 2, pad_h - pad_h // 2),
        (pad_w // 2, pad_w - pad_w // 2),
        (0, 0),
    )
    return np.pad(resized, padding, mode="reflect"), ar

# Image preprocessing for the semanticseal models: ToTensor maps a uint8 HWC
# image to a float CHW tensor in [0, 1]; Normalize(0.5, 0.5) then maps it to [-1, 1].
_SEMANTICSEAL_MEAN = [0.5, 0.5, 0.5]
_SEMANTICSEAL_STD = [0.5, 0.5, 0.5]
transform_imnet = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=_SEMANTICSEAL_MEAN, std=_SEMANTICSEAL_STD),
])

class semanticsealWatermarker:
    """Embed and detect watermarks with a pair of TorchScript semanticseal models.

    ``model_config`` must provide ``"encoder"`` and ``"decoder"`` (paths loaded
    with ``torch.jit.load``), ``"img_size"`` (square working resolution passed to
    ``resize_and_pad_to_numpy``) and ``"watermark_power"`` (default scale of the
    residual added by ``embed``).
    """
    def __init__(self, model_config):
        self.model_config = model_config
        self.img_size = model_config["img_size"]
        self.watermark_power = model_config["watermark_power"]
        
        # Load the TorchScript encoder/decoder onto `device` in eval mode.
        self.encoder = torch.jit.load(model_config["encoder"]).to(device).eval()
        self.decoder = torch.jit.load(model_config["decoder"]).to(device).eval()
        
        print(f"Initialized semanticseal model with watermark power {self.watermark_power}")

    def embed(self, img, msg,overwrite_power=None):
        """Return a watermarked copy of PIL image ``img`` carrying message ``msg``.

        Args:
            img: PIL image. NOTE(review): the residual added at the end has 3
                channels, so ``img`` presumably must be RGB — confirm.
            msg: Message vector. A NumPy array (presumably 1-D) is L2-normalised
                and turned into a (1, D) float32 tensor on ``device``. A torch
                tensor is L2-normalised over all its elements but NOT reshaped
                or moved, so it presumably must already have the shape/device
                the encoder expects — TODO confirm. An all-zero message would
                divide by zero.
            overwrite_power: If not None, used instead of ``self.watermark_power``.

        Returns:
            PIL image with the same width/height as ``img`` (uint8, clipped to [0, 255]).
        """
        # Normalize the message to unit L2 norm
        if isinstance(msg, np.ndarray):
            msg = msg / np.sqrt(np.dot(msg, msg))
            msg = torch.tensor(msg, dtype=torch.float32).unsqueeze(0).to(device)
        else:
            msg = msg / torch.sqrt(torch.sum(msg * msg))
            
        # Letterbox to img_size x img_size and build a [-1, 1] tensor batch of one
        imgo, ar = resize_and_pad_to_numpy(img, target=self.img_size)
        imgt = transform_imnet(imgo).unsqueeze(0).to(device)
        
        # Run the encoder
        with torch.no_grad():
            imgw = self.encoder(imgt, msg)

        # Keep only the residual the encoder added and map it from [-1, 1] to
        # [0, 1] so it can be stored as a uint8 image; residual values outside
        # [-1, 1] are clipped by the .clip(0,255.0) below.
        y = (imgw - imgt).cpu().numpy() * 0.5 + 0.5
        mimg = Image.fromarray((y[0].transpose((1,2,0)) * 255.0).clip(0,255.0).astype(np.uint8))
        
        # Scale the square residual up to the input's longest side, then crop the
        # centred region matching the original width/height (approximately
        # undoing the centred padding of resize_and_pad_to_numpy).
        size = max(img.size)
        mimg = mimg.resize((size, size), Image.LANCZOS)
        mimg = np.asarray(mimg)
        pw = (size - img.size[0]) // 2
        ph = (size - img.size[1]) // 2
        mimg = mimg[ph:ph+img.size[1],pw:pw+img.size[0]]

        # Back to a zero-centred residual in roughly [-0.5, 0.5]
        mimg = (mimg / 255.0 - 0.5)

        # Add the residual to the original pixels, scaled by the watermark power
        # (in 0-255 units), then clip and quantise back to uint8.
        if overwrite_power is not None:
            y = np.asarray(img) + mimg * overwrite_power * 255.0
        else:
            # Add to original with watermark power
            y = np.asarray(img) + mimg * self.watermark_power * 255.0

        imgw = Image.fromarray(y.clip(0,255.0).astype(np.uint8))
        return imgw
    
    def detect(self, img):
        """Decode the watermark vector from a single PIL image.

        Returns:
            np.ndarray: The decoder output for the (single) batch element,
            scaled to unit L2 norm. NOTE(review): no epsilon is added here, so
            an all-zero decoder output would produce NaNs.
        """
        # Isotropic downscale
        img, ar = resize_and_pad_to_numpy(img, target=self.img_size)
        img = transform_imnet(img).unsqueeze(0).to(device)

        # Detect the watermark
        with torch.no_grad():
            dec = self.decoder(img)[0].cpu().numpy()
            # Torchscript output is unnormalized
            dec = dec / np.sqrt(np.dot(dec, dec))

        return dec
    
    def detect_batch(self, img_batch_pil: list) -> np.ndarray:
        """
        Batched counterpart of ``detect`` for a list of PIL images.

        Returns:
            np.ndarray: One L2-normalised decoder output per row, presumably in
            the same order as ``img_batch_pil``.

        NOTE(review): preprocessing is delegated to
        ``preprocess_for_semanticseal_batch``, which is not defined in this part
        of the file. ``embed``/``detect`` use ``resize_and_pad_to_numpy`` +
        ``transform_imnet``; verify the helper does exactly the same, otherwise
        batched and serial detections will disagree. The normalisation below
        also adds 1e-8 to the norm (``detect`` does not), so results can differ
        very slightly.
        """
        # 1. Preprocess the whole list into a single tensor batch on `device`
        # (see the NOTE above about matching the serial preprocessing).
        img_batch_tensor = preprocess_for_semanticseal_batch(img_batch_pil, self.img_size, device)
        
        with torch.no_grad():
            # 2. Get the batched output from the decoder model.
            # The decoder itself is already batch-compatible.
            decoded_batch_tensor = self.decoder(img_batch_tensor)  # Shape: [N, D]
            
            # 3. Perform batched normalization of the output vectors.
            # This is the batched equivalent of `dec / np.sqrt(np.dot(dec, dec))`.
            norm = torch.linalg.norm(decoded_batch_tensor, ord=2, dim=1, keepdim=True)
            normalized_batch = decoded_batch_tensor / (norm + 1e-8)  # Add epsilon for numerical stability
            
        return normalized_batch.cpu().numpy()
        
def evaluate_videoseal_strength(
    watermark_strength, 
    checkpoint_name,
    transformation_name,
    num_test_images=10,
    num_test_texts=20,
    output_dir="results"
):
    """Evaluate VideoSeal text-in-image watermarking for one configuration.

    For every test text the function:
      1. compresses the text into a bit message,
      2. embeds it into a randomly chosen COCO test image, running a per-image
         bisection on ``model.blender.scaling_w`` that targets ~42 dB PSNR,
      3. applies ``TRANSFORMATIONS_GLOBAL[transformation_name]``,
      4. detects and decodes the bits back into text,
      5. scores the result (exact match, BLEU-1/4, PSNR, bit accuracy).

    Args:
        watermark_strength: First ``scaling_w`` probed by the per-image search;
            also used in the output model/file names.
        checkpoint_name: Text-compressor key. NOTE: overridden to
            ``"llmzip_opt125m"`` right after the banner is printed.
        transformation_name: Key into ``TRANSFORMATIONS_GLOBAL``.
        num_test_images: Number of COCO validation images to sample from.
        num_test_texts: Number of texts to evaluate.
        output_dir: Directory for the ``*_results.json`` file.

    Returns:
        dict | None: Summary metrics, or ``None`` if model/compressor/image
        initialization failed or no sample was processed.
    """
    
    # Create output directory
    os.makedirs(output_dir, exist_ok=True)
    print(f"\n\n{'='*80}")
    print(f"Evaluating VideoSeal with:")
    print(f"Watermark strength: {watermark_strength}")
    print(f"Text compressor: {checkpoint_name}")
    print(f"Transformation: {transformation_name}")
    print(f"{'='*80}\n")
    # NOTE(review): the caller's checkpoint_name is ignored from here on. The
    # fine-tuned branch's encode_text below does not accept the max_bits keyword
    # used while loading texts, so only the llmzip path is wired up end to end.
    # Remove this override once both branches share one interface.
    checkpoint_name = "llmzip_opt125m"
    # Output prefix for files
    model_name = f"videoseal_strength_{watermark_strength:.1f}"
    output_prefix = f"{output_dir}/{model_name}_{checkpoint_name}_{DATA_MODULE}_{transformation_name}"
    
    # Get transformation function
    transform_fn = TRANSFORMATIONS_GLOBAL[transformation_name]
    
    # Initialize VideoSeal model
    try:
        print("Initializing VideoSeal model...")
        model = setup_model_from_checkpoint('videoseal')
        model.eval()
        model.compile()
        model.to(device)
        
        # Scale the default strength. The per-image PSNR search below assigns
        # scaling_w explicitly, so this only matters until the first embed.
        model.blender.scaling_w *= watermark_strength
        print(f"VideoSeal model initialized with strength {watermark_strength}")
        
        # For converting between PIL and tensor
        to_tensor = torchvision.transforms.ToTensor()
        to_pil = torchvision.transforms.ToPILImage()
    except Exception as e:
        print(f"Error initializing VideoSeal model: {e}")
        return None
    
    # Initialize text compressor
    try:
        print("Initializing text compressor...")
        if "llmzip" in checkpoint_name and LLMZIP_AVAILABLE:
            adapter_path = None
            print(f"using {adapter_path}")
                
            text_zipper = TextZipper(
                modelname=TEXT_COMPRESSOR_CHECKPOINTS[checkpoint_name],
                adapter_path=adapter_path
            )
            text_zipper.model.to(device)
            print(f"Initialized LLMZip text compressor")
            
            # Returns (bit message with a leading batch dim on `device`, length).
            def encode_text(text,max_bits=256):
                bit_msg, length = text_to_bits(text, text_zipper,max_bits=max_bits, max_length=MAX_LENGTH)
                return bit_msg.unsqueeze(0).to(device), length
            
            def decode_bits(bits):
                return bits_to_text(bits, text_zipper, max_length=MAX_LENGTH)
                
        elif DATA_MODULES_AVAILABLE:
            compressor = FineTunedTextCompressor(
                TEXT_COMPRESSOR_CHECKPOINTS[checkpoint_name],
                max_length=MAX_LENGTH
            )
            print(f"Initialized fine-tuned text compressor")
            
            # Wrapper functions to maintain consistent interface
            def encode_text(text):
                return compressor.encode(text)
            
            def decode_bits(bits):
                return compressor.decode(bits.unsqueeze(0).to(device) if bits.dim() == 1 else bits)
        else:
            print("No suitable text compressor available")
            return None
    except Exception as e:
        print(f"Error initializing text compressor: {e}")
        return None
    
    # Load test images
    print("Loading test images...")
    try:
        test_images = load_coco_val_images(
            num_images=num_test_images,
            image_size=256  # VideoSeal works with 256x256 images
        )
        print(f"Loaded {len(test_images)} test images")
    except Exception as e:
        print(f"Error loading test images: {e}")
        return None
    
    # Load test texts
    print("Loading test texts...")
    test_texts = []
    if LOAD_FROM_JSON:
        test_texts = load_texts_from_json(JSON_FILE_PATH, num_texts=num_test_texts)
    elif LOAD_FROM_PKL:
        test_texts = load_and_reconstruct_texts_from_pkl(PKL_FILE_PATH,num_texts=num_test_texts)
    else:
        # Try loading from the configured data module if available
        if DATA_MODULES_AVAILABLE:
            try:
                if DATA_MODULE == "wikitext":
                    data_module = WikiTextDataModule(batch_size=1, max_length=MAX_LENGTH)
                elif DATA_MODULE == "coco":
                    data_module = Coco2017DataModule(batch_size=1, max_length=MAX_LENGTH)
                else:
                    data_module = PixmoCapDataModule(batch_size=1, max_length=MAX_LENGTH)
                    
                data_module.setup('test')
                test_loader = data_module.test_dataloader()
                print(f"loading up to {MAX_LENGTH} tokens")
                i = 0
                for batch in tqdm(test_loader,total=2000):
                    if i >= num_test_texts:
                        break
                    # Only keep samples that fill the full MAX_LENGTH window.
                    if len(batch['input_ids'][0])<MAX_LENGTH:
                        continue
                        
                    text = data_module.tokenizer.decode(batch['input_ids'][0][:MAX_LENGTH], skip_special_tokens=True)
                    # Encode with a generous bit budget only to measure the
                    # compressed length; keep texts whose length exceeds 256.
                    _, length_check = encode_text(text,max_bits=500)
                    
                    if length_check <=256:
                        continue
                        
                    if text.strip():  # Skip empty texts
                        test_texts.append(text)
                        i+=1
                        
                print(f"Loaded {len(test_texts)} texts from {DATA_MODULE}")
            except Exception as e:
                print(f"Error loading WikiText data: {e}")

    
    # Initialize metrics
    results = {
        'exact_matches': 0,
        'total_samples': 0,
        'psnr_values': [],
        'bleu1_scores': [],
        'bleu4_scores': [],
        'texts': []
    }

    def _psnr_as_float(watermarked, reference):
        # vs_psnr returns a tensor; reduce a per-image result to one float.
        result = vs_psnr(watermarked, reference)
        return result.item() if result.numel() == 1 else result.mean().item()

    print("Starting evaluation...")
    # Counts every non-empty text processed. The length filter that used to
    # decrement it is disabled, so 'oob_percentage' is currently ~processed/total.
    oob_counter = 0
    for i, text in enumerate(tqdm(test_texts)):
        if not text.strip():
            continue
        
        img = random.choice(test_images).copy()

        bit_msg_tensor, length = encode_text(text)
        oob_counter+=1
        reference_decoded_text = decode_bits(bit_msg_tensor[0])
        img_tensor = to_tensor(img).unsqueeze(0).float().to(device)

        # Bisection on scaling_w: the first probe is watermark_strength, later
        # probes are midpoints of [low, high].
        low = 0.0
        high = 2.0
        current_power = watermark_strength
        best_diff = float('inf')
        best_img_w = None
        
        target_psnr = 42
        tolerance = 0.2
        max_iterations = 12
        
        for _ in range(max_iterations):
            model.blender.scaling_w = current_power

            # Embed watermark
            with torch.no_grad():
                # Presumably model.embed draws its message via get_random_msg;
                # swap in our encoded text bits for this call only. Restoring in
                # `finally` keeps the model unpatched even if embed raises.
                msg_processor = model.embedder.unet.msg_processor
                original_get_random_msg = msg_processor.get_random_msg
                msg_processor.get_random_msg = lambda bsz, nb_repetitions: bit_msg_tensor
                try:
                    outputs = model.embed(img_tensor, is_video=False, lowres_attenuation=True)
                finally:
                    msg_processor.get_random_msg = original_get_random_msg

                imgs_w = outputs["imgs_w"]

            current_psnr = _psnr_as_float(imgs_w, img_tensor)
            diff = abs(current_psnr - target_psnr)

            # Remember the result closest to the target PSNR
            if diff < best_diff:
                best_diff = diff
                best_img_w = imgs_w.clone()  # Clone to avoid reference issues

            if diff <= tolerance:
                break

            # Assumes PSNR falls as scaling_w grows (TODO confirm): a too-clean
            # image raises the lower bound, a too-noisy one lowers the upper bound.
            if current_psnr > target_psnr:
                low = current_power
            else:
                high = current_power

            current_power = (low + high) / 2

        # Use the best watermarked image found (imgs_w already holds the last
        # iteration's result if nothing was recorded)
        if best_img_w is not None:
            imgs_w = best_img_w

        psnr_value = _psnr_as_float(imgs_w, img_tensor)
        results['psnr_values'].append(psnr_value)

        # Convert to PIL for transformations
        watermarked_img = to_pil(imgs_w[0].cpu())
        transformed_img = transform_fn(watermarked_img)
        transformed_tensor = to_tensor(transformed_img).unsqueeze(0).float().to(device)

        # Detect watermark and calculate metrics; column 0 of preds is skipped
        # and the remaining columns are compared against the embedded bits.
        detect_outputs = model.detect(transformed_tensor, is_video=False)
        preds = detect_outputs["preds"]
        bit_preds = preds[:, 1:]
        bit_acc = bit_accuracy(bit_preds, bit_msg_tensor).item()
        pred_bits = (bit_preds > 0).float()
        detected_text = decode_bits(pred_bits[0])

        is_exact_match = (detected_text.strip() == text.strip())
        if is_exact_match:
            results['exact_matches'] += 1

        # Scored against the original text, not the compressor round-trip.
        bleu_scores = compute_bleu(text, detected_text)
        results['bleu1_scores'].append(bleu_scores[0])
        results['bleu4_scores'].append(bleu_scores[3])

        results['texts'].append({
            'original': text,
            'reference': reference_decoded_text,
            'detected': detected_text,
            'is_match': is_exact_match,
            'bleu1': bleu_scores[0],
            'bleu4': bleu_scores[3],
            'psnr': psnr_value,
            'bit_accuracy': bit_acc
        })

        results['total_samples'] += 1
    # Skip if no samples were processed
    if results['total_samples'] == 0:
        print("No samples were successfully processed")
        return None
    
    # Calculate summary metrics
    summary = {
        'model': f"videoseal_strength_{watermark_strength:.1f}",
        'checkpoint': checkpoint_name,
        'transformation': transformation_name,
        'watermark_strength': watermark_strength,
        'exact_match_rate': results['exact_matches'] / results['total_samples'] * 100,
        'avg_psnr': np.mean(results['psnr_values']),
        'avg_bleu1': np.mean(results['bleu1_scores']),
        'avg_bleu4': np.mean(results['bleu4_scores']),
        'total_samples': results['total_samples'],
        'oob_percentage': oob_counter/len(test_texts),
    }
    
    # Save results to JSON (summary plus every per-text record)
    with open(f"{output_prefix}_results.json", 'w') as f:
        json.dump({
            'summary': summary,
            'texts': results['texts']
        }, f, indent=4, cls=NumpyEncoder)
    
    print(f"\nEvaluation complete for strength={watermark_strength}, checkpoint={checkpoint_name}, transform={transformation_name}")
    print(f"Exact match rate: {summary['exact_match_rate']:.2f}%")
    print(f"Average PSNR: {summary['avg_psnr']:.2f} dB")
    print(f"Average BLEU-4: {summary['avg_bleu4']:.4f}")
    
    return summary

def compute_psnr_batch(a: torch.Tensor, b: torch.Tensor, max_value: float = 255.0) -> torch.Tensor:
    """
    Computes PSNR for a batch of images (or any batched signals).

    Args:
        a, b: Batches with matching shape ``[N, ...]``, typically
            ``[N, C, H, W]``. Both are converted to float first, so integer
            inputs such as uint8 cannot wrap around when subtracted.
        max_value: Peak signal value: 255.0 (default) for images in [0, 255],
            1.0 for images in [0, 1].

    Returns:
        A tensor of PSNR values in dB, shape [N].
    """
    sq_err = (a.float() - b.float()) ** 2

    # Average over every non-batch dimension (C, H, W for images); a 1-D input
    # already holds one squared error per sample.
    if sq_err.dim() > 1:
        mse = torch.mean(sq_err, dim=tuple(range(1, sq_err.dim())))
    else:
        mse = sq_err

    # The epsilon keeps identical inputs finite (~128 dB at max_value=255)
    # instead of producing log(0) -> inf.
    return 10 * torch.log10(max_value * max_value / (mse + 1e-8))

def preprocess_for_semanticseal_batch(img_batch_pil: list, target_size: int, device: torch.device):
    """
    Turn a list of PIL images into one normalized tensor batch for semanticseal models.

    Each image is individually resized/padded to ``target_size`` with
    ``resize_and_pad_to_numpy`` and normalized with ``transform_imnet``. These
    are the same per-image steps used by the single-image ``detect`` path. The
    results are concatenated along a leading batch dimension and moved to
    ``device``.
    """
    def _prepare(image):
        # Aspect-ratio information returned by the helper is not needed here.
        padded, _unused_ratio = resize_and_pad_to_numpy(image, target=target_size)
        return transform_imnet(padded).unsqueeze(0)

    per_image = [_prepare(image) for image in img_batch_pil]
    return torch.cat(per_image, dim=0).to(device)

def embed_watermark_batch(
    watermarker,
    img_batch_pil: list,
    original_tensors_for_encoder: torch.Tensor,
    msg_batch: torch.Tensor,
    power_batch: torch.Tensor
) -> torch.Tensor:
    """
    Watermark a batch of images at per-image strengths, mirroring the serial embed path.

    The encoder runs once over the whole preprocessed batch. Then, per image:

    1. The residual (encoder output minus encoder input) is mapped by
       ``r * 0.5 + 0.5`` and quantized into an 8-bit PIL image.
    2. That image is LANCZOS-resized to a square whose side is the original
       image's longer side.
    3. It is center-cropped back to the original width/height (presumably
       undoing centered padding from preprocessing).
    4. It is added to the original image, scaled by that image's power.

    Args:
        watermarker: Object exposing ``encoder(images, messages)``.
        img_batch_pil: Original full-resolution PIL images, length N.
        original_tensors_for_encoder: Preprocessed encoder inputs for the same
            images, shape [N, C, H, W].
        msg_batch: Messages to embed, one row per image.
        power_batch: Per-image strength multipliers, length N.

    Returns:
        torch.Tensor: uint8 watermarked images stacked as [N, C, H, W], on the
        same device as ``original_tensors_for_encoder``. NOTE: stacking requires
        every original image to have the same size and channel count.
    """
    with torch.no_grad():
        # The only batched GPU work; the PIL post-processing below is per image.
        encoded = watermarker.encoder(original_tensors_for_encoder, msg_batch)

    stacked = []
    for idx, source_img in enumerate(img_batch_pil):
        strength = power_batch[idx].item()

        # Residual -> r * 0.5 + 0.5 -> 8-bit HWC image.
        delta = (encoded[idx:idx+1] - original_tensors_for_encoder[idx:idx+1]).cpu().numpy() * 0.5 + 0.5
        delta_img = Image.fromarray((delta[0].transpose((1, 2, 0)) * 255.0).clip(0, 255.0).astype(np.uint8))

        # Upscale to (longest side)^2, then crop the centered original-size window.
        width, height = source_img.size
        side = max(source_img.size)
        upscaled = np.asarray(delta_img.resize((side, side), Image.LANCZOS))
        left = (side - width) // 2
        top = (side - height) // 2
        cropped = upscaled[top:top + height, left:left + width]
        residual = cropped / 255.0 - 0.5

        # Add the scaled residual to the full-resolution original.
        blended = np.asarray(source_img) + residual * strength * 255.0
        out_img = Image.fromarray(blended.clip(0, 255.0).astype(np.uint8))

        # HWC uint8 -> CHW tensor for the PSNR computation.
        stacked.append(torch.from_numpy(np.array(out_img).transpose(2, 0, 1)))

    return torch.stack(stacked, dim=0).to(original_tensors_for_encoder.device)

def find_target_power_batch(
    watermarker,
    img_batch_pil: list,
    msg_batch: torch.Tensor,
    target_psnr: float = 42.0,
    tolerance: float = 0.2,
    max_iterations: int = 10,
    device: torch.device = torch.device('cuda')
) -> tuple[list, torch.Tensor]:
    """
    Run one binary search per image, in parallel over the batch, for the
    embedding power that reaches ``target_psnr``.

    Every image starts with the power interval [0, 2]. Each iteration:

    - embeds the whole batch at the interval midpoints via
      ``embed_watermark_batch``;
    - measures PSNR against the original PIL images with
      ``compute_psnr_batch``;
    - keeps the closest result per image and freezes images within
      ``tolerance`` dB;
    - halves the interval of the remaining images. PSNR above target raises
      the lower bound; otherwise the upper bound is lowered.

    Args:
        watermarker: Object with ``img_size`` and ``encoder`` (used through
            ``embed_watermark_batch``).
        img_batch_pil: Original PIL images, length N (same size expected).
        msg_batch: Messages to embed, one row per image.
        target_psnr: Desired PSNR in dB.
        tolerance: Convergence threshold on |PSNR - target|.
        max_iterations: Maximum number of search steps (must be >= 1).
        device: Device used for encoding and state tensors.

    Returns:
        (images, powers): list of N watermarked PIL images (the closest to
        target found) and the matching powers as a tensor [N] on ``device``.
        An empty input batch returns ``([], empty tensor)``.

    Raises:
        ValueError: if ``max_iterations < 1``. Without at least one step, no
            image would ever be embedded and all-black placeholders would be
            returned.
    """
    if max_iterations < 1:
        raise ValueError(f"max_iterations must be >= 1, got {max_iterations}")

    batch_size = len(img_batch_pil)
    if batch_size == 0:
        # Nothing to embed; torch.cat/stack on empty lists would raise below.
        return [], torch.zeros(0, device=device)

    # Preprocess ONCE for the encoder; reused for every search iteration.
    img_batch_tensor_norm = preprocess_for_semanticseal_batch(img_batch_pil, watermarker.img_size, device)

    # [0, 255] CHW tensors of the ORIGINAL PIL images, the PSNR reference.
    original_pil_as_tensors_255 = torch.stack(
        [torch.from_numpy(np.array(p).transpose(2, 0, 1)) for p in img_batch_pil], dim=0
    ).to(device)

    msg_batch = msg_batch.to(device)

    # Per-image search state.
    low_batch = torch.zeros(batch_size, device=device)
    high_batch = torch.full((batch_size,), 2.0, device=device)
    best_power_batch = torch.zeros_like(low_batch)
    best_diff_batch = torch.full_like(low_batch, float('inf'))
    active_mask = torch.ones(batch_size, dtype=torch.bool, device=device)

    best_watermarked_batch_255 = torch.zeros_like(original_pil_as_tensors_255)

    for _ in range(max_iterations):
        if not active_mask.any():
            break

        current_power_batch = (low_batch + high_batch) / 2.0

        # NOTE(review): the encoder output does not depend on the power, so it
        # is recomputed needlessly each iteration; caching it would need an
        # embed_watermark_batch API change.
        watermarked_batch_255 = embed_watermark_batch(
            watermarker,
            img_batch_pil,
            img_batch_tensor_norm,
            msg_batch,
            current_power_batch
        )

        # Compare against the original PIL images, not the pre-processed ones.
        current_psnr_batch = compute_psnr_batch(original_pil_as_tensors_255, watermarked_batch_255)
        diff_batch = torch.abs(current_psnr_batch - target_psnr)

        # Record improvements only for images still being searched.
        update_best_mask = (diff_batch < best_diff_batch) & active_mask
        best_power_batch = torch.where(update_best_mask, current_power_batch, best_power_batch)
        best_diff_batch = torch.where(update_best_mask, diff_batch, best_diff_batch)
        best_watermarked_batch_255[update_best_mask] = watermarked_batch_255[update_best_mask]

        converged_mask = (diff_batch <= tolerance) & active_mask
        active_mask[converged_mask] = False

        # Assumes PSNR decreases as power increases (TODO confirm).
        psnr_too_high_mask = current_psnr_batch > target_psnr
        update_low_mask = psnr_too_high_mask & active_mask
        update_high_mask = ~psnr_too_high_mask & active_mask
        low_batch[update_low_mask] = current_power_batch[update_low_mask]
        high_batch[update_high_mask] = current_power_batch[update_high_mask]

    # Convert the best CHW uint8 tensors back to HWC PIL images.
    final_watermarked_imgs_pil = []
    for img_tensor in best_watermarked_batch_255:
        img_np = img_tensor.cpu().numpy().astype(np.uint8).transpose(1, 2, 0)
        final_watermarked_imgs_pil.append(Image.fromarray(img_np))

    return final_watermarked_imgs_pil, best_power_batch

def find_target_power(watermarker, img, msg, target_psnr=42, tolerance=0.2, max_iterations=10,
                      min_power=0.0, max_power=1.0):
    """Binary-search the watermark power whose embedding reaches ``target_psnr``.

    Serial counterpart of ``find_target_power_batch``. Each step embeds
    ``msg`` into ``img`` at the midpoint of the current interval via
    ``watermarker.embed(img, msg, overwrite_power=...)``. It then measures
    PSNR with ``compute_psnr``. PSNR above the target raises the lower bound;
    otherwise the upper bound is lowered.

    Args:
        watermarker: Object exposing ``embed(img, msg, overwrite_power=...)``.
        img: Original image.
        msg: Message to embed.
        target_psnr: Desired PSNR in dB.
        tolerance: Stop as soon as |PSNR - target| <= tolerance.
        max_iterations: Maximum number of embed/measure steps.
        min_power: Lower end of the search interval (default 0.0).
        max_power: Upper end of the search interval (default 1.0). The batched
            search uses 2.0; pass ``max_power=2.0`` to match it.

    Returns:
        (watermarked_img, power): the first result within tolerance, otherwise
        the closest one seen. ``(None, None)`` if ``max_iterations < 1``.
    """
    low = min_power
    high = max_power
    best_power = None
    best_diff = float('inf')
    best_img = None

    for _ in range(max_iterations):
        current_power = (low + high) / 2
        watermarked_img = watermarker.embed(img, msg, overwrite_power=current_power)
        current_psnr = compute_psnr(img, watermarked_img)

        diff = abs(current_psnr - target_psnr)

        # Update best result if this is closer to target
        if diff < best_diff:
            best_diff = diff
            best_power = current_power
            best_img = watermarked_img

        # Check if we're within tolerance
        if diff <= tolerance:
            return watermarked_img, current_power

        # Assumes PSNR decreases as power increases (TODO confirm).
        if current_psnr > target_psnr:
            low = current_power
        else:
            high = current_power

    # Return best result found if we couldn't get within tolerance
    return best_img, best_power

def evaluate_semanticseal_model_batched(
    model_name,
    checkpoint_name,
    transformation_name,
    num_test_images=10,
    num_test_texts=20,
    output_dir="results",
    batch_size=32,
    num_workers=16,
):
    """
    Evaluate a semanticseal model (f0c/nautilus/hide-r) using batch processing.

    For each DataLoader batch:

    1. Encode the texts with the text compressor.
    2. Embed the latents with the parallel PSNR-targeted power search
       (``find_target_power_batch``, target 42 dB).
    3. Apply the named transformation and detect each image.
    4. Decode the detected vectors.
    5. Score them against the compressor's own round-trip reconstruction
       (exact match, BLEU-1/4, ROUGE-L, PSNR).

    Args:
        model_name: Key into ``SEMANTICSEAL_MODEL_CONFIGS``.
        checkpoint_name: Key into ``TEXT_COMPRESSOR_CHECKPOINTS``. Names
            containing ``"llmzip"`` use ``LLMZipTextCompressor``.
        transformation_name: Key into ``TRANSFORMATIONS_GLOBAL``.
        num_test_images: Number of COCO validation images to load.
        num_test_texts: Number of texts to evaluate.
        output_dir: Directory for the ``*_results.json`` file.
        batch_size: DataLoader batch size.
        num_workers: DataLoader worker processes (default 16).

    Returns:
        dict | None: Summary metrics, or ``None`` if no sample was processed.

    Raises:
        ValueError: if texts must come from a data module and ``DATA_MODULE``
            is not ``"wikitext"``, ``"coco"`` or a pixmo variant.
    """
    os.makedirs(output_dir, exist_ok=True)
    print(f"\n\n{'='*80}")
    print(f"BATCH EVALUATING semanticseal model:")
    print(f"Model: {model_name}, Compressor: {checkpoint_name}, Transform: {transformation_name}")
    print(f"Batch Size: {batch_size}")
    print(f"{'='*80}\n")

    output_prefix = f"{output_dir}/{model_name}_{DATA_MODULE}_{checkpoint_name}_{transformation_name}"
    transform_fn = TRANSFORMATIONS_GLOBAL[transformation_name]

    model_config = SEMANTICSEAL_MODEL_CONFIGS[model_name]
    watermarker = semanticsealWatermarker(model_config)

    if "llmzip" in checkpoint_name:
        compressor = LLMZipTextCompressor(TEXT_COMPRESSOR_CHECKPOINTS[checkpoint_name])
    else:
        compressor = FineTunedTextCompressor(
            TEXT_COMPRESSOR_CHECKPOINTS[checkpoint_name],
            max_length=MAX_LENGTH,
        )

    # Load images and texts
    img_size = model_config["img_size"]
    test_images = load_coco_val_images(num_images=num_test_images, image_size=img_size)

    test_texts = []
    if LOAD_FROM_JSON:
        test_texts = load_texts_from_json(JSON_FILE_PATH, num_texts=num_test_texts)
    elif LOAD_FROM_PKL:
        test_texts = load_and_reconstruct_texts_from_pkl(PKL_FILE_PATH, num_texts=num_test_texts)
    elif DATA_MODULES_AVAILABLE:
        if DATA_MODULE == "wikitext":
            data_module = WikiTextDataModule(batch_size=1, max_length=MAX_LENGTH)
        elif DATA_MODULE == "coco":
            data_module = Coco2017DataModule(batch_size=1, max_length=MAX_LENGTH)
        elif "pixmo" in DATA_MODULE:
            data_module = PixmoCapDataModule(batch_size=1, max_length=MAX_LENGTH)
        else:
            # Fail loudly instead of hitting a NameError on `data_module`.
            raise ValueError(f"Unsupported DATA_MODULE for text loading: {DATA_MODULE!r}")
        data_module.setup('test')
        for batch in data_module.test_dataloader():
            if len(test_texts) >= num_test_texts:
                break
            text = data_module.tokenizer.decode(batch['input_ids'][0], skip_special_tokens=True)
            if text.strip():
                test_texts.append(text)
        print(f"Loaded {len(test_texts)} texts from {DATA_MODULE}")

    # Pair images with texts; collate_fn keeps PIL images as a list.
    dataset = PairedImageTextDataset(test_images, test_texts, compressor.data_module.tokenizer, MAX_LENGTH)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, collate_fn=collate_fn)

    results = {
        'exact_matches': 0, 'total_samples': 0, 'psnr_values': [],
        'bleu1_scores': [], 'bleu4_scores': [], 'rouge_l_scores': [], 'texts': []
    }

    print("Starting batched evaluation...")
    for img_batch_pil, text_batch in tqdm(dataloader):

        # 1. Encode texts: batched if the compressor supports it, else one by one.
        try:
            latent_batch_tensor = compressor.encode(text_batch)
        except Exception:
            latent_batch_tensor = torch.cat([compressor.encode(text) for text in text_batch], dim=0)

        # Round-trip reconstructions are the reference the detections are scored
        # against. decode() on a one-row batch returns a sequence; element 0 is
        # the text.
        reference_texts = [compressor.decode(latent.unsqueeze(0))[0] for latent in latent_batch_tensor]

        # 2. Embed with a per-image PSNR-targeted power search.
        watermarked_imgs_pil, used_powers_tensor = find_target_power_batch(
            watermarker,
            img_batch_pil,
            latent_batch_tensor,  # msg_batch
            target_psnr=42.0,
            tolerance=0.2,
            max_iterations=10,
            device=device
        )

        # 3. PSNR of each final watermarked image against its original.
        psnr_batch = [compute_psnr(orig, wm) for orig, wm in zip(img_batch_pil, watermarked_imgs_pil)]
        results['psnr_values'].extend(psnr_batch)

        # 4. Transform, then detect each image (detection stays per image).
        transformed_imgs_pil = [transform_fn(img) for img in watermarked_imgs_pil]
        detected_vectors_np = np.stack([watermarker.detect(img) for img in transformed_imgs_pil], axis=0)
        detected_vectors_tensor = torch.tensor(detected_vectors_np, dtype=torch.float32).to(device)

        # 5. Decode exactly once, inside the try, so the per-vector fallback is
        # reachable when batched decoding is not supported. Both paths yield one
        # string per image.
        try:
            detected_texts = compressor.decode(detected_vectors_tensor)
        except Exception:
            detected_texts = [compressor.decode(vec.unsqueeze(0))[0] for vec in detected_vectors_tensor]

        # 6. Per-sample metrics
        for i, original_text in enumerate(text_batch):
            ref_text = reference_texts[i]
            det_text = detected_texts[i]

            is_exact_match = (det_text.strip().lower() == ref_text.strip().lower())
            if is_exact_match:
                results['exact_matches'] += 1

            bleu_scores = compute_bleu(ref_text, det_text)
            results['bleu1_scores'].append(bleu_scores[0])
            results['bleu4_scores'].append(bleu_scores[3])

            rouge_l_score = scorer.score(ref_text, det_text)['rougeL'].fmeasure
            results['rouge_l_scores'].append(rouge_l_score)

            results['texts'].append({
                'original': original_text, 'reference': ref_text, 'detected': det_text,
                'is_match': is_exact_match, 'bleu1': bleu_scores[0], 'bleu4': bleu_scores[3],
                'rouge_l': rouge_l_score, 'psnr': psnr_batch[i]
            })
            results['total_samples'] += 1

    if results['total_samples'] == 0:
        print("No samples were successfully processed")
        return None

    summary = {
        'model': model_name,
        'compressor_config': checkpoint_name,
        'transformation': transformation_name,
        'exact_match_rate': (results['exact_matches'] / results['total_samples'] * 100),
        'avg_psnr': np.mean(results['psnr_values']),
        'avg_bleu1': np.mean(results['bleu1_scores']),
        'avg_bleu4': np.mean(results['bleu4_scores']),
        'avg_rouge_l': np.mean(results['rouge_l_scores']),
        'total_samples': results['total_samples']
    }

    with open(f"{output_prefix}_results.json", 'w') as f:
        json.dump({'summary': summary, 'texts': results['texts']}, f, indent=4, cls=NumpyEncoder)

    print(f"\nEvaluation complete for model={model_name}, compressor={checkpoint_name}, transform={transformation_name}")
    print(f"Exact match rate: {summary['exact_match_rate']:.2f}%")
    print(f"Average PSNR: {summary['avg_psnr']:.2f} dB")
    print(f"Average BLEU-4: {summary['avg_bleu4']:.4f}")

    return summary
    
def evaluate_semanticseal_model(
    model_name,
    checkpoint_name,
    transformation_name,
    num_test_images=10,
    num_test_texts=20,
    output_dir="results",
    roc_only=False,
):
    """Evaluate a semanticseal model (f0c/nautilus/hide-r) on text-in-image watermarking.

    For each test text: compress it with the text compressor, embed the resulting
    message vector into an image from ``load_coco_val_images`` via
    ``find_target_power``, apply the requested transformation, detect the vector
    back with the watermarker, and decode it to text. Texts whose compressor
    round-trip (encode -> decode) does not reproduce the input are skipped, so the
    metrics reflect the watermark channel rather than compressor loss.

    Args:
        model_name: key into ``SEMANTICSEAL_MODEL_CONFIGS``.
        checkpoint_name: key into ``TEXT_COMPRESSOR_CHECKPOINTS``.
        transformation_name: key into ``TRANSFORMATIONS_GLOBAL``.
        num_test_images: number of cover images to load.
        num_test_texts: number of texts to load/evaluate.
        output_dir: directory receiving ``<prefix>_results.json``.
        roc_only: when True, only run ``plot_roc`` with the initialized
            compressor and return ``None``. This replaces a former unconditional
            ``plot_roc(...); exit()`` that terminated the process before any
            evaluation could run.

    Returns:
        Summary dict of aggregate metrics, or ``None`` when initialization fails,
        ``roc_only`` is set, or no sample could be processed.
    """
    os.makedirs(output_dir, exist_ok=True)
    print(f"\n\n{'='*80}")
    print(f"Evaluating semanticseal model:")
    print(f"Model: {model_name}")
    print(f"Text compressor: {checkpoint_name}")
    print(f"Transformation: {transformation_name}")
    print(f"{'='*80}\n")
    output_prefix = f"{output_dir}/{model_name}_{DATA_MODULE}_{checkpoint_name}_{transformation_name}"
    transform_fn = TRANSFORMATIONS_GLOBAL[transformation_name]

    # ---- watermarking model ----
    try:
        print("Initializing semanticseal model...")
        model_config = SEMANTICSEAL_MODEL_CONFIGS[model_name]
        watermarker = semanticsealWatermarker(model_config)
        message_dim = 256
        print(f"semanticseal model initialized: {model_name} (expects msg dim: {message_dim})")
    except Exception as e:
        print(f"Error initializing semanticseal model: {e}")
        return None

    # ---- text compressor ----
    compressor = None
    bpsk_modulator = None
    # NOTE(review): the flag is forced to False here, so the LLMZip/BPSK branch
    # below can never be taken as written; it is kept only for reference.
    use_bpsk_modulation = False

    try:
        print("Initializing text compressor...")
        if use_bpsk_modulation and LLMZIP_AVAILABLE:
            llm_model_path = TEXT_COMPRESSOR_CHECKPOINTS["llmzip_opt125m"]
            compressor = LLMZipTextCompressor(modelname=llm_model_path)
            print(f"Initialized LLMZip text compressor llmzip_opt125m")
            if not hasattr(sys.modules.get('modulations'), 'BPSKModulator'):
                raise ImportError("BPSKModulator class not found, cannot use BPSK.")
            bpsk_modulator = BPSKModulator()
            use_bpsk_modulation = True
            print("--> BPSK modulation will be used for LLMZip bits.")
        elif checkpoint_name in TEXT_COMPRESSOR_CHECKPOINTS and DATA_MODULES_AVAILABLE:
            compressor = FineTunedTextCompressor(
                TEXT_COMPRESSOR_CHECKPOINTS[checkpoint_name],
                max_length=MAX_LENGTH,
            )
            print(f"Initialized fine-tuned text compressor ({checkpoint_name})")
        else:
            print(f"No suitable text compressor found or available for checkpoint: {checkpoint_name}")
            return None
    except Exception as e:
        print(f"Error initializing text compressor or BPSK modulator: {e}")
        return None

    if roc_only:
        # Analysis-only mode: reuse the compressor for the rho ROC and stop here.
        plot_roc(compressor)
        return None

    # ---- test texts ----
    print("Loading test texts...")
    test_texts = []
    if LOAD_FROM_JSON:
        test_texts = load_texts_from_json(JSON_FILE_PATH, num_texts=num_test_texts)
    elif LOAD_FROM_PKL:
        test_texts = load_and_reconstruct_texts_from_pkl(PKL_FILE_PATH, num_texts=num_test_texts)
    elif DATA_MODULES_AVAILABLE:
        if DATA_MODULE == "wikitext":
            data_module = WikiTextDataModule(batch_size=1, max_length=MAX_LENGTH)
        elif DATA_MODULE == "coco":
            data_module = Coco2017DataModule(batch_size=1, max_length=MAX_LENGTH)
        elif "pixmo" in DATA_MODULE:
            data_module = PixmoCapDataModule(batch_size=1, max_length=MAX_LENGTH)
        else:
            # Previously fell through and raised NameError on `data_module`.
            print(f"Unsupported DATA_MODULE: {DATA_MODULE}")
            return None
        data_module.setup('test')
        test_loader = data_module.test_dataloader()
        for batch in test_loader:
            if len(test_texts) >= num_test_texts:
                break
            text = data_module.tokenizer.decode(batch['input_ids'][0], skip_special_tokens=True)
            if text.strip():
                test_texts.append(text)
        print(f"Loaded {len(test_texts)} texts from {DATA_MODULE}")

    # ---- cover images ----
    print("Loading test images...")
    img_size = model_config["img_size"]
    test_images = load_coco_val_images(
        num_images=num_test_images,
        image_size=img_size
    )
    print(f"Loaded {len(test_images)} test images")
    if len(test_images) == 0:
        # random.choice below would raise IndexError on an empty pool.
        print("No test images available")
        return None

    results = {
        'exact_matches': 0,
        'total_samples': 0,
        'psnr_values': [],
        'bleu1_scores': [],
        'bleu4_scores': [],
        'rouge_l_scores': [],
        'texts': [],
    }

    print("Starting evaluation...")
    for text in tqdm(test_texts):
        if not text.strip():
            continue
        img = random.choice(test_images).copy()
        original_bits_or_latent = compressor.encode(text)  # Shape [1, message_dim]

        # Latent vector (fine-tuned) or bits (LLMZip without BPSK) embedded as-is.
        msg_to_embed = original_bits_or_latent.squeeze(0).cpu().numpy()
        # [0] because decode returns a batch of texts.
        reference_text = compressor.decode(original_bits_or_latent)[0]
        if text != reference_text:
            # Lossy compressor round-trip: skip so only watermark errors are measured.
            print(f"{text} \n", f"{reference_text}")
            continue

        watermarked_img, used_power = find_target_power(watermarker, img, msg_to_embed)

        psnr_value = compute_psnr(img, watermarked_img)
        if psnr_value == float('inf'):
            psnr_value = 100.0  # finite stand-in so averaging stays meaningful
        results['psnr_values'].append(psnr_value)

        transformed_img = transform_fn(watermarked_img)
        detected_vector = watermarker.detect(transformed_img)  # numpy array

        if use_bpsk_modulation:
            detected_bits_list = bpsk_modulator.decode(detected_vector)
            detected_text = bytes_to_text(detected_bits_list, compressor)
        else:
            input_for_text_decode = torch.tensor(detected_vector, dtype=torch.float32).unsqueeze(0).to(device)
            detected_text = compressor.decode(input_for_text_decode)[0]

        is_exact_match = (detected_text.strip().lower() == reference_text.strip().lower())  # Case-insensitive match
        if is_exact_match:
            results['exact_matches'] += 1
        bleu_scores = compute_bleu(reference_text, detected_text)
        if bleu_scores[3] < 1.0:
            print(bleu_scores[3])
            print("[RECON]", detected_text, "[GT]:", reference_text)
        results['bleu1_scores'].append(bleu_scores[0])
        results['bleu4_scores'].append(bleu_scores[3])

        rouge_l_score = scorer.score(reference_text, detected_text)['rougeL'].fmeasure  # F1
        results['rouge_l_scores'].append(rouge_l_score)

        results['texts'].append({
            'original'      : text,
            'reference'     : reference_text,
            'detected'      : detected_text,
            'is_match'      : is_exact_match,
            'bleu1'         : bleu_scores[0],
            'bleu4'         : bleu_scores[3],
            'rouge_l'       : rouge_l_score,
            'psnr'          : psnr_value,
            'orig_vec'      : msg_to_embed.astype(float).tolist(),     # JSON-friendly
            'detected_vec'  : detected_vector.astype(float).tolist(),
        })
        results['total_samples'] += 1

    if results['total_samples'] == 0:
        print("No samples were successfully processed")
        return None

    # From here on every metric list is non-empty (total_samples > 0).
    watermark_power = model_config.get("watermark_power", "N/A")  # Handle missing power
    summary = {
        'model': model_name,
        'compressor_config': checkpoint_name,
        'transformation': transformation_name,
        'uses_bpsk': use_bpsk_modulation,
        'watermark_strength_setting': watermark_power,
        'exact_match_rate': results['exact_matches'] / results['total_samples'] * 100,
        'avg_psnr': np.mean(results['psnr_values']),
        'avg_bleu1': np.mean(results['bleu1_scores']),
        'avg_bleu4': np.mean(results['bleu4_scores']),
        'avg_rouge_l': np.mean(results['rouge_l_scores']),
        'total_samples': results['total_samples']
    }

    try:
        with open(f"{output_prefix}_results.json", 'w') as f:
            json.dump({
                'summary': summary,
                'texts': results['texts']  # all per-sample records
            }, f, indent=4, cls=NumpyEncoder)
    except Exception as e:
        print(f"Error saving results to JSON: {e}")

    print(f"\nEvaluation complete for model={model_name}, compressor={checkpoint_name}, transform={transformation_name}, BPSK={use_bpsk_modulation}")
    print(f"Exact match rate: {summary['exact_match_rate']:.2f}%")
    print(f"Average PSNR: {summary['avg_psnr']:.2f} dB")
    print(f"Average BLEU-4: {summary['avg_bleu4']:.4f}")
    print(f"Average Rouge-L: {summary['avg_rouge_l']:.4f}")

    return summary

def _model_display_name(row):
    """Heatmap column label for a results row; VideoSeal rows get their strength appended.

    Falls back to the plain model name when the row has no usable
    ``watermark_strength`` value. Previously a missing column raised KeyError and
    aborted all remaining visualizations.
    """
    if 'videoseal' not in row['model']:
        return row['model']
    strength = row.get('watermark_strength')
    if strength is None or pd.isna(strength):
        return row['model']
    return f"videoseal_{strength:.1f}"


def _plot_model_type_comparisons(df, transformations, output_dir):
    """Save a 3-panel bar chart (match rate, BLEU-4, PSNR) per transformation, by model family."""
    for transform in transformations:
        trans_df = df[df['transformation'] == transform]
        if trans_df.empty:
            continue  # nothing was evaluated under this transformation

        type_metrics = trans_df.groupby('model_type').agg({
            'exact_match_rate': 'mean',
            'avg_psnr': 'mean',
            'avg_bleu4': 'mean'
        }).reset_index()

        plt.figure(figsize=(12, 8))

        plt.subplot(1, 3, 1)
        plt.bar(type_metrics['model_type'], type_metrics['exact_match_rate'])
        plt.title(f'Exact Match Rate ({transform})')
        plt.ylabel('Rate (%)')

        plt.subplot(1, 3, 2)
        plt.bar(type_metrics['model_type'], type_metrics['avg_bleu4'])
        plt.title(f'BLEU-4 Score ({transform})')

        plt.subplot(1, 3, 3)
        plt.bar(type_metrics['model_type'], type_metrics['avg_psnr'])
        plt.title(f'PSNR (dB) ({transform})')

        plt.tight_layout()
        plt.savefig(f"{output_dir}/model_type_comparison_{transform}.png")
        plt.close()


def _plot_metric_heatmaps(df, output_dir):
    """Save one transformation-by-model heatmap per headline metric."""
    # Loop-invariant: compute the display label once, not once per metric.
    df['model_with_strength'] = df.apply(_model_display_name, axis=1)
    for metric, title in [
        ("exact_match_rate", "Text Match Rate (%)"),
        ("avg_bleu4", "BLEU-4 Score"),
        ("avg_psnr", "PSNR (dB)")
    ]:
        pivot_data = df.pivot_table(
            index="transformation",
            columns="model_with_strength",
            values=metric
        )
        plt.figure(figsize=(14, 8))
        sns.heatmap(pivot_data, annot=True, cmap='viridis', fmt='.2f')
        plt.title(f'{title} by Model and Transformation')
        plt.xlabel('Model')
        plt.ylabel('Transformation')
        plt.tight_layout()
        plt.savefig(f"{output_dir}/model_comparison_{metric}_heatmap.png")
        plt.close()


def _print_best_models(df):
    """Print the best model for pristine images, for robustness, and for balance."""
    print("\nBest Models Analysis:")

    # Best for pristine images (no transformation); skip if "none" was not evaluated,
    # otherwise idxmax on an empty frame raises and hides the other rankings.
    pristine_df = df[df["transformation"] == "none"]
    if pristine_df.empty:
        print("\nNo results for transformation 'none'; skipping pristine-image ranking.")
    else:
        best_pristine = pristine_df.loc[pristine_df["exact_match_rate"].idxmax()]
        print(f"\nBest model for pristine images:")
        print(f"  Model: {best_pristine['model']}")
        print(f"  Match rate: {best_pristine['exact_match_rate']:.2f}%")
        print(f"  PSNR: {best_pristine['avg_psnr']:.2f} dB")

    # Best for robustness (highest average match rate across all transformations)
    robustness_df = df.groupby('model').agg({
        'exact_match_rate': 'mean',
        'avg_psnr': 'mean',
        'avg_bleu4': 'mean'
    }).reset_index()

    best_robust = robustness_df.loc[robustness_df["exact_match_rate"].idxmax()]
    print(f"\nBest model for robustness (across all transformations):")
    print(f"  Model: {best_robust['model']}")
    print(f"  Average match rate: {best_robust['exact_match_rate']:.2f}%")
    print(f"  Average PSNR: {best_robust['avg_psnr']:.2f} dB")

    # Best balance of quality and robustness
    robustness_df['balance_score'] = robustness_df['exact_match_rate'] * robustness_df['avg_psnr'] / 100
    best_balance = robustness_df.loc[robustness_df["balance_score"].idxmax()]
    print(f"\nBest model for balanced performance (match rate × PSNR):")
    print(f"  Model: {best_balance['model']}")
    print(f"  Average match rate: {best_balance['exact_match_rate']:.2f}%")
    print(f"  Average PSNR: {best_balance['avg_psnr']:.2f} dB")


def main():
    """Run the VideoSeal / semanticseal comparison sweep and summarize the results.

    Parses CLI arguments, evaluates every (model, compressor, transformation)
    combination, writes ``all_results_{DATA_MODULE}.csv`` to the output directory,
    then saves comparison plots and prints best-model rankings. A failure in a
    single configuration is reported and skipped, for both model families.
    """
    parser = argparse.ArgumentParser(description='Watermark Models Comparison')
    parser.add_argument('--output_dir', type=str, default="watermark_comparison_results",
                      help='Directory to save results')
    parser.add_argument('--num_images', type=int, default=1000,
                      help='Number of test images to use')
    parser.add_argument('--num_texts', type=int, default=2000,
                      help='Number of test texts to use')
    parser.add_argument('--videoseal_strengths', type=float, nargs='+', default=VIDEOSEAL_STRENGTHS,
                      help='VideoSeal watermark strengths to evaluate')
    parser.add_argument('--semanticseal_models', type=str, nargs='+', default=list(SEMANTICSEAL_MODEL_CONFIGS.keys()),
                      help='semanticseal models to evaluate')
    parser.add_argument('--skip_videoseal', action='store_true',
                      help='Skip VideoSeal evaluation')
    parser.add_argument('--skip_semanticseal', action='store_true',
                      help='Skip semanticseal models evaluation')
    args = parser.parse_args()

    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)
    transformations = list(TRANSFORMATIONS_GLOBAL.keys())
    compressors = list(TEXT_COMPRESSOR_CHECKPOINTS.keys())
    all_results = []

    # 1. VideoSeal at each requested strength (if available and not skipped)
    if VIDEOSEAL_AVAILABLE and not args.skip_videoseal:
        print("\nEvaluating VideoSeal with different strengths...")
        for strength in args.videoseal_strengths:
            for checkpoint in compressors:
                for transform in transformations:
                    try:
                        result = evaluate_videoseal_strength(
                            watermark_strength=strength,
                            checkpoint_name=checkpoint,
                            transformation_name=transform,
                            num_test_images=args.num_images,
                            num_test_texts=args.num_texts,
                            output_dir=output_dir
                        )
                        if result:
                            all_results.append(result)
                    except Exception as e:
                        print(f"Error evaluating VideoSeal strength={strength}, checkpoint={checkpoint}, transform={transform}: {e}")

    # 2. semanticseal models; per-config errors are reported like the VideoSeal sweep
    if not args.skip_semanticseal:
        print("\nEvaluating semanticseal models...")
        for model_name in args.semanticseal_models:
            for checkpoint in compressors:
                for transform in transformations:
                    try:
                        result = evaluate_semanticseal_model(
                            model_name=model_name,
                            checkpoint_name=checkpoint,
                            transformation_name=transform,
                            num_test_images=args.num_images,
                            num_test_texts=args.num_texts,
                            output_dir=output_dir,
                        )
                    except Exception as e:
                        print(f"Error evaluating semanticseal model={model_name}, checkpoint={checkpoint}, transform={transform}: {e}")
                        continue
                    if result:
                        all_results.append(result)

    if not all_results:
        print("No results available for analysis")
        return

    df = pd.DataFrame(all_results)
    df.to_csv(f"{output_dir}/all_results_{DATA_MODULE}.csv", index=False)

    try:
        df['model_type'] = df['model'].apply(
            lambda x: 'videoseal' if 'videoseal' in x else 'semanticseal'
        )
        _plot_model_type_comparisons(df, transformations, output_dir)
        _plot_metric_heatmaps(df, output_dir)
        _print_best_models(df)
    except Exception as e:
        print(f"Error creating visualizations: {e}")

def _pfa(c, d):
    """Probability that a uniformly random direction in R^d has cosine >= c with a fixed vector.

    Uses the regularized incomplete beta function: for c >= 0 the tail is
    0.5 * I_{1-c^2}((d-1)/2, 1/2). For c < 0 it is the complement, 1 minus that
    value. The old code returned the c >= 0 expression for every c, which made
    strongly anti-correlated vectors look highly significant.

    Args:
        c: cosine similarity, a scalar or array in [-1, 1].
        d: dimensionality of the space.

    Returns:
        A numpy float for scalar input, or an array with the shape of ``c``.
    """
    c = np.asarray(c, dtype=np.float64)
    tail = 0.5 * betainc((d - 1) * 0.5, 0.5, 1.0 - c * c)
    out = np.where(c >= 0.0, tail, 1.0 - tail)
    return out[()]  # unwrap 0-d arrays to a scalar; arrays pass through


# --- rho for a batch ------------------------------------------- #
def _rho_batch(z: torch.Tensor, compressor) -> np.ndarray:
    """Self-consistency p-value (rho) of each latent under a decode/re-encode cycle.

    Each row of ``z`` is decoded to text with ``compressor.decode``, re-encoded
    with ``compressor.encode``, and the cosine ``c`` between the original and
    re-encoded latent is mapped through ``_pfa(c, D)``. A small rho means the
    latent is close to a fixed point of the codec.

    Args:
        z: (B, D) latents; they need not be unit-norm because the cosine is
            computed on normalized copies. D was previously hard-coded to 256;
            it is now read from ``z``.
        compressor: object with ``decode(z)`` and ``encode(texts)``. ``encode``
            may return a tensor or a tuple whose first item is the tensor.

    Returns:
        (B,) float64 numpy array of rho values.
    """
    with torch.no_grad():
        decoded = compressor.decode(z)
        re_z = compressor.encode(decoded)
        re_z = re_z[0] if isinstance(re_z, tuple) else re_z
        if re_z.dim() == 1:
            re_z = re_z.unsqueeze(0)
        re_z = re_z.to(z.device)

    # Cosine similarities between original and re-encoded latents.
    z_n = z / (z.norm(dim=1, keepdim=True) + 1e-12)
    re_n = re_z / (re_z.norm(dim=1, keepdim=True) + 1e-12)
    # Upper clamp keeps 1 - c^2 > 0 in _pfa. rho can still underflow to 0 for
    # large D, so callers should clip before taking logs.
    c = torch.sum(z_n * re_n, dim=1).clamp(-1.0, 0.999999).cpu().double().numpy()

    # One similarity per sample, so no min-P / multiple-testing correction.
    return _pfa(c, z.shape[1])
# ------------------------------------------------------------------ #

def evaluate_rho_vs_noise(
    compressor,
    texts,
    noise_std=(0.0, 0.005, 0.01, 0.015, 0.02, 0.025, 0.03),
    batch_size=256,
    out_path="rho_results.json",
    device=None,
):
    """Measure rho and exact-match recovery as Gaussian noise is added to the latents.

    Every text is encoded once. For each noise level ``sigma``, i.i.d. Gaussian
    noise is added to the latents, which are re-normalized, decoded, and
    compared with the original text. rho is computed with ``_rho_batch``.
    Decoding and rho are processed in chunks of ``batch_size``; previously
    ``batch_size`` was ignored and all N latents were processed at once.

    Args:
        compressor: object with ``encode(list[str])`` and ``decode(latents)``.
        texts: sequence of input strings.
        noise_std: noise standard deviations to sweep.
        batch_size: chunk size for encode, decode and rho.
        out_path: JSON file receiving the results (float keys become strings).
        device: torch device for the latents. Defaults to CUDA when available,
            otherwise CPU (previously CUDA was hard-coded).

    Returns:
        dict mapping sigma -> {"rho_values": list[float], "exact_match": list[int]}.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    results = {}
    if len(texts) == 0:
        # torch.cat below would fail on an empty list.
        with open(out_path, "w") as f:
            json.dump(results, f, indent=2)
        return results

    with torch.no_grad():
        # ---------- encode all texts once ---------- #
        chunks = []
        for start in range(0, len(texts), batch_size):
            emb = compressor.encode(texts[start:start + batch_size])
            emb = emb[0] if isinstance(emb, tuple) else emb
            chunks.append(emb.to(device))
        Z = torch.cat(chunks, dim=0)  # (N, D)

        # ---------- sweep over noise ---------- #
        for sigma in map(float, noise_std):
            z_noisy = Z + torch.randn_like(Z) * sigma
            z_noisy = z_noisy / (z_noisy.norm(dim=1, keepdim=True) + 1e-9)

            exact, rho_parts = [], []
            for start in range(0, len(z_noisy), batch_size):
                z_batch = z_noisy[start:start + batch_size]
                decoded = compressor.decode(z_batch)  # list[str]
                originals = texts[start:start + batch_size]
                exact.extend(int(a == b) for a, b in zip(originals, decoded))  # 1 = perfect
                rho_parts.append(np.asarray(_rho_batch(z_batch, compressor)))

            results[sigma] = {
                "rho_values": np.concatenate(rho_parts).tolist(),
                "exact_match": exact,  # same length as rho_values
            }

    with open(out_path, "w") as f:
        json.dump(results, f, indent=2)
    return results
    
def plot_rho_by_transform(
    compressor,
    results_dir="/user/videoseal/watermark_comparison_results",
    pattern="nautilus_256_900_pixmo_pixmo_robust_0.01_best_*_results.json",
    batch_size=256,
    out_fig="rho_transform_aaai.pdf",
):
    """Plot the -log10(rho) distribution and exact-match rate per transformation.

    Reads each matching ``*_results.json`` file and re-encodes its detected texts
    to compute rho with ``_rho_batch``. It then saves a two-panel figure: (a)
    violin plots of -log10(rho), and (b) the exact-match rate with the median
    -log10(rho) overlaid. Transformations are ordered by median rho.

    The transformation name is the part of the filename between "_best_" and
    "_results". Files without "_best_" or with no samples are skipped.

    Returns:
        None. When no usable file is found, prints a message and draws nothing.
    """
    rho_dict, match_dict = {}, {}

    for jf in sorted(glob.glob(os.path.join(results_dir, pattern))):
        name = os.path.basename(jf)
        if "_best_" not in name:
            continue
        transform = name.split("_best_")[1].split("_results")[0]
        print(transform)
        with open(jf) as f:
            items = json.load(f)["texts"]
        if not items:
            continue  # an empty entry would break median/violin plotting

        det = [it["detected"] for it in items]
        acc = [int(it["detected"] == it.get("reference", it["original"])) for it in items]

        rho_vals = []
        for i in range(0, len(det), batch_size):
            emb = compressor.encode(det[i:i + batch_size])
            emb = emb[0] if isinstance(emb, tuple) else emb
            rho_vals.extend(_rho_batch(emb, compressor))

        rho_dict[transform] = np.asarray(rho_vals)
        match_dict[transform] = np.asarray(acc)

    if not rho_dict:
        print(f"No non-empty result files matching {pattern!r} in {results_dir}")
        return None

    # ------------------- plot --------------------
    labels = sorted(rho_dict, key=lambda k: np.median(rho_dict[k]))
    # Clip before the log: rho can underflow to exactly 0.
    data = [-np.log10(np.clip(rho_dict[k], 1e-300, None)) for k in labels]
    positions = np.arange(1, len(labels) + 1)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(6.8, 2.6),
                                   gridspec_kw={"wspace": 0.35})

    # (a) confidence distribution
    ax1.violinplot(data, showmedians=True)
    ax1.set_xticks(positions)
    ax1.set_xticklabels(labels, rotation=60, ha="right", fontsize=6)
    ax1.set_ylabel(r"$-\log_{10}\rho$", fontsize=7)
    ax1.set_title("(a) Confidence distribution", fontsize=7)

    # (b) accuracy vs confidence. Numeric positions with fixed ticks avoid the
    # FixedFormatter-without-FixedLocator warning that categorical axes trigger.
    match_rate = [match_dict[k].mean() for k in labels]
    med_rho = [np.median(d) for d in data]
    ax2.bar(positions, match_rate, color="lightgray")
    ax2_t = ax2.twinx()
    ax2_t.plot(positions, med_rho, marker="o", color="tab:blue")
    ax2.set_xticks(positions)
    ax2.set_xticklabels(labels, rotation=60, ha="right", fontsize=6)
    ax2.set_ylabel("Exact-match rate", fontsize=7)
    ax2_t.set_ylabel(r"Median $-\log_{10}\rho$", fontsize=7)
    ax2.set_title("(b) Accuracy vs. confidence", fontsize=7)

    plt.tight_layout()
    fig.savefig(out_fig, dpi=300, bbox_inches="tight")
    plt.close(fig)  # release the figure; this runs in long sweeps
    print(f"Figure saved to {out_fig}")
    
def _nan_safe_corr(x, y, fn):
    """Correlation statistic ``fn(x, y)[0]`` over pairs where both values are non-NaN.

    Returns NaN when fewer than two valid pairs remain, or when either side is
    constant (the correlation is undefined there).
    """
    x, y = np.asarray(x, dtype=float), np.asarray(y, dtype=float)
    msk = ~np.isnan(x) & ~np.isnan(y)
    if msk.sum() < 2 or np.std(x[msk]) == 0 or np.std(y[msk]) == 0:
        return np.nan
    return fn(x[msk], y[msk])[0]


def rho_metric_correlations(
    compressor,
    results_dir="/user/videoseal/watermark_comparison_results",
    pattern="nautilus_256_900_pixmo_pixmo_robust_0.01_best_*_results.json",
    batch_size=256,
    out_csv="rho_metric_corrs.csv",
):
    """Correlate the confidence -log10(rho) with quality metrics, per transformation.

    For each matching results file, the stored ``detected_vec`` latents are
    normalized and scored with ``_rho_batch`` in chunks of ``batch_size``.
    ``batch_size`` was previously ignored. The confidence is then correlated
    with exact match (Pearson and Spearman), BLEU-1/4, ROUGE-L and PSNR
    (Spearman). rho is clipped at 1e-300 before the log, matching
    ``plot_rho_by_transform``, so underflowed values give finite confidences
    instead of inf.

    Returns:
        DataFrame with one row per transformation, sorted by ``spearman_exact``
        (descending). It is also written to ``out_csv``. Empty when nothing
        matched.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    columns = [
        "transform", "N", "pearson_exact", "spearman_exact", "spearman_bleu4",
        "spearman_bleu1", "spearman_rouge", "spearman_psnr",
        "median_neglog", "std_neglog",
    ]
    rows = []

    for jf in tqdm(sorted(glob.glob(os.path.join(results_dir, pattern))), desc="scan"):
        tr = os.path.basename(jf).split("_best_")[1].split("_results")[0]
        with open(jf) as f:
            items = json.load(f)["texts"]

        # --- gather vectors and metrics --------------------------------
        vecs, exact, bleu4, bleu1, rouge, psnr = [], [], [], [], [], []
        for it in items:
            v = it.get("detected_vec")
            if v is None:  # older result files lack the stored vector
                continue
            vecs.append(np.asarray(v, dtype=np.float32))
            exact.append(int(it["is_match"]))
            bleu4.append(float(it.get("bleu4", np.nan)))
            bleu1.append(float(it.get("bleu1", np.nan)))
            rouge.append(float(it.get("rouge_l", np.nan)))
            psnr.append(float(it.get("psnr", np.nan)))

        if not vecs:  # empty transform
            continue

        Z = torch.tensor(np.stack(vecs), device=device)
        Z = Z / (Z.norm(dim=1, keepdim=True) + 1e-9)
        rho = np.concatenate([
            np.asarray(_rho_batch(Z[i:i + batch_size], compressor))
            for i in range(0, len(Z), batch_size)
        ])
        neglog = -np.log10(np.clip(rho, 1e-300, None))  # confidence

        rows.append(dict(
            transform      = tr,
            N              = len(neglog),
            pearson_exact  = _nan_safe_corr(neglog, exact, pearsonr),
            spearman_exact = _nan_safe_corr(neglog, exact, spearmanr),
            spearman_bleu4 = _nan_safe_corr(neglog, bleu4, spearmanr),
            spearman_bleu1 = _nan_safe_corr(neglog, bleu1, spearmanr),
            spearman_rouge = _nan_safe_corr(neglog, rouge, spearmanr),
            spearman_psnr  = _nan_safe_corr(neglog, psnr,  spearmanr),
            median_neglog  = float(np.median(neglog)),
            std_neglog     = float(np.std(neglog)),
        ))

    # Explicit columns keep sort_values valid even when no file matched.
    df = pd.DataFrame(rows, columns=columns).sort_values("spearman_exact", ascending=False)
    df.to_csv(out_csv, index=False)
    try:
        print(df.to_markdown(index=False, floatfmt=".4g"))
    except ImportError:  # to_markdown needs the optional 'tabulate' package
        print(df.to_string(index=False))
    print(f"\nSaved correlation table to {out_csv}")
    return df

# Transformation names kept when collecting (rho, error) pairs for the ROC.
# Each name must match the transform tag embedded in a results filename.
# Presumably these are the mild/moderate distortions; confirm against TRANSFORMATIONS_GLOBAL.
NONCAT = [
    # no-op and mirror
    "identity",
    "horizontal_flip",
    # photometric
    "saturation_0.5",
    "saturation_1.5",
    "contrast_0.5",
    "contrast_1.5",
    "brightness_0.5",
    "brightness_1.5",
    "hue_0.1",
    "hue_-0.1",
    # geometric (note the mixed "rotation_" / "rotate_" prefixes)
    "rotation_5",
    "rotate_10",
    "rotate_90",
    "crop_90",
    "perspective_0.3",
    "perspective_0.5",
    # compression
    "jpeg_80",
    "jpeg_75",
    "jpeg_70",
    "jpeg_60",
    "jpeg_50",
]

def collect_scores_labels(
    compressor,
    results_dir="/user/videoseal/watermark_comparison_results_rho/pixmo",
    pattern="*_results.json",
    marker="robust_0.01_best_",
    transforms=None,
    batch_size=256,
):
    """Gather (rho, error-label) pairs from saved semanticseal result files.

    The transformation tag is the part of the filename between ``marker`` and
    "_results". Files without ``marker`` are skipped; the default ``pattern``
    also matches other result files, and those previously raised IndexError.
    Only transformations in ``transforms`` (default: ``NONCAT``) are used.

    Args:
        compressor: passed to ``_rho_batch``.
        results_dir, pattern: where to look for result JSON files.
        marker: filename substring that precedes the transformation tag.
        transforms: iterable of transformation tags to include; None means NONCAT.
        batch_size: chunk size for ``_rho_batch``.

    Returns:
        (scores, errors) numpy arrays. ``scores`` are RAW rho values, not
        -log10(rho); larger means less confident. ``errors`` is 1 where the
        sample was not an exact match.
    """
    allowed = set(NONCAT if transforms is None else transforms)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    scores_all, err_all = [], []

    for jf in tqdm(glob.glob(os.path.join(results_dir, pattern)), desc="json"):
        name = os.path.basename(jf)
        if marker not in name:
            continue  # result file from a different model/config
        tr = name.split(marker)[1].split("_results")[0]
        if tr not in allowed:
            continue
        with open(jf) as f:
            items = json.load(f)["texts"]

        vecs, labels = [], []
        for it in items:
            v = it.get("detected_vec")
            if v is None:
                continue
            vecs.append(np.asarray(v, dtype=np.float32))
            labels.append(int(it["is_match"] == 0))  # 1 = error

        if not vecs:
            continue

        Z = torch.tensor(np.stack(vecs), device=device)
        Z = Z / (Z.norm(dim=1, keepdim=True) + 1e-9)

        for i in range(0, len(Z), batch_size):
            scores_all.extend(_rho_batch(Z[i:i + batch_size], compressor))
        err_all.extend(labels)

    return np.asarray(scores_all), np.asarray(err_all)


# ------------------- ROC of rho as a decoding-error detector -----------
def plot_roc(compressor,
             out_fig="roc_neglogrho_pixmo.pdf",
             results_dir="/user/videoseal/watermark_comparison_results_rho/pixmo"):
    """Evaluate rho as a detector of decoding errors and save its ROC curve.

    ``collect_scores_labels`` returns raw rho values (larger = less confident)
    and labels where 1 = decoding error. For several FPR budgets this prints the
    rho threshold, FPR and TPR of the last operating point within that budget.

    Returns:
        The AUROC, or None when the ROC is undefined (no samples, or only one
        class present).
    """
    scores, errs = collect_scores_labels(compressor, results_dir)
    if scores.size == 0 or np.unique(errs).size < 2:
        print("Cannot compute ROC: need both correct and erroneous samples "
              f"(got {scores.size} samples).")
        return None

    # Raw rho already ranks errors higher (less confident), which is the
    # orientation roc_curve expects for the positive class (error = 1).
    fpr, tpr, thr = roc_curve(errs, scores)
    targets = [1e-10, 1e-6, 1e-4, 1e-3, 1e-2, 1e-1, 1]  # FPR budgets to report
    for target in targets:
        # Last point within budget. roc_curve's first point has FPR 0, so this is non-empty.
        idx = np.flatnonzero(fpr <= target)[-1]
        # rho thresholds are tiny, so use scientific notation (".2f" printed 0.00).
        print(f"threshold={thr[idx]:.3e},  FPR={fpr[idx]:.11f},  TPR={tpr[idx]:.4f}")

    auc_val = auc(fpr, tpr)
    fig = plt.figure(figsize=(3.2, 3.2))
    plt.plot(fpr, tpr, label=f"AUROC = {auc_val:.3f}")
    plt.plot([0, 1], [0, 1], "k--", lw=0.6)
    plt.xlabel("False–positive rate")
    plt.ylabel("True–positive rate")
    plt.grid(True, ls="--", alpha=.4)
    plt.legend()
    plt.tight_layout()
    plt.savefig(out_fig, dpi=300, bbox_inches="tight")
    plt.close(fig)
    print(f"ROC curve saved ➜ {out_fig}")
    return auc_val

# ---------- run (compressor must be initialised) ----------------------
# plot_roc(compressor)


# Script entry point: run the comparison sweep defined in main().
if __name__ == "__main__":
    main()
