import numpy as np
import torch
import torch.nn as nn
import torch.nn.init as nn_init
import torch.nn.functional as F
from torch import Tensor

import math

class Tokenizer(nn.Module):
    """Embed numerical and categorical features as a sequence of tokens.

    The produced sequence is ``[CLS], numerical tokens..., categorical
    tokens...`` where every token is a vector of size ``d_token``.

    Args:
        d_numerical: number of numerical features.
        categories: cardinality of each categorical feature, or ``None`` when
            there are no categorical features.
        d_token: size of every token embedding.
        bias: if truthy, learn one additive bias vector per feature token
            (the [CLS] token never receives a bias).
    """

    def __init__(self, d_numerical, categories, d_token, bias):
        super().__init__()
        if categories is None:
            d_bias = d_numerical
            self.category_offsets = None
            self.category_embeddings = None
        else:
            d_bias = d_numerical + len(categories)
            # Start column of every category block inside the concatenated
            # categorical input: [0, c0, c0 + c1, ...].
            category_offsets = torch.tensor([0] + list(categories[:-1])).cumsum(0)
            self.register_buffer('category_offsets', category_offsets)
            # One embedding row per category value, blocks stacked in order.
            self.cat_weight = nn.Parameter(Tensor(sum(categories), d_token))
            nn_init.kaiming_uniform_(self.cat_weight, a=math.sqrt(5))

        # take [CLS] token into account
        self.weight = nn.Parameter(Tensor(d_numerical + 1, d_token))
        self.bias = nn.Parameter(Tensor(d_bias, d_token)) if bias else None
        # The initialization is inspired by nn.Linear
        nn_init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            nn_init.kaiming_uniform_(self.bias, a=math.sqrt(5))

    @property
    def n_tokens(self):
        """Number of tokens produced per sample, including [CLS]."""
        return len(self.weight) + (
            0 if self.category_offsets is None else len(self.category_offsets)
        )

    def forward(self, x_num, x_cat):
        """Tokenize one batch.

        Args:
            x_num: ``(batch, d_numerical)`` numerical features, or ``None``.
            x_cat: ``(batch, sum(categories))`` categorical features laid out
                as consecutive per-feature column blocks, or ``None``.
                Presumably one-hot (or soft one-hot) rows -- confirm against
                the caller.

        Returns:
            Tensor of shape ``(batch, n_tokens, d_token)``.

        Raises:
            AssertionError: if both ``x_num`` and ``x_cat`` are ``None``.
        """
        x_some = x_num if x_cat is None else x_cat
        assert x_some is not None

        # A constant column of ones turns weight[0] into the [CLS] embedding.
        cls_column = torch.ones(len(x_some), 1, device=x_some.device)
        columns = [cls_column] if x_num is None else [cls_column, x_num]
        x_num = torch.cat(columns, dim=1)

        tokens = [self.weight[None] * x_num[:, :, None]]

        if x_cat is not None:
            # Read the offsets once as Python ints so slicing and comparing
            # does not trigger a device sync for every category.
            starts = self.category_offsets.tolist()
            ends = starts[1:] + [x_cat.shape[1]]
            for start, end in zip(starts, ends):
                # A zero-width block produces no token.
                if start < end:
                    tokens.append(
                        x_cat[:, start:end].unsqueeze(1)
                        @ self.cat_weight[start:end][None]
                    )

        # Concatenate once instead of re-copying the sequence per category.
        x = tokens[0] if len(tokens) == 1 else torch.cat(tokens, dim=1)

        if self.bias is not None:
            # The prepended zero row keeps the [CLS] token bias-free.
            bias = torch.cat(
                [
                    torch.zeros(1, self.bias.shape[1], device=x.device),
                    self.bias,
                ]
            )
            x = x + bias[None]

        return x


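# NOTE: legacy MultiheadAttention without padding-mask support, kept commented out for reference only.
# class MultiheadAttention(nn.Module):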
#     def __init__(self, d, n_heads, dropout, initialization = 'kaiming'):
#
#         if n_heads > 1:
#             assert d % n_heads == 0
#         assert initialization in ['xavier', 'kaiming']
#
#         super().__init__()
#         self.W_q = nn.Linear(d, d)
#         self.W_k = nn.Linear(d, d)
#         self.W_v = nn.Linear(d, d)
#         self.W_out = nn.Linear(d, d) if n_heads > 1 else None
#         self.n_heads = n_heads
#         self.dropout = nn.Dropout(dropout) if dropout else None
#
#         for m in [self.W_q, self.W_k, self.W_v]:
#             if initialization == 'xavier' and (n_heads > 1 or m is not self.W_v):
#                 # gain is needed since W_qkv is represented with 3 separate layers
#                 nn_init.xavier_uniform_(m.weight, gain=1 / math.sqrt(2))
#             nn_init.zeros_(m.bias)
#         if self.W_out is not None:
#             nn_init.zeros_(self.W_out.bias)
#
#     def _reshape(self, x):
#         batch_size, n_tokens, d = x.shape
#         d_head = d // self.n_heads
#         return (
#             x.reshape(batch_size, n_tokens, self.n_heads, d_head)
#             .transpose(1, 2)
#             .reshape(batch_size * self.n_heads, n_tokens, d_head)
#         )
#
#     def forward(self, x_q, x_kv, key_compression = None, value_compression = None):
#
#         q, k, v = self.W_q(x_q), self.W_k(x_kv), self.W_v(x_kv)
#         for tensor in [q, k, v]:
#             assert tensor.shape[-1] % self.n_heads == 0
#         if key_compression is not None:
#             assert value_compression is not None
#             k = key_compression(k.transpose(1, 2)).transpose(1, 2)
#             v = value_compression(v.transpose(1, 2)).transpose(1, 2)
#         else:
#             assert value_compression is None
#
#         batch_size = len(q)
#         d_head_key = k.shape[-1] // self.n_heads
#         d_head_value = v.shape[-1] // self.n_heads
#         n_q_tokens = q.shape[1]
#
#         q = self._reshape(q)
#         k = self._reshape(k)
#
#         a = q @ k.transpose(1, 2)
#         b = math.sqrt(d_head_key)
#         attention = F.softmax(a/b , dim=-1)
#
#
#         if self.dropout is not None:
#             attention = self.dropout(attention)
#         x = attention @ self._reshape(v)
#         x = (
#             x.reshape(batch_size, self.n_heads, n_q_tokens, d_head_value)
#             .transpose(1, 2)
#             .reshape(batch_size, n_q_tokens, self.n_heads * d_head_value)
#         )
#         if self.W_out is not None:
#             x = self.W_out(x)
#
#         return x

class MultiheadAttention(nn.Module):
    """Multi-head scaled dot-product attention with optional padding masks.

    Args:
        d: width of the input and output token vectors.
        n_heads: number of attention heads; ``d`` must be divisible by it
            when it is greater than one.
        dropout: dropout probability applied to the attention weights; a
            falsy value disables dropout.
        initialization: ``'xavier'`` or ``'kaiming'``. ``'kaiming'`` keeps the
            default ``nn.Linear`` weight initialization.
    """

    def __init__(self, d, n_heads, dropout, initialization='kaiming'):

        if n_heads > 1:
            assert d % n_heads == 0
        assert initialization in ['xavier', 'kaiming']

        super().__init__()
        self.W_q = nn.Linear(d, d)
        self.W_k = nn.Linear(d, d)
        self.W_v = nn.Linear(d, d)
        # A single head has nothing to mix, so it gets no output projection.
        self.W_out = nn.Linear(d, d) if n_heads > 1 else None
        self.n_heads = n_heads
        self.dropout = nn.Dropout(dropout) if dropout else None

        use_xavier = initialization == 'xavier'
        for projection in (self.W_q, self.W_k, self.W_v):
            # With one head the value projection keeps its default weights.
            single_head_value = n_heads <= 1 and projection is self.W_v
            if use_xavier and not single_head_value:
                # The gain is needed because W_q, W_k and W_v are kept as
                # separate layers.
                nn_init.xavier_uniform_(projection.weight, gain=1 / math.sqrt(2))
            nn_init.zeros_(projection.bias)
        if self.W_out is not None:
            nn_init.zeros_(self.W_out.bias)

    def _reshape(self, x):
        """Split the feature axis into heads and fold heads into the batch.

        x: (batch_size, n_tokens, d)
        return: (batch_size * n_heads, n_tokens, d // n_heads)
        """
        batch_size, n_tokens, d = x.shape
        d_head = d // self.n_heads
        per_head = x.reshape(batch_size, n_tokens, self.n_heads, d_head)
        heads_first = per_head.transpose(1, 2)
        return heads_first.reshape(batch_size * self.n_heads, n_tokens, d_head)

    def forward(self, x_q, x_kv, q_padding_mask=None, kv_padding_mask=None,
                key_compression=None, value_compression=None):
        """Attend from ``x_q`` to ``x_kv``.

        Args:
            x_q: (batch_size, n_q_tokens, d) query tokens.
            x_kv: (batch_size, n_kv_tokens, d) key/value tokens.
            q_padding_mask: optional (batch_size, n_q_tokens) mask; truthy
                entries mark query positions whose scores are set to -inf.
            kv_padding_mask: optional (batch_size, n_kv_tokens) mask; truthy
                entries mark key positions whose scores are set to -inf.
            key_compression: optional callable applied to the keys transposed
                to (batch_size, d, n_kv_tokens); must be given together with
                ``value_compression``.
            value_compression: same as ``key_compression`` but for the values.

        Returns:
            Tensor of shape (batch_size, n_q_tokens, d).
        """
        q = self.W_q(x_q)
        k = self.W_k(x_kv)
        v = self.W_v(x_kv)
        for projected in (q, k, v):
            assert projected.shape[-1] % self.n_heads == 0

        if key_compression is None:
            assert value_compression is None
        else:
            assert value_compression is not None
            k = key_compression(k.transpose(1, 2)).transpose(1, 2)
            v = value_compression(v.transpose(1, 2)).transpose(1, 2)

        batch_size, n_q_tokens = q.shape[:2]
        d_head_key = k.shape[-1] // self.n_heads
        d_head_value = v.shape[-1] // self.n_heads

        # Each of q, k, v becomes (batch_size * n_heads, n_tokens, d_head).
        q, k, v = self._reshape(q), self._reshape(k), self._reshape(v)

        # Raw scores, viewed per head so the per-sample masks can broadcast.
        scores = (q @ k.transpose(1, 2)).view(
            batch_size, self.n_heads, n_q_tokens, -1
        )
        neg_inf = float('-inf')
        if q_padding_mask is not None:
            # (B, n_q_tokens) -> (B, 1, n_q_tokens, 1)
            row_mask = q_padding_mask[:, None, :, None].bool()
            scores = scores.masked_fill(row_mask, neg_inf)
        if kv_padding_mask is not None:
            # (B, n_kv_tokens) -> (B, 1, 1, n_kv_tokens)
            column_mask = kv_padding_mask[:, None, None, :].bool()
            scores = scores.masked_fill(column_mask, neg_inf)
        scores = scores.view(batch_size * self.n_heads, n_q_tokens, -1)

        # A fully masked row is all -inf, so softmax yields NaN there;
        # nan_to_num turns those rows into zero attention.
        attention = F.softmax(scores / math.sqrt(d_head_key), dim=-1).nan_to_num()
        if self.dropout is not None:
            attention = self.dropout(attention)

        # (B * n_heads, n_q_tokens, d_head_value)
        x = attention @ v

        # Unfold the heads from the batch axis and merge them back into features.
        x = x.reshape(batch_size, self.n_heads, n_q_tokens, d_head_value)
        x = x.transpose(1, 2)
        x = x.reshape(batch_size, n_q_tokens, self.n_heads * d_head_value)

        return x if self.W_out is None else self.W_out(x)
        
# NOTE: legacy Transformer without padding-mask support, kept commented out for reference only.
# class Transformer(nn.Module):
#
#     def __init__(
#         self,
#         n_layers: int,
#         d_token: int,
#         n_heads: int,
#         d_out: int,
#         d_ffn_factor: int,
#         attention_dropout = 0.0,
#         ffn_dropout = 0.0,
#         residual_dropout = 0.0,
#         activation = 'relu',
#         prenormalization = True,
#         initialization = 'kaiming',
#     ):
#         super().__init__()
#
#         def make_normalization():
#             return nn.LayerNorm(d_token)
#
#         d_hidden = int(d_token * d_ffn_factor)
#         self.layers = nn.ModuleList([])
#         for layer_idx in range(n_layers):
#             layer = nn.ModuleDict(
#                 {
#                     'attention': MultiheadAttention(
#                         d_token, n_heads, attention_dropout, initialization
#                     ),
#                     'linear0': nn.Linear(
#                         d_token, d_hidden
#                     ),
#                     'linear1': nn.Linear(d_hidden, d_token),
#                     'norm1': make_normalization(),
#                 }
#             )
#             if not prenormalization or layer_idx:
#                 layer['norm0'] = make_normalization()
#
#             self.layers.append(layer)
#
#         self.activation = nn.ReLU()
#         self.last_activation = nn.ReLU()
#         # self.activation = lib.get_activation_fn(activation)
#         # self.last_activation = lib.get_nonglu_activation_fn(activation)
#         self.prenormalization = prenormalization
#         self.last_normalization = make_normalization() if prenormalization else None
#         self.ffn_dropout = ffn_dropout
#         self.residual_dropout = residual_dropout
#         self.head = nn.Linear(d_token, d_out)
#
#
#     def _start_residual(self, x, layer, norm_idx):
#         x_residual = x
#         if self.prenormalization:
#             norm_key = f'norm{norm_idx}'
#             if norm_key in layer:
#                 x_residual = layer[norm_key](x_residual)
#         return x_residual
#
#     def _end_residual(self, x, x_residual, layer, norm_idx):
#         if self.residual_dropout:
#             x_residual = F.dropout(x_residual, self.residual_dropout, self.training)
#         x = x + x_residual
#         if not self.prenormalization:
#             x = layer[f'norm{norm_idx}'](x)
#         return x
#
#     def forward(self, x):
#         for layer_idx, layer in enumerate(self.layers):
#             is_last_layer = layer_idx + 1 == len(self.layers)
#
#             x_residual = self._start_residual(x, layer, 0)
#             x_residual = layer['attention'](
#                 # for the last attention, it is enough to process only [CLS]
#                 x_residual,
#                 x_residual,
#             )
#
#             x = self._end_residual(x, x_residual, layer, 0)
#
#             x_residual = self._start_residual(x, layer, 1)
#             x_residual = layer['linear0'](x_residual)
#             x_residual = self.activation(x_residual)
#             if self.ffn_dropout:
#                 x_residual = F.dropout(x_residual, self.ffn_dropout, self.training)
#             x_residual = layer['linear1'](x_residual)
#             x = self._end_residual(x, x_residual, layer, 1)
#         return x


class Transformer(nn.Module):
    """Stack of attention + feed-forward layers over a token sequence.

    Args:
        n_layers: number of layers.
        d_token: width of every token vector.
        n_heads: number of attention heads per layer.
        d_out: output width of ``self.head``.
        d_ffn_factor: the feed-forward hidden width is
            ``int(d_token * d_ffn_factor)``.
        attention_dropout: dropout probability on the attention weights.
        ffn_dropout: dropout probability inside the feed-forward block.
        residual_dropout: dropout probability on each residual branch.
        activation: accepted for compatibility but currently unused; ReLU is
            hard-coded.
        prenormalization: if true, LayerNorm is applied before each sub-block
            (the first layer's attention gets none); otherwise after each
            residual addition.
        initialization: forwarded to every ``MultiheadAttention``.

    Note:
        ``last_activation``, ``last_normalization`` and ``head`` are created
        here but are not applied by ``forward``, which returns the token
        sequence as is.
    """

    def __init__(
            self,
            n_layers: int,
            d_token: int,
            n_heads: int,
            d_out: int,
            d_ffn_factor: int,
            attention_dropout=0.0,
            ffn_dropout=0.0,
            residual_dropout=0.0,
            activation='relu',
            prenormalization=True,
            initialization='kaiming',
    ):
        super().__init__()

        def make_normalization():
            return nn.LayerNorm(d_token)

        d_hidden = int(d_token * d_ffn_factor)
        self.layers = nn.ModuleList([])
        for layer_idx in range(n_layers):
            layer = nn.ModuleDict(
                {
                    'attention': MultiheadAttention(
                        d_token, n_heads, attention_dropout, initialization
                    ),
                    'linear0': nn.Linear(
                        d_token, d_hidden
                    ),
                    'linear1': nn.Linear(d_hidden, d_token),
                    'norm1': make_normalization(),
                }
            )
            # With pre-normalization the very first attention block sees the
            # raw input, so layer 0 gets no 'norm0'.
            if not prenormalization or layer_idx:
                layer['norm0'] = make_normalization()

            self.layers.append(layer)

        self.activation = nn.ReLU()
        self.last_activation = nn.ReLU()
        # self.activation = lib.get_activation_fn(activation)
        # self.last_activation = lib.get_nonglu_activation_fn(activation)
        self.prenormalization = prenormalization
        self.last_normalization = make_normalization() if prenormalization else None
        self.ffn_dropout = ffn_dropout
        self.residual_dropout = residual_dropout
        self.head = nn.Linear(d_token, d_out)

    def _start_residual(self, x, layer, norm_idx):
        """Return the residual-branch input: pre-normalized ``x`` when
        pre-normalization is on and the layer has ``norm{norm_idx}``,
        otherwise ``x`` itself."""
        x_residual = x
        if self.prenormalization:
            norm_key = f'norm{norm_idx}'
            if norm_key in layer:
                x_residual = layer[norm_key](x_residual)
        return x_residual

    def _end_residual(self, x, x_residual, layer, norm_idx):
        """Add the (optionally dropped-out) branch output to ``x`` and, in
        post-normalization mode, normalize the sum."""
        if self.residual_dropout:
            x_residual = F.dropout(x_residual, self.residual_dropout, self.training)
        x = x + x_residual
        if not self.prenormalization:
            x = layer[f'norm{norm_idx}'](x)
        return x

    def forward(self, x, padding_mask=None):
        """Run all layers on ``x``.

        Args:
            x: (batch_size, n_tokens, d_token) token sequence.
            padding_mask: optional (batch_size, n_tokens) mask passed to every
                attention layer as both the query and the key/value padding
                mask.

        Returns:
            Tensor with the same shape as ``x``.
        """
        for layer in self.layers:
            # Self-attention sub-block: every token attends to every token,
            # in every layer.
            x_residual = self._start_residual(x, layer, 0)
            x_residual = layer['attention'](
                x_residual,
                x_residual,
                padding_mask,
                padding_mask
            )
            x = self._end_residual(x, x_residual, layer, 0)

            # Position-wise feed-forward sub-block.
            x_residual = self._start_residual(x, layer, 1)
            x_residual = layer['linear0'](x_residual)
            x_residual = self.activation(x_residual)
            if self.ffn_dropout:
                x_residual = F.dropout(x_residual, self.ffn_dropout, self.training)
            x_residual = layer['linear1'](x_residual)
            x = self._end_residual(x, x_residual, layer, 1)
        return x


class Reconstructor(nn.Module):
    """Map token embeddings back to numerical values and category logits.

    Args:
        d_numerical: number of numerical features.
        categories: cardinality of each categorical feature, or ``None`` when
            there are no categorical features.
        d_token: width of every token vector.
    """

    def __init__(self, d_numerical, categories, d_token):
        super().__init__()

        self.d_numerical = d_numerical
        self.categories = categories
        self.d_token = d_token

        # One weight vector per numerical feature; the reconstruction is the
        # dot product between a token and its weight vector.
        self.weight = nn.Parameter(Tensor(d_numerical, d_token))
        nn.init.xavier_uniform_(self.weight, gain=1 / math.sqrt(2))
        self.cat_recons = nn.ModuleList()

        # ``None`` means "no categorical features", so build no heads.
        if categories is not None:
            for d in categories:
                recon = nn.Linear(d_token, d)
                nn.init.xavier_uniform_(recon.weight, gain=1 / math.sqrt(2))
                self.cat_recons.append(recon)

    def forward(self, h):
        """Reconstruct features from token embeddings.

        Args:
            h: (batch, d_numerical + len(categories), d_token) tokens, with
                numerical tokens first and categorical tokens after them.
                Token 0 is treated as the first numerical feature, so callers
                presumably strip any [CLS] token beforehand -- confirm.

        Returns:
            Tuple ``(recon_x_num, recon_x_cat)`` where ``recon_x_num`` has
            shape (batch, d_numerical) and ``recon_x_cat`` is a list with one
            (batch, categories[i]) tensor per categorical feature (empty when
            there are none).
        """
        h_num = h[:, :self.d_numerical]
        h_cat = h[:, self.d_numerical:]

        # Per-feature dot product over the embedding axis.
        recon_x_num = torch.mul(h_num, self.weight.unsqueeze(0)).sum(-1)
        recon_x_cat = [
            recon(h_cat[:, i]) for i, recon in enumerate(self.cat_recons)
        ]

        return recon_x_num, recon_x_cat
