import torch
import torchvision.transforms as transforms
from torchvision.utils import save_image
from PIL import Image
import numpy as np
import os
import random
import glob
import torch.nn.functional as F

class ImageCutMix:
    """Perform CutMix between a target image and random scene images, filtered by CLIP scores.

    For each requested scene, a random ``*.jpg`` from ``base_dir/<scene>`` is pasted into a
    random rectangle of the target image. The result is scored with ``clip_model`` against
    ``text_embeddings``, and mixes are re-drawn until the score margin exceeds a threshold
    or the attempt budget runs out.
    """

    def __init__(self, base_dir, clip_model, clip_preprocess, tokenizer, save_dir, text_embeddings, device, beta=10, seed=42):
        """
        Initialize the ImageCutMix class.

        Args:
            base_dir (str): Base directory containing one sub-folder of ``*.jpg`` images per scene
            clip_model: CLIP model; must provide ``encode_image(x, attn_method=..., normalize=...)``
            clip_preprocess: Callable mapping a PIL image to an un-batched tensor
                (a batch dimension is added afterwards with ``unsqueeze(0)``)
            tokenizer: CLIP tokenizer (stored; not used by this class)
            save_dir (str): Directory where ``process`` writes PNG images
            text_embeddings: Indexable with ``[0]`` and ``[1]``. Each entry is a matrix that is
                multiplied (``feature @ emb.T``) with a NumPy image feature. Presumably NumPy
                arrays of L2-normalised text features; confirm with caller.
            device: Computing device ('cuda' or 'cpu')
            beta (float): Parameter of the symmetric Beta(beta, beta) used to size the CutMix box
            seed (int): Random seed applied to the Python, NumPy and torch RNGs
        """
        self.base_dir = base_dir
        self.save_dir = save_dir
        self.beta = beta
        self.seed = seed
        self.device = device
        self.set_seed(seed)
        self.clip_model = clip_model
        self.clip_preprocess = clip_preprocess
        self.tokenizer = tokenizer
        self.text_embeddings = text_embeddings

    def set_seed(self, seed):
        """Seed the Python, NumPy and torch global RNGs (and CUDA/cuDNN when available)."""
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
            # Trade cuDNN autotuning speed for deterministic kernels.
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False

    def load_image(self, image_path):
        """
        Load an image file as an RGB tensor resized to 256x256.

        Args:
            image_path (str): Path to image file

        Returns:
            torch.Tensor: Float tensor in [0, 1] of shape [1, 3, 256, 256] on ``self.device``
        """
        # The context manager closes the file handle. convert() returns a new, fully loaded
        # image, so it stays usable after the file is closed.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),  # Convert to tensor, range [0,1]
        ])
        return transform(image).unsqueeze(0).to(self.device)  # Add batch dimension and move to device

    def random_scene_image(self, scene_name):
        """
        Randomly select a ``*.jpg`` image from ``base_dir/scene_name``.

        Args:
            scene_name (str): Name of the scene folder

        Returns:
            str: Path to the randomly selected image

        Raises:
            ValueError: If the folder contains no ``*.jpg`` files
        """
        scene_dir = os.path.join(self.base_dir, scene_name)
        # glob order depends on the filesystem. Sort so a seeded random.choice is reproducible.
        image_paths = sorted(glob.glob(os.path.join(scene_dir, "*.jpg")))  # Assuming images are in jpg format
        if not image_paths:
            raise ValueError(f"No images found in {scene_dir}")
        return random.choice(image_paths)

    def cutmix(self, image, scene_image):
        """
        Paste a random rectangle of ``scene_image`` into ``image`` (CutMix).

        Args:
            image (torch.Tensor): Target image tensor of shape [N, C, H, W]
            scene_image (torch.Tensor): Scene image tensor with the same shape as ``image``

        Returns:
            tuple: (mixed image tensor, lam). ``lam`` is the fraction of pixels still taken
                from ``image``, i.e. 1 - pasted box area / (H * W).

        Raises:
            ValueError: If the two tensors have different shapes
        """
        # Explicit check rather than assert, so it is not stripped under `python -O`.
        if image.shape != scene_image.shape:
            raise ValueError("Image sizes do not match")

        _, _, h, w = image.shape
        # Beta(beta, beta) is symmetric around 0.5; the sample sets the target box size.
        lam = np.random.beta(self.beta, self.beta)
        y1, x1, y2, x2 = self._rand_bbox(h, w, lam)

        mask = torch.zeros_like(image)
        mask[:, :, y1:y2, x1:x2] = 1
        mixed_image = image * (1 - mask) + scene_image * mask

        # The box may have been clipped at the border (and halved widths round down),
        # so report the ratio that was actually applied.
        lam = 1.0 - float((y2 - y1) * (x2 - x1)) / (h * w)
        return mixed_image, lam

    def _rand_bbox(self, h, w, lam):
        """
        Generate a random box covering roughly ``1 - lam`` of the image, clipped to its bounds.

        Args:
            h (int): Height of image
            w (int): Width of image
            lam (float): Target fraction of the image to keep from the original

        Returns:
            tuple: (y1, x1, y2, x2) box coordinates, with 0 <= y1 <= y2 <= h and 0 <= x1 <= x2 <= w
        """
        cut_rat = np.sqrt(1. - lam)
        cut_w = int(w * cut_rat)
        cut_h = int(h * cut_rat)

        # Randomly select crop center
        cx = np.random.randint(w)
        cy = np.random.randint(h)

        # Calculate crop boundaries
        bbx1 = np.clip(cx - cut_w // 2, 0, w)
        bby1 = np.clip(cy - cut_h // 2, 0, h)
        bbx2 = np.clip(cx + cut_w // 2, 0, w)
        bby2 = np.clip(cy + cut_h // 2, 0, h)

        return bby1, bbx1, bby2, bbx2

    def tensor_to_pil(self, tensor):
        """
        Convert an image tensor with values in [0, 1] to a PIL image.

        Args:
            tensor (torch.Tensor): Image tensor of shape [1, C, H, W] or [C, H, W]

        Returns:
            PIL.Image: Converted PIL image object
        """
        # Round to nearest and clamp, the same quantisation torchvision's save_image uses.
        # The old truncating .byte() biased every pixel down and wrapped out-of-range values.
        array = (
            tensor.detach().squeeze(0).cpu()
            .mul(255).add(0.5).clamp(0, 255)
            .to(torch.uint8)
            .permute(1, 2, 0)
            .numpy()
        )
        return Image.fromarray(array)

    def _encode_image(self, image):
        """Return the CLIP feature of an image tensor as a NumPy array, L2-normalised on the last axis.

        The leading batch dimension (size 1) is kept.
        """
        image_input = self.clip_preprocess(self.tensor_to_pil(image)).unsqueeze(0).to(self.device)
        # Inference only: no_grad avoids building an autograd graph through the CLIP encoder.
        with torch.no_grad():
            image_feature = self.clip_model.encode_image(image_input, attn_method="head", normalize=False)
            image_feature = F.normalize(image_feature, dim=-1)
        return image_feature.cpu().numpy()

    def get_satisfactory_cutmix(self, image, scene_name, max_attempts=20, threshold=0.05):
        """
        Draw CutMix images until one passes the CLIP margin test.

        The margin is max(best positive prompt, first negative prompt) minus the best of
        the remaining negative prompts. Scores are 100 * dot product with the normalised
        image feature.

        Args:
            image (torch.Tensor): Input image tensor
            scene_name (str): Name of the scene to use
            max_attempts (int): Maximum number of attempts
            threshold (float): Margin that a mix must exceed to be accepted immediately

        Returns:
            tuple: (mixed image tensor, scene image tensor used, pos_sim). ``pos_sim`` is the
                NumPy array of scores against each row of ``text_embeddings[0]``. If no attempt
                passes, the attempt with the largest margin is returned and a warning is printed.

        Raises:
            ValueError: If no attempt produced a comparable margin
                (e.g. ``max_attempts <= 0`` or every margin was NaN)
        """
        best_diff = float('-inf')
        best = None  # (mixed_image, scene_image, pos_sim) with the largest margin so far

        for _attempt in range(max_attempts):
            scene_image = self.load_image(self.random_scene_image(scene_name))
            mixed_image, _lam = self.cutmix(image, scene_image)

            image_feature = self._encode_image(mixed_image)
            pos_sim = (100.0 * image_feature @ self.text_embeddings[0].T).squeeze(0)
            neg_sims = (100.0 * image_feature @ self.text_embeddings[1].T).squeeze(0)

            # Requires text_embeddings[1] to have at least two rows (neg_sims[1:] must be non-empty).
            current_diff = max(np.max(pos_sim), neg_sims[0]) - neg_sims[1:].max()

            if current_diff > threshold:
                return mixed_image, scene_image, pos_sim

            if current_diff > best_diff:
                best_diff = current_diff
                best = (mixed_image, scene_image, pos_sim)

        # If no satisfactory image is found, return the result with the maximum difference
        if best is not None:
            print(f"Warning: No satisfactory CutMix image found, returning best difference result (diff={best_diff:.4f})")
            return best

        raise ValueError(f"Failed to find any usable CutMix image for scene {scene_name} after {max_attempts} attempts")

    def save_images(self, images, prefix):
        """
        Save image tensors as ``<save_dir>/<prefix>_<index>.png``, creating ``save_dir`` if needed.

        Args:
            images (list): List of image tensors
            prefix (str): Filename prefix
        """
        os.makedirs(self.save_dir, exist_ok=True)
        for i, img in enumerate(images):
            save_image(img, os.path.join(self.save_dir, f"{prefix}_{i}.png"))

    def process(self, image_path, scene_list, image_enable=False):
        """
        Load the target image, build one CLIP-filtered CutMix per scene, and collect scores.

        Args:
            image_path (str): Target image path
            scene_list (list): List of scene names
            image_enable (bool): If True, also save the original, scene and mixed images to
                ``save_dir`` and return the mixed images

        Returns:
            list: ``pos_sim`` score arrays; the first is for the original image, then one
                per scene in ``scene_list`` order.
            list (only when image_enable=True): The mixed image tensors, each of shape
                [C, H, W]. These are pixel tensors, not CLIP embeddings.
        """
        image = self.load_image(image_path)
        image_feature = self._encode_image(image)
        pos_sims = [(100.0 * image_feature @ self.text_embeddings[0].T).squeeze(0)]

        mixed_images = []
        scene_images_used = []
        for scene in scene_list:
            mixed_image, scene_image, pos_sim = self.get_satisfactory_cutmix(image, scene)
            pos_sims.append(pos_sim)
            if image_enable:
                mixed_images.append(mixed_image)
                scene_images_used.append(scene_image)

        if not image_enable:
            return pos_sims

        self.save_images([image], "image")
        self.save_images(scene_images_used, "scene")
        self.save_images(mixed_images, "cutmix")
        embeddings = [img.squeeze(0) for img in mixed_images]
        return pos_sims, embeddings