#!/usr/bin/env python3
"""
Watermark Evaluation Script
This script provides functions to:
1. Encode messages into watermarked images and save them with correspondence (targeting 42 dB PSNR)
2. Extract and decode messages from watermarked images and compare with originals
"""

import os
import sys
import torch
import numpy as np
import json
import csv
import argparse
from PIL import Image
import pandas as pd
from pathlib import Path
from tqdm import tqdm
import random
from typing import Dict, List, Tuple, Optional
import io
from datasets import load_dataset
from PIL import Image as PILImage

# Import from XP1.py - adjust path as needed
sys.path.append('/path/to/your/XP1/script/directory')  # Update this path
from XP1 import *#(
#     # Classes and functions we need
#     SwiftWatermarker, LLMZipTextCompressor, FineTunedTextCompressor,
#     load_coco_val_images, text_to_bits, bits_to_text, compute_bleu, compute_psnr,
#     NumpyEncoder, TRANSFORMATIONS, SWIFT_MODEL_CONFIGS, 
#     TEXT_COMPRESSOR_CHECKPOINTS, VIDEOSEAL_AVAILABLE,
#     device, MAX_LENGTH, find_target_power,
#     # Text loading functions
#     load_and_reconstruct_texts_from_pkl, load_texts_from_json,
#     WikiTextDataModule, Coco2017DataModule, PixmoCapDataModule,
#     DATA_MODULES_AVAILABLE
# )

# Import VideoSeal if available
if VIDEOSEAL_AVAILABLE:
    from XP1 import setup_model_from_checkpoint, vs_psnr, bit_accuracy
    import torchvision

# Configuration from XP1.py.
# Text-source selection used by load_texts_for_embedding(): LOAD_FROM_JSON is
# checked first, then LOAD_FROM_PKL; when both are False, DATA_MODULE picks the
# data module ("wikitext", "coco", or any name containing "pixmo").
PKL_FILE_PATH: str = '/home/gevennou/text_reconstruction/trial/challenge_set_top_5percent.pkl'
JSON_FILE_PATH: str = '/home/gevennou/videoseal/watermark_comparison_results/videoseal_strength_1.2_llmzip_opt125m_wikitext_none_results.json'

LOAD_FROM_JSON: bool = False  # True -> read texts from JSON_FILE_PATH (takes priority over PKL)
LOAD_FROM_PKL: bool = True    # True -> read texts from PKL_FILE_PATH
DATA_MODULE: str = "pixmo_challenge"  # also recorded in the saved results metadata

def load_emu_edit_images(
    num_images: int = 2000,
    image_size: int = 256,
    split: str = "test"
) -> Tuple[List[PILImage.Image], List]:
    """
    Load carrier images from the facebook/emu_edit_test_set_generations dataset.

    Each image is converted to RGB, resized so its short side is `image_size`,
    and center-cropped to `image_size` x `image_size`. Examples that fail to
    process are reported and skipped without counting towards `num_images`.

    Args:
        num_images: Number of images to load
        image_size: Size to resize/crop images to
        split: Dataset split to use (default: "test")

    Returns:
        Tuple `(images, ids)`: the list of PIL Images and, aligned with it, the
        dataset's "idx" value for each image. If the dataset cannot be loaded,
        COCO images from `load_coco_val_images` are returned instead, paired
        with positional ids 0..N-1 so callers can always unpack two values.
    """

    print("Loading images from facebook/emu_edit_test_set_generations dataset...")

    try:
        dataset = load_dataset("facebook/emu_edit_test_set_generations", split=split)

        from torchvision import transforms
        transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
        ])

        images = []
        id_samples = []
        for i, example in enumerate(dataset):
            # Stop on the number of images actually collected, so examples that
            # failed to process do not shrink the returned set.
            if len(images) >= num_images:
                break

            try:
                img = example['image']
                id_sample = example["idx"]

                # Array-like payloads (anything with a .shape) are converted to PIL.
                if not isinstance(img, PILImage.Image) and hasattr(img, 'shape'):
                    img = PILImage.fromarray(img)

                if img.mode != 'RGB':
                    img = img.convert('RGB')

                img = transform(img)
            except Exception as e:
                print(f"Error processing image {i}: {e}")
                continue

            images.append(img)
            id_samples.append(id_sample)

        print(f"Successfully loaded {len(images)} images from emu_edit dataset")
        return images, id_samples

    except Exception as e:
        print(f"Error loading emu_edit dataset: {e}")
        print("Falling back to COCO images...")
        # Keep the (images, ids) return shape on the fallback path; callers
        # unpack two values, and a bare image list would break that.
        coco_images = load_coco_val_images(num_images=num_images, image_size=image_size)
        return coco_images, list(range(len(coco_images)))

def _make_text_data_module(max_length: int):
    """Build the data module selected by DATA_MODULE, with a batch size of one."""
    if DATA_MODULE == "wikitext":
        return WikiTextDataModule(batch_size=1, max_length=max_length)
    if DATA_MODULE == "coco":
        return Coco2017DataModule(batch_size=1, max_length=max_length)
    if "pixmo" in DATA_MODULE:
        return PixmoCapDataModule(batch_size=1, max_length=max_length)
    raise ValueError(f"Unknown data module: {DATA_MODULE}")


def _rejected_by_compression(text: str, compressor, max_length: int) -> bool:
    """Return True when `text` should be dropped by the compression filter.

    Only compressors exposing `text_zipper` (LLMZip) are filtered: a text is
    rejected when the length reported by `text_to_bits` is <= 256. Any error
    raised during the check is printed and the text is rejected as well.
    """
    try:
        if not hasattr(compressor, 'text_zipper'):
            # Fine-tuned compressor - nothing to filter on.
            return False
        _, compressed_length = text_to_bits(
            text, compressor.text_zipper,
            max_bits=500, max_length=max_length
        )
        return bool(compressed_length <= 256)
    except Exception as e:
        print(f"Error in compression filtering: {e}")
        return True


def _texts_from_data_module(
    num_texts: int,
    max_length: int,
    compressor,
    filter_by_compression: bool
) -> List[str]:
    """Decode up to `num_texts` texts from the configured data module's test split.

    Only samples with at least `max_length` tokens are used, truncated to
    exactly `max_length` tokens; blank decoded texts are dropped. Errors are
    printed and re-raised.
    """
    collected: List[str] = []
    try:
        print(f"Loading texts from data module: {DATA_MODULE}")

        data_module = _make_text_data_module(max_length)
        data_module.setup('test')
        loader = data_module.test_dataloader()

        print(f"Loading up to {max_length} tokens from {DATA_MODULE}")

        apply_filter = filter_by_compression and compressor is not None
        for batch in tqdm(loader, total=min(num_texts, 2000)):
            if len(collected) >= num_texts:
                break

            token_ids = batch['input_ids'][0]
            if len(token_ids) < max_length:
                continue

            text = data_module.tokenizer.decode(
                token_ids[:max_length],
                skip_special_tokens=True
            )

            if apply_filter and _rejected_by_compression(text, compressor, max_length):
                continue

            if text.strip():
                collected.append(text)

        print(f"Loaded {len(collected)} texts from {DATA_MODULE}")
    except Exception as e:
        print(f"Error loading from data module: {e}")
        raise
    return collected


def load_texts_for_embedding(
    num_texts: int = 2000,
    max_length: int = 30,
    compressor=None,
    filter_by_compression: bool = True
) -> List[str]:
    """
    Load the texts to embed, following the same source priority as XP1.py.

    The source is chosen from the module flags, in order: LOAD_FROM_JSON
    (JSON_FILE_PATH), then LOAD_FROM_PKL (PKL_FILE_PATH), then the data module
    named by DATA_MODULE (only if DATA_MODULES_AVAILABLE).

    Args:
        num_texts: Number of texts to load
        max_length: Token length used when reading from a data module
        compressor: Optional text compressor used for compression filtering
        filter_by_compression: Whether to apply the compression-length filter
            (data-module source only)

    Returns:
        List of text strings ready for embedding

    Raises:
        ValueError: if no source flag is set and data modules are unavailable,
            or DATA_MODULE names an unknown module.
    """

    print(f"Loading texts from DATA_MODULE: {DATA_MODULE}")
    print(f"LOAD_FROM_PKL: {LOAD_FROM_PKL}")
    print(f"LOAD_FROM_JSON: {LOAD_FROM_JSON}")

    if LOAD_FROM_JSON:
        print(f"Loading texts from JSON: {JSON_FILE_PATH}")
        texts = load_texts_from_json(JSON_FILE_PATH, num_texts=num_texts)
    elif LOAD_FROM_PKL:
        print(f"Loading texts from PKL: {PKL_FILE_PATH}")
        texts = load_and_reconstruct_texts_from_pkl(PKL_FILE_PATH, num_texts=num_texts)
    elif DATA_MODULES_AVAILABLE:
        texts = _texts_from_data_module(num_texts, max_length, compressor, filter_by_compression)
    else:
        raise ValueError("No data loading method available")

    print(f"Successfully loaded {len(texts)} texts for embedding")
    return texts

def _build_text_compressor(compressor_name: str):
    """Instantiate the text compressor named `compressor_name`.

    Names containing "llmzip" select LLMZipTextCompressor; any other name
    selects FineTunedTextCompressor. The checkpoint is looked up in
    TEXT_COMPRESSOR_CHECKPOINTS. Construction errors are printed and re-raised.
    """
    print(f"Initializing text compressor: {compressor_name}")
    try:
        if "llmzip" in compressor_name:
            compressor = LLMZipTextCompressor(
                TEXT_COMPRESSOR_CHECKPOINTS[compressor_name],
                max_length=MAX_LENGTH
            )
            print("Initialized LLMZip text compressor")
        else:
            compressor = FineTunedTextCompressor(
                TEXT_COMPRESSOR_CHECKPOINTS[compressor_name],
                max_length=MAX_LENGTH
            )
            print("Initialized fine-tuned text compressor")
    except Exception as e:
        print(f"Error initializing text compressor: {e}")
        raise
    return compressor


def _encode_text_message(text: str, compressor_name: str, text_compressor):
    """Encode `text` to a message and decode it back.

    Returns `(message, reference_text)`. For LLMZip compressors the message
    is the 256-bit vector from `text_to_bits` with a leading batch dim, moved
    to `device`. For fine-tuned compressors it is whatever `encode` returns.
    `reference_text` is the decoded round trip, i.e. what a perfect extractor
    would recover.
    """
    if "llmzip" in compressor_name:
        bit_msg, _ = text_to_bits(
            text, text_compressor.text_zipper,
            max_bits=256, max_length=MAX_LENGTH
        )
        message = bit_msg.unsqueeze(0).to(device)
        reference_text = bits_to_text(
            message[0], text_compressor.text_zipper,
            max_length=MAX_LENGTH
        )
    else:
        message = text_compressor.encode(text)
        reference_text = text_compressor.decode(message)
    return message, reference_text


def _scalar_psnr(psnr_result) -> float:
    """Collapse a `vs_psnr` result to a single float (mean if it has several elements)."""
    if psnr_result.numel() == 1:
        return psnr_result.item()
    return psnr_result.mean().item()


def _videoseal_embed_at_target_psnr(
    model,
    img_tensor,
    bit_msg_tensor,
    start_strength: float,
    target_psnr: float,
    psnr_tolerance: float,
    max_iterations: int
):
    """Binary-search VideoSeal's `blender.scaling_w` towards `target_psnr`.

    The search starts at `start_strength` within the bracket [0, 2] and stops
    early once the PSNR is within `psnr_tolerance`. At least one embedding is
    always performed.

    Returns:
        `(imgs_w, used_strength, iterations_used)`. `imgs_w` is the watermarked
        batch whose PSNR was closest to the target, and `used_strength` is the
        strength that actually produced it.
    """
    msg_processor = model.embedder.unet.msg_processor
    original_get_random_msg = msg_processor.get_random_msg
    # model.embed draws its payload from get_random_msg; substitute our message.
    msg_processor.get_random_msg = lambda bsz, nb_repetitions: bit_msg_tensor

    low, high = 0.0, 2.0
    strength = start_strength
    best_diff = float('inf')
    best_imgs_w, best_strength = None, start_strength
    last_imgs_w, last_strength = None, start_strength
    iterations_used = 0

    try:
        for _ in range(max(1, max_iterations)):
            iterations_used += 1
            model.blender.scaling_w = strength

            with torch.no_grad():
                imgs_w = model.embed(img_tensor, is_video=False, lowres_attenuation=True)["imgs_w"]
            last_imgs_w, last_strength = imgs_w, strength

            current_psnr = _scalar_psnr(vs_psnr(imgs_w, img_tensor))
            diff = abs(current_psnr - target_psnr)

            if diff < best_diff:
                best_diff = diff
                best_imgs_w = imgs_w.clone()
                best_strength = strength

            if diff <= psnr_tolerance:
                break

            # PSNR above target means the watermark is weaker than needed:
            # move the lower bound up; otherwise move the upper bound down.
            if current_psnr > target_psnr:
                low = strength
            else:
                high = strength
            strength = (low + high) / 2
    finally:
        # Always restore the original method, even if embedding raised.
        msg_processor.get_random_msg = original_get_random_msg

    if best_imgs_w is None:
        # Only reachable when no PSNR compared below inf (e.g. NaN);
        # fall back to the last embedding and the strength that produced it.
        return last_imgs_w, last_strength, iterations_used
    return best_imgs_w, best_strength, iterations_used


def encode_and_save_watermarked_images(
    output_dir: str,
    model_config: Dict,
    num_texts: int = 2000,
    num_images: int = 100,
    image_size: int = 256,
    target_psnr: float = 42.0,
    psnr_tolerance: float = 0.2,
    max_psnr_iterations: int = 12,
    coco_val_dir: str = "/home/gevennou/BIG_storage/Paper2/coco_dataset/val2017",
    use_emu_edit_images: bool = True,
    texts: Optional[List[str]] = None
) -> Dict[str, any]:
    """
    Encode messages into watermarked images and save them with correspondence.

    Each text is compressed into a message, round-tripped to obtain its
    reference text (texts whose round trip differs are skipped and counted as
    encoding mismatches), and embedded into the carrier image at the same
    position. Embedding strength is tuned towards `target_psnr`: VideoSeal
    uses the binary search implemented here, Swift uses `find_target_power`.
    Texts are paired with carrier images by position, so only the first
    min(len(texts), number of images) texts are processed.

    Outputs written to `output_dir`:
    - watermarked_<carrier id>.png
    - correspondence.json
    - correspondence.csv
    - summary_stats.json

    Args:
        output_dir: Directory to save watermarked images and correspondence
        model_config: Configuration for the watermarking model ('type' is
            'videoseal' or 'swift'; 'text_compressor' selects the compressor)
        num_texts: Number of texts to load (if texts not provided)
        num_images: Number of images to use as carriers
        image_size: Size to resize images to
        target_psnr: Target PSNR value in dB (default: 42.0)
        psnr_tolerance: Tolerance for PSNR targeting in dB (default: 0.2)
        max_psnr_iterations: Maximum iterations for PSNR binary search
        coco_val_dir: Directory containing COCO validation images
        use_emu_edit_images: Use EMU Edit images as carriers (else COCO)
        texts: Optional list of texts to encode (if None, will load from configured source)

    Returns:
        Dictionary containing results and correspondence information

    Raises:
        ValueError: for an unknown model type, or VideoSeal requested but unavailable.
    """

    os.makedirs(output_dir, exist_ok=True)

    # Initialize model based on type
    model_type = model_config.get('type', 'swift')

    if model_type == 'videoseal':
        if not VIDEOSEAL_AVAILABLE:
            raise ValueError("VideoSeal not available")

        print("Initializing VideoSeal model...")
        model = setup_model_from_checkpoint('videoseal')
        model.eval()
        model.compile()
        model.to(device)

        # Starting point of the per-image strength search
        base_watermark_strength = model_config.get('watermark_strength', 1.0)

        to_tensor = torchvision.transforms.ToTensor()
        to_pil = torchvision.transforms.ToPILImage()

        print(f"VideoSeal model initialized with base strength {base_watermark_strength}")

    elif model_type == 'swift':
        print("Initializing Swift model...")
        watermarker = SwiftWatermarker(model_config)
        print(f"Swift model initialized: {model_config.get('model_name', 'unknown')}")

    else:
        raise ValueError(f"Unknown model type: {model_type}")

    compressor_name = model_config.get('text_compressor', 'llmzip_opt125m')
    text_compressor = _build_text_compressor(compressor_name)

    if texts is None:
        texts = load_texts_for_embedding(
            num_texts=num_texts,
            max_length=MAX_LENGTH,
            compressor=text_compressor,
            filter_by_compression=True
        )

    print(f"Processing {len(texts)} texts for embedding")

    print("Loading carrier images...")
    if use_emu_edit_images:
        loaded = load_emu_edit_images(
            num_images=num_images,
            image_size=image_size
        )
        # load_emu_edit_images normally returns (images, ids); tolerate a bare
        # image list (e.g. a COCO fallback) by assigning positional ids.
        if isinstance(loaded, tuple) and len(loaded) == 2:
            carrier_images, carrier_idx = loaded
        else:
            carrier_images = loaded
            carrier_idx = list(range(len(carrier_images)))
    else:
        carrier_images = load_coco_val_images(
            coco_val_dir=coco_val_dir,
            image_size=image_size,
            num_images=num_images
        )
        # COCO carriers have no dataset id here; name outputs by position.
        carrier_idx = list(range(len(carrier_images)))

    print(f"Loaded {len(carrier_images)} carrier images")

    # Texts and carriers are paired by index; never index past the carriers.
    num_pairs = min(len(texts), len(carrier_images))
    if len(texts) > len(carrier_images):
        print(f"Warning: {len(texts)} texts but only {len(carrier_images)} carrier images; "
              f"only the first {num_pairs} texts will be embedded")

    results = {
        'model_config': model_config,
        'data_source': {
            'PKL_FILE_PATH': PKL_FILE_PATH if LOAD_FROM_PKL else None,
            'JSON_FILE_PATH': JSON_FILE_PATH if LOAD_FROM_JSON else None,
            'DATA_MODULE': DATA_MODULE,
            'LOAD_FROM_PKL': LOAD_FROM_PKL,
            'LOAD_FROM_JSON': LOAD_FROM_JSON,
            'num_texts_requested': num_texts,
            'num_texts_loaded': len(texts)
        },
        'image_source': {
            'use_emu_edit_images': use_emu_edit_images,
            'dataset': 'facebook/emu_edit_test_set_generations' if use_emu_edit_images else 'COCO',
            'num_images_requested': num_images,
            'num_images_loaded': len(carrier_images),
            'image_size': image_size
        },
        'target_psnr': target_psnr,
        'psnr_tolerance': psnr_tolerance,
        'correspondence': [],
        'statistics': {
            'total_processed': 0,
            'successful_embeddings': 0,
            'failed_embeddings': 0,
            'encoding_mismatches': 0,
            'avg_psnr': 0,
            'psnr_values': [],
            'avg_used_strength': 0,
            'used_strengths': [],
            'psnr_targeting_stats': {
                'within_tolerance': 0,
                'outside_tolerance': 0,
                'avg_iterations': 0,
                'iterations_list': []
            }
        }
    }
    stats = results['statistics']
    psnr_stats = stats['psnr_targeting_stats']

    print(f"Processing {len(texts)} texts with target PSNR {target_psnr} dB...")

    processed_count = 0
    for i, text in enumerate(tqdm(texts[:num_pairs], desc="Encoding texts")):
        if not text.strip():
            continue

        carrier_img = carrier_images[i].copy()

        message, reference_text = _encode_text_message(text, compressor_name, text_compressor)

        # Skip texts the compressor cannot reproduce exactly
        if text.strip() != reference_text.strip():
            stats['encoding_mismatches'] += 1
            continue

        if model_type == 'videoseal':
            img_tensor = to_tensor(carrier_img).unsqueeze(0).float().to(device)
            imgs_w, used_strength, iterations_used = _videoseal_embed_at_target_psnr(
                model, img_tensor, message, base_watermark_strength,
                target_psnr, psnr_tolerance, max_psnr_iterations
            )
            psnr_value = _scalar_psnr(vs_psnr(imgs_w, img_tensor))
            watermarked_img = to_pil(imgs_w[0].cpu())
        else:  # swift
            msg_to_embed = message.squeeze(0).cpu().numpy()
            watermarked_img, used_strength = find_target_power(
                watermarker, carrier_img, msg_to_embed,
                target_psnr=target_psnr,
                tolerance=psnr_tolerance,
                max_iterations=max_psnr_iterations
            )
            psnr_value = compute_psnr(carrier_img, watermarked_img)
            if psnr_value == float('inf'):
                psnr_value = 100.0
            # find_target_power does not report its iteration count; record the cap.
            iterations_used = max_psnr_iterations

        within_tolerance = abs(psnr_value - target_psnr) <= psnr_tolerance
        if within_tolerance:
            psnr_stats['within_tolerance'] += 1
        else:
            psnr_stats['outside_tolerance'] += 1
        psnr_stats['iterations_list'].append(iterations_used)

        img_filename = f"watermarked_{carrier_idx[i]}.png"
        img_path = os.path.join(output_dir, img_filename)
        watermarked_img.save(img_path)

        results['correspondence'].append({
            'image_id': processed_count,
            'text_index': i,  # Original index in the text list
            'image_filename': img_filename,
            'image_path': img_path,
            'original_text': text,
            'reference_text': reference_text,
            'psnr': psnr_value,
            'used_strength': used_strength,
            'target_psnr': target_psnr,
            'psnr_diff': abs(psnr_value - target_psnr),
            'within_tolerance': within_tolerance,
            'carrier_image_size': carrier_img.size,
            'watermarked_image_size': watermarked_img.size
        })
        stats['successful_embeddings'] += 1
        stats['psnr_values'].append(psnr_value)
        stats['used_strengths'].append(used_strength)

        processed_count += 1

    # Summary statistics
    stats['total_processed'] = num_pairs

    if stats['psnr_values']:
        stats['avg_psnr'] = np.mean(stats['psnr_values'])

    if stats['used_strengths']:
        stats['avg_used_strength'] = np.mean(stats['used_strengths'])

    if psnr_stats['iterations_list']:
        psnr_stats['avg_iterations'] = np.mean(psnr_stats['iterations_list'])

    total_processed = psnr_stats['within_tolerance'] + psnr_stats['outside_tolerance']
    psnr_stats['success_rate'] = (psnr_stats['within_tolerance'] / total_processed * 100) if total_processed > 0 else 0

    correspondence_json_path = os.path.join(output_dir, 'correspondence.json')
    with open(correspondence_json_path, 'w') as f:
        json.dump(results, f, indent=2, cls=NumpyEncoder)

    # CSV copy for easier viewing
    correspondence_csv_path = os.path.join(output_dir, 'correspondence.csv')
    pd.DataFrame(results['correspondence']).to_csv(correspondence_csv_path, index=False)

    summary_path = os.path.join(output_dir, 'summary_stats.json')
    with open(summary_path, 'w') as f:
        json.dump({
            'data_source': results['data_source'],
            'image_source': results['image_source'],
            'model_config': results['model_config'],
            'statistics': results['statistics']
        }, f, indent=2, cls=NumpyEncoder)

    print("\nEncoding complete!")
    print(f"Total texts processed: {stats['total_processed']}")
    print(f"Encoding mismatches (skipped): {stats['encoding_mismatches']}")
    print(f"Successful embeddings: {stats['successful_embeddings']}")
    print(f"Failed embeddings: {stats['failed_embeddings']}")
    print(f"Average PSNR: {stats['avg_psnr']:.2f} dB (target: {target_psnr} dB)")
    print(f"Average used strength: {stats['avg_used_strength']:.4f}")
    print(f"PSNR within tolerance: {psnr_stats['within_tolerance']}/{total_processed} ({psnr_stats['success_rate']:.1f}%)")
    print(f"Data source: {DATA_MODULE} ({'PKL' if LOAD_FROM_PKL else 'JSON' if LOAD_FROM_JSON else 'DataModule'})")
    print(f"Image source: {'EMU Edit' if use_emu_edit_images else 'COCO'}")
    print(f"Results saved to: {output_dir}")

    return results

def extract_and_compare_messages(
    watermarked_dir: str,
    original_correspondence_path: str,
    model_config: Dict,
    transformation_name: str = "none",
    output_dir: Optional[str] = None
) -> Dict[str, any]:
    """
    Extract and decode messages from watermarked images and compare with originals.
    
    Args:
        watermarked_dir: Directory containing watermarked images
        original_correspondence_path: Path to original correspondence JSON file
        model_config: Configuration for the watermarking model
        transformation_name: Name of transformation to apply before detection
        output_dir: Directory to save extraction results (optional)
        
    Returns:
        Dictionary containing extraction results and comparisons
    """
    
    # Load original correspondence
    with open(original_correspondence_path, 'r') as f:
        original_data = json.load(f)
    
    original_correspondence = original_data['correspondence']
    original_model_config = original_data['model_config']
    data_source = original_data.get('data_source', {})
    
    print(f"Loaded correspondence for {len(original_correspondence)} images")
    print(f"Original model config: {original_model_config}")
    print(f"Data source: {data_source}")
    
    # Set output directory
    if output_dir is None:
        output_dir = os.path.join(watermarked_dir, 'extraction_results')
    os.makedirs(output_dir, exist_ok=True)
    
    # Initialize model based on type
    model_type = model_config.get('type', 'swift')
    
    if model_type == 'videoseal':
        if not VIDEOSEAL_AVAILABLE:
            raise ValueError("VideoSeal not available")
        
        print("Initializing VideoSeal model...")
        model = setup_model_from_checkpoint('videoseal')
        model.eval()
        model.compile()
        model.to(device)
        
        # Use the same watermark strength as original encoding
        watermark_strength = original_model_config.get('watermark_strength', 1.0)
        model.blender.scaling_w *= watermark_strength
        
        # For converting between PIL and tensor
        to_tensor = torchvision.transforms.ToTensor()
        
        print(f"VideoSeal model initialized with strength {watermark_strength}")
        
    elif model_type == 'swift':
        print("Initializing Swift model...")
        watermarker = SwiftWatermarker(model_config)
        print(f"Swift model initialized: {model_config.get('model_name', 'unknown')}")
        
    else:
        raise ValueError(f"Unknown model type: {model_type}")
    
    # Initialize text compressor
    compressor_name = model_config.get('text_compressor')
    print(f"Initializing text compressor: {compressor_name}")
    
    try:
        if "llmzip" in compressor_name:
            # Check if there's an adapter path
            adapter_path = None
            # You can add adapter path logic here if needed
            
            text_compressor = LLMZipTextCompressor(
                TEXT_COMPRESSOR_CHECKPOINTS[compressor_name],
                max_length=MAX_LENGTH
            )
            print(f"Initialized LLMZip text compressor")
        else:
            text_compressor = FineTunedTextCompressor(
                TEXT_COMPRESSOR_CHECKPOINTS[compressor_name],
                max_length=MAX_LENGTH
            )
            print(f"Initialized fine-tuned text compressor")
    except Exception as e:
        print(f"Error initializing text compressor: {e}")
        raise
    
    # Get transformation function
    transform_fn = TRANSFORMATIONS.get(transformation_name, TRANSFORMATIONS["none"])
    print(f"Using transformation: {transformation_name}")
    
    # Results storage
    results = {
        'model_config': model_config,
        'original_model_config': original_model_config,
        'data_source': data_source,
        'transformation': transformation_name,
        'original_correspondence_path': original_correspondence_path,
        'extractions': [],
        'statistics': {
            'total_processed': 0,
            'successful_extractions': 0,
            'failed_extractions': 0,
            'exact_matches': 0,
            'bleu1_scores': [],
            'bleu4_scores': [],
            'bit_accuracy_scores': [],
            'avg_bleu1': 0,
            'avg_bleu4': 0,
            'avg_bit_accuracy': 0,
            'exact_match_rate': 0
        }
    }
    
    # Process each image
    print(f"Extracting messages from {len(original_correspondence)} images...")
    
    for entry in tqdm(original_correspondence, desc="Extracting messages"):
        # Load watermarked image
        img_path = entry['image_path']    
        idx_png = img_path.split("_")[-1]
        idx, _ = idx_png.split(".")
        curr_dir = watermarked_dir+"/"+idx
        img_path = curr_dir+"/"+f"{idx}_0_synthetic.png"
        # BIG_storage/Paper3/ip2p_swift_watermarked_edited/0
        # exit()
        if not os.path.exists(img_path):


            # Try relative path from watermarked_dir
            img_path = os.path.join(watermarked_dir, entry['image_filename'])
        
        if not os.path.exists(img_path):
            print(f"Warning: Image not found: {img_path}")
            results['statistics']['failed_extractions'] += 1
            continue
        
        watermarked_img = Image.open(img_path).convert('RGB')
        
        # Apply transformation
        transformed_img = transform_fn(watermarked_img)
        
        # Extract message
        if model_type == 'videoseal':
            # Convert to tensor
            img_tensor = to_tensor(transformed_img).unsqueeze(0).float().to(device)
            
            # Detect watermark
            with torch.no_grad():
                detect_outputs = model.detect(img_tensor, is_video=False)
                preds = detect_outputs["preds"]
                bit_preds = preds[:, 1:]  # Remove first bit
                pred_bits = (bit_preds > 0).float()
            
            # We need the original message for bit accuracy calculation
            original_text = entry['original_text']
            
            # Re-encode original text to get bit accuracy
            if "llmzip" in compressor_name:
                original_bit_msg, _ = text_to_bits(
                    original_text, text_compressor.text_zipper, 
                    max_bits=256, max_length=MAX_LENGTH
                )
                original_bit_msg_tensor = original_bit_msg.unsqueeze(0).to(device)
                
                # Calculate bit accuracy
                bit_accuracy_score = bit_accuracy(bit_preds, original_bit_msg_tensor).item()
                
                # Decode message
                detected_text = bits_to_text(
                    pred_bits[0], text_compressor.text_zipper,
                    max_length=MAX_LENGTH
                )
            else:
                # For fine-tuned compressor
                original_latent = text_compressor.encode(original_text)
                bit_accuracy_score = bit_accuracy(bit_preds, original_latent).item()
                detected_text = text_compressor.decode(pred_bits)
            
        elif model_type == 'swift':
            # Detect watermark
            detected_vector = watermarker.detect(transformed_img)
            
            # Decode message
            if "llmzip" in compressor_name:
                detected_text = bits_to_text(
                    torch.tensor(detected_vector), text_compressor.text_zipper,
                    max_length=MAX_LENGTH
                )
            else:
                input_tensor = torch.tensor(detected_vector, dtype=torch.float32).unsqueeze(0).to(device)
                detected_text = text_compressor.decode(input_tensor)
                print(detected_text)
            bit_accuracy_score = None  # Not directly applicable for Swift
        
        # Compare with original
        original_text = entry['original_text']
        reference_text = entry['reference_text']
        
        # Use case-insensitive comparison
        is_exact_match = (detected_text.strip().lower() == reference_text.strip().lower())
        
        # Calculate BLEU scores (compare with reference text, not original)
        bleu_scores = compute_bleu(reference_text, detected_text)
        
        # Store extraction result
        extraction_entry = {
            'image_id': entry['image_id'],
            'text_index': entry.get('text_index', entry['image_id']),
            'image_filename': entry['image_filename'],
            'original_text': original_text,
            'reference_text': reference_text,
            'detected_text': detected_text,
            'is_exact_match': is_exact_match,
            'bleu1': bleu_scores[0],
            'bleu2': bleu_scores[1],
            'bleu3': bleu_scores[2],
            'bleu4': bleu_scores[3],
            'bit_accuracy': bit_accuracy_score,
            'transformation': transformation_name,
            'original_psnr': entry.get('psnr', None),
            'used_strength': entry.get('used_strength', None)
        }
        
        results['extractions'].append(extraction_entry)
        results['statistics']['successful_extractions'] += 1
        
        if is_exact_match:
            results['statistics']['exact_matches'] += 1
        
        results['statistics']['bleu1_scores'].append(bleu_scores[0])
        results['statistics']['bleu4_scores'].append(bleu_scores[3])
        
        if bit_accuracy_score is not None:
            results['statistics']['bit_accuracy_scores'].append(bit_accuracy_score)

    
    # Calculate summary statistics
    results['statistics']['total_processed'] = len(original_correspondence)
    
    if results['statistics']['bleu1_scores']:
        results['statistics']['avg_bleu1'] = np.mean(results['statistics']['bleu1_scores'])
        results['statistics']['avg_bleu4'] = np.mean(results['statistics']['bleu4_scores'])
    
    if results['statistics']['bit_accuracy_scores']:
        results['statistics']['avg_bit_accuracy'] = np.mean(results['statistics']['bit_accuracy_scores'])
    
    if results['statistics']['successful_extractions'] > 0:
        results['statistics']['exact_match_rate'] = (
            results['statistics']['exact_matches'] / 
            results['statistics']['successful_extractions'] * 100
        )
    
    # Save extraction results
    extraction_json_path = os.path.join(output_dir, 'extraction_results.json')
    with open(extraction_json_path, 'w') as f:
        json.dump(results, f, indent=2, cls=NumpyEncoder)
    
    # Save extraction results as CSV
    extraction_csv_path = os.path.join(output_dir, 'extraction_results.csv')
    df = pd.DataFrame(results['extractions'])
    df.to_csv(extraction_csv_path, index=False)
    
    print(f"\nExtraction complete!")
    print(f"Successful extractions: {results['statistics']['successful_extractions']}")
    print(f"Failed extractions: {results['statistics']['failed_extractions']}")
    print(f"Exact match rate: {results['statistics']['exact_match_rate']:.2f}%")
    print(f"Average BLEU-1: {results['statistics']['avg_bleu1']:.4f}")
    print(f"Average BLEU-4: {results['statistics']['avg_bleu4']:.4f}")
    if results['statistics']['bit_accuracy_scores']:
        print(f"Average bit accuracy: {results['statistics']['avg_bit_accuracy']:.4f}")
    print(f"Results saved to: {output_dir}")
    
    return results

def batch_process_directories(
    input_dirs: List[str],
    model_configs: Dict[str, Dict],
    transformations: Optional[List[str]] = None,
    output_base_dir: str = "batch_results"
) -> Dict[str, object]:
    """
    Run message extraction/comparison over several watermarked-image directories.

    For every directory in ``input_dirs`` that contains a ``correspondence.json``,
    the first entry of ``model_configs`` whose ``'type'`` equals the ``'type'``
    stored under the correspondence file's ``'model_config'`` key is used to call
    ``extract_and_compare_messages`` once per transformation. The per-run
    statistics are collected and then averaged per transformation.

    Args:
        input_dirs: Directories containing watermarked images and a
            ``correspondence.json`` file. Directories without that file, or
            without a matching model configuration, are skipped with a warning.
        model_configs: Mapping of configuration name -> model configuration
            dict. Only the ``'type'`` key is used for matching.
        transformations: Transformation names to apply before detection.
            Defaults to ``["none"]`` when omitted or ``None``.
        output_base_dir: Base directory for per-run outputs
            (``<output_base_dir>/<dir name>/<transformation>``) and for the
            ``batch_summary.json`` / ``batch_summary.csv`` files.

    Returns:
        Dict with ``'processed_directories'`` (one record per successful
        directory/transformation run) and ``'summary_statistics'``
        (per-transformation aggregates).
    """
    # Avoid a shared mutable default list; None means "no transformation".
    if transformations is None:
        transformations = ["none"]

    os.makedirs(output_base_dir, exist_ok=True)

    batch_results = {
        'processed_directories': [],
        'summary_statistics': {}
    }

    for input_dir in input_dirs:
        print(f"\n{'='*60}")
        print(f"Processing directory: {input_dir}")
        print(f"{'='*60}")

        # Look for correspondence file
        correspondence_file = os.path.join(input_dir, 'correspondence.json')
        if not os.path.exists(correspondence_file):
            print(f"Warning: No correspondence.json found in {input_dir}")
            continue

        # Load correspondence to determine model type
        with open(correspondence_file, 'r', encoding='utf-8') as f:
            correspondence_data = json.load(f)

        original_model_config = correspondence_data.get('model_config', {})
        data_source = correspondence_data.get('data_source', {})

        # First configuration whose model type matches the one used at encoding time.
        matching_config = next(
            (
                config for config in model_configs.values()
                if config.get('type') == original_model_config.get('type')
            ),
            None,
        )

        if matching_config is None:
            print(f"Warning: No matching model configuration found for {input_dir}")
            continue

        # normpath strips a trailing separator; otherwise basename() returns ''
        # and every such directory would write into the same output folder.
        dir_label = os.path.basename(os.path.normpath(input_dir))

        # Process with each transformation
        for transformation in transformations:
            print(f"\nProcessing with transformation: {transformation}")

            output_dir = os.path.join(output_base_dir, dir_label, transformation)

            try:
                results = extract_and_compare_messages(
                    watermarked_dir=input_dir,
                    original_correspondence_path=correspondence_file,
                    model_config=matching_config,
                    transformation_name=transformation,
                    output_dir=output_dir
                )
            except Exception as e:
                # Best-effort batch: one failing run must not abort the others.
                print(f"Error processing {input_dir} with {transformation}: {e}")
                continue
            else:
                batch_results['processed_directories'].append({
                    'input_dir': input_dir,
                    'transformation': transformation,
                    'output_dir': output_dir,
                    'data_source': data_source,
                    'results': results['statistics']
                })

    # Group per-run statistics by transformation
    by_transformation: Dict[str, List[Dict]] = {}
    for record in batch_results['processed_directories']:
        by_transformation.setdefault(record['transformation'], []).append(record['results'])

    # Calculate averages
    for trans, stats_list in by_transformation.items():
        # Bit accuracy is only recorded for some runs (the Swift path stores
        # None). Average it over the runs that actually produced scores instead
        # of counting the others as 0; report None when no run has any.
        bit_accuracy_values = [
            stats['avg_bit_accuracy'] for stats in stats_list
            if stats.get('bit_accuracy_scores') and stats.get('avg_bit_accuracy') is not None
        ]

        batch_results['summary_statistics'][trans] = {
            'avg_exact_match_rate': np.mean([stats['exact_match_rate'] for stats in stats_list]),
            'avg_bleu1': np.mean([stats['avg_bleu1'] for stats in stats_list]),
            'avg_bleu4': np.mean([stats['avg_bleu4'] for stats in stats_list]),
            'avg_bit_accuracy': np.mean(bit_accuracy_values) if bit_accuracy_values else None,
            'total_processed': sum(stats['total_processed'] for stats in stats_list),
            'total_successful': sum(stats['successful_extractions'] for stats in stats_list),
            'directories_count': len(stats_list)
        }

    # Save batch results
    batch_results_path = os.path.join(output_base_dir, 'batch_summary.json')
    with open(batch_results_path, 'w') as f:
        json.dump(batch_results, f, indent=2, cls=NumpyEncoder)

    # Save batch results as CSV
    batch_csv_path = os.path.join(output_base_dir, 'batch_summary.csv')
    df = pd.DataFrame(batch_results['processed_directories'])
    df.to_csv(batch_csv_path, index=False)

    print(f"\nBatch processing complete!")
    print(f"Results saved to: {output_base_dir}")

    return batch_results

def main():
    """
    Command-line entry point for the watermark evaluation tool.

    Modes:
        encode  -- embed texts into carrier images with
                   ``encode_and_save_watermarked_images``.
        extract -- detect and decode messages from one directory with
                   ``extract_and_compare_messages``; requires
                   ``--watermarked_dir`` and ``--correspondence_file``.
        batch   -- run extraction over several directories with
                   ``batch_process_directories``; requires ``--batch_dirs``.

    Raises:
        ValueError: if ``--model_type swift`` is used with an unknown
            ``--model_name``, or a mode-specific required argument is missing.
    """
    parser = argparse.ArgumentParser(description='Watermark Evaluation Tool')
    parser.add_argument('--mode', choices=['encode', 'extract', 'batch'], required=True,
                      help='Mode: encode messages, extract messages, or batch process')
    parser.add_argument('--texts_file', type=str,
                      help='Path to file containing texts to encode (one per line) - optional if using PKL/JSON loading')
    parser.add_argument('--watermarked_dir', type=str,
                      help='Directory containing watermarked images')
    parser.add_argument('--correspondence_file', type=str,
                      help='Path to correspondence JSON file')
    parser.add_argument('--output_dir', type=str, required=True,
                      help='Output directory')
    parser.add_argument('--model_type', choices=['videoseal', 'swift'], default='swift',
                      help='Type of watermarking model')
    parser.add_argument('--model_name', type=str, default='nautilus_256_900',
                      help='Name of the model configuration')
    parser.add_argument('--text_compressor', type=str, default='llmzip_opt125m',
                      help='Text compressor to use')
    parser.add_argument('--watermark_strength', type=float, default=1.0,
                      help='Base watermark strength (will be adjusted for PSNR targeting)')
    parser.add_argument('--target_psnr', type=float, default=42.0,
                      help='Target PSNR in dB (default: 42.0)')
    parser.add_argument('--psnr_tolerance', type=float, default=0.2,
                      help='PSNR tolerance in dB (default: 0.2)')
    parser.add_argument('--transformation', type=str, default='none',
                      help='Transformation to apply before detection')
    parser.add_argument('--num_texts', type=int, default=1000,
                      help='Number of texts to load/process')
    parser.add_argument('--num_images', type=int, default=1000,
                      help='Number of carrier images to use')
    # NOTE: store_true with default=True means this flag is always True; it is
    # kept only so existing command lines keep parsing. --use_coco_images is
    # the switch that actually changes the image source.
    parser.add_argument('--use_emu_edit_images', action='store_true', default=True,
                      help='Use EMU Edit images instead of COCO (default: True)')
    parser.add_argument('--use_coco_images', action='store_true',
                      help='Use COCO images instead of EMU Edit')
    parser.add_argument('--batch_dirs', type=str, nargs='+',
                      help='List of directories for batch processing')

    # Parse once. The original parsed twice, which was redundant.
    args = parser.parse_args()

    # EMU Edit images are the default; --use_coco_images overrides them.
    use_emu_edit_images = args.use_emu_edit_images and not args.use_coco_images

    # Create model configuration
    if args.model_type == 'videoseal':
        model_config = {
            'type': 'videoseal',
            'watermark_strength': args.watermark_strength,
            'text_compressor': args.text_compressor
        }
    else:
        if args.model_name not in SWIFT_MODEL_CONFIGS:
            raise ValueError(f"Unknown Swift model: {args.model_name}")

        model_config = SWIFT_MODEL_CONFIGS[args.model_name].copy()
        model_config['text_compressor'] = args.text_compressor
        model_config['model_name'] = args.model_name  # Store model name for reference

    if args.mode == 'encode':
        # Load texts if provided; otherwise the encoder falls back to its own loading (texts=None).
        texts = None
        if args.texts_file:
            # Explicit UTF-8 so non-ASCII texts do not depend on the platform locale.
            with open(args.texts_file, 'r', encoding='utf-8') as f:
                texts = [line.strip() for line in f if line.strip()]
            print(f"Loaded {len(texts)} texts from {args.texts_file}")

        encode_and_save_watermarked_images(
            output_dir=args.output_dir,
            model_config=model_config,
            num_texts=args.num_texts,
            num_images=args.num_images,
            target_psnr=args.target_psnr,
            psnr_tolerance=args.psnr_tolerance,
            use_emu_edit_images=use_emu_edit_images,
            texts=texts
        )

    elif args.mode == 'extract':
        if not args.watermarked_dir:
            raise ValueError("--watermarked_dir is required for extract mode")
        if not args.correspondence_file:
            raise ValueError("--correspondence_file is required for extract mode")

        extract_and_compare_messages(
            watermarked_dir=args.watermarked_dir,
            original_correspondence_path=args.correspondence_file,
            model_config=model_config,
            transformation_name=args.transformation,
            output_dir=args.output_dir
        )

    elif args.mode == 'batch':
        if not args.batch_dirs:
            raise ValueError("--batch_dirs is required for batch mode")

        batch_process_directories(
            input_dirs=args.batch_dirs,
            model_configs={'current_config': model_config},
            transformations=[args.transformation],
            output_base_dir=args.output_dir
        )

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()