from functools import partial

import torch
import torch.nn as nn
from einops import rearrange

from vit import ViT, TransformerBlock


def get_1d_sincos_pos_embed(embed_dim: int,
                            grid_size: int,
                            temperature: float = 10000,
                            sep_embed: bool = False):
    """Build a fixed 1D sine-cosine positional embedding table.

    For position ``t`` and frequency index ``k`` (``0 <= k < embed_dim // 2``):
    column ``k`` holds ``sin(t * w_k)`` and column ``k + embed_dim // 2`` holds
    ``cos(t * w_k)``, where ``w_k = temperature ** (-2k / embed_dim)``.

    Args:
        embed_dim: embedding width; must be even.
        grid_size: number of positions (0 .. grid_size - 1).
        temperature: base of the geometric frequency progression.
        sep_embed: if True, prepend and append one all-zero row (used in this
            file as the positional slots of the front/back SEP tokens).

    Returns:
        float32 tensor of shape ``(grid_size, embed_dim)``, or
        ``(grid_size + 2, embed_dim)`` when ``sep_embed`` is True.
    """
    assert (embed_dim % 2) == 0, 'feature dimension must be multiple of 2 for sincos emb'
    half = embed_dim // 2
    positions = torch.arange(grid_size, dtype=torch.float32)
    exponents = torch.arange(half, dtype=torch.float32) / (embed_dim / 2.)
    inv_freq = 1. / (temperature ** exponents)

    # angles[t, k] = t * inv_freq[k]
    angles = torch.outer(positions, inv_freq)
    table = torch.cat([angles.sin(), angles.cos()], dim=1)
    if not sep_embed:
        return table

    pad = torch.zeros(1, embed_dim)
    return torch.cat([pad, table, pad], dim=0)


class ST_MEM(nn.Module):
    """Spatio-Temporal Masked Electrocardiogram Modeling (ST-MEM) model.

    Masked-autoencoder pre-training for multi-lead ECG:

    * Each lead is cut into patches that are embedded separately.
    * A random subset of every lead's patches is dropped.
    * Each lead's remaining tokens are wrapped in a front/back SEP token and
      get a per-lead embedding added.
    * All leads are encoded jointly.
    * The decoder puts a learned mask token at every dropped position and
      regresses the raw samples of each patch.
    """

    def __init__(self,
                 num_leads: int = 12,
                 seq_len: int = 2250,
                 patch_size: int = 75,
                 embed_dim: int = 768,
                 depth: int = 12,
                 num_heads: int = 12,
                 decoder_embed_dim: int = 256,
                 decoder_depth: int = 4,
                 decoder_num_heads: int = 4,
                 mlp_ratio: int = 4,
                 qkv_bias: bool = True,
                 norm_layer: nn.Module = nn.LayerNorm,
                 norm_pix_loss: bool = False):
        """Build the encoder and the decoder.

        Args:
            num_leads: number of ECG leads (input channels).
            seq_len: length of each lead's signal; ``seq_len // patch_size``
                patches are produced per lead.
            patch_size: samples per patch.
            embed_dim: encoder token width.
            depth: number of encoder transformer blocks.
            num_heads: encoder attention heads.
            decoder_embed_dim: decoder token width.
            decoder_depth: number of decoder transformer blocks.
            decoder_num_heads: decoder attention heads.
            mlp_ratio: MLP hidden size as a multiple of the token width.
            qkv_bias: whether the attention projections use a bias.
            norm_layer: factory for the decoder's final normalization layer.
            norm_pix_loss: if True, each target patch is normalized to zero
                mean / unit variance before the loss is computed.
        """
        super().__init__()
        self.num_leads = num_leads
        self.patch_size = patch_size
        self.num_patches = seq_len // patch_size
        # --------------------------------------------------------------------------
        # MAE encoder specifics
        self.encoder = ViT(num_leads=num_leads,
                           seq_len=seq_len,
                           patch_size=patch_size,
                           width=embed_dim,
                           depth=depth,
                           mlp_dim=mlp_ratio * embed_dim,
                           heads=num_heads,
                           qkv_bias=qkv_bias)
        self.patch_embed = self.encoder.to_patch_embedding
        # --------------------------------------------------------------------------

        # --------------------------------------------------------------------------
        # MAE decoder specifics
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

        # Learned token placed at every masked position before decoding.
        self.mask_embedding = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        # Fixed sin-cos embedding for one lead: [front SEP, num_patches patches, back SEP].
        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 2, decoder_embed_dim),
                                              requires_grad=False)

        self.decoder_blocks = nn.ModuleList([TransformerBlock(input_dim=decoder_embed_dim,
                                                              output_dim=decoder_embed_dim,
                                                              hidden_dim=decoder_embed_dim * mlp_ratio,
                                                              heads=decoder_num_heads,
                                                              dim_head=64,
                                                              qkv_bias=qkv_bias)
                                             for _ in range(decoder_depth)])

        self.decoder_norm = norm_layer(decoder_embed_dim)
        # Projects each decoded token back to ``patch_size`` raw samples.
        self.decoder_pred = nn.Linear(decoder_embed_dim,
                                      patch_size,
                                      bias=True)  # decoder to patch
        # --------------------------------------------------------------------------
        self.norm_pix_loss = norm_pix_loss
        self.initialize_weights()

    def initialize_weights(self):
        """Initialize all parameters.

        * Encoder and decoder positional embeddings: fixed sin-cos tables,
          frozen.
        * SEP, mask and lead embeddings: N(0, 0.02).
        * Every ``nn.Linear`` / ``nn.LayerNorm``: reset via ``_init_weights``.
        """
        # initialize (and freeze) pos_embed by sin-cos embedding
        # NOTE(review): assumes encoder.pos_embedding is (1, num_patches + 2, width) — confirm in vit.py
        pos_embed = get_1d_sincos_pos_embed(self.encoder.pos_embedding.shape[-1],
                                            self.num_patches,
                                            sep_embed=True)
        self.encoder.pos_embedding.data.copy_(pos_embed.float().unsqueeze(0))
        self.encoder.pos_embedding.requires_grad = False

        decoder_pos_embed = get_1d_sincos_pos_embed(self.decoder_pos_embed.shape[-1],
                                                    self.num_patches,
                                                    sep_embed=True)
        self.decoder_pos_embed.data.copy_(decoder_pos_embed.float().unsqueeze(0))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.encoder.sep_embedding, std=.02)
        torch.nn.init.normal_(self.mask_embedding, std=.02)
        for i in range(self.num_leads):
            torch.nn.init.normal_(self.encoder.lead_embeddings[i], std=.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-uniform weights / zero bias for Linear; unit weight / zero bias for LayerNorm."""
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def patchify(self, series):
        """Split a signal into per-lead patches.

        series: (B, L, T) with T divisible by ``patch_size``
        returns: (B, L * n, patch_size), all patches of lead 0 first, then lead 1, ...
        """
        p = self.patch_size
        assert series.shape[2] % p == 0
        x = rearrange(series, 'b (g c) (n p) -> b (g n) (c p)',
                      g=self.num_leads,
                      p=p)
        return x

    def unpatchify(self, x):
        """Inverse of ``patchify``.

        x: (B, L * n, patch_size)
        returns: (B, L, T)
        """
        p = self.patch_size
        series = rearrange(x, 'b (g n) (c p) -> b (g c) (n p)',
                           g=self.num_leads,
                           p=p)
        return series

    def random_masking(self, x, mask_ratio):
        """Per-sample random masking via argsort of uniform noise.

        Args:
            x: [B, n, D] token sequence.
            mask_ratio: fraction of tokens to drop; ``int(n * (1 - mask_ratio))`` are kept.

        Returns:
            x_masked: [B, n_keep, D], the kept tokens in shuffled order.
            mask: [B, n] float, 0 = kept, 1 = removed, in original token order.
            ids_restore: [B, n] indices that undo the shuffle.
        """
        B, n, D = x.shape  # batch, length, dim
        len_keep = int(n * (1 - mask_ratio))

        noise = torch.rand(B, n, device=x.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([B, n], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

    def forward_encoder(self, x, mask_ratio):
        """Embed, mask and jointly encode every lead.

        Args:
            x: [B, num_leads, seq_len] raw signal.
            mask_ratio: fraction of each lead's patches to drop. With
                ``mask_ratio <= 0`` nothing is dropped and patches keep their
                original order.

        Returns:
            latent: [B, num_leads * (n_keep + 2), embed_dim]. Each lead's segment
                is [front SEP, kept patches..., back SEP].
            mask: [B, num_leads * n], 0 = kept, 1 = masked.
            ids_restore: [B, num_leads * n] unshuffle indices, already offset by
                ``n * lead_index`` so they index the concatenation of all leads.
        """
        all_leads = []
        all_masks = []
        all_ids_restore = []

        for i in range(self.num_leads):
            x_lead = self.patch_embed(x[:, i:i + 1, :])
            B, n, _ = x_lead.shape

            # add positional embeddings; the first and last slots belong to the SEP tokens below
            x_lead = x_lead + self.encoder.pos_embedding[:, 1:n + 1, :]

            # masking: length -> length * (1 - mask_ratio)
            if mask_ratio > 0:
                x_lead, mask, ids_restore = self.random_masking(x_lead, mask_ratio)
            else:
                # Nothing is masked: keep every patch, in its original order.
                mask = torch.zeros(B, n, device=x_lead.device)
                ids_restore = torch.arange(n, device=x_lead.device).unsqueeze(0).expand(B, -1)
            ids_restore = ids_restore + n * i

            # lead indicating modules
            front_sep = self.encoder.sep_embedding + self.encoder.pos_embedding[:, :1, :]
            front_sep = front_sep.expand(B, -1, -1)
            back_sep = self.encoder.sep_embedding + self.encoder.pos_embedding[:, -1:, :]
            back_sep = back_sep.expand(B, -1, -1)
            x_lead = torch.cat((front_sep, x_lead, back_sep), dim=1)
            lead_embedding = self.encoder.lead_embeddings[i].unsqueeze(0)
            lead_embedding = lead_embedding.expand(x_lead.shape[0], x_lead.shape[1], -1)
            x_lead = x_lead + lead_embedding

            all_leads.append(x_lead)
            all_masks.append(mask)
            all_ids_restore.append(ids_restore)

        all_leads = torch.cat(all_leads, dim=1)
        all_masks = torch.cat(all_masks, dim=1)
        all_ids_restore = torch.cat(all_ids_restore, dim=1)

        # apply Transformer blocks
        for i in range(self.encoder.depth):
            all_leads = getattr(self.encoder, f'block{i}')(all_leads)
        all_leads = self.encoder.norm(all_leads)

        return all_leads, all_masks, all_ids_restore

    def _decoder_inputs(self, x, ids_restore):
        """Project encoder output and rebuild each lead's full token sequence.

        Args:
            x: encoder output, ``num_leads`` equal-length segments each laid out
                as [front SEP, kept patches..., back SEP].
            ids_restore: [B, num_leads * n] indices from ``forward_encoder``.

        Returns:
            List of ``num_leads`` tensors, each [B, n + 2, decoder_embed_dim]:
            front SEP, all ``n`` patch tokens in original order (mask tokens at
            masked positions), back SEP, with decoder positional embeddings added.
        """
        x = self.decoder_embed(x)
        batch_size, _, dim = x.shape
        n = ids_restore.shape[1] // self.num_leads   # patches per lead
        kept_len = x.shape[1] // self.num_leads      # kept patches + 2 SEP tokens per lead
        mask_tokens = self.mask_embedding.repeat(batch_size, n + 2 - kept_len, 1)

        # Per lead: drop the SEP tokens and pad with mask tokens up to n tokens.
        padded = []
        for i in range(self.num_leads):
            start = i * kept_len
            padded.append(torch.cat([x[:, start + 1:start + kept_len - 1, :], mask_tokens], dim=1))
        padded = torch.cat(padded, dim=1)
        # Undo the shuffle; ids_restore already carries the per-lead offset.
        unshuffled = torch.gather(padded, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, dim))

        sequences = []
        for i in range(self.num_leads):
            start = i * kept_len
            patches = unshuffled[:, i * n:(i + 1) * n, :] + self.decoder_pos_embed[:, 1:n + 1, :]
            front_sep = x[:, start:start + 1, :] + self.decoder_pos_embed[:, :1, :]
            back_sep = x[:, start + kept_len - 1:start + kept_len, :] + self.decoder_pos_embed[:, -1:, :]
            sequences.append(torch.cat([front_sep, patches, back_sep], dim=1))
        return sequences

    def forward_decoder(self, x, ids_restore):
        """Decode all leads jointly in one sequence.

        Returns:
            [B, num_leads * n, patch_size] predicted patches (SEP outputs removed).
        """
        y = torch.cat(self._decoder_inputs(x, ids_restore), dim=1)

        # apply Transformer blocks
        for blk in self.decoder_blocks:
            y = blk(y)
        y = self.decoder_norm(y)

        # predictor projection
        y = self.decoder_pred(y)

        # remove each lead's front/back SEP predictions
        tokens_per_lead = y.shape[1] // self.num_leads  # n + 2
        return torch.cat([y[:, i * tokens_per_lead + 1:(i + 1) * tokens_per_lead - 1, :]
                          for i in range(self.num_leads)], dim=1)

    def forward_leadwise_decoder(self, x, ids_restore):
        """Decode each lead independently (no attention across leads).

        Returns:
            [B, num_leads * n, patch_size] predicted patches (SEP outputs removed).
        """
        decoded = []
        for seq in self._decoder_inputs(x, ids_restore):
            # apply Transformer blocks
            for blk in self.decoder_blocks:
                seq = blk(seq)
            seq = self.decoder_norm(seq)

            # predictor projection, then drop the SEP predictions
            seq = self.decoder_pred(seq)
            decoded.append(seq[:, 1:-1, :])

        return torch.cat(decoded, dim=1)

    def forward_loss(self, series, pred, mask):
        """Mean squared reconstruction error over the masked patches.

        Args:
            series: [B, num_leads, seq_len] original signal.
            pred: [B, num_leads * n, patch_size] decoder output.
            mask: [B, num_leads * n], 0 = kept, 1 = masked.

        Returns:
            Scalar loss. When no patch is masked the result is 0, not the NaN
            that a 0 / 0 division would give.
        """
        target = self.patchify(series)
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6)**.5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [B, num_leads * n], mean loss per patch

        # mean loss on removed patches; clamp guards the nothing-masked case
        return (loss * mask).sum() / mask.sum().clamp(min=1)

    def forward(self,
                series,
                mask_ratio=0.75,
                decoding='leadwise'):
        """Encode with masking, decode, and compute the reconstruction loss.

        Args:
            series: [B, num_leads, seq_len] raw signal.
            mask_ratio: fraction of patches masked per lead.
            decoding: ``'leadwise'`` (each lead decoded separately) or
                ``'all'`` (all leads decoded in one sequence).

        Returns:
            (recon_loss, pred, mask), where pred is [B, num_leads * n, patch_size].

        Raises:
            ValueError: if ``decoding`` is not ``'leadwise'`` or ``'all'``.
        """
        decoders = {'leadwise': self.forward_leadwise_decoder,
                    'all': self.forward_decoder}
        if decoding not in decoders:
            raise ValueError(f"decoding must be 'leadwise' or 'all', got {decoding!r}")

        latent, mask, ids_restore = self.forward_encoder(series, mask_ratio)
        pred = decoders[decoding](latent, ids_restore)
        recon_loss = self.forward_loss(series, pred, mask)

        return recon_loss, pred, mask


def st_mem_vit_base_dec256d4b(**kwargs):
    """Build an ST-MEM model: ViT-Base encoder, 256-dim / 4-block decoder.

    Preset:
        * Encoder: width 768, 12 blocks, 12 heads.
        * Decoder: width 256, 4 blocks, 4 heads.
        * MLP ratio 4.
        * Decoder norm: ``LayerNorm(eps=1e-6)``.

    Any other ``ST_MEM`` argument (e.g. ``num_leads``, ``seq_len``,
    ``patch_size``) is forwarded through ``kwargs``. Passing one of the preset
    keys again raises ``TypeError`` (duplicate keyword argument).
    """
    preset = dict(embed_dim=768,
                  depth=12,
                  num_heads=12,
                  decoder_embed_dim=256,
                  decoder_depth=4,
                  decoder_num_heads=4,
                  mlp_ratio=4,
                  norm_layer=partial(nn.LayerNorm, eps=1e-6))
    return ST_MEM(**preset, **kwargs)


# Shorter name for the recommended architecture (decoder: 256 dim, 4 blocks).
st_mem_vit_base = st_mem_vit_base_dec256d4b
