

import torch
import wandb


def merge_reconstructions_on_similarity(reconstructions, soft_pred_masks, threshold=0.9, do_log=False):
    """Merge slots whose soft masks are similar, modifying the inputs in place.

    Each slot mask is flattened and normalised to sum to one; the pairwise dot
    products of the normalised masks are compared against ``threshold``.  For
    every pair ``(slot, second_slot)`` with ``second_slot < slot`` whose dot
    product exceeds the threshold, ``second_slot`` is folded into ``slot``:
    the reconstruction becomes the mask-weighted average of both
    reconstructions, the masks are summed, and ``second_slot`` is zeroed.

    NOTE(review): the masks are L1-normalised, so the dot product is not a
    cosine similarity and only gets close to 1 when both masks concentrate on
    the same few pixels -- confirm that this (rather than L2 normalisation)
    is the intended similarity measure.

    Args:
        reconstructions: tensor of shape (batch, num_slots, C, H, W).
            Modified in place.
        soft_pred_masks: tensor of shape (batch, num_slots, H, W).
            Modified in place.
        threshold: a pair is merged when its dot product is strictly greater
            than this value.
        do_log: when true, log a histogram of all dot products to wandb
            (requires an active wandb run).

    Returns:
        The tuple ``(reconstructions, soft_pred_masks)`` -- the same tensor
        objects that were passed in, after merging.
    """
    # Unpacking doubles as a rank check on both inputs.
    batch_size, num_slots, _, H, W = reconstructions.shape
    batch_size, num_slots, H, W = soft_pred_masks.shape

    # reshape (instead of view) also accepts non-contiguous masks
    flat_masks = soft_pred_masks.reshape(batch_size, num_slots, H * W)
    normed_masks = flat_masks / torch.sum(flat_masks, dim=2, keepdim=True)
    dots = torch.einsum('bnd,bmd->bnm', normed_masks, normed_masks)
    if do_log:
        # detach so the histogram can be built from tensors that require grad
        wandb.log({"merge_reconstructions_dotprods": wandb.Histogram(dots.detach().cpu().numpy().flatten())})

    # Only merge a slot with slots that come before it (strict lower triangle).
    slot_ids = torch.arange(num_slots, device=dots.device)
    is_previous = slot_ids.unsqueeze(1) > slot_ids.unsqueeze(0)
    # nonzero() returns indices in lexicographic (batch, slot, second_slot)
    # order, i.e. the order a nested loop would visit them, and a single
    # tolist() avoids one device sync per (batch, slot, second_slot) triple.
    merge_pairs = torch.nonzero((dots > threshold) & is_previous).tolist()

    for batch, slot, second_slot in merge_pairs:
        mask = soft_pred_masks[batch, slot]
        second_mask = soft_pred_masks[batch, second_slot]
        merged_mask = mask + second_mask
        # Where both masks are zero the weighted sum is zero as well; divide
        # by one there instead of producing 0/0 = NaN.
        denominator = torch.where(merged_mask == 0, torch.ones_like(merged_mask), merged_mask)

        # weighted merge of reconstructions: each one is weighted by its OWN mask
        reconstructions[batch, slot] = (
            mask * reconstructions[batch, slot]
            + second_mask * reconstructions[batch, second_slot]
        ) / denominator
        # set second slot to zero
        reconstructions[batch, second_slot] = 0
        # merge masks
        soft_pred_masks[batch, slot] = merged_mask
        soft_pred_masks[batch, second_slot] = 0

    return reconstructions, soft_pred_masks