import math
import random

from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
from copy import deepcopy

import torch
from torch import nn
import torch.nn.functional as F

from .utils import cosine_scheduler

class RandomlySelectedCrossAttentionMasking:
    """Masks the input positions most attended to by a random subset of queries.

    For every sample, a random subset of query rows is drawn from ``attentions``.
    Their softmaxed, head-averaged attention over the input sequence is summed,
    and the top-scoring input positions are masked by zeroing ``attention_mask``.
    """

    def __init__(
        self,
        masking_ratio=0.5,
        num_per_query=576,
        replace=None,
        exclude_seen_reconstruction=True,
        varying_length=True,
        mask_latents=False,
        relative_noise=0.0,
        select_initial_ratio=None,
        **kwargs,
    ):
        """Configure the masking transform.

        Args:
            masking_ratio: Fraction of input positions to mask.
            num_per_query: Divisor used to derive how many queries to sample:
                ``(#positions * masking_ratio) / num_per_query``.
            replace: Token value that, when ``varying_length`` is set, marks a
                position as not maskable. This transform never writes it into
                ``inputs``.
            exclude_seen_reconstruction: If True, ``loss_mask`` is set only at
                the newly masked positions; otherwise at every position whose
                ``attention_mask`` was non-zero before masking.
            varying_length: If True, query and mask counts are computed per
                sample from its maskable positions; otherwise from ``seq_len``.
            mask_latents: If True, also return a ``latent_mask`` in which the
                chosen queries are zeroed.
            relative_noise: Scale of Gaussian noise added to the attention
                logits, relative to their std over the head and query dims, and
                further multiplied by ``noise_schedule_scale``.
            select_initial_ratio: Fixed-length mode only. When set and
                ``noise_schedule_scale > 0``, the top-k candidate pool is
                enlarged towards this ratio (interpolated by
                ``noise_schedule_scale``) and the final masked positions are a
                random subset of that pool.
            **kwargs: Ignored.
        """
        self.masking_ratio = masking_ratio
        self.num_per_query = num_per_query
        self.replace = replace
        self.exclude_seen_reconstruction = exclude_seen_reconstruction
        self.varying_length = varying_length
        self.mask_latents = mask_latents
        self.relative_noise = relative_noise
        # Will be set by callback
        self.noise_schedule_scale = 0.0
        self.select_initial_ratio = select_initial_ratio

    def set_noise_scale(
        self,
        noise_scale
    ):
        """Set the multiplier used for attention noise and candidate-pool size."""
        self.noise_schedule_scale = noise_scale

    def __call__(
        self,
        attentions,
        inputs,
        attention_mask,
        loss_mask: Optional[torch.Tensor] = None,
        **kwargs
    ):
        """Mask input positions based on cross-attention scores.

        Args:
            attentions: Attention logits of shape
                ``(batch, heads, queries, seq_len)``.
            inputs: Input tokens. They are only compared against ``replace`` and
                are returned unchanged.
            attention_mask: ``(batch, seq_len)`` mask. It is modified in place:
                masked positions are set to 0.
            loss_mask: Optional ``(batch, seq_len)`` tensor, modified in place
                (see ``exclude_seen_reconstruction``).
            **kwargs: Passed through to the returned dict.

        Returns:
            A dict with the softmaxed ``attentions``, ``inputs``, the updated
            ``attention_mask`` and ``loss_mask``, ``latent_mask`` (or None), and
            every extra kwarg.
        """
        batch_size, _, num_queries, seq_len = attentions.shape
        device = attentions.device

        if self.varying_length:
            # A position is maskable if it is attended to and not already `replace`.
            unmasked_bool = attention_mask == 1
            if self.replace is not None:
                unmasked_bool = unmasked_bool & (inputs != self.replace)
            num_queries_to_use = unmasked_bool.sum(-1) * self.masking_ratio / self.num_per_query
        else:
            num_queries_to_use = round(seq_len * self.masking_ratio / self.num_per_query)

        if self.relative_noise > 0:
            attentions = attentions + torch.randn(attentions.shape, dtype=attentions.dtype, device=device) * attentions.std(dim=(1, 2), keepdim=True) * self.relative_noise * self.noise_schedule_scale
        attentions = F.softmax(attentions, dim=-1)
        averaged_attentions = attentions.mean(dim=1)  # (batch, queries, seq_len)

        if self.varying_length:
            # Draw all query subsets first so the RNG consumption order is stable.
            chosen_queries = [
                torch.randperm(num_queries, device=device)[:int(num_queries_to_use[i].round().item())]
                for i in range(batch_size)
            ]
            per_sample_indices = []
            for i in range(batch_size):
                summed = averaged_attentions[i, chosen_queries[i]].sum(dim=0)
                k = int((self.masking_ratio * unmasked_bool[i].sum()).round().item())
                per_sample_indices.append(summed.topk(k=k).indices)
            # Row index for every selected column, on the same device as the columns.
            batch_indices = torch.cat([
                torch.full((len(idx),), i, dtype=torch.long, device=idx.device)
                for i, idx in enumerate(per_sample_indices)
            ])
            masked_indices = (batch_indices, torch.cat(per_sample_indices))
        else:
            chosen_queries = [torch.randperm(num_queries, device=device)[:num_queries_to_use] for _ in range(batch_size)]
            chosen_attentions = torch.stack([averaged_attentions[i, q] for i, q in enumerate(chosen_queries)])
            summed_attentions = chosen_attentions.sum(dim=1)  # (batch, seq_len)
            num_to_mask = round(self.masking_ratio * seq_len)
            if self.select_initial_ratio is not None and self.noise_schedule_scale > 0.0:
                # Take a larger top-k pool, then randomly keep `num_to_mask` of it.
                select_ratio = self.masking_ratio + (self.select_initial_ratio - self.masking_ratio) * self.noise_schedule_scale
                candidates = summed_attentions.topk(k=round(select_ratio * seq_len), dim=-1).indices
                candidates = [row[torch.randperm(len(row))] for row in candidates]
                sample_indices = torch.stack([row[:num_to_mask] for row in candidates])
            else:
                sample_indices = summed_attentions.topk(k=num_to_mask, dim=-1).indices
            masked_indices = (
                torch.arange(sample_indices.shape[0], device=sample_indices.device).repeat_interleave(sample_indices.shape[1]),
                sample_indices.flatten(),
            )

        if loss_mask is not None:
            if self.exclude_seen_reconstruction:
                loss_mask[masked_indices] = 1
            else:
                loss_mask[attention_mask.nonzero(as_tuple=True)] = 1

        attention_mask[masked_indices] = 0

        latent_mask = None
        if self.mask_latents:
            # Allocate on the device of the query indices. A CPU tensor indexed
            # with CUDA indices raises at runtime.
            latent_mask = torch.ones((batch_size, 1, 1, num_queries), device=device)
            for i, queries in enumerate(chosen_queries):
                latent_mask[i, :, :, queries] = 0

        return dict(
            attentions=attentions,
            inputs=inputs,
            attention_mask=attention_mask,
            loss_mask=loss_mask,
            latent_mask=latent_mask,
            **kwargs,
        )

class DelimiterBasedRandomMasking:
    """Randomly masks whole delimiter-separated chunks of each sequence. Used for baselines.

    The default delimiter id (38) is described by the original author as the
    space token. It is presumably tokenizer-specific; confirm it against the
    tokenizer in use.
    """
    def __init__(
        self,
        masking_ratio=0.5,
        delimiter=38,
        replace=None,
        exclude_seen_reconstruction=False,
        mask_delimiter=False,
        **kwargs,
    ):
        """
        Args:
            masking_ratio: Fraction of chunks to mask in each row.
            delimiter: Token id that separates chunks.
            replace: If not None, masked tokens in ``inputs`` are overwritten
                with this value. Otherwise ``attention_mask`` is zeroed there.
            exclude_seen_reconstruction: If True, ``loss_mask`` is set only at
                masked positions; otherwise at every attended position.
            mask_delimiter: If True, the delimiter that opens a chunk is masked
                along with the chunk.
            **kwargs: Ignored.
        """
        self.masking_ratio = masking_ratio
        self.delimiter = delimiter
        self.replace = replace
        self.exclude_seen_reconstruction = exclude_seen_reconstruction
        self.mask_delimiter = 1 if mask_delimiter else 0

    def __call__(
        self,
        inputs,
        attention_mask,
        loss_mask: Optional[torch.Tensor] = None,
        **kwargs
    ):
        """Mask random chunks of each row in place.

        Args:
            inputs: ``(batch, seq_len)`` token ids. Positive ids are counted as
                real tokens when locating the end of the sequence.
            attention_mask: ``(batch, seq_len)`` mask, modified in place when
                ``replace`` is None.
            loss_mask: Optional ``(batch, seq_len)`` tensor, modified in place.
            **kwargs: Passed through to the returned dict.

        Returns:
            A dict with ``inputs``, ``attention_mask``, ``loss_mask`` and any
            extra kwargs.
        """
        batch_size, seq_len = inputs.shape
        for batch_idx in range(batch_size):
            row = inputs[batch_idx]
            delimiter_indices = (row == self.delimiter).nonzero(as_tuple=True)[0].tolist()
            # Always start a chunk at position 0. This also covers rows with no delimiter.
            if not delimiter_indices or delimiter_indices[0] != 0:
                delimiter_indices = [0] + delimiter_indices
            # NOTE(review): the first chunk starts at index 1 when mask_delimiter
            # is False, so position 0 is never masked. Confirm this is intended
            # (e.g. a BOS/CLS token).

            seq_end = int((row > 0).sum())
            # Only close with seq_end if it lies past the last delimiter. This
            # avoids an empty or backwards range.
            if seq_end > delimiter_indices[-1]:
                delimiter_indices.append(seq_end)
            num_chunks = len(delimiter_indices) - 1
            num_to_mask = round(num_chunks * self.masking_ratio)
            if num_to_mask > 0:
                chunks_to_mask = torch.randperm(num_chunks)[:num_to_mask].tolist()
                masked_indices = torch.cat([
                    torch.arange(delimiter_indices[c] + 1 - self.mask_delimiter, delimiter_indices[c + 1])
                    for c in chunks_to_mask
                ])
                if loss_mask is not None:
                    if self.exclude_seen_reconstruction:
                        loss_mask[batch_idx][masked_indices] = 1
                    else:
                        loss_mask[batch_idx][attention_mask[batch_idx].nonzero(as_tuple=True)[0]] = 1

                if self.replace is not None:
                    inputs[batch_idx][masked_indices] = self.replace
                else:
                    attention_mask[batch_idx][masked_indices] = 0
            else:
                print("Not enough delimiters to mask")
        return dict(
            inputs=inputs,
            attention_mask=attention_mask,
            loss_mask=loss_mask,
            **kwargs,
        )
    

class PatchBasedMasking:
    """Masks random rectangular patches of a flattened image."""
    def __init__(
        self,
        masking_ratio=0.75,
        patch_size=(32, 32),
        original_image_shape=(224, 224),
        replace=None,
        **kwargs,
    ):
        """
        Args:
            masking_ratio: Fraction of patches to mask per image.
            patch_size: ``(height, width)`` of each patch in pixels.
            original_image_shape: ``(height, width)`` used to un-flatten inputs.
            replace: If not None, masked pixels in ``inputs`` are overwritten
                with this value. Otherwise ``attention_mask`` is zeroed there.
            **kwargs: Ignored.
        """
        self.masking_ratio = masking_ratio
        self.patch_size = patch_size
        self.original_image_shape = original_image_shape
        self.replace = replace

    def __call__(
        self,
        inputs,
        attention_mask,
        loss_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """Mask ``round(#patches * masking_ratio)`` random patches per image.

        Args:
            inputs: ``(batch, H*W, channels)`` tensor, viewable as
                ``(batch, H, W, channels)`` with ``(H, W) = original_image_shape``.
            attention_mask: ``(batch, H*W)`` mask.
            loss_mask: Optional ``(batch, H*W)`` tensor. Masked pixels are set to 1.
            **kwargs: Passed through to the returned dict.

        Returns:
            A dict with the flattened ``inputs``, ``attention_mask``,
            ``loss_mask`` (None if it was None), and any extra kwargs.
        """
        height, width = self.original_image_shape
        batch_size = inputs.shape[0]
        inputs = inputs.view(batch_size, height, width, inputs.shape[-1])
        attention_mask = attention_mask.view(batch_size, height, width)
        # loss_mask is optional; only reshape it when provided.
        if loss_mask is not None:
            loss_mask = loss_mask.view(batch_size, height, width)

        patch_h, patch_w = self.patch_size
        n_rows = math.ceil(height / patch_h)
        n_cols = math.ceil(width / patch_w)
        patch_top_lefts = [(r * patch_h, c * patch_w) for r in range(n_rows) for c in range(n_cols)]
        num_to_mask = round(n_rows * n_cols * self.masking_ratio)
        for batch_idx in range(batch_size):
            for top, left in random.sample(patch_top_lefts, k=num_to_mask):
                # Edge patches are clipped by slicing when the image size is not a multiple of the patch size.
                region = (batch_idx, slice(top, top + patch_h), slice(left, left + patch_w))
                if loss_mask is not None:
                    loss_mask[region] = 1

                if self.replace is not None:
                    inputs[region] = self.replace
                else:
                    attention_mask[region] = 0

        return dict(
            inputs=inputs.view(batch_size, -1, inputs.shape[-1]),
            attention_mask=attention_mask.view(batch_size, -1),
            loss_mask=None if loss_mask is None else loss_mask.view(batch_size, -1),
            **kwargs,
        )
        
class RandomMasking:
    """Randomly masks a fraction of sequence positions."""

    def __init__(
        self,
        masking_ratio=0.5,
        replace=None,
        exclude_seen_reconstruction=False,
        varying_length=False,
        **kwargs,
    ):
        """
        Args:
            masking_ratio: Fraction of positions to mask.
            replace: If not None (varying-length mode), masked positions in
                ``inputs`` are overwritten with this value instead of zeroing
                ``attention_mask``. Positions already equal to it are not
                re-masked.
            exclude_seen_reconstruction: If True, ``loss_mask`` is set only at
                masked positions; otherwise at every attended position.
            varying_length: If True, sample from each row's attended, non-replaced
                positions. Otherwise sample from all ``seq_len`` positions and
                zero ``attention_mask``.
            **kwargs: Ignored.
        """
        self.masking_ratio = masking_ratio
        self.replace = replace
        self.exclude_seen_reconstruction = exclude_seen_reconstruction
        self.varying_length = varying_length

    def __call__(
        self,
        inputs: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        loss_mask: Optional[torch.Tensor] = None,
        **kwargs
    ):
        """Mask random positions in place.

        Args:
            inputs: ``(batch, seq_len, ...)`` tensor (2-D or with trailing feature dims).
            attention_mask: ``(batch, seq_len)`` mask.
            loss_mask: Optional ``(batch, seq_len)`` tensor, modified in place.
            **kwargs: Passed through to the returned dict.

        Returns:
            A dict with ``inputs``, ``attention_mask``, ``loss_mask`` and any
            extra kwargs.
        """
        batch_size, seq_len = inputs.shape[:2]
        if self.varying_length:
            unmasked_bool = attention_mask == 1
            if self.replace is not None:
                not_replaced = inputs != self.replace
                # A position counts as replaced only if every feature equals `replace`,
                # mirroring how the whole row is overwritten below.
                if not_replaced.dim() > 2:
                    not_replaced = not_replaced.flatten(2).any(-1)
                unmasked_bool = unmasked_bool & not_replaced
            for batch_idx in range(batch_size):
                # Always 1-D, even when exactly one position is maskable.
                unmasked_indices = unmasked_bool[batch_idx].nonzero(as_tuple=True)[0]
                num_candidates = unmasked_indices.shape[0]
                masked_indices = unmasked_indices[torch.randperm(num_candidates)[:round(num_candidates * self.masking_ratio)]]

                if loss_mask is not None:
                    if self.exclude_seen_reconstruction:
                        loss_mask[batch_idx][masked_indices] = 1
                    else:
                        loss_mask[batch_idx][attention_mask[batch_idx].nonzero(as_tuple=True)[0]] = 1

                if self.replace is not None:
                    inputs[batch_idx][masked_indices] = self.replace
                else:
                    attention_mask[batch_idx][masked_indices] = 0
        else:
            num_to_mask = round(seq_len * self.masking_ratio)
            per_row = torch.stack([torch.randperm(seq_len)[:num_to_mask] for _ in range(batch_size)])
            masked_indices = (
                torch.arange(per_row.shape[0]).repeat_interleave(per_row.shape[1]),
                per_row.flatten(),
            )

            if loss_mask is not None:
                if self.exclude_seen_reconstruction:
                    loss_mask[masked_indices] = 1
                else:
                    loss_mask[attention_mask.nonzero(as_tuple=True)] = 1

            attention_mask[masked_indices] = 0

        return dict(
            inputs=inputs,
            attention_mask=attention_mask,
            loss_mask=loss_mask,
            **kwargs,
        )

class IdentityTransform:
    """No-op transform that only fills in a missing ``loss_mask``.

    If ``loss_mask`` is absent or None, it defaults to ``attention_mask``
    (or None when that is absent too). Every other keyword is returned unchanged.
    """

    def __init__(self, **kwargs):
        # Accepts and ignores any configuration, like the other transforms.
        pass

    def __call__(
        self,
        **kwargs,
    ):
        outputs = dict(kwargs)
        if outputs.get("loss_mask") is None:
            outputs["loss_mask"] = outputs.get("attention_mask")
        return outputs


class TransformCompose(nn.Module):
    """Chains keyword-in / dict-out transforms.

    Each transform is called with the previous transform's output dict
    unpacked as keyword arguments. The last transform's output is returned.
    """

    def __init__(self, transforms: List[Callable]) -> None:
        super().__init__()
        if not isinstance(transforms, list):
            raise TypeError("Argument transforms should be a list of callables")
        self.transforms = transforms

    def forward(self, **kwargs: Any) -> Any:
        state = kwargs
        for step in self.transforms:
            state = step(**state)
        return state

    def extra_repr(self) -> str:
        # One indented line per transform in the module's repr.
        return "\n".join(f"    {step}" for step in self.transforms)


# Registry from the transform names accepted by create_transforms to their classes.
TRANSFORMS2CLS = dict(
    RandomMasking=RandomMasking,
    RandomlySelectedCrossAttentionMasking=RandomlySelectedCrossAttentionMasking,
    DelimiterBasedRandomMasking=DelimiterBasedRandomMasking,
    PatchBasedMasking=PatchBasedMasking,
)

def create_transforms(transform_args):
    """Build the transform pipeline described by ``transform_args``.

    Args:
        transform_args: Iterable of ``(name, kwargs)`` pairs. ``name`` must be a
            key of ``TRANSFORMS2CLS``; ``kwargs`` is passed to that class's
            constructor. A falsy value (e.g. None or an empty list) selects the
            identity transform.

    Returns:
        A ``TransformCompose`` over the instantiated transforms, or an
        ``IdentityTransform`` when ``transform_args`` is falsy.

    Raises:
        KeyError: If a name is not registered in ``TRANSFORMS2CLS``. The
            message lists the valid names.
    """
    if not transform_args:
        return IdentityTransform()
    transforms = []
    for name, args in transform_args:
        try:
            transform_cls = TRANSFORMS2CLS[name]
        except KeyError:
            raise KeyError(
                f"Unknown transform {name!r}; expected one of {sorted(TRANSFORMS2CLS)}"
            ) from None
        transforms.append(transform_cls(**args))
    return TransformCompose(transforms)
