from typing import Dict, List, Union
import torch
from torch.distributions.dirichlet import Dirichlet

def split_region(ratios: torch.Tensor, total_regions: torch.Tensor) -> torch.Tensor:
    """
    Turn per-region proportions into per-region counts.

    Each row of ``ratios`` is scaled by that row's total and rounded. Rounding
    does not generally preserve the sum, so the whole rounding error of a row
    is charged to a single "anchor" region: the one with the largest strictly
    positive ratio (region 0 when the row has no positive ratio). Counts are
    finally clamped at 0; if the correction would push the anchor negative,
    the clamp means that row can sum to more than its requested total.

    Args:
        ratios: Tensor of shape [B, C]; each row is expected to sum to 1.
        total_regions: Tensor of shape [B] with the total count for each row.
    Returns:
        Tensor of shape [B, C] holding the rounded per-region counts.
    """
    num_regions = ratios.size(1)
    counts = (ratios * total_regions[..., None]).round()

    # Per-row rounding error, kept as a column so it broadcasts over regions.
    residual = (total_regions - counts.sum(dim=1))[..., None]

    # Zero ratios are pushed below every positive ratio so argmax ignores them.
    anchor = torch.where(ratios > 0, ratios, -0.1).argmax(dim=1)
    on_anchor = torch.nn.functional.one_hot(anchor, num_classes=num_regions).bool()

    counts = counts + torch.where(on_anchor, residual, torch.zeros_like(counts))
    return counts.clamp(min=0)

def _curriculum_alphas(curr_epoch, total_epoch):
    """Return the ``(alpha_sem, alpha_ins)`` guidance weights for ``curr_epoch``.

    Training is split into five equal phases of length ``total_epoch / 5`` (T/5):

    1. ``curr_epoch <= T/5``: semantic guidance only (``alpha_sem = 1``).
    2. ``(T/5, 2T/5]``: linear hand-over from semantic to instance guidance;
       the two weights sum to 1.
    3. ``(2T/5, 3T/5]``: instance guidance only (``alpha_ins = 1``).
    4. ``(3T/5, 4T/5]``: instance weight decays linearly to 0.
    5. afterwards: no guidance (both weights 0).
    """
    interval = total_epoch / 5
    if curr_epoch <= total_epoch / 5:
        return 1, 0
    if curr_epoch <= 2 * total_epoch / 5:
        progress = (curr_epoch - total_epoch / 5) / interval
        return 1.0 - progress, progress
    if curr_epoch <= 3 * total_epoch / 5:
        return 0, 1.0
    if curr_epoch <= 4 * total_epoch / 5:
        progress = (curr_epoch - 3 * total_epoch / 5) / interval
        return 0, 1.0 - progress
    return 0, 0


def generate_random_masks(input_tokens: Dict[str, torch.Tensor],
                          num_encoded_tokens: int,
                          alphas: Union[float, List[float]] = 1.0,
                          insseg_mask = None,
                          semseg_mask = None,
                          curr_epoch = 0,
                          total_epoch = 1600,
                          num_semantic_classes: int = 133):
    """
    Sample a total of num_encoded_tokens from different tasks using Dirichlet sampling.

    The per-task token budget is drawn from a Dirichlet distribution. Within each
    task, the tokens to keep are chosen by ranking a score that blends uniform
    noise with (optionally) instance- and semantic-guided masks. The blend weights
    follow an epoch curriculum: semantic guidance first, then instance guidance,
    then pure random masking. A guidance mask that is not supplied gets weight 0,
    so the missing share falls back to random noise.

    :param input_tokens: Dictionary of tensors to sample num_encoded_tokens from.
        Dim 0 is read as the batch size and dim 1 as the number of tokens of that task.
    :param num_encoded_tokens: Number of tokens to select (keep) in total per sample.
    :param alphas: Dirichlet distribution parameter alpha. Lower alpha = harder,
        less uniform sampling. Can be a single int/float (shared by all tasks) or
        a list with one value per task.
    :param insseg_mask: Optional instance segmentation masks, presumably
        (B, L, pixels_per_token) with 0 = no instance; pixels > 0 are counted per
        token. Tokens with more instance pixels rank as "remove".
        NOTE(review): must have the same L as every task.
    :param semseg_mask: Optional semantic segmentation masks (B, L, patch_size**2)
        of class ids in [0, num_semantic_classes); each token takes its mode label.
    :param curr_epoch: Current epoch, used for the guidance curriculum.
    :param total_epoch: Total number of epochs of the curriculum.
    :param num_semantic_classes: Number of semantic classes for the one-hot encoding.
    :return: ``(task_masks, ids_keep, ids_restore)``. ``task_masks`` maps each task
        name to a (B, num_tokens) mask with 0 = keep and 1 = masked. ``ids_keep``
        (B, num_encoded_tokens) holds indices of kept tokens into the concatenation of
        all tasks. ``ids_restore`` (B, total_tokens) is the inverse of the shuffle.
    """
    first_tokens = next(iter(input_tokens.values()))
    B = first_tokens.shape[0]
    device = first_tokens.device

    alpha_sem, alpha_ins = _curriculum_alphas(curr_epoch, total_epoch)
    # Without its mask a guidance term would contribute 0 while still taking its
    # weight away from the noise; at weight 1 the ranking score would be all zeros
    # and the "random" mask deterministic. Give that weight back to the noise.
    if insseg_mask is None:
        alpha_ins = 0
    if semseg_mask is None:
        alpha_sem = 0

    # A scalar concentration (int or float) is shared by every task.
    if isinstance(alphas, (int, float)):
        alphas = [alphas] * len(input_tokens)
    task_sampling_dist = Dirichlet(torch.Tensor(alphas)).sample((B,)).to(device)
    samples_per_task = (task_sampling_dist * num_encoded_tokens).round().long()

    num_tokens_per_task = [task_tokens.shape[1] for task_tokens in input_tokens.values()]

    if alpha_ins != 0:
        # Count instance pixels per token, then min-max normalise per sample.
        instance_mask = torch.where(insseg_mask > 0, 1, 0).sum(dim=-1)
        inst_min = instance_mask.min(dim=1, keepdim=True)[0]
        inst_max = instance_mask.max(dim=1, keepdim=True)[0]
        instance_mask = (instance_mask - inst_min) / (inst_max - inst_min + 1e-8)
    if alpha_sem != 0:
        # One class label per token (its most frequent pixel label), one-hot encoded.
        semantic_mask = torch.mode(semseg_mask, dim=-1)[0]
        semantic_mask = torch.nn.functional.one_hot(semantic_mask, num_classes=num_semantic_classes)  # (B, L, C)

    task_masks = []
    for i, num_tokens in enumerate(num_tokens_per_task):
        keep_counts = samples_per_task[:, i].unsqueeze(1)  # (B, 1)

        if alpha_ins != 0:
            # Small jitter breaks ties; ascending order: small is keep, large is remove.
            ins_scores = 0.1 * torch.rand(B, num_tokens, device=device) + instance_mask
            ins_rank = torch.argsort(torch.argsort(ins_scores, dim=1), dim=1)  # rank of each token
            # 0 is keep (unmasked), 1 is remove (masked): keep the keep_counts lowest ranks.
            insguided_mask = torch.where(ins_rank < keep_counts, 0, 1)
        else:
            insguided_mask = 0

        if alpha_sem != 0:
            # Split this task's budget across classes in proportion to their area.
            samples_per_region = split_region(
                semantic_mask.sum(1).float() / num_tokens,
                samples_per_task[:, i]
            )
            semseg_noise = torch.rand(B, num_tokens, num_semantic_classes, device=device)
            # +0.1 keeps in-class scores strictly above the zeros of out-of-class tokens.
            semseg_noise = (semseg_noise + 0.1) * semantic_mask
            sem_rank = torch.argsort(torch.argsort(semseg_noise, dim=1, descending=True), dim=1)
            # Per class, select its samples_per_region highest-scoring (in-class) tokens.
            selected = sem_rank < samples_per_region.unsqueeze(1)
            # 0 is keep (unmasked) if any class selected the token, 1 is remove (masked).
            semguided_mask = (~selected.any(dim=-1)).to(torch.int)
        else:
            semguided_mask = 0

        noise = torch.rand(B, num_tokens, device=device)  # noise in [0, 1]
        combined_noise = (1 - alpha_ins - alpha_sem) * noise + alpha_ins * insguided_mask + alpha_sem * semguided_mask
        # Ascending rank: small is keep, large is remove. 0 is keep, 1 is remove.
        rank = torch.argsort(torch.argsort(combined_noise, dim=1), dim=1)
        task_masks.append(torch.where(rank < keep_counts, 0, 1))

    mask_all = torch.cat(task_masks, dim=1)
    # Re-rank all tokens jointly (kept ones first, random order within each group)
    # so exactly num_encoded_tokens are kept even if per-task rounding did not add up.
    ids_shuffle = torch.argsort(mask_all + torch.rand_like(mask_all.float()), dim=1)
    ids_restore = torch.argsort(ids_shuffle, dim=1)
    ids_keep = ids_shuffle[:, :num_encoded_tokens]

    # Update binary mask to adjust for task rounding
    mask_all = torch.ones_like(mask_all)
    mask_all[:, :num_encoded_tokens] = 0
    # Unshuffle to get the binary mask
    mask_all = torch.gather(mask_all, dim=1, index=ids_restore)
    # Split back into per-task masks, keyed by task name
    task_masks = torch.split(mask_all, num_tokens_per_task, dim=1)
    task_masks = dict(zip(input_tokens.keys(), task_masks))

    return task_masks, ids_keep, ids_restore