import torch
import random
from torch.utils.data import DataLoader
import numpy as np
from tqdm import tqdm
from PIL import Image
from torchvision import transforms
from sklearn import metrics
import torch.nn.functional as F

# Side length, in pixels, that images are resized to by the pipelines below.
size = 256

# Default preprocessing pipeline: square resize, conversion to a float tensor
# in [0, 1], then a linear rescale to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((size, size)),
    transforms.ToTensor(),
    lambda tensor: 2.0 * tensor - 1.0,  # normalize to [-1, 1]
])

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

def pil_list_to_tensor(image_list, device='cuda', normalize_taming=False):
    """
    Turn a list of PIL images into one batched tensor [B, C, H, W] on `device`.

    Every image is converted to RGB and mapped to a float tensor in [0, 1].
    With normalize_taming=True the values are rescaled to [-1, 1].
    The images are concatenated along the batch axis, so they must all share
    the same height and width.
    """
    to_tensor = transforms.ToTensor()

    per_image = []
    for pil_image in image_list:
        chw = to_tensor(pil_image.convert("RGB"))  # [C,H,W] in [0,1]
        if normalize_taming:
            chw = 2.0 * chw - 1.0  # [0,1] -> [-1,1]
        # Add the leading batch axis when it is missing.
        per_image.append(chw.unsqueeze(0) if chw.ndim == 3 else chw)

    return torch.cat(per_image, dim=0).to(device)  # [B,C,H,W]

def set_seeds(seed):
    """Seed PyTorch, Python's `random` module and NumPy with the same value."""
    for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
        seed_fn(seed)


def update_weights(model, ckpt_path, delta=True):  # Deltas!
    """
    Load a checkpoint into `model`, either as deltas or as absolute weights.

    The checkpoint is read from `ckpt_path` onto the CPU. If it holds a
    "state_dict" entry, that entry is used as the weights mapping.

    delta=True:  every checkpoint tensor whose key already exists in the model
                 is added to the model's current tensor; keys unknown to the
                 model are passed through unchanged.
    delta=False: the checkpoint mapping is loaded as-is.

    Loading is non-strict; the missing and unexpected keys are printed.

    NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    objects, so only checkpoints from a trusted source should be passed here.
    """
    checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    if "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]

    if not delta:
        weights = checkpoint
    else:
        weights = model.state_dict().copy()
        for name in checkpoint:
            if name not in weights:
                weights[name] = checkpoint[name]
                continue
            current = weights[name]
            # Move the delta to wherever the model tensor lives before adding.
            weights[name] = current + checkpoint[name].to(current.device)

    missing, unexpected = model.load_state_dict(weights, strict=False)
    print(f"Missing: {missing}")
    print(f"Unexpected: {unexpected}")



def preprocess_rar(image):
    """
    Preprocess one PIL image for the RAR tokenizer.

    The image is converted to RGB if needed, resized to (size, size) and
    turned into a float tensor in [0, 1].

    Returns:
        Tensor of shape [1, C, H, W] on the module-level `device`.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")

    to_unit_tensor = transforms.Compose([
        transforms.Resize((size, size)),  # resize to 256x256
        transforms.ToTensor(),
    ])
    # Use the detected device instead of a hard-coded "cuda" so the function
    # also works on CPU-only machines.
    return to_unit_tensor(image).to(device).unsqueeze(0)

def preprocess(image):
    """
    Preprocess one PIL image for the Taming model.

    The image is converted to RGB if needed, resized to (size, size), turned
    into a float tensor in [0, 1] and rescaled to [-1, 1].

    Returns:
        Tensor of shape [1, C, H, W] on the module-level `device`.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")

    to_signed_tensor = transforms.Compose([
        transforms.Resize((size, size)),  # resize to 256x256
        transforms.ToTensor(),
        lambda x: 2.0 * x - 1.0,  # normalize to [-1, 1]
    ])
    # Use the detected device instead of a hard-coded "cuda" so the function
    # also works on CPU-only machines.
    return to_signed_tensor(image).to(device).unsqueeze(0)

def postprocess_rar(image):
    """
    Convert a RAR-style image tensor (values expected in [0, 1]) to a PIL image.

    A leading batch axis of size 1 is dropped, the tensor is moved to the CPU,
    reordered from channel-first to channel-last, scaled by 255, clamped to
    [0, 255] and cast to uint8.
    """
    chw = image.squeeze(0).detach().cpu()
    hwc = chw.permute(1, 2, 0)
    pixels = (hwc * 255).clamp(0, 255).numpy().astype('uint8')
    return Image.fromarray(pixels)

def postprocess_taming(image):
    """
    Convert a Taming-style image tensor (values expected in [-1, 1]) to a PIL image.

    A leading batch axis of size 1 is dropped and the values are mapped to
    [0, 1] via (x + 1) / 2. The tensor is then moved to the CPU, reordered
    from channel-first to channel-last, scaled by 255, clamped to [0, 255]
    and cast to uint8.
    """
    unit_range = (image.squeeze(0) + 1.0) / 2.0
    hwc = unit_range.detach().cpu().permute(1, 2, 0)
    pixels = (hwc * 255).clamp(0, 255).numpy().astype('uint8')
    return Image.fromarray(pixels)
    

def evaluate(rar_scores, other_scores):
    """
    Score how well a scalar value separates `rar_scores` from `other_scores`.

    `rar_scores` are labelled 1 and `other_scores` 0. The scores are negated
    before the ROC computation, so a LOWER raw score counts as more positive.

    Prints and returns a 4-tuple:
        threshold_at_1fpr: raw (un-negated) threshold at the last ROC point
                           whose FPR is below 1%
        auc:               area under the ROC curve
        acc:               best value of 1 - (FPR + FNR) / 2 over all thresholds
        tpr_at_1fpr:       TPR at that same ROC point
    """
    n_other, n_rar = len(other_scores), len(rar_scores)
    labels = np.concatenate([np.zeros(n_other), np.ones(n_rar)])
    negated_scores = -np.concatenate([other_scores, rar_scores])

    fpr, tpr, negated_thresholds = metrics.roc_curve(labels, negated_scores)
    auc = metrics.auc(fpr, tpr)
    fnr = 1 - tpr
    acc = np.max(1 - (fpr + fnr)/2)

    # Last ROC point that still keeps the false-positive rate under 1%.
    below_1pct = np.where(fpr < 0.01)[0]
    idx = below_1pct[-1]
    threshold_at_1fpr = -negated_thresholds[idx]
    tpr_at_1fpr = tpr[idx]

    print(f"Threshold at 1% FPR: {threshold_at_1fpr:.4f}, AUC: {auc:.4f}, Acc: {acc:.4f}, TPR at 1% FPR: {tpr_at_1fpr:.4f}")

    return threshold_at_1fpr, auc, acc, tpr_at_1fpr


def collate_images_only(batch):
    """
    Collate (image_tensor, token) pairs into a single stacked image tensor.

    The second element of each pair is discarded. Single-channel [1, H, W]
    images are expanded to three channels by repeating the channel so every
    item can be stacked together.
    """
    def _to_three_channels(img_tensor):
        # Grayscale input: repeat the only channel to obtain [3,H,W].
        if img_tensor.ndim == 3 and img_tensor.shape[0] == 1:
            return img_tensor.repeat(3, 1, 1)
        return img_tensor

    return torch.stack([_to_three_channels(img) for img, _ in batch])



def preprocess_rar_batch(image_batch):
    """
    Preprocess a batch of PIL images for the RAR model.

    Each image is converted to RGB if needed, resized to (size, size) and
    turned into a float tensor in [0, 1].

    Returns:
        Stacked tensor [B, C, H, W] on the module-level `device`.
    """
    # The pipeline does not depend on the image, so build it once.
    to_unit_tensor = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
    ])

    processed_batch = [
        to_unit_tensor(image if image.mode == "RGB" else image.convert("RGB"))
        for image in image_batch
    ]

    # Use the detected device instead of a hard-coded "cuda" so the function
    # also works on CPU-only machines.
    return torch.stack(processed_batch).to(device)

def preprocess_batch(image_batch):
    """
    Preprocess a batch of PIL images for the Taming model.

    Each image is converted to RGB if needed, resized to (size, size), turned
    into a float tensor in [0, 1] and rescaled to [-1, 1].

    Returns:
        Stacked tensor [B, C, H, W] on the module-level `device`.
    """
    # The pipeline does not depend on the image, so build it once.
    to_signed_tensor = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        lambda x: 2.0 * x - 1.0,  # normalize to [-1, 1]
    ])

    processed_batch = [
        to_signed_tensor(image if image.mode == "RGB" else image.convert("RGB"))
        for image in image_batch
    ]

    # Use the detected device instead of a hard-coded "cuda" so the function
    # also works on CPU-only machines.
    return torch.stack(processed_batch).to(device)

def tokenize_and_reconstruct_rar_batch(image_batch, vqgan, use_quant=True):
    """
    Batch version of tokenize_and_reconstruct_rar.

    Args:
        image_batch: already-preprocessed image tensor; the reductions over
            dims [1, 2, 3] assume a 4-D [B, C, H, W] layout.
        vqgan: tokenizer exposing encode_with_internals / decode_tokens
            (use_quant=True) or encode_without_quant / decode_states
            (use_quant=False).
        use_quant: selects which encode/decode pair is used.

    Returns:
        (encoded_tokens, reconstructed_images, codebook_loss_mse, img_rec_loss_mse)
        reconstructed_images is a list of PIL images. codebook_loss_mse is the
        per-image MSE between hidden and quantized states. img_rec_loss_mse is
        the per-image MSE between the clamped reconstruction and the input,
        computed on the CPU.
    """
    processed_batch = image_batch
    batch_size = processed_batch.shape[0] if processed_batch.ndim == 4 else 1

    # Pure evaluation: run without autograd, as the Taming counterpart does.
    # This avoids building a graph and keeps the returned losses convertible
    # to NumPy by the caller.
    with torch.no_grad():
        if use_quant:
            encoded_tokens, hidden_states, quantized_states, _ = vqgan.encode_with_internals(processed_batch)
        else:
            encoded_tokens, hidden_states, quantized_states, _ = vqgan.encode_without_quant(processed_batch)

        # Per-image codebook loss (MSE between hidden and quantized states).
        codebook_loss_mse = torch.mean((hidden_states - quantized_states) ** 2, dim=[1, 2, 3])

        if use_quant:
            reconstructed_batch = vqgan.decode_tokens(encoded_tokens.clone())
        else:
            reconstructed_batch = vqgan.decode_states(hidden_states.clone())

        reconstructed_batch = torch.clamp(reconstructed_batch, 0.0, 1.0)
        # Per-image reconstruction loss against the input batch.
        img_rec_loss_mse = torch.mean((reconstructed_batch.cpu() - processed_batch.cpu()) ** 2, dim=[1, 2, 3])

    # Convert back to PIL Images (channel-last uint8).
    reconstructed_images = []
    for i in range(batch_size):
        img_array = (reconstructed_batch[i].clone() * 255.0).permute(1, 2, 0).to("cpu", dtype=torch.uint8).numpy()
        reconstructed_images.append(Image.fromarray(img_array))

    return encoded_tokens, reconstructed_images, codebook_loss_mse, img_rec_loss_mse

def tokenize_and_reconstruct_taming_batch(processed_batch, vqgan, use_quant=True):
    """
    Batch version of tokenize_and_reconstruct_taming.

    Encodes `processed_batch`, optionally quantizes the latent, decodes it and
    measures the per-image reconstruction MSE against the input. The MSE is
    taken before the output is mapped from [-1, 1] to [0, 1].

    Returns:
        (z_q, reconstructed_images, codebook_loss, img_rec_loss_mse)
        With use_quant=False, z_q is None and codebook_loss is a zero vector
        of length batch_size on the input's device. reconstructed_images is a
        list of PIL images.
    """
    batch_size = processed_batch.shape[0] if processed_batch.ndim == 4 else 1

    with torch.no_grad():
        z = vqgan.quant_conv(vqgan.encoder(processed_batch))

        if use_quant:
            z_q, _emb_loss, _info = vqgan.quantize(z)
            codebook_loss = torch.mean((z - z_q) ** 2, dim=[1, 2, 3])
            latent = z_q
        else:
            # The quantizer is bypassed, so there is no codebook loss.
            z_q = None
            codebook_loss = torch.zeros(batch_size).to(processed_batch.device)
            latent = z

        img_rec = vqgan.decoder(vqgan.post_quant_conv(latent))
        img_rec_loss_mse = torch.mean((img_rec.cpu() - processed_batch.cpu()) ** 2, dim=[1, 2, 3])
        img_rec = ((img_rec + 1.0) / 2.0).clamp(0.0, 1.0)  # denormalize to [0, 1]

    # Convert back to PIL Images
    reconstructed_images = [
        Image.fromarray((img_rec[idx].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
        for idx in range(batch_size)
    ]

    return z_q, reconstructed_images, codebook_loss, img_rec_loss_mse

def _safe_loss_ratio(first_pass, second_pass):
    """Element-wise first_pass / second_pass, with 0 wherever second_pass is 0."""
    return torch.where(
        second_pass != 0,
        first_pass / second_pass,
        torch.zeros_like(first_pass),
    )


def calculate_losses_dataset_batched(dataset, dataset_name, vqgan, model_name, args, batch_size=16, get_overlap=False):
    """
    Batched version of calculate_losses_dataset.

    Every batch is reconstructed twice: once from the original images, then
    once more from the first reconstruction. Both passes are run with and
    without quantization. Per-image losses, their first/second-pass ratios
    and a combined score (first-pass codebook loss times the no-quantize
    reconstruction ratio) are collected.

    Args:
        dataset: dataset with a `transform` attribute, yielding
            (image, token) pairs. The model-specific preprocessing is attached
            to `dataset.transform` for the duration of the call and the
            original transform is restored afterwards, so repeated calls on
            the same dataset do not stack transforms.
        dataset_name: label used in the progress bar and printed summary.
        vqgan: tokenizer passed through to the per-model batch functions.
        model_name: "taming" or "rar".
        args: unused here; kept for interface compatibility.
        batch_size: DataLoader batch size.
        get_overlap: when true, also prints the mean of `overlapping_ratios`.
            That list is never filled by this function.

    Returns:
        Tuple of 11 lists, in this order: codebook_loss_mses, rec_loss_mses,
        codebook_loss_mses_double, rec_loss_mses_double,
        codebook_loss_mses_double_ratio, rec_loss_mses_double_ratio,
        overlapping_ratios, rec_loss_mse_no_quantize,
        rec_loss_mse_no_quantize_double, rec_loss_mse_no_quantize_double_ratio,
        combined.

    Raises:
        ValueError: if `model_name` is neither "taming" nor "rar".
    """
    # Resolve everything that depends on the model once, up front.
    if model_name == "taming":
        model_transform = transforms.Compose([
            transforms.Resize((size, size)),
            transforms.ToTensor(),
            lambda x: 2.0 * x - 1.0,  # normalize to [-1, 1]
        ])
        reconstruct_batch = tokenize_and_reconstruct_taming_batch
        normalize_taming = True
    elif model_name == "rar":
        model_transform = transforms.Compose([
            transforms.Resize((size, size)),
            transforms.ToTensor(),
        ])
        reconstruct_batch = tokenize_and_reconstruct_rar_batch
        normalize_taming = False
    else:
        raise ValueError(f"Model {model_name} not supported for batching")

    # Initialize lists to store results
    codebook_loss_mses, rec_loss_mses, rec_loss_mse_no_quantize = [], [], []
    codebook_loss_mses_double, rec_loss_mses_double, rec_loss_mse_no_quantize_double = [], [], []
    codebook_loss_mses_double_ratio, rec_loss_mses_double_ratio, rec_loss_mse_no_quantize_double_ratio = [], [], []
    overlapping_ratios = []
    combined = []

    # Attach the model preprocessing after any transform the dataset already has.
    original_transform = dataset.transform
    if original_transform:
        dataset.transform = transforms.Compose([original_transform, model_transform])
    else:
        dataset.transform = model_transform

    try:
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4, collate_fn=collate_images_only)

        for batch_data in tqdm(dataloader, desc=f"Processing {dataset_name}"):
            image_batch = batch_data.to(device)
            if image_batch.ndim == 3:
                image_batch = image_batch.unsqueeze(0)

            # First reconstruction, with and without quantization.
            _, reconstructed_images, codebook_loss_mse, img_rec_loss_mse = reconstruct_batch(
                image_batch, vqgan, use_quant=True
            )
            _, reconstructed_images_no_quantize, _, img_rec_loss_mse_no_quantize = reconstruct_batch(
                image_batch, vqgan, use_quant=False
            )

            assert len(codebook_loss_mse) == image_batch.shape[0], f"{len(codebook_loss_mse), image_batch.shape}"

            # Detach before converting so the values can always be turned into NumPy.
            codebook_loss_mse_cpu = codebook_loss_mse.detach().cpu()
            img_rec_loss_mse_cpu = img_rec_loss_mse.detach().cpu()
            img_rec_loss_mse_no_quantize_cpu = img_rec_loss_mse_no_quantize.detach().cpu()

            # Store first reconstruction results
            codebook_loss_mses.extend(codebook_loss_mse_cpu.numpy())
            rec_loss_mses.extend(img_rec_loss_mse_cpu.numpy())
            rec_loss_mse_no_quantize.extend(img_rec_loss_mse_no_quantize_cpu.numpy())

            # Second reconstruction: feed the first reconstructions back in.
            reconstructed_images_tensor = pil_list_to_tensor(
                reconstructed_images, device=device, normalize_taming=normalize_taming
            )
            reconstructed_images_no_quantize_tensor = pil_list_to_tensor(
                reconstructed_images_no_quantize, device=device, normalize_taming=normalize_taming
            )
            _, _, codebook_loss_mse_double, img_rec_loss_mse_double = reconstruct_batch(
                reconstructed_images_tensor, vqgan, use_quant=True
            )
            _, _, _, img_rec_loss_mse_no_quantize_double = reconstruct_batch(
                reconstructed_images_no_quantize_tensor, vqgan, use_quant=False
            )

            codebook_loss_mse_double_cpu = codebook_loss_mse_double.detach().cpu()
            img_rec_loss_mse_double_cpu = img_rec_loss_mse_double.detach().cpu()
            img_rec_loss_mse_no_quantize_double_cpu = img_rec_loss_mse_no_quantize_double.detach().cpu()

            # Store second reconstruction results
            codebook_loss_mses_double.extend(codebook_loss_mse_double_cpu.numpy())
            rec_loss_mses_double.extend(img_rec_loss_mse_double_cpu.numpy())
            rec_loss_mse_no_quantize_double.extend(img_rec_loss_mse_no_quantize_double_cpu.numpy())

            # First-pass / second-pass ratios, 0 where the second pass is 0.
            codebook_ratio = _safe_loss_ratio(codebook_loss_mse_cpu, codebook_loss_mse_double_cpu)
            rec_ratio = _safe_loss_ratio(img_rec_loss_mse_cpu, img_rec_loss_mse_double_cpu)
            rec_no_quant_ratio = _safe_loss_ratio(
                img_rec_loss_mse_no_quantize_cpu, img_rec_loss_mse_no_quantize_double_cpu
            )

            codebook_loss_mses_double_ratio.extend(codebook_ratio.numpy())
            rec_loss_mses_double_ratio.extend(rec_ratio.numpy())
            rec_loss_mse_no_quantize_double_ratio.extend(rec_no_quant_ratio.numpy())

            # Combined metric
            combined.extend((codebook_loss_mse_cpu * rec_no_quant_ratio).numpy())
    finally:
        # Undo the temporary preprocessing so the caller's dataset is unchanged.
        dataset.transform = original_transform

    # Print results (same as original)
    print(f'[Results for {dataset_name}]')
    print(f"[1st Rec] Average codebook loss MSE for {dataset_name} images: {np.mean(codebook_loss_mses)}")
    print(f"[1st Rec] Average reconstruction loss MSE for {dataset_name} images: {np.mean(rec_loss_mses)}")
    print(f"[1st Rec] Average reconstruction loss no quantize MSE for {dataset_name} images: {np.mean(rec_loss_mse_no_quantize)}")

    print(f"[2nd Rec] Average codebook loss MSE for {dataset_name} images: {np.mean(codebook_loss_mses_double)}")
    print(f"[2nd Rec] Average reconstruction loss MSE for {dataset_name} images: {np.mean(rec_loss_mses_double)}")
    print(f"[2nd Rec] Average reconstruction loss MSE no quantize for {dataset_name} images: {np.mean(rec_loss_mse_no_quantize_double)}")

    print(f"[2nd Rec ratio] Average codebook loss ratio MSE for {dataset_name} images: {np.mean(codebook_loss_mses_double_ratio)}")
    print(f"[2nd Rec ratio] Average reconstruction loss ratio MSE for {dataset_name} images: {np.mean(rec_loss_mses_double_ratio)}")
    print(f"[2nd Rec ratio] Average reconstruction loss ratio MSE no quantize for {dataset_name} images: {np.mean(rec_loss_mse_no_quantize_double_ratio)}")

    print(f"Average Combined Value for {dataset_name} images: {np.mean(combined)}")

    if get_overlap:
        print(f"Average overlapping ratio for {dataset_name} images: {np.mean(overlapping_ratios)}")

    return (codebook_loss_mses, rec_loss_mses, codebook_loss_mses_double, rec_loss_mses_double,
            codebook_loss_mses_double_ratio, rec_loss_mses_double_ratio, overlapping_ratios,
            rec_loss_mse_no_quantize, rec_loss_mse_no_quantize_double, rec_loss_mse_no_quantize_double_ratio,
            combined)



def evaluate_real_and_gen_datasets(real_dataset_values, gen_dataset_values, attack="None"):
    """
    Run `evaluate` on every loss variant, comparing generated against real values.

    Both arguments are the 11-element tuples returned by the
    calculate_losses_dataset* functions. For each metric the generated values
    are passed to `evaluate` as the positive class and the real values as the
    negative class. `attack` is accepted but not used.

    Returns:
        dict keyed by (metric_name, stat), where stat is one of
        "Threshold", "AUC", "ACC", "TPR@1%FPR".
    """
    (cb_gen, rec_gen, cb_double_gen, rec_double_gen, cb_ratio_gen, rec_ratio_gen,
     _overlap_gen, rec_nq_gen, rec_nq_double_gen, rec_nq_ratio_gen, combined_gen) = gen_dataset_values

    (cb_real, rec_real, cb_double_real, rec_double_real, cb_ratio_real, rec_ratio_real,
     _overlap_real, rec_nq_real, rec_nq_double_real, rec_nq_ratio_real, combined_real) = real_dataset_values

    print(f"Using {len(rec_gen)} number of values")

    # (printed heading, key in the result row, generated values, real values)
    comparisons = (
        ("[1st] Rec Loss", "rec_loss_1", rec_gen, rec_real),
        ("Rec Loss Ratio", "rec_loss_ratio", rec_ratio_gen, rec_ratio_real),
        ("[1st] Rec Loss no Quantize", "rec_loss_no_quantize", rec_nq_gen, rec_nq_real),
        ("Rec no Quantize Ratio", "rec_loss_no_quantize_ratio", rec_nq_ratio_gen, rec_nq_ratio_real),
        # codebook_loss variants
        ("[1st] Codebook Loss", "codebook_loss", cb_gen, cb_real),
        ("Codebook Ratio", "codebook_loss_ratio", cb_ratio_gen, cb_ratio_real),
        # combined method
        ("Combined", "combined", combined_gen, combined_real),
        # second-pass variants
        ("[2nd] rec loss", "rec_loss_double", rec_double_gen, rec_double_real),
        ("[2nd] rec loss no Quantize", "rec_loss_no_quantize_double", rec_nq_double_gen, rec_nq_double_real),
        ("[2nd] Codebook loss", "codebook_loss_double", cb_double_gen, cb_double_real),
    )

    row = {}
    for heading, metric, gen_values, real_values in comparisons:
        print(heading)
        threshold, auc, acc, tpr_at_1fpr = evaluate(np.array(gen_values), np.array(real_values))
        row[(metric, "Threshold")] = threshold
        row[(metric, "AUC")] = auc
        row[(metric, "ACC")] = acc
        row[(metric, "TPR@1%FPR")] = tpr_at_1fpr

    return row
    



def tokenize_and_reconstruct_rar(original_image_batch, vqgan, display_img=False, use_quant=True):
    """
    Tokenize a single image with the RAR tokenizer and reconstruct it.

    Args:
        original_image_batch: despite the name, this is handed straight to
            preprocess_rar, which treats it as one PIL image.
        vqgan: tokenizer exposing encode_with_internals / decode_tokens
            (use_quant=True) or encode_without_quant / decode_states
            (use_quant=False).
        display_img: accepted for signature parity with the Taming variant;
            not used here.
        use_quant: selects which encode/decode pair is used.

    Returns:
        (encoded_tokens, reconstructed PIL image, codebook_loss_mse, img_rec_loss_mse)
        Both losses are per-image MSE tensors reduced over dims [1, 2, 3].
    """
    image_batch = preprocess_rar(original_image_batch)

    # Pure evaluation: run without autograd, as the Taming counterpart does.
    with torch.no_grad():
        if use_quant:
            encoded_tokens, hidden_states, quantized_states, _ = vqgan.encode_with_internals(image_batch.clone().to(device))
        else:
            encoded_tokens, hidden_states, quantized_states, _ = vqgan.encode_without_quant(image_batch.clone().to(device))

        # compute codebook loss with MSE
        codebook_loss_mse = torch.mean((hidden_states - quantized_states) ** 2, dim=[1, 2, 3])

        # compute image reconstruction loss
        if use_quant:
            reconstructed_image_batch = vqgan.decode_tokens(encoded_tokens.clone())
        else:
            reconstructed_image_batch = vqgan.decode_states(hidden_states.clone())
        reconstructed_image_batch = torch.clamp(reconstructed_image_batch, 0.0, 1.0)
        img_rec_loss_mse = torch.mean((reconstructed_image_batch.cpu() - image_batch.cpu()) ** 2, dim=[1, 2, 3])

    # Only the first (and only) image of the batch is converted to PIL.
    reconstructed_image_batch_show = (reconstructed_image_batch.clone() * 255.0).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()[0]
    reconstructed_image_batch_show = Image.fromarray(reconstructed_image_batch_show)

    return encoded_tokens, reconstructed_image_batch_show, codebook_loss_mse, img_rec_loss_mse

def tokenize_and_reconstruct_taming(img, vqgan, display_img=False, use_quant=True):
    """
    Tokenize a single PIL image with the Taming VQGAN and reconstruct it.

    Args:
        img: PIL image; converted to RGB if needed and preprocessed with the
            module-level `transform`.
        vqgan: model exposing encoder, quant_conv, quantize, post_quant_conv
            and decoder.
        display_img: when true, shows the RGB input and the reconstruction.
        use_quant: when false the quantizer is skipped and the pre-quantization
            latent is decoded directly.

    Returns:
        (z_q, img_rec, codebook_loss, img_rec_loss_mse)
        img_rec is a PIL image. With use_quant=False both z_q and
        codebook_loss are None. The reconstruction MSE is measured before the
        output is mapped from [-1, 1] to [0, 1].
    """
    if img.mode != "RGB":
        img = img.convert("RGB")
    img_display = img.copy()
    # Use the detected device instead of a hard-coded "cuda" so the function
    # also works on CPU-only machines.
    img = transform(img).to(device)

    z_q = None
    codebook_loss = None
    with torch.no_grad():
        z = vqgan.quant_conv(vqgan.encoder(img.unsqueeze(0)))

        if use_quant:
            z_q, _emb_loss, _info = vqgan.quantize(z)
            # compute L2 distance between the latent and its quantized version
            codebook_loss = torch.mean((z - z_q) ** 2, dim=[1, 2, 3])
            latent = z_q
        else:
            latent = z

        img_rec = vqgan.decoder(vqgan.post_quant_conv(latent))
        # img is [C,H,W] and broadcasts against the [1,C,H,W] reconstruction.
        img_rec_loss_mse = torch.mean((img_rec.cpu() - img.cpu()) ** 2, dim=[1, 2, 3])

        img_rec = ((img_rec + 1.0) / 2.0).clamp(0.0, 1.0)  # denormalize to [0, 1]
        img_rec = img_rec.squeeze(0).permute(1, 2, 0).cpu().numpy()
        img_rec = Image.fromarray((img_rec * 255).astype(np.uint8))

    if display_img:
        img_display.show()
        img_rec.show()
    return z_q, img_rec, codebook_loss, img_rec_loss_mse




def calculate_losses_dataset(dataset, dataset_name, vqgan, model_name, get_overlap=False, display_img=False):
    '''
    Per-image (unbatched) loss computation over a dataset.

    Every image is reconstructed twice: once from the original, then once more
    from the first reconstruction. Both passes are run with and without
    quantization.

    *Input:
    dataset: iterable of (image, tokens) pairs; tokens are ignored
    dataset_name: label used in the printed summary
    vqgan: tokenizer forwarded to tokenize_and_reconstruct_<model_name>
    model_name: "rar" or "taming"
    get_overlap (optional): when true, also prints the mean of the overlapping
        ratios (that list is never filled by this function)
    display_img (optional): forwarded to the per-image reconstruction function
    *Return:
    codebook loss, rec loss,
    codebook loss double, rec loss double,
    codebook loss double ratio, rec loss double ratio,
    overlapping ratio,
    rec loss no quantize, rec loss no quantize double,
    rec loss no quantize double ratio, combined
    *Raises:
    ValueError if model_name is not supported
    '''
    # Explicit dispatch instead of eval() on a formatted string.
    reconstructors = {
        "rar": tokenize_and_reconstruct_rar,
        "taming": tokenize_and_reconstruct_taming,
    }
    if model_name not in reconstructors:
        raise ValueError(f"Model {model_name} not supported")
    reconstruct = reconstructors[model_name]

    def _ratio(first_pass, second_pass):
        # Guard on the denominator (as the batched version does) and always
        # return a plain float.
        if second_pass.item() == 0:
            return 0.0
        return (first_pass / second_pass).item()

    codebook_loss_mses, rec_loss_mses, rec_loss_mse_no_quantize = [], [], []
    codebook_loss_mses_double, rec_loss_mses_double, rec_loss_mse_no_quantize_double = [], [], []
    codebook_loss_mses_double_ratio, rec_loss_mses_double_ratio, rec_loss_mse_no_quantize_double_ratio = [], [], []
    overlapping_ratios = []
    combined = []

    for image, _tokens in tqdm(dataset):
        # First reconstruction, with and without quantization.
        _, reconstructed_image, codebook_loss_mse, img_rec_loss_mse = reconstruct(
            image, vqgan, display_img=display_img, use_quant=True
        )
        _, reconstructed_image_no_quantize, _, img_rec_loss_mse_no_quantize = reconstruct(
            image, vqgan, display_img=display_img, use_quant=False
        )

        codebook_loss_mses.append(codebook_loss_mse.item())
        rec_loss_mses.append(img_rec_loss_mse.item())
        rec_loss_mse_no_quantize.append(img_rec_loss_mse_no_quantize.item())

        # double reconstruction and get ratio
        _, _, codebook_loss_mse_double, img_rec_loss_mse_double = reconstruct(
            reconstructed_image, vqgan, display_img=display_img, use_quant=True
        )
        _, _, _, img_rec_loss_mse_no_quantize_double = reconstruct(
            reconstructed_image_no_quantize, vqgan, display_img=display_img, use_quant=False
        )

        codebook_loss_mses_double.append(codebook_loss_mse_double.item())
        rec_loss_mses_double.append(img_rec_loss_mse_double.item())
        rec_loss_mse_no_quantize_double.append(img_rec_loss_mse_no_quantize_double.item())

        no_quantize_ratio = _ratio(img_rec_loss_mse_no_quantize, img_rec_loss_mse_no_quantize_double)
        codebook_loss_mses_double_ratio.append(_ratio(codebook_loss_mse, codebook_loss_mse_double))
        rec_loss_mses_double_ratio.append(_ratio(img_rec_loss_mse, img_rec_loss_mse_double))
        rec_loss_mse_no_quantize_double_ratio.append(no_quantize_ratio)

        combined.append(codebook_loss_mse.item() * no_quantize_ratio)

    print(f'[Results for {dataset_name}]')
    print(f"[1st Rec] Average codebook loss MSE for {dataset_name} images: {np.mean(codebook_loss_mses)}")
    print(f"[1st Rec] Average reconstruction loss MSE for {dataset_name} images: {np.mean(rec_loss_mses)}")
    print(f"[1st Rec] Average reconstruction loss no quantize MSE for {dataset_name} images: {np.mean(rec_loss_mse_no_quantize)}")

    print(f"[2nd Rec] Average codebook loss MSE for {dataset_name} images: {np.mean(codebook_loss_mses_double)}")
    print(f"[2nd Rec] Average reconstruction loss MSE for {dataset_name} images: {np.mean(rec_loss_mses_double)}")
    print(f"[2nd Rec] Average reconstruction loss MSE no quantize for {dataset_name} images: {np.mean(rec_loss_mse_no_quantize_double)}")

    print(f"[2nd Rec ratio] Average codebook loss ratio MSE for {dataset_name} images: {np.mean(codebook_loss_mses_double_ratio)}")
    print(f"[2nd Rec ratio] Average reconstruction loss ratio MSE for {dataset_name} images: {np.mean(rec_loss_mses_double_ratio)}")
    print(f"[2nd Rec ratio] Average reconstruction loss ratio MSE no quantize for {dataset_name} images: {np.mean(rec_loss_mse_no_quantize_double_ratio)}")

    # The value was missing from this message; print it like the batched version.
    print(f"Average Combined Value for {dataset_name} images: {np.mean(combined)}")

    if get_overlap:
        print(f"Average overlapping ratio for {dataset_name} images: {np.mean(overlapping_ratios)}")
    return (codebook_loss_mses, rec_loss_mses, codebook_loss_mses_double, rec_loss_mses_double,
            codebook_loss_mses_double_ratio, rec_loss_mses_double_ratio, overlapping_ratios,
            rec_loss_mse_no_quantize, rec_loss_mse_no_quantize_double, rec_loss_mse_no_quantize_double_ratio,
            combined)


def encode_to_quantized(img, vqgan, display_img=False):
    """
    Encode `img` with the VQGAN and return its quantized latent.

    Runs encoder -> quant_conv -> quantize without gradient tracking.
    `display_img` is accepted but not used.

    Returns:
        (None, None, z_q, None). Only the third slot is filled; the 4-tuple
        layout matches how the result is unpacked by
        tokenize_and_reconstruct_batch_latent_optim.
    """
    with torch.no_grad():
        latent = vqgan.quant_conv(vqgan.encoder(img))
        quantized, _emb_loss, _info = vqgan.quantize(latent)

    return None, None, quantized, None

def tokenize_and_reconstruct_batch_latent_optim(original_image_batch, tokenizer, args, display_img=False, use_quant=True, lr=1e-2, iters=50):
    image_batch = original_image_batch.clone()
    # initialize the states
    match args.model:
        case "taming":
            if use_quant:
                encoded_tokens, hidden_states, quantized_states, codebook_loss = encode_to_quantized(image_batch.clone().to(device), vqgan=tokenizer)
            else:
                encoded_tokens, hidden_states, quantized_states, codebook_loss = tokenizer.encode_without_quant(image_batch.clone().to(device))
        case "rar":
            if use_quant:
                encoded_tokens, hidden_states, quantized_states, codebook_loss = tokenizer.encode_with_internals(image_batch.clone().to(device))
            else:
                encoded_tokens, hidden_states, quantized_states, codebook_loss = tokenizer.encode_without_quant(image_batch.clone().to(device))
    # fhat_optim = torch.nn.Parameter(tokenizer.post_quant_conv(quantized_states).clone().detach()).cuda()
    fhat_optim = torch.nn.Parameter(quantized_states.clone().detach()).cuda()
    optimizer = torch.optim.Adam([fhat_optim], lr=lr)
    for i in range(iters):
        optimizer.zero_grad()
        if args.model =="taming":
            rec_gen_img = tokenizer.decoder(tokenizer.post_quant_conv(fhat_optim))
        else:
            rec_gen_img = tokenizer.decoder(fhat_optim)

        if args.model =="taming":
            rec_gen_img = torch.clamp(rec_gen_img, -1.0, 1.0)
        else:
            rec_gen_img = torch.clamp(rec_gen_img, 0.0, 1.0)

        loss = F.mse_loss(rec_gen_img, image_batch.clone())
        loss.backward()
        optimizer.step()
        # print(loss)
        if i%50==0:
            for g in optimizer.param_groups:
                g['lr'] = g['lr']*0.5

    return rec_gen_img, fhat_optim