import torch
import torchvision.transforms as transforms
from torchvision.utils import save_image
from PIL import Image
import numpy as np
import os
import random
import glob
import torch.nn.functional as F

class ImageCutout:
    """Apply random Cutout to an image and keep variants that pass a CLIP-score filter.

    Each Cutout zeroes one random axis-aligned rectangle (all channels) of the
    image. A candidate is accepted by ``get_satisfactory_cutout`` when its CLIP
    similarity margin exceeds a threshold; otherwise the best candidate seen is
    returned.
    """

    def __init__(self, clip_model, clip_preprocess, tokenizer, save_dir, text_embeddings, device, beta=10, seed=42):
        """
        Initialize the ImageCutout class.

        Args:
            clip_model: CLIP model instance; must expose
                ``encode_image(x, attn_method=..., normalize=...)``.
            clip_preprocess: Callable mapping a PIL image to a [C, H, W] tensor.
            tokenizer: CLIP tokenizer (stored; not used by any method of this class).
            save_dir (str): Directory in which ``save_images`` writes PNG files.
            text_embeddings: Indexable pair ``(positive, negative)`` of text
                embedding matrices. NOTE(review): presumably NumPy arrays of
                shape (K, D) matching the image feature dim D -- confirm.
            device: Computing device ('cuda' or 'cpu').
            beta (float): Parameter of the symmetric Beta(beta, beta) distribution
                used to sample the Cutout area ratio.
            seed (int): Random seed applied globally via ``set_seed``.
        """
        self.save_dir = save_dir
        self.beta = beta
        self.seed = seed
        self.device = device
        self.set_seed(seed)
        self.clip_model = clip_model
        self.clip_preprocess = clip_preprocess
        self.tokenizer = tokenizer
        self.text_embeddings = text_embeddings

    def set_seed(self, seed):
        """Seed Python, NumPy and PyTorch RNGs (and cuDNN determinism when CUDA is available)."""
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False

    def load_image(self, image_path):
        """
        Load an image file as an RGB tensor resized to 256x256.

        Args:
            image_path (str): Path to image file.

        Returns:
            torch.Tensor: Image tensor of shape [1, 3, 256, 256] with values in
            [0, 1], on ``self.device``.
        """
        transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),  # Convert to tensor, range [0,1]
        ])
        # Close the file handle promptly; convert() returns a new, fully loaded image.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        return transform(image).unsqueeze(0).to(self.device)  # Add batch dimension and move to device

    def cutout(self, image):
        """
        Zero out one random rectangle of the image.

        Args:
            image (torch.Tensor): Input image tensor of shape [N, C, H, W].

        Returns:
            torch.Tensor: New tensor equal to ``image`` with the sampled
            rectangle set to 0 in every channel.
        """
        _, _, h, w = image.shape
        lam = np.random.beta(self.beta, self.beta)  # Sampled area-keep ratio
        y1, x1, y2, x2 = self._rand_bbox(h, w, lam)

        mask = torch.ones_like(image)
        mask[:, :, y1:y2, x1:x2] = 0  # Rows index H, columns index W
        return image * mask

    def _rand_bbox(self, h, w, lam):
        """
        Sample a random rectangle whose side lengths are ``sqrt(1 - lam)`` of the image's.

        The centre is drawn uniformly over the image and the box is clipped to
        the image bounds, so the actual box may be smaller near the edges.

        Args:
            h (int): Height of image.
            w (int): Width of image.
            lam (float): Value in [0, 1]; the box covers roughly ``1 - lam`` of
                the image area before clipping.

        Returns:
            tuple: ``(y1, x1, y2, x2)`` with ``0 <= y1 <= y2 <= h`` and
            ``0 <= x1 <= x2 <= w``.
        """
        cut_rat = np.sqrt(1. - lam)
        cut_w = int(w * cut_rat)
        cut_h = int(h * cut_rat)

        # Randomly select center point
        cx = np.random.randint(w)
        cy = np.random.randint(h)

        # Calculate boundaries and constrain to image dimensions
        bbx1 = np.clip(cx - cut_w // 2, 0, w)
        bby1 = np.clip(cy - cut_h // 2, 0, h)
        bbx2 = np.clip(cx + cut_w // 2, 0, w)
        bby2 = np.clip(cy + cut_h // 2, 0, h)

        return bby1, bbx1, bby2, bbx2

    def tensor_to_pil(self, tensor):
        """
        Convert a [1, C, H, W] (or [C, H, W]) float tensor in [0, 1] to a PIL image.

        Values are clamped to [0, 1] before scaling so out-of-range values cannot
        wrap around in the uint8 conversion; scaling truncates (no rounding).

        Args:
            tensor (torch.Tensor): Input image tensor.

        Returns:
            PIL.Image: Converted PIL image object.
        """
        array = (
            tensor.detach()  # .numpy() refuses tensors that require grad
            .squeeze(0)
            .cpu()
            .clamp(0, 1)
            .mul(255)
            .byte()
            .permute(1, 2, 0)
            .numpy()
        )
        return Image.fromarray(array)

    def _encode_image(self, image):
        """Return L2-normalized CLIP features of ``image`` as a NumPy array (batch dim kept).

        Runs the model under ``torch.no_grad()`` so no autograd graph is built.
        """
        pil_image = self.tensor_to_pil(image)
        image_input = self.clip_preprocess(pil_image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            feature = self.clip_model.encode_image(image_input, attn_method="head", normalize=False)
            feature = F.normalize(feature, dim=-1)
        return feature.detach().cpu().numpy()

    def get_satisfactory_cutout(self, image, max_attempts=20, threshold=0.05):
        """
        Sample Cutout images until one passes the CLIP margin test.

        For each candidate the margin is
        ``max(max(pos_sim), neg_sims[0]) - max(neg_sims[1:])`` where
        ``pos_sim``/``neg_sims`` are 100x cosine similarities against
        ``text_embeddings[0]`` / ``text_embeddings[1]``. NOTE(review): this
        assumes ``text_embeddings[1]`` has at least two rows; otherwise the
        empty ``max`` raises.

        Args:
            image (torch.Tensor): Input image tensor.
            max_attempts (int): Maximum number of candidates to sample.
            threshold (float): A candidate is accepted as soon as its margin is
                strictly greater than this value.

        Returns:
            tuple: ``(cutout image tensor, pos_sim)`` for the first accepted
            candidate, or for the highest-margin candidate (with a printed
            warning) if none was accepted.

        Raises:
            ValueError: If no candidate could be evaluated (e.g. ``max_attempts <= 0``).
        """
        best_diff = float('-inf')
        best_cutout_image = None
        best_pos_sim = None

        for _ in range(max_attempts):
            cutout_image = self.cutout(image)
            image_feature = self._encode_image(cutout_image)

            pos_sim = (100.0 * image_feature @ self.text_embeddings[0].T).squeeze(0)
            neg_sims = (100.0 * image_feature @ self.text_embeddings[1].T).squeeze(0)
            current_diff = max(np.max(pos_sim), neg_sims[0]) - neg_sims[1:].max()

            if current_diff > best_diff:
                best_diff = current_diff
                best_cutout_image = cutout_image
                best_pos_sim = pos_sim

            if current_diff > threshold:
                # Return the scores of the image actually being returned.
                return cutout_image, pos_sim

        # If no satisfactory image is found, return the result with the maximum difference
        if best_cutout_image is not None:
            print(f"Warning: No satisfactory Cutout image found, returning best difference result (diff={best_diff:.4f})")
            return best_cutout_image, best_pos_sim

        raise ValueError(f"Failed to find any usable Cutout image after {max_attempts} attempts")

    def save_images(self, images, prefix):
        """
        Save image tensors as ``{prefix}_{i}.png`` in ``self.save_dir`` (created if missing).

        Args:
            images (list): List of image tensors accepted by ``save_image``.
            prefix (str): Filename prefix.
        """
        os.makedirs(self.save_dir, exist_ok=True)
        for i, img in enumerate(images):
            save_image(img, os.path.join(self.save_dir, f"{prefix}_{i}.png"))

    def process(self, image_path, image_enable=False, num_augments=5):
        """
        Load an image, score it with CLIP, and generate filtered Cutout variants.

        Args:
            image_path (str): Target image path.
            image_enable (bool): If True, also save the original and Cutout
                images and return the Cutout image tensors.
            num_augments (int): Number of Cutout variants to generate.

        Returns:
            list: ``pos_sim`` arrays for the original image followed by each
            Cutout variant (``1 + num_augments`` entries).
            list (only when ``image_enable=True``): the Cutout image tensors with
            the batch dimension removed. Note these are images, not CLIP embeddings.
        """
        image = self.load_image(image_path)
        image_feature = self._encode_image(image)
        pos_sims = [(100.0 * image_feature @ self.text_embeddings[0].T).squeeze(0)]

        cutout_images = []
        for _ in range(num_augments):
            cutout_image, pos_sim = self.get_satisfactory_cutout(image)
            cutout_images.append(cutout_image)
            pos_sims.append(pos_sim)

        if not image_enable:
            return pos_sims

        self.save_images([image], "image")
        self.save_images(cutout_images, "cutout")
        embeddings = [img.squeeze(0) for img in cutout_images]
        return pos_sims, embeddings
