import torch
import torch.nn.functional as F


def get_mix_image(c_fg, c_bg, alpha):
    """
    Blend a foreground embedding with background/scene embeddings.

    The result is ``alpha * c_fg + (1 - alpha) * c_bg``, L2-normalised along
    the last dimension. When ``c_fg`` is 2-D ([batch_size, dim]) a singleton
    axis is inserted at position 1 first, so it broadcasts against a 3-D
    ``c_bg`` such as [batch_size, num_scenes, dim].

    Args:
        c_fg (torch.Tensor): Foreground embedding.
        c_bg (torch.Tensor): Background/scene embedding(s).
        alpha (float): Weight given to the foreground term.

    Returns:
        torch.Tensor: The normalised mixed embedding (c_mix).
    """
    fg = c_fg.unsqueeze(1) if c_fg.dim() == 2 else c_fg  # [batch_size, 1, dim] when 2-D
    bg_weight = 1 - alpha
    blended = alpha * fg + bg_weight * c_bg
    return F.normalize(blended, dim=-1)

def select_scene_embeddings(c_z, c_xj, scene_embeddings, selct_scene_num):
    """
    Select, per sample, the scene embeddings most dissimilar to both its
    background and its foreground embedding.

    Every candidate scene ``s`` is scored for sample ``b`` as
    ``(1 - cos(s, c_z[b])) + (1 - cos(s, c_xj[b]))``. The highest-scoring
    scenes are returned in descending score order.

    Args:
        c_z (torch.Tensor): Background embeddings of shape [batch_size, dim].
        c_xj (torch.Tensor): Foreground embeddings of shape [batch_size, dim].
        scene_embeddings (torch.Tensor): Scene embeddings pool of shape [num_scenes, dim].
        selct_scene_num (int): Number of scenes to select. Values larger than
            ``num_scenes`` are clamped to ``num_scenes`` (``torch.topk`` would
            otherwise raise).

    Returns:
        torch.Tensor: Selected scene embeddings of shape
        [batch_size, min(selct_scene_num, num_scenes), dim].
    """
    num_scenes = scene_embeddings.shape[0]
    k = min(selct_scene_num, num_scenes)

    # [1, num_scenes, dim] broadcasts against [batch_size, 1, dim] inside
    # cosine_similarity, giving [batch_size, num_scenes] without an explicit expand.
    pool = scene_embeddings.unsqueeze(0)

    # Dissimilarity of each scene to the background and to the foreground.
    d_z = 1 - F.cosine_similarity(pool, c_z.unsqueeze(1), dim=-1)  # [batch_size, num_scenes]
    d_x = 1 - F.cosine_similarity(pool, c_xj.unsqueeze(1), dim=-1)  # [batch_size, num_scenes]

    # Combined score (equal weights).
    scores = d_z + d_x

    _, best_idx = torch.topk(scores, k, dim=1)  # [batch_size, k]
    # Advanced indexing on the pool's first axis -> [batch_size, k, dim].
    return scene_embeddings[best_idx]

def process_batch(f_data, text_embeddings, f_places, index_list, args, MAX_K):
    """
    Process a batch of data to compute base and counterfactual TDE scores.

    Steps:
    1. Build per-token effects from the attention / MLP decompositions and
       score every token against every class with an independent sigmoid.
    2. Build a background embedding ``c_z`` from tokens that look like
       background, and compute the base score
       ``tde_base = logits_img - lam_hat * logits_c_z``.
    3. For each of the top-``min(MAX_K, num_classes)`` classes of every image,
       build a class-specific foreground embedding ``c_xj``. Select scene
       embeddings according to ``args.scene_type``, mix foreground and scenes
       with ``alpha``, and score
       ``mean_over_scenes(logits_c_mix - lam_hat * logits_scene)``.
    4. Write that score into ``tde_scores`` only at the sample's own
       candidate class.

    Args:
        f_data (dict): Mapping with 'attentions' and 'mlps' arrays indexable by
            ``index_list`` and convertible with ``torch.from_numpy``. The
            indexed 'attentions' must be 3-D; presumably
            [batch, tokens, dim] with token 0 a global token — TODO confirm.
        text_embeddings (torch.Tensor): Class text embeddings, [num_classes, dim].
        f_places (dict or None): Scene store, read only when
            ``args.scene_type == "outer_cz"``: ``f_places["scene_names"]`` and
            ``f_places["scenes"][name]["embeddings"]``.
        index_list (list or range): Indices of the batch samples.
        args: Object providing ``lam_hat``, ``alpha``, ``scene_type``,
            ``select_scene_num`` and, for ``"virtual_cz"``,
            ``virtual_scene_embeddings``.
        MAX_K (int): Maximum number of top classes per image to rescore.

    Returns:
        tuple: (tde_base, tde_scores)
            - tde_base: [batch_size, num_classes] base scores.
            - tde_scores: [batch_size, num_classes] counterfactual scores; only
              each image's candidate classes are filled, every other entry is 0.

    Raises:
        ValueError: If ``args.scene_type`` is not one of ``"outer_cz"``,
            ``"inner_cz"``, ``"virtual_cz"``, ``"random_cz"``. Also raised if
            ``f_places`` is None for ``"outer_cz"``, or if
            ``args.virtual_scene_embeddings`` is None for ``"virtual_cz"``.
    """
    device = text_embeddings.device
    lam_hat = args.lam_hat
    alpha = args.alpha
    bg_threshold = 0.2  # tokens with background probability <= this are ignored for c_z
    fg_threshold = 0.6  # tokens with class probability <= this are ignored for c_xj

    attentions_batch = torch.from_numpy(f_data["attentions"][index_list]).to(device)
    mlps_batch = torch.from_numpy(f_data["mlps"][index_list]).to(device)

    batch_size, num_tokens, _ = attentions_batch.shape
    num_tokens = num_tokens - 1  # token 0 is folded into the shared term below, not scored itself
    num_classes = text_embeddings.shape[0]

    # Spread the summed MLP contributions and token 0's attention evenly over the
    # remaining tokens.
    # NOTE(review): both terms are averaged over the batch (dim=0), so one
    # sample's token effects depend on the other samples in the batch —
    # confirm this is intended rather than a per-sample term.
    shared_term = (
        torch.sum(torch.mean(mlps_batch, dim=0), dim=0) / num_tokens
        + torch.mean(attentions_batch[:, 0, :], dim=0) / num_tokens
    )
    tokens_effect = attentions_batch[:, 1:, :] + shared_term.repeat(batch_size, num_tokens, 1)  # [batch_size, num_tokens, dim]

    # Independent (sigmoid) probability of each token for each class.
    logits_token = 100 * torch.matmul(tokens_effect, text_embeddings.T)  # [batch_size, num_tokens, num_classes]
    probs_tok = torch.sigmoid(logits_token)

    # A token counts as foreground if it scores high for any class; its
    # background probability is the complement of that maximum.
    probs_bg = 1 - probs_tok.max(dim=-1).values  # [batch_size, num_tokens]

    # Background embedding c_z: weighted mean of token effects over tokens
    # whose background probability exceeds the threshold.
    w_bg = probs_bg.masked_fill(probs_bg <= bg_threshold, 0).unsqueeze(-1)  # [batch_size, num_tokens, 1]
    c_z = (tokens_effect * w_bg).sum(dim=1) / (w_bg.sum(dim=1) + 1e-6)
    c_z = F.normalize(c_z, dim=-1)  # [batch_size, dim]
    logits_c_z = 100 * torch.matmul(c_z, text_embeddings.T)

    # Image-level logits from the full decomposition (all tokens + all MLP terms).
    logits_img = 100 * torch.matmul(attentions_batch.sum(dim=1) + mlps_batch.sum(dim=1), text_embeddings.T)
    probs_img = torch.softmax(logits_img, dim=-1)

    tde_base = logits_img - lam_hat * logits_c_z

    # Candidate classes per image: the top-K by image-level probability.
    num_candidates = min(MAX_K, num_classes)
    _, topk_indices = torch.topk(probs_img, num_candidates, dim=-1)

    # The outer scene pools do not depend on the candidate class, so read them
    # from f_places once instead of once per class. They are moved to the
    # device per use to keep peak device memory as before.
    outer_scene_pools = None
    if args.scene_type == "outer_cz":
        if f_places is None:
            raise ValueError("f_places is required when scene_type is 'outer_cz'")
        scene_names = f_places["scene_names"][:]
        outer_scene_pools = [f_places["scenes"][sc]["embeddings"][:] for sc in scene_names.tolist()]

    # 1-D so that (batch_idx, cls_idx) pair element-wise; a [B, 1] index would
    # broadcast against cls_idx into a [B, B] grid and write other samples' classes.
    batch_idx = torch.arange(batch_size, device=topk_indices.device)
    tde_scores = torch.zeros(batch_size, num_classes, device=topk_indices.device)  # [batch_size, num_classes]

    for j in range(num_candidates):
        cls_idx = topk_indices[:, j]  # [batch_size]: each image's j-th candidate class
        cls_idx_exp = cls_idx.view(batch_size, 1, 1).expand(-1, num_tokens, 1)  # [batch_size, num_tokens, 1]

        # Class-specific foreground embedding c_xj: weighted mean of token
        # effects over tokens whose probability for this class exceeds the threshold.
        p_cls = probs_tok.gather(dim=2, index=cls_idx_exp)
        w_cls = p_cls.masked_fill(p_cls <= fg_threshold, 0)

        # # Visualization (optional)
        # draw_id = 0
        # from visualization import visual_segmentation_process
        # visual_segmentation_process(draw_id, w_cls[draw_id].flatten(), args)

        c_xj = (tokens_effect * w_cls).sum(dim=1) / (w_cls.sum(dim=1) + 1e-6)
        c_xj = F.normalize(c_xj, dim=-1)  # [batch_size, dim]

        if args.scene_type == "outer_cz":
            # Select scenes from each pool independently, then concatenate along the scene axis.
            selected_scenes_pool = [
                select_scene_embeddings(c_z, c_xj, torch.from_numpy(pool).to(device), args.select_scene_num)
                for pool in outer_scene_pools
            ]
            all_selected_scenes = torch.cat(selected_scenes_pool, dim=1)
        elif args.scene_type == "inner_cz":
            # The batch's own background embeddings act as the scene pool.
            all_selected_scenes = select_scene_embeddings(c_z, c_xj, c_z, min(args.select_scene_num, len(cls_idx)))
        elif args.scene_type == "virtual_cz":
            if args.virtual_scene_embeddings is None:
                raise ValueError("virtual_scene_embeddings is not set")
            all_selected_scenes = select_scene_embeddings(c_z, c_xj, args.virtual_scene_embeddings, args.select_scene_num)
        elif args.scene_type == "random_cz":
            # A fresh random pool per candidate class; its width follows the
            # embedding dim instead of a hard-coded 1024.
            random_cz = torch.randn(400, c_z.shape[-1]).to(device)
            all_selected_scenes = select_scene_embeddings(c_z, c_xj, random_cz, args.select_scene_num)
        else:
            raise ValueError(f"Unknown scene_type: {args.scene_type!r}")

        c_mix = get_mix_image(c_xj.unsqueeze(1), all_selected_scenes, alpha)  # [batch_size, num_selected, dim]

        # Counterfactual score per class, averaged over the selected scenes.
        logits_c_mix = 100 * torch.matmul(c_mix, text_embeddings.T)
        logits_scene = 100 * torch.matmul(all_selected_scenes, text_embeddings.T)
        tde_mix = torch.mean(logits_c_mix - lam_hat * logits_scene, dim=1)  # [batch_size, num_classes]

        # For each b: tde_scores[b, cls_idx[b]] = tde_mix[b, cls_idx[b]]
        tde_scores[batch_idx, cls_idx] = tde_mix[batch_idx, cls_idx]

    return tde_base, tde_scores