#!/usr/bin/env python

# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/


import math

import torch

from torch import nn
from torch.nn import functional as F

##############################

class Residual(nn.Module):
    """Skip connection computing ``x + f(x)``.

    A single module passed to the constructor is used directly as ``f``;
    any other number of modules (including zero) is chained into an
    ``nn.Sequential`` that becomes ``f``.
    """

    def __init__(self, *f):
        super().__init__()
        if len(f) == 1:
            (branch,) = f
        else:
            branch = nn.Sequential(*f)
        self.f = branch

    def forward(self, x):
        branch_out = self.f(x)
        return x + branch_out

##############################

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to ``x`` of shape (B, T, D).

    Follows Vaswani et al., "Attention Is All You Need" (2017):

        PE[t, 2i]   = sin(t / L^(2i/D))
        PE[t, 2i+1] = cos(t / L^(2i/D))

    with ``L = len_max`` and ``D = x.size(2)``. Cosine channels are computed
    as ``sin(angle + pi/2)``. The (T, D) table is rebuilt on every call in
    ``x``'s dtype/device and broadcast over the batch axis.
    """

    def __init__(self, len_max):
        super().__init__()
        self.len_max = len_max

    def forward(self, x):
        seq_len, dim = x.size(1), x.size(2)
        like_x = dict(dtype = x.dtype, device = x.device)
        positions = torch.arange(seq_len, **like_x).unsqueeze(1)  # (T, 1)
        channels = torch.arange(dim, **like_x).unsqueeze(0)       # (1, D)
        is_odd = channels % 2
        # Channel 2i+1 shares the frequency of channel 2i; the pi/2 phase
        # shift turns its sine into a cosine.
        freq_exponent = (channels - is_odd) / dim
        angles = positions / (self.len_max ** freq_exponent) + is_odd * (math.pi / 2)
        return x + angles.sin().unsqueeze(0)

##############################

class QKVAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with optional masks.

    ``x`` of shape (N, T, dim_in) is projected per head to queries and keys
    of size ``dim_qk`` and values of size ``dim_v``. The output concatenates
    the heads, giving shape (N, T, nb_heads * dim_v).

    Optional masks set selected pre-softmax logits to ``-inf``:

    * ``causal``: query position t may only attend to key positions s <= t.
    * ``cross_attn``: query and key axes are split at ``T // 2`` / ``S // 2``;
      the two diagonal blocks are masked, so each half attends only to the
      other half.
    * ``self_attn``: the two off-diagonal blocks are masked, so each half
      attends only within itself.

    NOTE: ``cross_attn`` together with ``self_attn`` masks every logit and the
    softmax then yields NaN. A row left fully masked by another combination
    (e.g. ``causal`` + ``cross_attn`` at t = 0) likewise becomes NaN.

    Args:
        dim_in: size of the last axis of the input.
        dim_qk: per-head query/key size; logits are scaled by 1/sqrt(dim_qk).
        dim_v: per-head value size.
        nb_heads: number of heads.
        causal, cross_attn, self_attn: enable the masks described above.
        attention_dropout: dropout rate on the attention probabilities
            (active only in training mode).
    """

    def __init__(self, dim_in, dim_qk, dim_v, nb_heads = 1, causal = False, cross_attn=False, self_attn=False, attention_dropout = 0.0):
        super().__init__()

        def gaussian_weight(*shape):
            # N(0, 1 / fan_in) init; fan_in is the contracted (last) axis.
            std = 1 / math.sqrt(shape[-1])
            return nn.Parameter(torch.empty(*shape).normal_(0, std))

        self.wq = gaussian_weight(nb_heads, dim_qk, dim_in)
        self.wk = gaussian_weight(nb_heads, dim_qk, dim_in)
        self.wv = gaussian_weight(nb_heads, dim_v, dim_in)
        self.causal = causal
        self.cross_attn = cross_attn
        self.self_attn = self_attn
        self.attention_dropout = attention_dropout

    def forward(self, x, return_attn_prenorm=False):
        """Attend over ``x``.

        Returns the (N, T, nb_heads * dim_v) output, or, when
        ``return_attn_prenorm`` is true, a tuple of that output and the
        scaled, masked, pre-softmax logits of shape (N, nb_heads, T, T).
        """
        def project(weight):
            # (N, T, C) x (H, D, C) -> (N, H, T, D)
            return torch.einsum('ntc,hdc->nhtd', x, weight)

        queries = project(self.wq)
        keys = project(self.wk)
        values = project(self.wv)

        logits = torch.einsum('nhtd,nhsd->nhts', queries, keys)
        logits = logits.div(math.sqrt(queries.size(3)))

        if self.causal or self.cross_attn or self.self_attn:
            t_len, s_len = logits.size(2), logits.size(3)
            device = queries.device
            t_idx = torch.arange(t_len, device = device)[:, None]  # (T, 1)
            s_idx = torch.arange(s_len, device = device)[None, :]  # (1, S)
            blocked = torch.zeros(t_len, s_len, dtype = torch.bool, device = device)
            if self.causal:
                blocked |= s_idx > t_idx
            t_first_half = t_idx < t_len // 2
            s_first_half = s_idx < s_len // 2
            if self.cross_attn:
                # mask same-half pairs -> only cross-half attention remains
                blocked |= t_first_half == s_first_half
            if self.self_attn:
                # mask cross-half pairs -> only within-half attention remains
                blocked |= t_first_half != s_first_half
            logits = logits.masked_fill(blocked, float('-inf'))

        probs = F.dropout(logits.softmax(dim = 3), self.attention_dropout, self.training)
        mixed = torch.einsum('nhts,nhsd->nhtd', probs, values)
        out = mixed.permute(0, 2, 1, 3).flatten(2)  # (N, H, T, D) -> (N, T, H*D)

        if return_attn_prenorm:
            return out, logits
        return out

##############################

class CloudGPT(nn.Module):
    """Non-causal pre-LayerNorm transformer mapping per-token inputs to per-token outputs.

    The input is embedded by a Linear layer, passed through ``nb_blocks``
    pairs of residual sublayers, and mapped to ``dim_out`` by a Linear readout:

    * attention:    ``x + Linear(QKVAttention(LayerNorm(x)))``
    * feed-forward: ``x + Dropout(Linear(ReLU(Linear(LayerNorm(x)))))``

    No positional encoding is added and attention is unmasked. Dropout aside,
    the model is therefore equivariant to permutations of the token axis (dim 1).

    Args:
        dim_in: size of the last axis of the input.
        dim_out: size of the last axis of the output.
        dim_model: width of the residual stream; must be divisible by ``nb_heads``.
        dim_keys: per-head query/key size.
        dim_hidden: hidden width of the feed-forward sublayers.
        nb_heads: number of attention heads.
        nb_blocks: number of (attention, feed-forward) sublayer pairs.
        dropout: rate used after the input embedding, on the attention
            probabilities, and at the end of each feed-forward sublayer.
    """

    def __init__(self,
                 dim_in, dim_out,
                 dim_model, dim_keys, dim_hidden,
                 nb_heads, nb_blocks, dropout = 0.):

        super().__init__()

        # Heads split the residual width evenly (dim_v = dim_model // nb_heads),
        # so the concatenated attention output is exactly dim_model wide.
        assert dim_model % nb_heads == 0, "dim_model must be divisible by nb_heads"

        # Pointwise embedding only; no PositionalEncoding is applied here.
        self.embedding = nn.Sequential(
            nn.Linear(in_features = dim_in, out_features = dim_model),
            nn.Dropout(dropout),
        )

        trunk_blocks = []
        for _ in range(nb_blocks):
            trunk_blocks += [
                Residual(
                    nn.LayerNorm(dim_model),
                    QKVAttention(
                        dim_in = dim_model,
                        dim_qk = dim_keys, dim_v = dim_model // nb_heads,
                        nb_heads = nb_heads,
                        causal = False, attention_dropout = dropout
                    ),
                    nn.Linear(in_features = dim_model, out_features = dim_model),
                ),
                Residual(
                    nn.LayerNorm(dim_model),
                    nn.Linear(in_features = dim_model, out_features = dim_hidden),
                    nn.ReLU(),
                    nn.Linear(in_features = dim_hidden, out_features = dim_model),
                    nn.Dropout(dropout),
                ),
            ]

        self.trunk = nn.ModuleList(trunk_blocks)

        self.readout = nn.Linear(in_features = dim_model, out_features = dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the model token-wise through the attention trunk.

        Args:
            x: input of shape (B, N, dim_in).

        Returns:
            Tensor of shape (B, N, dim_out).
        """
        x = self.embedding(x)

        for layer in self.trunk:
            x = layer(x)

        return self.readout(x)

######################################################################

class MatchingGPT(nn.Module):
    """Transformer whose output is the attention-logit matrix of a final single-head layer.

    Pipeline:

    1. Embed ``x`` with a Linear layer.
    2. Add a learned Linear embedding of one scalar per token (``pos_emb``).
    3. Run ``nb_blocks`` pre-LayerNorm residual (attention, feed-forward) pairs.
    4. Apply a LayerNorm and a single-head, unmasked ``QKVAttention``.
    5. Return that layer's scaled pre-softmax logits.

    ``self.readout`` is constructed, so it is part of the parameters and the
    state dict, but ``forward`` does not use it.

    Args:
        dim_in: size of the last axis of ``x``.
        dim_out: output size of the (unused) ``readout`` layer.
        dim_model: width of the residual stream; must be divisible by ``nb_heads``.
        dim_keys: per-head query/key size.
        dim_hidden: hidden width of the feed-forward sublayers.
        nb_heads: number of heads in the trunk attention layers.
        nb_blocks: number of (attention, feed-forward) pairs.
        dropout: rate used in both embeddings, on attention probabilities,
            and at the end of each feed-forward sublayer.
    """

    def __init__(self,
                 dim_in, dim_out,
                 dim_model, dim_keys, dim_hidden,
                 nb_heads, nb_blocks, dropout = 0.):

        super().__init__()

        assert dim_model % nb_heads == 0, "dim_model must be divisible by nb_heads"

        self.embedding = nn.Sequential(
            nn.Linear(in_features = dim_in, out_features = dim_model),
            nn.Dropout(dropout),
        )

        # Embeds one scalar per token (see forward).
        self.pos_embedding = nn.Sequential(
            nn.Linear(in_features = 1, out_features = dim_model),
            nn.Dropout(dropout),
        )

        # Mask pattern per block, indexed by block parity (even, odd). Both are
        # currently disabled; e.g. self_attn = [True, False] with
        # cross_attn = [False, True] would alternate within-half and
        # across-half attention between consecutive blocks.
        self_attn = [False, False]
        cross_attn = [False, False]

        trunk_blocks = []
        for idx in range(nb_blocks):
            trunk_blocks += [
                Residual(
                    nn.LayerNorm(dim_model),
                    QKVAttention(
                        dim_in = dim_model,
                        dim_qk = dim_keys, dim_v = dim_model // nb_heads,
                        nb_heads = nb_heads,
                        causal = False, cross_attn = cross_attn[idx % 2], self_attn = self_attn[idx % 2],
                        attention_dropout = dropout
                    ),
                    nn.Linear(in_features = dim_model, out_features = dim_model),
                ),
                Residual(
                    nn.LayerNorm(dim_model),
                    nn.Linear(in_features = dim_model, out_features = dim_hidden),
                    nn.ReLU(),
                    nn.Linear(in_features = dim_hidden, out_features = dim_model),
                    nn.Dropout(dropout),
                ),
            ]

        # Final head: its logits are what forward() returns. The value-weighted
        # output of this layer is discarded, so dim_v only sets parameter shapes.
        trunk_blocks += [
            nn.LayerNorm(dim_model),
            QKVAttention(
                dim_in = dim_model,
                dim_qk = dim_keys, dim_v = dim_model // nb_heads,
                nb_heads = 1,
                causal = False, cross_attn = False, self_attn = False, attention_dropout = dropout
            ),
        ]

        self.trunk = nn.ModuleList(trunk_blocks)

        # Not used by forward(); presumably kept for checkpoint compatibility — TODO confirm.
        self.readout = nn.Linear(in_features = dim_model, out_features = dim_out)

    def forward(self, x: torch.Tensor, pos_emb: torch.Tensor) -> torch.Tensor:
        """Return the final layer's pre-softmax attention logits.

        Args:
            x: token features of shape (B, N, dim_in).
            pos_emb: one scalar per token. It is unsqueezed to a trailing size-1
                axis, embedded to dim_model, and broadcast over the batch via
                ``[None, ...]``. This presumably means shape (N,) — TODO confirm.

        Returns:
            Scaled, unmasked attention logits ``q.k / sqrt(dim_keys)`` of the
            final single-head layer, shape (B, 1, N, N).
        """
        x = self.embedding(x)
        # Assuming pos_emb is (N,): (N, 1) -> (N, dim_model) -> (1, N, dim_model),
        # shared by every batch element.
        x = x + self.pos_embedding(pos_emb.unsqueeze(-1))[None, ...]

        for layer in self.trunk[:-1]:
            x = layer(x)

        _, attn = self.trunk[-1](x, return_attn_prenorm = True)
        return attn

######################################################################

class MyGPT(nn.Module):
    """Causal transformer over integer token ids.

    Token ids are embedded with ``nn.Embedding``, then dropout is applied and
    sinusoidal position encodings are added. The result goes through
    ``nb_blocks`` pre-LayerNorm residual (causal attention, feed-forward)
    pairs. A Linear readout produces ``vocabulary_size`` scores per position.

    Input: integer tensor (B, T). Output: tensor (B, T, vocabulary_size).
    Any extra positional arguments to ``forward`` are accepted and ignored.
    """

    def __init__(self,
                 vocabulary_size,
                 dim_model, dim_keys, dim_hidden,
                 nb_heads, nb_blocks, dropout = 0.):

        super().__init__()

        assert dim_model % nb_heads == 0

        self.embedding = nn.Sequential(
            nn.Embedding(vocabulary_size, dim_model),
            nn.Dropout(dropout),
            PositionalEncoding(len_max = 1e5),
        )

        def attention_sublayer():
            return Residual(
                nn.LayerNorm(dim_model),
                QKVAttention(
                    dim_in = dim_model,
                    dim_qk = dim_keys,
                    dim_v = dim_model // nb_heads,
                    nb_heads = nb_heads,
                    causal = True,
                    attention_dropout = dropout,
                ),
                nn.Linear(in_features = dim_model, out_features = dim_model),
            )

        def feed_forward_sublayer():
            return Residual(
                nn.LayerNorm(dim_model),
                nn.Linear(in_features = dim_model, out_features = dim_hidden),
                nn.ReLU(),
                nn.Linear(in_features = dim_hidden, out_features = dim_model),
                nn.Dropout(dropout),
            )

        layers = []
        for _ in range(nb_blocks):
            layers.append(attention_sublayer())
            layers.append(feed_forward_sublayer())

        self.trunk = nn.Sequential(*layers)

        self.readout = nn.Linear(in_features = dim_model, out_features = vocabulary_size)

    def forward(self, x, *args):
        hidden = self.trunk(self.embedding(x))
        return self.readout(hidden)

######################################################################

if __name__ == '__main__':
    # Smoke test: run a small causal MyGPT on random token ids and check that
    # it yields one score per vocabulary entry at every position.
    vocabulary_size = 10
    batch_size, seq_len = 25, 100
    x = torch.randint(vocabulary_size, (batch_size, seq_len))

    model = MyGPT(
        vocabulary_size = vocabulary_size,
        dim_model = 16, dim_keys = 50, dim_hidden = 100,
        nb_heads = 2, nb_blocks = 3,
        dropout = 0.1
    )

    y = model(x)
    assert y.shape == (batch_size, seq_len, vocabulary_size), y.shape
    print(f'MyGPT output shape: {tuple(y.shape)}')

######################################################################