# 1D Vision Transformer for time-series

import torch
import torch.nn as nn
from einops import rearrange
from einops.layers.torch import Rearrange


class DropPath(nn.Module):
    """Stochastic depth: randomly zero entire samples of a residual branch.

    While training, every sample along dim 0 is independently kept with
    probability ``1 - drop_prob`` (the same mask value is broadcast over all
    remaining dims). In eval mode, or when ``drop_prob <= 0``, the input is
    returned unchanged.

    Args:
        drop_prob: probability of dropping a sample.
        scale_by_keep: if True, surviving samples are divided by the keep
            probability so the expected output equals the input.
    """

    def __init__(self, drop_prob: float, scale_by_keep: bool = True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        if not self.training or self.drop_prob <= 0.:
            return x

        keep = 1 - self.drop_prob
        # One Bernoulli draw per sample, broadcast over all other dimensions.
        mask_shape = [x.shape[0]] + [1] * (x.dim() - 1)
        mask = x.new_empty(mask_shape).bernoulli_(keep)
        # Rescale only when something can survive (avoids division by zero).
        if self.scale_by_keep and keep > 0.0:
            mask = mask / keep
        return mask * x


class PreNorm(nn.Module):
    """Wrap a module so its input is layer-normalized first (pre-norm style).

    Args:
        dim: size of the last input dimension, normalized by ``nn.LayerNorm``.
        fn: module applied to the normalized tensor; any extra keyword
            arguments given to ``forward`` are passed through to it.
    """

    def __init__(self, dim: int, fn: nn.Module):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normalized = self.norm(x)
        result = self.fn(normalized, **kwargs)
        return result


class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        input_dim: size of the last input dimension.
        output_dim: size of the last output dimension.
        hidden_dim: width of the intermediate layer.
        drop_out_rate: dropout probability used after the activation and
            after the second linear layer.
    """

    def __init__(self, input_dim: int, output_dim: int, hidden_dim: int, drop_out_rate=0.):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(drop_out_rate),
            nn.Linear(hidden_dim, output_dim),
            nn.Dropout(drop_out_rate),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out


class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention over a token sequence.

    Args:
        input_dim: feature size of the incoming tokens (last dim of ``x``).
        output_dim: feature size of the returned tokens.
        heads: number of attention heads.
        dim_head: per-head feature size; q, k and v each have
            ``heads * dim_head`` features before being split into heads.
        qkv_bias: whether the fused q/k/v projection has a bias term.
        drop_out_rate: dropout applied after the output projection.
        attn_drop_out_rate: dropout applied to the attention weights.
    """

    def __init__(self, input_dim: int, output_dim: int, heads: int = 8, dim_head: int = 64,
                 qkv_bias: bool = True, drop_out_rate: float = 0., attn_drop_out_rate: float = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        # The output projection may only be skipped when a single head already
        # produces exactly ``output_dim`` features. Checking ``input_dim`` alone
        # would return tokens of width ``dim_head`` whenever input_dim != output_dim.
        project_out = not (heads == 1 and dim_head == input_dim == output_dim)

        self.heads = heads
        self.scale = dim_head ** -0.5
        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(attn_drop_out_rate)
        self.to_qkv = nn.Linear(input_dim, inner_dim * 3, bias=qkv_bias)

        if project_out:
            self.to_out = nn.Sequential(nn.Linear(inner_dim, output_dim), nn.Dropout(drop_out_rate))
        else:
            self.to_out = nn.Identity()

    def forward(self, x):
        # x: [B, N, input_dim] (the rearrange pattern below requires 3 dims).
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        # Split the fused feature dim into heads: [B, heads, N, dim_head].
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv)
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)
        attn = self.dropout(attn)
        out = torch.matmul(attn, v)
        # Merge heads back: [B, N, heads * dim_head].
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)


class TransformerBlock(nn.Module):
    """Pre-norm transformer encoder block with optional stochastic depth.

    Computes::

        h   = x + DropPath(Attention(LayerNorm(x)))
        out = h + DropPath(FeedForward(LayerNorm(h)))

    ``DropPath`` is replaced by ``nn.Identity`` when ``drop_path_rate <= 0``.
    Both sub-layers are added back onto their input, so in practice
    ``input_dim`` is expected to equal ``output_dim`` (presumably always the
    case for callers; otherwise the residual additions would not line up).
    """

    def __init__(self, input_dim: int, output_dim: int, hidden_dim: int, heads: int = 8,
                 dim_head: int = 32, qkv_bias: bool = True, drop_out_rate: float = 0.,
                 attn_drop_out_rate: float = 0., drop_path_rate: float = 0.):
        super().__init__()

        def stochastic_depth():
            # Only build a DropPath when it can actually drop something.
            return DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()

        self.attn = PreNorm(input_dim, Attention(
            input_dim, output_dim, heads=heads, dim_head=dim_head, qkv_bias=qkv_bias,
            drop_out_rate=drop_out_rate, attn_drop_out_rate=attn_drop_out_rate))
        self.droppath1 = stochastic_depth()

        self.ff = PreNorm(output_dim, FeedForward(
            output_dim, output_dim, hidden_dim, drop_out_rate=drop_out_rate))
        self.droppath2 = stochastic_depth()

    def forward(self, x):
        h = x + self.droppath1(self.attn(x))
        return h + self.droppath2(self.ff(h))


class ViT(nn.Module):
    """1D Vision Transformer encoder for multi-lead time series.

    Tokenization is selected by ``lead_wise``:

    * ``lead_wise == 0``: a ``Conv1d`` with kernel/stride ``patch_size`` maps
      each time window (all ``num_leads`` channels together) to one token,
      giving ``seq_len // patch_size`` tokens.
    * otherwise: the series is treated as a one-channel image of shape
      ``(num_leads, seq_len)`` and a ``Conv2d`` with kernel/stride
      ``(patch_size_ch, patch_size)`` maps each tile to one token. If
      ``use_lead_embedding`` is set, a learned embedding per lead-group row is
      added as well.

    Tokens get a learned positional embedding and pass through ``depth``
    pre-norm transformer blocks. The stochastic-depth rate increases linearly
    from 0 to ``drop_path_rate`` across the blocks. The result is mean-pooled
    over tokens (unless ``is_patch``) and layer-normalized. ``head`` starts as
    ``nn.Identity``; call ``reset_head`` to attach a linear classifier.

    Extra keyword arguments are accepted and ignored. This includes the
    ``num_classes`` passed by the ``vit_*`` factories.
    """

    def __init__(self,
                 num_leads: int,
                 seq_len: int,
                 patch_size: int,
                 lead_wise=0,
                 patch_size_ch=4,
                 use_lead_embedding: bool = True,
                 width: int = 768,
                 depth: int = 12,
                 mlp_dim: int = 3072,
                 heads: int = 12,
                 dim_head: int = 64,
                 qkv_bias: bool = True,
                 drop_out_rate: float = 0.,
                 attn_drop_out_rate: float = 0.,
                 drop_path_rate: float = 0.,
                 **kwargs):
        super().__init__()
        assert seq_len % patch_size == 0, (
            f"seq_len ({seq_len}) must be divisible by patch_size ({patch_size})")
        num_patches = seq_len // patch_size
        self.lead_wise = lead_wise
        self.use_lead_embedding = use_lead_embedding

        # NOTE: module creation order below fixes the parameter-init RNG stream
        # and the state_dict layout; keep it stable for checkpoint compatibility.
        if lead_wise == 0:
            self.to_patch_embedding = nn.Conv1d(num_leads, width, kernel_size=patch_size, stride=patch_size, bias=False)
            self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, width))
        else:
            self.to_patch_embedding = nn.Conv2d(1, width, kernel_size=(patch_size_ch, patch_size),
                                                stride=(patch_size_ch, patch_size), bias=False)
            # Sized as (num_patches * num_leads) // patch_size_ch. This can exceed the
            # number of tokens actually produced when num_leads is not a multiple of
            # patch_size_ch; the forward pass only reads the leading rows.
            self.pos_embedding = nn.Parameter(torch.randn(1, num_patches * num_leads // patch_size_ch, width))
            if use_lead_embedding:
                self.lead_emb = nn.Embedding(num_leads // patch_size_ch, width)
            else:
                self.lead_emb = None

        self.dropout = nn.Dropout(drop_out_rate)
        self.depth = depth
        self.width = width

        # Linearly increasing stochastic-depth rate per block.
        drop_path_rate_list = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        for i in range(depth):
            block = TransformerBlock(width, width, mlp_dim, heads, dim_head, qkv_bias,
                                     drop_out_rate, attn_drop_out_rate, drop_path_rate_list[i])
            self.add_module(f'block{i}', block)

        self.norm = nn.LayerNorm(width)
        self.head = nn.Identity()

    def _tokenize(self, series: torch.Tensor) -> torch.Tensor:
        """Map ``[B, C, T]`` series to ``[B, N, width]`` tokens with position (and lead) embeddings added."""
        if self.lead_wise == 0:
            x = self.to_patch_embedding(series)  # [B, D, N]
            x = rearrange(x, 'b c n -> b n c')  # [B, N, D]
            n_tokens = x.shape[1]
            return x + self.pos_embedding[:, :n_tokens, :].to(x.device)

        x = self.to_patch_embedding(series.unsqueeze(1))  # [B, D, Lr, Nt]
        Lr, Nt = x.shape[-2], x.shape[-1]
        # Flatten row by row: token k belongs to lead-group row k // Nt.
        x = rearrange(x, 'b c lr nt -> b (lr nt) c')  # [B, Lr * Nt, D]
        x = x + self.pos_embedding[:, :(Lr * Nt), :].to(x.device)
        if self.use_lead_embedding and self.lead_emb is not None:
            row_ids = torch.arange(Lr, device=x.device).repeat_interleave(Nt)
            x = x + self.lead_emb(row_ids)[None, :, :]
        return x

    def forward_encoding(self, series: torch.Tensor, is_patch: bool = False) -> torch.Tensor:
        """Encode a batch of series.

        Args:
            series: tensor of shape ``[B, C, T]``.
            is_patch: if True, return per-token features ``[B, N, width]``;
                otherwise return their mean over tokens, ``[B, width]``.

        Raises:
            ValueError: if ``series`` is not 3-dimensional.
        """
        if series.ndim != 3:
            raise ValueError(f"expected series of shape [B, C, T], got {tuple(series.shape)}")

        x = self._tokenize(series)
        x = self.dropout(x)
        for i in range(self.depth):
            x = getattr(self, f'block{i}')(x)
        if not is_patch:
            x = x.mean(dim=1)
        return self.norm(x)

    def forward(self, series, is_patch=False):
        x = self.forward_encoding(series, is_patch=is_patch)
        return self.head(x)

    def reset_head(self, num_classes=1):
        """Replace ``head`` with a fresh ``nn.Linear(width, num_classes)``.

        The new layer is initialized as usual and then moved to the device and
        dtype of the final LayerNorm. This way, calling this method after
        ``model.to(...)`` does not leave the head on the default device/dtype.
        """
        ref = self.norm.weight
        del self.head
        self.head = nn.Linear(self.width, num_classes).to(device=ref.device, dtype=ref.dtype)


def vit_nano(num_leads, num_classes=1, seq_len=5000, patch_size=50, **kwargs):
    """Build the nano preset: width 128, 6 blocks, 4 heads, MLP hidden 512.

    ``num_classes`` is forwarded to ``ViT``, which accepts it via ``**kwargs``
    but does not use it. Any other keyword arguments go straight to ``ViT``.
    """
    arch = dict(width=128, depth=6, heads=4, mlp_dim=512)
    return ViT(num_leads=num_leads, num_classes=num_classes, seq_len=seq_len,
               patch_size=patch_size, **arch, **kwargs)


def vit_tiny(num_leads, num_classes=1, seq_len=5000, patch_size=50, **kwargs):
    """Build the tiny preset: width 192, 12 blocks, 3 heads, MLP hidden 768.

    ``num_classes`` is forwarded to ``ViT``, which accepts it via ``**kwargs``
    but does not use it. Any other keyword arguments go straight to ``ViT``.
    """
    arch = dict(width=192, depth=12, heads=3, mlp_dim=768)
    return ViT(num_leads=num_leads, num_classes=num_classes, seq_len=seq_len,
               patch_size=patch_size, **arch, **kwargs)


def vit_small(num_leads, num_classes=1, seq_len=5000, patch_size=50, **kwargs):
    """Build the small preset: width 384, 12 blocks, 6 heads, MLP hidden 1536.

    ``num_classes`` is forwarded to ``ViT``, which accepts it via ``**kwargs``
    but does not use it. Any other keyword arguments go straight to ``ViT``.
    """
    arch = dict(width=384, depth=12, heads=6, mlp_dim=1536)
    return ViT(num_leads=num_leads, num_classes=num_classes, seq_len=seq_len,
               patch_size=patch_size, **arch, **kwargs)


def vit_middle(num_leads, num_classes=1, seq_len=5000, patch_size=50, **kwargs):
    """Build the middle preset: width 512, 12 blocks, 8 heads, MLP hidden 2048.

    ``num_classes`` is forwarded to ``ViT``, which accepts it via ``**kwargs``
    but does not use it. Any other keyword arguments go straight to ``ViT``.
    """
    arch = dict(width=512, depth=12, heads=8, mlp_dim=2048)
    return ViT(num_leads=num_leads, num_classes=num_classes, seq_len=seq_len,
               patch_size=patch_size, **arch, **kwargs)


def vit_base(num_leads, num_classes=1, seq_len=5000, patch_size=50, **kwargs):
    """Build the base preset: width 768, 12 blocks, 12 heads, MLP hidden 3072.

    ``num_classes`` is forwarded to ``ViT``, which accepts it via ``**kwargs``
    but does not use it. Any other keyword arguments go straight to ``ViT``.
    """
    arch = dict(width=768, depth=12, heads=12, mlp_dim=3072)
    return ViT(num_leads=num_leads, num_classes=num_classes, seq_len=seq_len,
               patch_size=patch_size, **arch, **kwargs)
