import random
from tqdm import tqdm

import torch

# extract lowest and highest patches
def get_lowest_highest_patches(
    dataloader: torch.utils.data.DataLoader,
    model: torch.nn.Module,
    output_size: int,
    patch_size: int = 64,
    stride: int = 16,
    num_patches: int = 4,
    return_images: bool = False,
    device: "str | torch.device" = "cuda"):
    """
    Find, for every model output channel, the image patches with the highest
    and the lowest model responses over a whole dataset.

    Every image is cut into square ``patch_size`` x ``patch_size`` windows
    (sliding by ``stride`` along height and width, no padding). The windows
    are scored by ``model``, and a running top-k / bottom-k per output
    channel is kept across all batches.

    Side effects: ``model`` is moved to ``device`` and switched to eval mode;
    it is not switched back afterwards.

    Args:
        dataloader: Iterable yielding ``(images, label)`` pairs, where
            ``images`` is a 4-D tensor (batch, channels, H, W). Any batch
            size is accepted; labels are ignored.
        model: Network mapping a stack of patches
            (n, channels, patch_size, patch_size) to ``output_size`` scores
            per patch. Its output is flattened to (n, -1).
        output_size: Number of output channels the model produces per patch.
        patch_size: Side length of the square patches.
        stride: Step between neighbouring patches, in pixels.
        num_patches: How many highest / lowest patches to keep per channel.
        return_images: Also return the position of the source image of each
            kept patch.
        device: Device the model and the images are moved to.

    Returns:
        ``(lowest_patches, lowest_scores, highest_patches, highest_scores)``,
        or, when ``return_images`` is true,
        ``(lowest_patches, lowest_scores, lowest_images,
        highest_patches, highest_scores, highest_images)``.

        * ``*_patches``: (output_size, num_patches, channels, patch_size,
          patch_size).
        * ``*_scores``: (output_size, num_patches). Highest scores are sorted
          descending, lowest scores ascending.
        * ``*_images``: (output_size, num_patches) int64 index of the source
          image, counted in the order the dataloader yields samples.

        Slots that were never filled (fewer than ``num_patches`` patches in
        the whole dataset) keep placeholder values: score -inf (highest) or
        +inf (lowest), an all-ones patch, and image index -1. If the
        dataloader yields no usable patch at all, the placeholder patches
        have 3 channels.

    Raises:
        ValueError: if a batch is not 4-D, or the model output does not
            contain ``output_size`` values per patch.
    """
    model = model.to(device)
    model.eval()

    # Running (scores, patches, image indices) state per direction; created on
    # the first non-empty batch, once channel count and dtypes are known.
    highest = None
    lowest = None
    sample_offset = 0  # dataset position of the first image of the current batch

    with torch.no_grad():
        print('Finding patches...')
        for images, _label in tqdm(dataloader):
            images = images.to(device)
            if images.dim() != 4:
                raise ValueError(
                    "expected 4-D image batches (batch, channels, H, W), "
                    f"got shape {tuple(images.shape)}")
            batch_size, channels = images.shape[0], images.shape[1]

            patches, per_image = _extract_patches(images, patch_size, stride)
            # Source-image index of every patch, aligned with `patches`.
            patch_images = sample_offset + torch.arange(
                batch_size, device=patches.device).repeat_interleave(per_image)
            sample_offset += batch_size
            if patches.shape[0] == 0:
                # Images smaller than one patch contribute no candidates.
                continue

            outputs = model(patches).reshape(patches.shape[0], -1)
            if outputs.shape[1] != output_size:
                raise ValueError(
                    f"model produced {outputs.shape[1]} values per patch, "
                    f"expected output_size={output_size}")

            if highest is None:
                highest = _placeholder_state(
                    output_size, num_patches, channels, patch_size,
                    outputs.dtype, patches.dtype, patches.device, -float('inf'))
                lowest = _placeholder_state(
                    output_size, num_patches, channels, patch_size,
                    outputs.dtype, patches.dtype, patches.device, float('inf'))

            highest = _merge_candidates(
                highest, outputs, patches, patch_images, num_patches, largest=True)
            lowest = _merge_candidates(
                lowest, outputs, patches, patch_images, num_patches, largest=False)

    if highest is None:
        # Nothing was scored: return placeholders shaped like 3-channel patches.
        highest = _placeholder_state(
            output_size, num_patches, 3, patch_size,
            torch.float32, torch.float32, device, -float('inf'))
        lowest = _placeholder_state(
            output_size, num_patches, 3, patch_size,
            torch.float32, torch.float32, device, float('inf'))

    highest_scores, highest_patches, highest_images = highest
    lowest_scores, lowest_patches, lowest_images = lowest

    if return_images:
        return (lowest_patches, lowest_scores, lowest_images,
                highest_patches, highest_scores, highest_images)
    return lowest_patches, lowest_scores, highest_patches, highest_scores


def _extract_patches(images, patch_size, stride):
    """Cut a (B, C, H, W) batch into (B * P, C, patch_size, patch_size) windows; also return P."""
    channels = images.shape[1]
    # (B, C, nH, W, k) -> (B, C, nH, nW, k, k): unfold height, then width.
    windows = images.unfold(2, patch_size, stride).unfold(3, patch_size, stride)
    per_image = windows.shape[2] * windows.shape[3]
    # Move the channel axis next to the window axes so each patch is (C, k, k);
    # patches are ordered image-major, then row-major over window positions.
    patches = windows.permute(0, 2, 3, 1, 4, 5).reshape(-1, channels, patch_size, patch_size)
    return patches, per_image


def _placeholder_state(output_size, num_patches, channels, patch_size,
                       score_dtype, patch_dtype, device, fill):
    """Build (scores, patches, image indices) filler that any finite score displaces."""
    scores = torch.full((output_size, num_patches), fill, dtype=score_dtype, device=device)
    patches = torch.ones((output_size, num_patches, channels, patch_size, patch_size),
                         dtype=patch_dtype, device=device)
    images = torch.full((output_size, num_patches), -1, dtype=torch.long, device=device)
    return scores, patches, images


def _merge_candidates(best, outputs, patches, patch_images, num_patches, largest):
    """Merge one batch's strongest (or weakest) patches into the running per-channel state."""
    best_scores, best_patches, best_images = best

    # Pre-select per channel within the batch so only a few patches get copied;
    # clamp k in case the batch has fewer patches than num_patches.
    local_k = min(num_patches, outputs.shape[0])
    cand_scores, cand_idx = torch.topk(outputs, local_k, dim=0, largest=largest)
    cand_scores, cand_idx = cand_scores.T, cand_idx.T  # (output_size, local_k)

    # Re-rank the previous best together with the new candidates. There are
    # always >= num_patches columns because the state holds num_patches entries.
    all_scores = torch.cat([best_scores, cand_scores], dim=1)
    kept_scores, order = torch.topk(all_scores, num_patches, dim=1, largest=largest)
    rows = torch.arange(all_scores.shape[0], device=all_scores.device).unsqueeze(1)

    kept_patches = torch.cat([best_patches, patches[cand_idx]], dim=1)[rows, order]
    kept_images = torch.cat([best_images, patch_images[cand_idx]], dim=1)[rows, order]
    return kept_scores, kept_patches, kept_images