import torch
import torchvision.transforms as transforms
from torchvision.utils import save_image
from PIL import Image
import numpy as np
import os
import random
import torch.nn.functional as F

class ImageAugMix:
    """AugMix-style augmentation of a target image, filtered by CLIP similarity scores.

    Several randomly chosen chains of torchvision transforms are applied to the
    image and blended with Dirichlet-distributed weights. Candidates are then
    scored against pre-computed text embeddings with a CLIP model, and a
    candidate is accepted once its similarity margin exceeds a threshold.
    """

    def __init__(self, clip_model, clip_preprocess, tokenizer, save_dir, text_embeddings, device, severity=3, width=3, depth=-1, alpha=1.0, seed=42):
        """
        Initialize the ImageAugMix class

        Note: constructing an instance re-seeds the global ``random``, NumPy
        and torch generators (see ``set_seed``).

        Args:
            clip_model: CLIP model instance; must provide
                ``encode_image(x, attn_method=..., normalize=...)``
            clip_preprocess: CLIP preprocessing function (PIL image -> tensor)
            tokenizer: CLIP tokenizer (stored, not used by this class itself)
            save_dir (str): Directory to save augmented images
            text_embeddings: Pre-computed text embeddings; index 0 is used as the
                positive set and index 1 as the negative set
                (assumed to be NumPy arrays of shape [N, D] -- TODO confirm with caller)
            device: Computing device ('cuda' or 'cpu')
            severity (int): Intensity of augmentation operations
                (stored, but not consumed by any augmentation defined in this class)
            width (int): Width of augmentation chain (number of parallel chains)
            depth (int): Depth of augmentation chain (number of operations in each chain,
                if <= 0 randomly selects 1-3)
            alpha (float): Parameter for Dirichlet distribution, used for mixing weights
            seed (int): Random seed
        """
        self.save_dir = save_dir
        self.severity = severity
        self.width = width
        self.depth = depth
        self.alpha = alpha
        self.seed = seed
        self.device = device
        self.set_seed(seed)
        self.clip_model = clip_model
        self.clip_preprocess = clip_preprocess
        self.tokenizer = tokenizer
        self.text_embeddings = text_embeddings
        self.augmentations = self.get_augmentations()

    def set_seed(self, seed):
        """
        Set global random seeds to ensure reproducibility

        Seeds Python's ``random``, NumPy and torch. When CUDA is available it
        also seeds every CUDA device and switches cuDNN to deterministic mode.

        Args:
            seed (int): Random seed
        """
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False

    def load_image(self, image_path):
        """
        Load image and convert to tensor format

        Args:
            image_path (str): Path to image file

        Returns:
            torch.Tensor: Transformed image tensor, shape [1, 3, 256, 256], range [0, 1]
        """
        # Open inside a context manager so the file handle is released right
        # away; convert() returns a fully loaded copy that outlives the handle.
        with Image.open(image_path) as handle:
            image = handle.convert('RGB')
        transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),  # Convert to tensor, range [0,1]
        ])
        return transform(image).unsqueeze(0).to(self.device)  # Add batch dimension and move to device

    def get_augmentations(self):
        """
        Define augmentation operations for AugMix

        Returns:
            list: List of image augmentation operations
        """
        return [
            transforms.RandomHorizontalFlip(),  # Random horizontal flip
            transforms.RandomVerticalFlip(),    # Random vertical flip
            transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),  # Color jitter
            transforms.RandomRotation(30),      # Random rotation
            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),  # Random affine transformation
            # More augmentation operations can be added as needed
        ]

    def augmix(self, image):
        """
        Perform AugMix augmentation on target image

        Builds ``self.width`` augmentation chains, each made of operations drawn
        (with replacement) from ``self.augmentations``, and returns their convex
        combination with Dirichlet(``self.alpha``) weights. The original image
        itself is not blended back into the result.

        Args:
            image (torch.Tensor): Input image tensor

        Returns:
            torch.Tensor: AugMix augmented image tensor
        """
        mixed_images = []
        for _ in range(self.width):
            # A non-positive configured depth means "pick 1-3 operations at random"
            if self.depth <= 0:
                depth = random.randint(1, 3)
            else:
                depth = self.depth
            ops = [random.choice(self.augmentations) for _ in range(depth)]
            # Apply augmentation operations sequentially
            augmented = image
            for op in ops:
                augmented = op(augmented)
            mixed_images.append(augmented)

        # Generate mixing weights using Dirichlet distribution and mix augmented images.
        # Weights are turned into plain Python floats so the result keeps the
        # dtype and device of the image tensors.
        weights = np.random.dirichlet([self.alpha] * self.width)
        mixed_image = sum(float(w) * img for w, img in zip(weights, mixed_images))
        return mixed_image

    def tensor_to_pil(self, tensor):
        """
        Convert tensor to PIL image

        Args:
            tensor (torch.Tensor): Input image tensor, shape [1, C, H, W] or [C, H, W],
                values expected in [0, 1]

        Returns:
            PIL.Image: Converted PIL image object
        """
        # Clamp before the uint8 cast: values outside [0, 1] would otherwise
        # wrap around (or be undefined) when converted with .byte().
        array = tensor.squeeze(0).cpu().clamp(0, 1).mul(255).byte().permute(1, 2, 0).numpy()
        return Image.fromarray(array)

    def _encode_image_tensor(self, image):
        """Return the L2-normalized CLIP feature of an image tensor as a NumPy array."""
        pil_image = self.tensor_to_pil(image)
        image_input = self.clip_preprocess(pil_image).unsqueeze(0).to(self.device)
        # Inference only: avoid building an autograd graph
        with torch.no_grad():
            image_feature = self.clip_model.encode_image(image_input, attn_method="head", normalize=False)
            image_feature = F.normalize(image_feature, dim=-1)
        return image_feature.detach().cpu().numpy()

    def get_satisfactory_augmix(self, image, max_attempts=10, threshold=0.05):
        """
        Get AugMix image that satisfies CLIP conditions

        For each candidate the score is
        ``max(max(pos_sim), neg_sims[0]) - max(neg_sims[1:])``; the first
        candidate whose score exceeds ``threshold`` is returned immediately.

        Args:
            image (torch.Tensor): Input image tensor
            max_attempts (int): Maximum number of attempts
            threshold (float): Similarity difference a candidate must exceed

        Returns:
            tuple: (Satisfactory augmented image tensor, corresponding CLIP similarity).
                If no candidate passes the threshold, the candidate with the
                largest difference is returned and a warning is printed.

        Raises:
            ValueError: If no candidate could be scored (e.g. ``max_attempts <= 0``)
        """
        best_diff = float('-inf')
        best_mixed_image = None
        best_pos_sim = None

        for _ in range(max_attempts):
            mixed_image = self.augmix(image)

            # Calculate CLIP similarity difference for augmented image
            image_feature = self._encode_image_tensor(mixed_image)
            pos_sim = (100.0 * image_feature @ self.text_embeddings[0].T).squeeze(0)
            neg_sims = (100.0 * image_feature @ self.text_embeddings[1].T).squeeze(0)
            current_diff = max(np.max(pos_sim), neg_sims[0]) - neg_sims[1:].max()

            if current_diff > threshold:  # Similarity difference threshold
                return mixed_image, pos_sim

            # Update best result
            if current_diff > best_diff:
                best_diff = current_diff
                best_mixed_image = mixed_image
                best_pos_sim = pos_sim

        # If no satisfactory image is found, return the result with the maximum difference
        if best_mixed_image is not None:
            print(f"Warning: No satisfactory AugMix image found, returning best difference result (diff={best_diff:.4f})")
            return best_mixed_image, best_pos_sim

        raise ValueError(f"Failed to find any usable AugMix image after {max_attempts} attempts")

    def save_images(self, images, prefix):
        """
        Save images to specified directory

        Files are written as ``<prefix>_<index>.png`` inside ``self.save_dir``;
        existing files with the same name are overwritten.

        Args:
            images (list): List of image tensors
            prefix (str): Filename prefix
        """
        # exist_ok avoids the check-then-create race of a separate exists() test
        os.makedirs(self.save_dir, exist_ok=True)
        for i, img in enumerate(images):
            save_path = os.path.join(self.save_dir, f"{prefix}_{i}.png")
            save_image(img, save_path)

    def process(self, image_path, image_enable=False, num_augmentations=5):
        """
        Main processing function: load image, perform AugMix, filter based on CLIP scores and save

        Args:
            image_path (str): Target image path
            image_enable (bool): Whether to save the images and also return the augmented images
            num_augmentations (int): Number of augmented images to generate

        Returns:
            list: List of CLIP similarities for original and augmented images
                (original image first)
            list (optional): Only when image_enable=True, as the second element of a
                tuple. These are the augmented image tensors with the batch
                dimension removed, not CLIP feature vectors.
        """
        # Load and calculate CLIP similarity for original image
        image = self.load_image(image_path)
        image_feature = self._encode_image_tensor(image)
        pos_sims = [(100.0 * image_feature @ self.text_embeddings[0].T).squeeze(0)]

        # Generate the augmented images and collect their similarities
        mixed_images = []
        for _ in range(num_augmentations):
            mixed_image, pos_sim = self.get_satisfactory_augmix(image)
            mixed_images.append(mixed_image)
            pos_sims.append(pos_sim)

        if not image_enable:
            # Return similarities only
            return pos_sims

        # Save original and augmented images
        self.save_images([image], "image")
        self.save_images(mixed_images, "augmix")

        # Return the augmented image tensors alongside the similarities
        embeddings = [img.squeeze(0) for img in mixed_images]
        return pos_sims, embeddings