import torch
from torch import nn
import numpy as np
from einops import rearrange, repeat
from croma_transformer import BaseTransformer
import math
import itertools
import torch.nn.functional as F
from torch import distributed as dist

def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """Build a fixed 2D sine/cosine positional-embedding table for a square grid.

    Args:
        embed_dim: Width of each embedding vector.
        grid_size: Number of positions along each side of the grid.
        cls_token: When True, an all-zero row is prepended to the table.

    Returns:
        NumPy array with ``grid_size * grid_size`` rows (plus one leading zero
        row when ``cls_token`` is True) and ``embed_dim`` columns.
    """
    axis = np.arange(grid_size, dtype=np.float32)
    # Same argument order as before: the "w" axis is passed to meshgrid first.
    coords = np.stack(np.meshgrid(axis, axis), axis=0)
    coords = coords.reshape([2, 1, grid_size, grid_size])
    table = get_2d_sincos_pos_embed_from_grid(embed_dim, coords)
    if not cls_token:
        return table
    return np.concatenate([np.zeros([1, embed_dim]), table], axis=0)

def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode a two-plane coordinate grid into sine/cosine embeddings.

    Each of the two coordinate planes ``grid[0]`` and ``grid[1]`` is encoded
    with ``embed_dim // 2`` channels, and the two encodings are concatenated
    along the feature axis.

    Args:
        embed_dim: Total embedding width; must be even.
        grid: Array whose first axis indexes the two coordinate planes.

    Returns:
        Array of shape (number of grid positions, embed_dim).
    """
    assert embed_dim % 2 == 0
    per_axis_dim = embed_dim // 2
    # Order matters: the encoding of grid[0] fills the first half of the features.
    halves = [
        get_1d_sincos_pos_embed_from_grid(per_axis_dim, grid[plane])
        for plane in (0, 1)
    ]
    return np.concatenate(halves, axis=1)  # (H*W, D)

def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Encode scalar positions with sine/cosine features.

    Args:
        embed_dim: Output width per position; must be even.
        pos: Array of positions of any shape; it is flattened to (M,).

    Returns:
        Array of shape (M, embed_dim): the first ``embed_dim // 2`` columns are
        sines and the remaining columns are cosines of ``pos * omega``, where
        ``omega[k] = 1 / 10000 ** (k / (embed_dim / 2))``.
    """
    assert embed_dim % 2 == 0
    # `np.float` was only an alias of the builtin float and was removed in
    # NumPy 1.24; np.float64 is the same dtype and works on every version.
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)
    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)
    return np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)


def get_alibi(attention_heads, num_patches):
    """Build a 2D ALiBi-style attention bias for a square grid of patches.

    Patches are laid out row-major on a ``sqrt(num_patches)`` square grid.
    Entry ``[0, h, i, j]`` of the result is
    ``-slope[h] * euclidean_distance(patch_i, patch_j)``.

    Args:
        attention_heads: Number of attention heads (one slope per head).
        num_patches: Total number of patches; must be a positive perfect square.

    Returns:
        Tensor of shape (1, attention_heads, num_patches, num_patches).

    Raises:
        ValueError: If ``num_patches`` is not a positive perfect square.
    """
    if num_patches < 1:
        raise ValueError(f"num_patches must be a positive perfect square, got {num_patches}")
    side = math.isqrt(num_patches)
    if side * side != num_patches:
        raise ValueError(f"num_patches must be a positive perfect square, got {num_patches}")

    def get_slopes(n):
        # Geometric sequence of per-head slopes; for a head count that is not
        # a power of two, the sequence for the closest lower power of two is
        # extended with every other slope of the next power of two.
        def get_slopes_power_of_2(n):
            start = (2 ** (-2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * ratio ** i for i in range(n)]
        if math.log2(n).is_integer():
            return get_slopes_power_of_2(n)
        closest_power_of_2 = 2 ** math.floor(math.log2(n))
        return (get_slopes_power_of_2(closest_power_of_2)
                + get_slopes(2 * closest_power_of_2)[0::2][:n - closest_power_of_2])

    slopes = torch.tensor(get_slopes(attention_heads))  # (heads,)

    # All pairwise distances in one shot instead of a Python double loop that
    # allocates num_patches**2 tiny tensors.  Distances are computed in
    # float64 and then cast to the slope dtype, mirroring the previous
    # math.sqrt-then-multiply behaviour.
    coords = torch.tensor(list(itertools.product(range(side), repeat=2)), dtype=torch.float64)  # (P, 2)
    offsets = coords.unsqueeze(1) - coords.unsqueeze(0)  # (P, P, 2)
    distances = offsets.pow(2).sum(dim=-1).sqrt().to(slopes.dtype)  # (P, P)

    all_bias = distances.unsqueeze(0) * slopes.view(-1, 1, 1) * -1  # (heads, P, P)
    return all_bias.view(1, attention_heads, num_patches, num_patches)

def get_mask(bsz, seq_len, device, mask_ratio):
    """Draw an independent random keep/mask split for every sequence in a batch.

    Args:
        bsz: Batch size.
        seq_len: Number of tokens per sequence.
        device: Device on which the tensors are created.
        mask_ratio: Fraction of tokens to mask; ``int(seq_len * (1 - mask_ratio))``
            tokens are kept.

    Returns:
        Dict with:
            'ids_restore': (bsz, seq_len) indices that undo the random shuffle.
            'ids_keep': (bsz, len_keep) indices of the kept tokens.
            'len_keep': number of kept tokens per sequence.
            'mask_for_mae': (bsz, seq_len) tensor, 0 for kept and 1 for masked
                tokens, in original token order.
    """
    len_keep = int(seq_len * (1 - mask_ratio))

    # Random scores in [0, 1); the len_keep smallest scores per row are kept.
    scores = torch.rand(bsz, seq_len, device=device)
    shuffle_order = scores.argsort(dim=1)
    restore_order = shuffle_order.argsort(dim=1)

    # Mask expressed in shuffled order (kept tokens first), then un-shuffled.
    shuffled_mask = torch.ones([bsz, seq_len], device=device)
    shuffled_mask[:, :len_keep] = 0

    return {
        'ids_restore': restore_order,
        'ids_keep': shuffle_order[:, :len_keep],
        'len_keep': len_keep,
        'mask_for_mae': shuffled_mask.gather(1, restore_order),
    }

def apply_mask_to_sequence(x, ids_keep):
    """Select the kept tokens of each sequence.

    Args:
        x: Tensor indexed as (batch, token, feature).
        ids_keep: (batch, kept) integer tensor of token indices to retain.

    Returns:
        Tensor of shape (batch, kept, feature) holding ``x[b, ids_keep[b, k]]``.
    """
    feature_dim = x.shape[-1]
    # Broadcast each kept token index across the feature axis for gather.
    gather_index = ids_keep.unsqueeze(-1).expand(-1, -1, feature_dim)
    return x.gather(1, gather_index)

def apply_mask_to_alibi(alibi, ids_keep_queries, ids_keep_keys, batch_size, orig_seq_len, masked_seq_len, attention_heads):
    """Restrict an attention bias to the kept query/key tokens of each sample.

    Result entry ``[b, n, i, j]`` is
    ``alibi[0, n, ids_keep_queries[b, i], ids_keep_keys[b, j]]``, i.e. the
    second-to-last axis follows the queries and the last axis the keys.

    The previous implementation built the flat index as
    ``query + key * orig_seq_len`` while flattening the bias row-major, which
    looked up the transposed entry ``alibi[..., key, query]``.  For a
    symmetric bias (such as a pairwise-distance bias) both give the same
    values; for an asymmetric bias only the row-major index used here is
    correct.

    Args:
        alibi: Bias of shape (1, attention_heads, orig_seq_len, orig_seq_len).
        ids_keep_queries: (batch, kept) indices of kept query tokens.
        ids_keep_keys: (batch, kept) indices of kept key tokens.
        batch_size: Number of samples; the bias is repeated this many times.
        orig_seq_len: Sequence length before masking.
        masked_seq_len: Number of kept tokens (same for queries and keys).
        attention_heads: Number of attention heads.

    Returns:
        Tensor of shape (batch, attention_heads, masked_seq_len, masked_seq_len).
    """
    # Row-major flat position of every (query, key) pair: (b, i, j) -> (b, i*j).
    pair_index = ids_keep_queries.unsqueeze(2) * orig_seq_len + ids_keep_keys.unsqueeze(1)
    pair_index = pair_index.flatten(1)

    # (b, n, L, L) -> (b, L*L, n) so one gather serves every head.
    alibi_flat = alibi.repeat(batch_size, 1, 1, 1).permute(0, 2, 3, 1).flatten(1, 2)
    gather_index = pair_index.unsqueeze(-1).expand(-1, -1, attention_heads)
    alibi_masked = torch.gather(alibi_flat, dim=1, index=gather_index)  # (b, i*j, n)

    # (b, i*j, n) -> (b, n, i, j)
    alibi_masked = alibi_masked.view(alibi_masked.shape[0], masked_seq_len, masked_seq_len, attention_heads)
    return alibi_masked.permute(0, 3, 1, 2)

def gather_features(features, world_size):
    """Collect ``features`` from every rank and concatenate along dim 0.

    Args:
        features: Local tensor; one same-shaped receive buffer is allocated
            per rank.
        world_size: Number of participating ranks.

    Returns:
        The buffers filled by ``dist.all_gather``, concatenated along dim 0.
    """
    # One zero-initialised receive buffer per rank for all_gather to fill.
    per_rank_buffers = [torch.zeros_like(features) for _ in range(world_size)]
    dist.all_gather(per_rank_buffers, features)
    return torch.cat(per_rank_buffers, dim=0)

class ContrastLossInputMix(nn.Module):
    """Symmetric cross-modal contrastive loss with mixed-up targets.

    Two linear heads project the s1 and s2 features, the projections are
    L2-normalised, and each rank scores its local features of one modality
    against the features of the other modality gathered from all ranks.  The
    returned loss is ``lam * loss(identity labels) + (1 - lam) * loss(mixup
    labels)``, averaged over the local batch.
    """

    def __init__(
            self,
            projection_input=768,
            projection_output=768,
    ):
        """
        Args:
            projection_input: Width of the incoming feature vectors.
            projection_output: Width of the projected embeddings.
        """
        super().__init__()
        self.s1_proj = nn.Linear(projection_input, projection_output)
        self.s2_proj = nn.Linear(projection_input, projection_output)
        # Learnable log-temperature, initialised to log(1 / 0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def forward(self, s1_features, s2_features, world_size, rank, lam, mixup_labels):
        """Compute the mixed contrastive loss for the local batch.

        Args:
            s1_features: 2-D tensor (local batch, projection_input).
            s2_features: 2-D tensor (local batch, projection_input).
            world_size: Number of ranks passed to ``gather_features``.
            rank: Index of this rank; used to offset labels into the gathered
                batch.
            lam: Weight of the un-mixed loss; ``1 - lam`` weights the mixed loss.
            mixup_labels: Per-sample local index of the mixing partner; it is
                offset by ``num_logits * rank`` like the identity labels.

        Returns:
            Scalar loss tensor.
        """
        s1_features = self.s1_proj(s1_features)
        s2_features = self.s2_proj(s2_features)
        # L2-normalise each embedding along the feature axis.  The keyword is
        # `dim`; `Tensor.norm` has no `width` argument and raised TypeError.
        s1_features = s1_features / s1_features.norm(dim=1, keepdim=True)
        s2_features = s2_features / s2_features.norm(dim=1, keepdim=True)
        all_s1_features = gather_features(features=s1_features, world_size=world_size)
        all_s2_features = gather_features(features=s2_features, world_size=world_size)
        logit_scale = self.logit_scale.exp()
        # Local rows vs. all gathered columns: (local batch, global batch).
        logits_per_s2 = logit_scale * s2_features @ all_s1_features.t()
        logits_per_s1 = logit_scale * s1_features @ all_s2_features.t()
        num_logits = logits_per_s2.shape[0]
        # The positive for local sample k sits at column k + num_logits * rank.
        labels = torch.arange(num_logits, device=s1_features.device, dtype=torch.long)
        labels = labels + num_logits * rank
        mixed_up_labels = mixup_labels + num_logits * rank
        og_loss = (
            F.cross_entropy(logits_per_s2, labels, reduction='none') +
            F.cross_entropy(logits_per_s1, labels, reduction='none')
            ) / 2
        mixed_loss = (
            F.cross_entropy(logits_per_s2, mixed_up_labels, reduction='none') +
            F.cross_entropy(logits_per_s1, mixed_up_labels, reduction='none')
            ) / 2
        total_loss = lam*og_loss + (1-lam)*mixed_loss
        return total_loss.mean()


class ViT(nn.Module):
    """Patch-embedding front end followed by a ``BaseTransformer`` encoder."""

    def __init__(self,
                 num_patches,
                 width=768,
                 layers=12,
                 attention_heads=16,
                 in_channels=12,
                 patch_size=8,
                 ):
        """
        Args:
            num_patches: Number of patches per image (stored, not used here).
            width: Token embedding width.
            layers: Number of transformer layers.
            attention_heads: Number of attention heads.
            in_channels: Channels of the input image.
            patch_size: Side length, in pixels, of each square patch.
        """
        super().__init__()
        self.width = width
        self.layers = layers
        self.attention_heads = attention_heads
        self.num_patches = num_patches
        self.patch_size = patch_size
        # Each flattened patch (patch_size * patch_size * in_channels values)
        # is linearly projected to one token.
        self.linear_input = nn.Linear(int(patch_size * patch_size * in_channels), self.width)
        self.transformer = BaseTransformer(
            width=self.width,
            layers=self.layers,
            attention_heads=self.attention_heads,
        )

    def forward(self, imgs, attn_bias, mask_info=None):
        """Encode images, optionally keeping only the unmasked patches.

        Args:
            imgs: Image batch laid out as (b, c, H, W), with H and W multiples
                of ``patch_size``.
            attn_bias: Bias forwarded to the transformer as ``alibi``.
            mask_info: Optional dict with key 'ids_keep'; when given, only
                those tokens are fed to the transformer.

        Returns:
            The transformer output for the full or the masked token sequence.
        """
        p = self.patch_size
        patches = rearrange(imgs, 'b c (h i) (w j) -> b (h w) (c i j)', i=p, j=p)
        tokens = self.linear_input(patches)
        if mask_info is not None:
            tokens = apply_mask_to_sequence(x=tokens, ids_keep=mask_info['ids_keep'])
        return self.transformer(tokens, alibi=attn_bias)

class DecoderMAE(nn.Module):
    """Masked-autoencoder decoder that reconstructs per-patch pixel values.

    The pixel masks built in ``forward`` hard-code 12 optical and 2 radar
    channels, matching the default ``total_channels=14``.
    """

    def __init__(self,
                 num_patches,
                 encoder_width=768,
                 decoder_width=768,
                 decoder_layers=12,
                 attention_heads=16,
                 total_channels=14,
                 patch_size=8,
                 ):
        """
        Args:
            num_patches: Number of patches per image.
            encoder_width: Width of the incoming encoder tokens.
            decoder_width: Token width inside the decoder.
            decoder_layers: Number of decoder transformer layers.
            attention_heads: Number of attention heads.
            total_channels: Channels reconstructed per pixel.
            patch_size: Side length, in pixels, of each square patch.
        """
        super().__init__()
        self.decoder_width = decoder_width
        self.decoder_layers = decoder_layers
        self.attention_heads = attention_heads
        self.num_patches = num_patches
        self.patch_size = patch_size
        self.encoder_to_decoder = nn.Linear(encoder_width, self.decoder_width)
        self.decoder = BaseTransformer(
            width=self.decoder_width,
            layers=self.decoder_layers,
            attention_heads=self.attention_heads,
        )
        # Frozen 2D sine/cosine positional embedding (requires_grad=False).
        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches, self.decoder_width), requires_grad=False)
        sincos_table = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(num_patches ** .5), cls_token=False)
        self.decoder_pos_embed.data.copy_(torch.from_numpy(sincos_table).float().unsqueeze(0))
        # One output value per pixel and channel of a patch.
        self.linear_output = nn.Linear(self.decoder_width, int(patch_size * patch_size * total_channels))
        # Learned placeholder inserted at every masked position.
        self.mask_token = nn.Parameter(torch.zeros(1, 1, self.decoder_width))
        torch.nn.init.normal_(self.mask_token, std=.02)

    def forward(self, x, mask_info_radar, mask_info_optical, target):
        """Decode the visible tokens and return the masked reconstruction loss.

        Args:
            x: Encoder tokens of the kept patches, (b, kept, encoder_width).
            mask_info_radar: Dict with 'ids_restore' and 'mask_for_mae'; its
                'ids_restore' is used to un-shuffle the decoder input.
            mask_info_optical: Dict with 'mask_for_mae' for the optical channels.
            target: Reconstruction target; assumed to be patchified as
                (b, num_patches, patch_size * patch_size * 14) with the 12
                optical channels first -- TODO confirm against the caller.

        Returns:
            Scalar loss: squared error against the per-patch normalised
            target, weighted by the combined pixel mask and divided by the
            number of masked radar patches.
        """
        ids_restore = mask_info_radar['ids_restore']
        x = self.encoder_to_decoder(x)

        # Pad with mask tokens, then undo the shuffle.  The "+ 1" yields one
        # surplus token at the end; with ids_restore values below
        # ids_restore.shape[1] it is never selected by the gather.
        num_fill = ids_restore.shape[1] + 1 - x.shape[1]
        x = torch.cat([x, self.mask_token.repeat(x.shape[0], num_fill, 1)], dim=1)
        x = torch.gather(x, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))

        x = self.linear_output(self.decoder(x + self.decoder_pos_embed))

        side = int(self.num_patches**0.5)
        p = self.patch_size

        def to_pixel_mask(patch_mask, channels):
            # Expand a per-patch mask (b, l) to a per-pixel mask (b, c, H, W).
            per_patch = repeat(patch_mask, 'b l -> b l c', c=int(p*p*channels))
            return rearrange(per_patch, 'b (h w) (c i j) -> b c (h i) (w j)',
                             c=channels, h=side, w=side, i=p, j=p)

        pixel_mask_optical = to_pixel_mask(mask_info_optical['mask_for_mae'], 12)
        pixel_mask_radar = to_pixel_mask(mask_info_radar['mask_for_mae'], 2)
        # Stack optical then radar channels and patchify back to (b, l, c*i*j).
        combined_pixel_mask = rearrange(torch.cat([pixel_mask_optical, pixel_mask_radar], dim=1),
                                        'b c (h i) (w j) -> b (h w) (c i j)', i=p, j=p)

        # Normalise every target patch to zero mean / unit variance.
        patch_mean = target.mean(dim=-1, keepdim=True)
        patch_var = target.var(dim=-1, keepdim=True)
        target = (target - patch_mean) / (patch_var + 1.e-6)**.5

        squared_error = (x - target) ** 2
        per_patch_error = (squared_error * combined_pixel_mask).mean(dim=-1)
        return per_patch_error.sum() / mask_info_radar['mask_for_mae'].sum()
