import random
from tqdm import tqdm

import torch

# Wrapper that runs the original model and returns the pooled output of the layer of interest
class ModelWrapper(torch.nn.Module):
    """Run ``model`` and return a pooled version of ``layer_of_interest``'s output.

    A forward hook on ``layer_of_interest`` captures its output on every forward
    pass; ``forward`` then pools it into one vector per input:

    * ``architecture='CNN'``: global average pooling over the last two dims
      (assumes a (B, C, H, W) activation) -> (B, C).
    * ``architecture='Transformer'``: mean over dim 1 (presumably the token
      dimension of a (B, T, D) activation — confirm for new models) -> (B, D).

    The hook handle is kept in ``self.hook_handle`` so callers can detach the
    hook with ``wrapper.hook_handle.remove()``.
    """

    def __init__(self, model, layer_of_interest, architecture='CNN'):
        super(ModelWrapper, self).__init__()

        self.model = model
        self.activations = {}
        self.architecture = architecture

        def forward_hook(_module, _inputs, output):
            # Save the output of the layer of interest for forward() to read.
            self.activations["output"] = output

        # Keep the handle; discarding it would make the hook impossible to remove.
        self.hook_handle = layer_of_interest.register_forward_hook(forward_hook)

        self.downsample = torch.nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        """Return the pooled activation of the hooked layer for input ``x``.

        Raises:
            ValueError: if ``self.architecture`` is not 'CNN' or 'Transformer'.
            KeyError: if the hooked layer did not run during ``self.model(x)``.
        """
        # Drop any activation from a previous call so a stale tensor is never returned.
        self.activations.pop("output", None)
        self.model(x)
        activation = self.activations["output"]

        if self.architecture == 'CNN':
            return self.downsample(activation).squeeze(-1).squeeze(-1)
        if self.architecture == 'Transformer':
            return activation.mean(dim=1)
        raise ValueError(
            f"Unsupported architecture {self.architecture!r}; expected 'CNN' or 'Transformer'")



# remove inplace operations
def ReLU_inplace_to_False(module):
    """Set ``inplace = False`` on every ``torch.nn.ReLU`` in ``module``.

    ``module.modules()`` yields ``module`` itself plus every descendant (each
    once, skipping ``None`` placeholder children), so a bare ReLU passed
    directly is also handled. Modifies the modules in place; returns None.
    """
    for submodule in module.modules():
        if isinstance(submodule, torch.nn.ReLU):
            submodule.inplace = False


# extract lowest and highest patches
def get_lowest_highest_patches(
    dataloader: torch.utils.data.DataLoader,
    model: torch.nn.Module,
    output_size: int,
    patch_size: int = 64,
    stride: int = 16,
    num_patches: int = 4,
    return_images: bool = False,
    device: str = 'cuda'):
    """
    Slide a square window over every image and keep, per output channel, the
    patches with the highest and the lowest model output.

    Args:
        dataloader: PyTorch dataloader yielding ``(images, labels)`` with batch size 1;
            images are expected to have 3 channels
        model: Neural network model to evaluate patches; it is called on a batch
            shaped (num_windows, 3, patch_size, patch_size) and is expected to
            return (num_windows, output_size)
        output_size: Number of output channels
        patch_size: Size of square patches to extract
        stride: Stride of patches
        num_patches: Number of highest / lowest patches to keep per channel
        return_images: Also return, for every kept patch, the 0-based index of the
            batch it came from (the dataset position when the loader is not shuffled)
        device: Device the model and all buffers are placed on (default 'cuda')

    Returns:
        ``(lowest_patches, lowest_scores, highest_patches, highest_scores)``, or
        ``(lowest_patches, lowest_scores, lowest_images, highest_patches,
        highest_scores, highest_images)`` when ``return_images`` is True.
        Patches are (output_size, num_patches, 3, patch_size, patch_size); scores and
        image indices (stored as floats) are (output_size, num_patches), ordered from
        most to least extreme. Slots that were never filled keep a +/-inf score.

    Raises:
        ValueError: if a batch contains more than one image.
    """
    model = model.to(device)
    model.eval()

    # Running per-channel buffers, seeded with sentinels that any real score beats.
    highest_scores = torch.full((output_size, num_patches), -float('inf'), device=device)
    lowest_scores = torch.full((output_size, num_patches), float('inf'), device=device)
    highest_patches = torch.ones((output_size, num_patches, 3, patch_size, patch_size), device=device)
    lowest_patches = torch.ones((output_size, num_patches, 3, patch_size, patch_size), device=device)

    # Source-batch index of each kept patch; None when not requested.
    highest_images = lowest_images = None
    if return_images:
        highest_images = torch.ones((output_size, num_patches), device=device)
        lowest_images = torch.ones((output_size, num_patches), device=device)

    with torch.no_grad():
        print('Finding patches...')
        for iteration, (images, _label) in enumerate(tqdm(dataloader)):
            images = images.to(device)
            if images.shape[0] != 1:
                raise ValueError(
                    f'get_lowest_highest_patches expects batch size 1, got {images.shape[0]}')

            patches = _unfold_image_into_patches(images, patch_size, stride)
            outputs = model(patches)  # expected (num_windows, output_size)

            # An image may yield fewer windows than num_patches; take what it has.
            k = min(num_patches, outputs.shape[0])
            top_vals, top_idx = torch.topk(outputs, k, dim=0)
            bottom_vals, bottom_idx = torch.topk(outputs, k, dim=0, largest=False)
            # Channel dimension first, matching the running buffers: (output_size, k).
            top_vals, top_idx = top_vals.t(), top_idx.t()
            bottom_vals, bottom_idx = bottom_vals.t(), bottom_idx.t()

            batch_images = None
            if return_images:
                batch_images = torch.full(
                    (output_size, k), iteration, dtype=highest_images.dtype, device=device)

            highest_scores, highest_patches, highest_images = _merge_running_topk(
                highest_scores, highest_patches, highest_images,
                top_vals, patches[top_idx], batch_images, num_patches, largest=True)
            lowest_scores, lowest_patches, lowest_images = _merge_running_topk(
                lowest_scores, lowest_patches, lowest_images,
                bottom_vals, patches[bottom_idx], batch_images, num_patches, largest=False)

    if return_images:
        return lowest_patches, lowest_scores, lowest_images, highest_patches, highest_scores, highest_images
    return lowest_patches, lowest_scores, highest_patches, highest_scores


def _unfold_image_into_patches(images, patch_size, stride):
    """Cut a (1, 3, H, W) batch into every patch_size x patch_size window taken each
    ``stride`` pixels; returns (num_windows, 3, patch_size, patch_size), row-major."""
    # Unfolding dim 1 with size 3 / step 1 keeps the 3 channels together in one window.
    windows = images.unfold(1, 3, 1).unfold(2, patch_size, stride).unfold(3, patch_size, stride)
    return windows.contiguous().view(-1, 3, patch_size, patch_size)


def _merge_running_topk(run_scores, run_patches, run_images,
                        new_scores, new_patches, new_images, k, largest):
    """Merge one image's per-channel candidates into the running top-k (or bottom-k).

    ``*_scores`` / ``*_images`` are (channels, n) and ``*_patches`` are
    (channels, n, 3, h, w); ``run_images`` may be None, in which case images are
    not tracked. Returns the updated ``(scores, patches, images)``.
    """
    scores = torch.cat([run_scores, new_scores], dim=1)
    patches = torch.cat([run_patches, new_patches], dim=1)
    best_scores, best_idx = torch.topk(scores, k, dim=1, largest=largest)
    # One row index per channel, broadcast against best_idx for paired advanced indexing.
    rows = torch.arange(scores.shape[0], device=scores.device).unsqueeze(1)
    best_images = None
    if run_images is not None:
        best_images = torch.cat([run_images, new_images], dim=1)[rows, best_idx]
    return best_scores, patches[rows, best_idx], best_images

def get_random_patches(
    dataloader: torch.utils.data.DataLoader,
    patch_size: int = 64,
    stride: int = 16,
    num_patches: int = 4):
    """
    Sample square patches at random grid positions from randomly chosen images.

    Each patch comes from a different image (images are drawn without
    replacement). Its top-left corner is drawn uniformly from the grid
    ``0, stride, 2*stride, ...`` up to and including ``size - patch_size``. This is
    the same grid that ``Tensor.unfold(dim, patch_size, stride)`` produces, and it
    uses each image's actual height and width.

    Args:
        dataloader: PyTorch dataloader; only ``dataloader.dataset`` is used, and
            ``dataset[i][0]`` must be a (3, H, W) image tensor
        patch_size: Size of square patches to extract
        stride: Spacing between candidate patch positions
        num_patches: Number of patches to sample (at most ``len(dataset)``)

    Returns:
        Returns patches (num_patches, 3, patch_size, patch_size)

    Raises:
        ValueError: if ``num_patches`` exceeds the dataset size, or a sampled
            image is smaller than ``patch_size`` in either dimension.
    """
    patches = torch.zeros((num_patches, 3, patch_size, patch_size))
    dataset = dataloader.dataset

    # One distinct image per patch.
    img_ids = random.sample(range(len(dataset)), num_patches)

    for slot, img_id in enumerate(img_ids):
        image = dataset[img_id][0]
        height, width = image.shape[-2], image.shape[-1]
        # "+ 1" keeps the window flush with the bottom/right border as a valid position.
        xs = range(0, width - patch_size + 1, stride)
        ys = range(0, height - patch_size + 1, stride)
        if not xs or not ys:
            raise ValueError(
                f'image {img_id} of size {height}x{width} is smaller than patch_size={patch_size}')
        x = random.choice(xs)
        y = random.choice(ys)
        patches[slot] = image[:, y:y + patch_size, x:x + patch_size]

    return patches

def paste_patches(images, patches):
    """Paste one randomly chosen patch into a random corner of every image.

    For each image in the (B, C, H, W) batch ``images``, a patch is drawn (with
    replacement) from ``patches`` (N, C, ph, pw). It is pasted with its top-left
    corner at x in {0, W - pw} and y in {0, H - ph}, i.e. one of the four corners.
    Patches may be rectangular. The input is not modified; a pasted copy is returned.

    Raises:
        ValueError: if a patch is larger than the images in either dimension.
    """
    num_images, _, height, width = images.shape
    patch_h, patch_w = patches.shape[-2], patches.shape[-1]
    if patch_h > height or patch_w > width:
        raise ValueError(
            f'patch of size {patch_h}x{patch_w} does not fit in images of size {height}x{width}')

    # Draw order (patch index, then x, then y) is kept fixed so seeded runs are reproducible.
    patch_indices = torch.randint(0, patches.size(0), (num_images,))
    # 0 or 1 scaled to the maximum offset -> the patch always lands in a corner.
    xs = torch.randint(0, 2, (num_images,)) * (width - patch_w)
    ys = torch.randint(0, 2, (num_images,)) * (height - patch_h)

    augmented_images = images.clone()  # avoid modifying the caller's tensor
    for i, (p_idx, x, y) in enumerate(zip(patch_indices.tolist(), xs.tolist(), ys.tolist())):
        augmented_images[i, :, y:y + patch_h, x:x + patch_w] = patches[p_idx]

    return augmented_images




@torch.no_grad()
def get_highest_patches(
    dataloader: torch.utils.data.DataLoader,
    model: torch.nn.Module,
    channel_of_interest: int,
    patch_size: int = 64,
    stride: int = 16,
    num_patches: int = 4,
    return_images: bool = False,
    device: str = 'cuda',
):
    """
    Find the patches that maximize one output channel of ``model``.

    Every image is cut into all patch_size x patch_size windows (step ``stride``).
    The model scores them, and a running top-``num_patches`` over
    ``outputs[:, channel_of_interest]`` is kept across the whole dataloader.

    Args:
        dataloader: PyTorch dataloader yielding ``(images, labels)`` with batch size 1;
            images are expected to have 3 channels
        model: Model called on (num_windows, 3, patch_size, patch_size) patches;
            its output is indexed as ``outputs[:, channel_of_interest]``
        channel_of_interest: Output channel to maximize
        patch_size: Size of square patches to extract
        stride: Stride of patches
        num_patches: Number of top patches to keep
        return_images: Also return the full source image of every kept patch
        device: Device the model and all buffers are placed on (default 'cuda')

    Returns:
        ``(patches, scores)``, or ``(patches, images, scores)`` when
        ``return_images`` is True. Shapes: patches (1, num_patches, 3, patch_size,
        patch_size); images (1, num_patches, 3, H, W), or None if the dataloader
        was empty; scores (1, num_patches), sorted descending. Slots that were
        never filled keep a -inf score.

    Raises:
        ValueError: if a batch contains more than one image.
    """
    model = model.to(device)
    model.eval()

    # Running top-k buffers, seeded with -inf so any real score replaces them.
    best_scores = torch.full((1, num_patches), -float('inf'), device=device)
    best_patches = torch.ones((1, num_patches, 3, patch_size, patch_size), device=device)
    # Allocated on the first batch, once the real image size and dtype are known.
    best_images = None

    for images, _label in dataloader:
        images = images.to(device)
        if images.shape[0] != 1:
            raise ValueError(f'get_highest_patches expects batch size 1, got {images.shape[0]}')

        # Unfolding dim 1 with size 3 / step 1 keeps the 3 channels together in one window.
        windows = images.unfold(1, 3, 1).unfold(2, patch_size, stride).unfold(3, patch_size, stride)
        patches = windows.contiguous().view(-1, 3, patch_size, patch_size)

        scores = model(patches)[:, channel_of_interest]  # one score per window
        # An image may yield fewer windows than num_patches; take what it has.
        k = min(num_patches, scores.shape[0])
        batch_scores, batch_idx = torch.topk(scores, k)

        # Merge this image's candidates with the running best and re-select the top-k.
        cand_scores = torch.cat([best_scores, batch_scores.unsqueeze(0)], dim=1)
        cand_patches = torch.cat([best_patches, patches[batch_idx].unsqueeze(0)], dim=1)
        best_scores, keep = torch.topk(cand_scores, num_patches, dim=1)
        best_patches = cand_patches[0, keep]

        if return_images:
            if best_images is None:
                best_images = torch.ones(
                    (1, num_patches, *images.shape[1:]), dtype=images.dtype, device=device)
            cand_images = torch.cat(
                [best_images, images.unsqueeze(0).expand(1, k, -1, -1, -1)], dim=1)
            best_images = cand_images[0, keep]

    if return_images:
        return best_patches, best_images, best_scores
    return best_patches, best_scores
