from typing import List, Optional, Union
import torch
import torchvision.transforms as transforms
import numpy as np
import clip
import torchvision
from PIL import Image
import os

def zeroshot_classifier(model, classnames, templates, device=None):
    """Build a zero-shot CLIP classification head from text prompts.

    For every class name, each template is formatted with that name,
    tokenized with ``clip.tokenize`` and embedded with ``model.encode_text``.
    The per-prompt embeddings are L2-normalised, averaged over the templates,
    and the mean is re-normalised. This gives one unit-length embedding per
    class.

    Args:
        model: CLIP model exposing ``encode_text`` and ``parameters()``
            (e.g. the model returned by ``clip.load``).
        classnames: iterable of class names substituted into the templates.
        templates: format strings, each with one ``{}`` placeholder.
        device: device for the tokenized prompts and the returned weights.
            Defaults to the device of the model's parameters. Previously this
            was hard-coded to CUDA, which failed on CPU-only machines.

    Returns:
        Tensor built by stacking the class embeddings along ``dim=1``, so
        column ``i`` is the embedding of ``classnames[i]``. This assumes
        ``encode_text`` returns (num_templates, embed_dim), which would make
        the result (embed_dim, num_classes); TODO confirm.
    """
    if device is None:
        # Put the tokens wherever the model lives instead of assuming CUDA.
        device = next(model.parameters()).device

    with torch.no_grad():
        class_embeddings_per_class = []
        for classname in classnames:
            prompts = [template.format(classname) for template in templates]
            tokens = clip.tokenize(prompts).to(device)
            prompt_embeddings = model.encode_text(tokens)
            prompt_embeddings = prompt_embeddings / prompt_embeddings.norm(dim=-1, keepdim=True)
            # Average the normalised prompt embeddings, then re-normalise the mean.
            class_embedding = prompt_embeddings.mean(dim=0)
            class_embedding = class_embedding / class_embedding.norm()
            class_embeddings_per_class.append(class_embedding)
        weights = torch.stack(class_embeddings_per_class, dim=1).detach().to(device)
    return weights

def img_process(images, img_size, output_size=224):
    """Batch-resample the square region ``[0, img_size]^2`` of every image.

    Each image gets one RoI box ``[0, 0, img_size, img_size]``. The boxes go
    through :func:`torchvision.ops.roi_align` (``aligned=True``,
    ``sampling_ratio=-1``, ``spatial_scale=1``), which resamples that region
    to a fixed ``output_size`` x ``output_size`` grid.

    Args:
        images: input image tensor, shape (N, C, H, W).
        img_size: side length, in input pixels, of the square region to
            sample, starting at the top-left corner.
        output_size: side length of the output grid. Defaults to 224, the
            value previously hard-coded.

    Returns:
        Tensor of shape (N, C, output_size, output_size).
    """
    num_images = images.shape[0]
    if not images.is_floating_point():
        # Non-float input is passed through unchunked, exactly as before.
        return _roi_align_square(images, img_size, output_size)

    # roi_align needs the boxes (including their batch-index column) in the
    # same dtype as the images. Low-precision floats cannot represent every
    # integer: float16 is exact only up to 2048 and bfloat16 only up to 256.
    # Larger batches would silently pair boxes with the wrong image.
    # Work in chunks small enough that every index is exact.
    max_chunk = int(2 / torch.finfo(images.dtype).eps)
    if num_images <= max_chunk:
        return _roi_align_square(images, img_size, output_size)
    return torch.cat(
        [
            _roi_align_square(images[start:start + max_chunk], img_size, output_size)
            for start in range(0, num_images, max_chunk)
        ],
        dim=0,
    )


def _roi_align_square(images, img_size, output_size):
    """Run roi_align with one ``[0, 0, img_size, img_size]`` box per image."""
    num_images = images.shape[0]
    # Boxes in roi_align's (K, 5) format: [batch_index, x1, y1, x2, y2].
    boxes = torch.zeros((num_images, 5), device=images.device, dtype=images.dtype)
    boxes[:, 0] = torch.arange(num_images, device=images.device, dtype=images.dtype)
    boxes[:, 3] = float(img_size)
    boxes[:, 4] = float(img_size)
    return torchvision.ops.roi_align(
        images,
        boxes,
        output_size=output_size,
        spatial_scale=1.0,
        sampling_ratio=-1,
        aligned=True,
    )

# def img_process(images, img_size):
#     roiAlign = torchvision.ops.RoIAlign(output_size=224, sampling_ratio = -1, spatial_scale=1, aligned = True)
#     batch_image = []
#     coord = torch.tensor([[0.0,0.0,float(img_size),float(img_size)]]).cuda().to(torch.float16)
#     for i in range(images.shape[0]):
#         image = images[i].unsqueeze(0)
#         image = roiAlign(image, [coord]).squeeze()
#         batch_image.append(image)
#     batch_image = torch.stack(batch_image, dim=0)
#     return batch_image

class CLIPZeroShotClassifier:
    """Zero-shot image classifier built from a CLIP model and text prompts.

    Class weights are computed once at construction by
    ``zeroshot_classifier(classes, templates)``. ``get_logits`` then encodes
    image tensors with the CLIP image encoder and scores them against those
    weights.
    """

    def __init__(
            self,
            classes: List[str],
            templates: List[str],
            img_size: int,
            model_name: str = 'ViT-B/16',
            device: str = 'cuda' if torch.cuda.is_available() else 'cpu',
    ):
        """
        Args:
            classes: class names inserted into each template.
            templates: prompt format strings, each with one ``{}`` placeholder.
            img_size: region size passed to ``img_process`` for every batch.
            model_name: CLIP model identifier passed to ``clip.load``.
            device: device for the model and the class weights. The default
                is evaluated once, when the class is defined.
        """
        self.device = device
        # clip.load returns (model, preprocess). The preprocess transform is
        # stored but not used by get_logits, which resamples via img_process.
        self.classifier, self.preprocess = clip.load(model_name, device=device)
        self.img_size = img_size
        self.zeroshot_weights = zeroshot_classifier(
            self.classifier,
            classes,
            templates,
        ).to(self.device)

    def _logits_for_batch(self, images: torch.Tensor) -> torch.Tensor:
        """Score one batch that is already on ``self.device`` in the classifier dtype."""
        resampled = img_process(images, self.img_size)
        features = self.classifier.encode_image(resampled)
        features = features / features.norm(dim=-1, keepdim=True)
        # Image features are L2-normalised; multiply by the class weights
        # and apply a fixed scale of 100.
        return 100. * features @ self.zeroshot_weights

    def get_logits(
            self,
            images: torch.Tensor,
            batch_size: Optional[int] = None
    ) -> torch.Tensor:
        """
        Compute classification logits for a batch of images.

        Args:
            images: input image tensor, shape (N, C, H, W). It is moved to
                ``self.device`` and cast to the classifier's dtype.
            batch_size: maximum number of images encoded per forward pass.
                If None, all images are processed at once.

        Returns:
            logits: classification logits, shape (N, num_classes)

        Raises:
            ValueError: if ``batch_size`` is given and is smaller than 1.
        """
        # Previously 0 failed inside range() and negatives failed in torch.cat.
        if batch_size is not None and batch_size < 1:
            raise ValueError(
                f"batch_size must be a positive integer or None, got {batch_size}"
            )

        # No torch.no_grad() here: autograd tracks this computation unless the
        # caller disables it.
        images = images.to(device=self.device, dtype=self.classifier.dtype)
        num_images = images.shape[0]

        if batch_size is None or num_images <= batch_size:
            return self._logits_for_batch(images)

        return torch.cat(
            [
                self._logits_for_batch(images[start:start + batch_size])
                for start in range(0, num_images, batch_size)
            ],
            dim=0,
        )

def save_pil_image(image, clean_text, successful, adv_text):
    """Save ``image[0]`` as ``images/<clean_text>/<successful>/<name>.png``.

    The directory is created if needed. ``<name>`` is ``adv_text`` passed
    through ``clean_filename``.

    Args:
        image: indexable whose first element has a ``.save(path)`` method.
            Presumably a list of PIL images, given the function name; verify
            against callers.
        clean_text: first directory component under ``images``.
        successful: second directory component. It must be a ``str`` (or
            path-like), because it is passed to ``os.path.join``.
        adv_text: text used to build the file name.
    """
    # NOTE(review): clean_text and successful are used verbatim as path
    # components (only adv_text is sanitised). Do not pass untrusted input.
    file_path = os.path.join('images', clean_text, successful)
    # exist_ok=True avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(file_path, exist_ok=True)
    save_path = os.path.join(file_path, clean_filename(adv_text) + '.png')
    image[0].save(save_path)


import string

# Characters kept by clean_filename: ASCII letters, digits, space and -_.()
_FILENAME_SAFE_CHARS = frozenset(f"-_.() {string.ascii_letters}{string.digits}")


def clean_filename(filename, max_length=100):
    """Reduce a string to a conservative set of filename-safe characters.

    Only ASCII letters, digits, space and ``-_.()`` are kept. Every other
    character is dropped, including path separators and non-ASCII letters.
    The result is then truncated to ``max_length`` characters.

    Args:
        filename: text to clean.
        max_length: maximum length of the result. Defaults to 100, the value
            previously hard-coded. Must be non-negative.

    Returns:
        The cleaned string, which may be empty.

    Raises:
        ValueError: if ``max_length`` is negative.
    """
    if max_length < 0:
        raise ValueError(f"max_length must be non-negative, got {max_length}")
    cleaned = "".join(ch for ch in filename if ch in _FILENAME_SAFE_CHARS)
    # Slicing is a no-op when the string is already short enough.
    return cleaned[:max_length]

