import copy
from typing import Optional, Union, Callable

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F


def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
    """Translate an activation name into its functional implementation.

    Args:
        activation: ``"relu"`` or ``"gelu"``.

    Returns:
        ``F.relu`` or ``F.gelu`` respectively.

    Raises:
        RuntimeError: if ``activation`` is any other value.
    """
    # Compared in order with ``==`` (no hashing), so any input type falls through
    # to the RuntimeError below rather than raising something else.
    for known_name, known_fn in (("relu", F.relu), ("gelu", F.gelu)):
        if activation == known_name:
            return known_fn
    raise RuntimeError(f"activation should be relu/gelu, not {activation}")


def _get_clones(module, n):
    """Return an ``nn.ModuleList`` of ``n`` independent deep copies of ``module``.

    Each copy owns its own parameters; none of them is ``module`` itself.
    """
    copies = [copy.deepcopy(module) for _ in range(n)]
    return nn.ModuleList(copies)


def _get_seq_len(src: Tensor, batch_first: bool) -> Optional[int]:
    """Return the sequence length of ``src``, or ``None`` for nested tensors.

    A 2-D tensor is treated as unbatched ``(S, E)``. Otherwise the sequence axis
    is dim 1 when ``batch_first`` (``(B, S, E)``) and dim 0 when not (``(S, B, E)``).
    """
    if src.is_nested:
        return None
    shape = src.size()
    if len(shape) == 2:
        return shape[0]
    return shape[1 if batch_first else 0]


def _detect_is_causal_mask(
    mask: Optional[Tensor],
    is_causal: Optional[bool] = None,
    size: Optional[int] = None,
) -> bool:
    """Decide whether an attention mask should be treated as causal.

    Args:
        mask: the attention mask, or ``None``.
        is_causal: an explicit hint. When it is not ``None`` it is trusted as is,
            and only the literal ``True`` yields ``True``.
        size: when given, compare ``mask`` against a causal mask of this size.
            Otherwise the size is taken from ``mask.size(-2)``.

    Returns:
        ``True`` only if the hint is ``True``, or if there is no hint and ``mask``
        equals a square causal mask of the same shape, device and dtype.

    Warning:
        A wrong hint is not detected. ``is_causal=False`` on a mask that really is
        causal may cost performance compared with ``is_causal=True``.
        ``is_causal=True`` on a mask that is not causal may produce incorrect,
        unpredictable results. Depending on factors such as alignment and hardware,
        some execution paths apply a causal mask based on the hint, while others
        use the supplied mask.
    """
    # An explicit hint short-circuits detection; a missing mask cannot be causal.
    if is_causal is not None or mask is None:
        return is_causal is True

    side = mask.size(-2) if size is None else size
    reference = _generate_square_subsequent_mask(side, device=mask.device, dtype=mask.dtype)

    # The shape must match exactly, so a batched mask (extra leading dims) is
    # reported as non-causal even if every slice is causal.
    if mask.size() != reference.size():
        return False
    return bool((mask == reference).all())


def _generate_square_subsequent_mask(
    sz: int,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
) -> Tensor:
    r"""Generate a square causal mask for a sequence of length ``sz``.

    Positions above the main diagonal (future positions) are ``float('-inf')``.
    The diagonal and everything below it are ``0.0``.

    Args:
        sz: side length of the square mask.
        device: device of the returned tensor. ``None`` means torch's current
            default device, resolved at call time rather than at import time.
        dtype: dtype of the returned tensor. ``None`` means
            ``torch.get_default_dtype()`` at call time.

    Returns:
        A ``(sz, sz)`` tensor.
    """
    # Passing None through lets torch.full pick up the *current* defaults. Evaluating
    # them in the signature would freeze whatever they were when the module was imported.
    return torch.triu(
        torch.full((sz, sz), float('-inf'), dtype=dtype, device=device),
        diagonal=1,
    )


def _weights_init(m):
    """Per-module initializer intended for ``nn.Module.apply``.

    * ``nn.Linear`` and ``nn.Conv1d``: Kaiming-normal weights (``fan_out``, relu gain).
      Biases are left untouched.
    * ``nn.BatchNorm1d``: weight set to 1 and bias set to 0.
    * Any other module (e.g. ``nn.Conv2d``, ``nn.LayerNorm``) is left as is.
    """
    if isinstance(m, (nn.Linear, nn.Conv1d)):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    elif isinstance(m, nn.BatchNorm1d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)



class TransformerEncoder(nn.Module):
    """A stack of ``num_layers`` deep copies of ``encoder_layer`` with an optional final norm.

    Args:
        encoder_layer: template layer; it is deep-copied ``num_layers`` times.
        num_layers: number of stacked layers.
        norm: optional module applied to the output of the last layer.
        enable_nested_tensor, mask_check: accepted for signature compatibility with
            ``torch.nn.TransformerEncoder``; not used by this implementation.
    """

    def __init__(self, encoder_layer, num_layers, norm=None, enable_nested_tensor=True, mask_check=True):
        super().__init__()
        torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")
        # Registration order (layers, then norm) is kept for state_dict / apply() ordering.
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
            self,
            src: Tensor,
            mask: Optional[Tensor] = None,
            src_key_padding_mask: Optional[Tensor] = None,
            is_causal: Optional[bool] = None) -> Tensor:
        """Run ``src`` through every layer in order, then through ``norm`` if one was given.

        Only ``mask`` is forwarded to the layers, as ``src_mask``.
        NOTE(review): ``src_key_padding_mask`` and ``is_causal`` are accepted but
        ignored; passing them has no effect.
        """
        hidden = src
        for layer in self.layers:
            hidden = layer(hidden, src_mask=mask)
        return hidden if self.norm is None else self.norm(hidden)


class TransformerEncoderLayer(nn.Module):
    r"""Encoder layer with split self-attention over a 4-D token grid.

    :meth:`forward` unpacks its input as ``(bz, ch_num, patch_num, d)``. The feature
    dimension ``d`` is split in half. The first half is attended along the ``ch_num``
    axis by ``self_attn_s``, and the second half along the ``patch_num`` axis by
    ``self_attn_t`` (the ``_s``/``_t`` names presumably mean spatial/temporal). The two
    results are concatenated back along the feature axis. A position-wise feed-forward
    block follows. Both sub-blocks are pre-normalised and wrapped in residual connections.

    Args:
        d_model: feature size. Must be even so it splits into two ``d_model // 2`` halves.
        n_head: total number of heads; each attention module gets ``n_head // 2``.
        dim_ffn: hidden size of the feed-forward block.
        dropout: dropout probability, used inside both attentions, inside the FFN and
            on both residual branches.
        activation: ``"relu"``, ``"gelu"`` or a callable, used in the FFN.
        layer_norm_eps: epsilon of both LayerNorms.
        batch_first: accepted for signature compatibility with
            ``torch.nn.TransformerEncoderLayer`` but has no effect. ``_sa_block`` always
            lays its tensors out as ``(batch, seq, feature)``, so both attention
            modules are always built with ``batch_first=True``.
        norm_first: stored on the module but not read; :meth:`forward` always uses
            pre-normalisation.
        bias: whether the attention and FFN projections have a bias.
        device, dtype: factory kwargs forwarded to the submodules.
    """
    __constants__ = ['norm_first']

    def __init__(self, d_model: int, n_head: int, dim_ffn: int = 2048, dropout: float = 0.1,
                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
                 bias: bool = True, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        # Always batch_first=True: _sa_block feeds (batch, seq, feature) tensors. With
        # batch_first=False the attention would run over the flattened batch axis and
        # mix different samples together. The flag does not change any parameters,
        # so existing checkpoints still load.
        # Submodule creation order below is unchanged so that RNG consumption during
        # init and state_dict key order stay the same.
        self.self_attn_s = nn.MultiheadAttention(d_model // 2, n_head // 2, dropout=dropout,
                                                 bias=bias, batch_first=True,
                                                 **factory_kwargs)
        self.self_attn_t = nn.MultiheadAttention(d_model // 2, n_head // 2, dropout=dropout,
                                                 bias=bias, batch_first=True,
                                                 **factory_kwargs)

        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_ffn, bias=bias, **factory_kwargs)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_ffn, d_model, bias=bias, **factory_kwargs)

        # NOTE(review): stored only; forward() is always pre-norm regardless of this flag.
        self.norm_first = norm_first
        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        # Legacy string support for activation function.
        if isinstance(activation, str):
            activation = _get_activation_fn(activation)

        # Mirrors torch.nn.TransformerEncoderLayer's TorchScript-friendly tag
        # (1 = relu, 2 = gelu, 0 = other). It is not read by any code in this class.
        if activation is F.relu or isinstance(activation, torch.nn.ReLU):
            self.activation_relu_or_gelu = 1
        elif activation is F.gelu or isinstance(activation, torch.nn.GELU):
            self.activation_relu_or_gelu = 2
        else:
            self.activation_relu_or_gelu = 0
        self.activation = activation

    def __setstate__(self, state):
        # Objects pickled without an ``activation`` attribute fall back to relu.
        super().__setstate__(state)
        if not hasattr(self, 'activation'):
            self.activation = F.relu

    def forward(
            self,
            src: Tensor,
            src_mask: Optional[Tensor] = None,
            src_key_padding_mask: Optional[Tensor] = None,
            is_causal: bool = False) -> Tensor:
        """Apply pre-norm split self-attention, then a pre-norm feed-forward block.

        Args:
            src: 4-D tensor ``(bz, ch_num, patch_num, d_model)``.
            src_mask: forwarded unchanged as ``attn_mask`` to BOTH attentions.
            src_key_padding_mask: forwarded unchanged as ``key_padding_mask`` to BOTH
                attentions.
            is_causal: accepted but not used.

        NOTE(review): the two attentions run over different sequence lengths
        (``ch_num`` vs ``patch_num``) and batch sizes (``bz*patch_num`` vs
        ``bz*ch_num``). A single non-None mask therefore only fits both when those
        shapes coincide.
        """
        x = src
        x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask, is_causal=is_causal)
        x = x + self._ff_block(self.norm2(x))
        return x

    # self-attention block
    def _sa_block(self, x: Tensor,
                  attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor],
                  is_causal: bool = False) -> Tensor:
        bz, ch_num, patch_num, feat_dim = x.shape
        half = feat_dim // 2
        xs = x[:, :, :, :half]
        xt = x[:, :, :, half:]
        # (bz, ch, patch, half) -> (bz*patch, ch, half): one length-ch_num sequence
        # per (sample, patch).
        xs = xs.transpose(1, 2).contiguous().view(bz * patch_num, ch_num, half)
        # (bz, ch, patch, half) -> (bz*ch, patch, half): one length-patch_num sequence
        # per (sample, channel). The view requires an even feat_dim.
        xt = xt.contiguous().view(bz * ch_num, patch_num, half)
        xs = self.self_attn_s(xs, xs, xs,
                              attn_mask=attn_mask,
                              key_padding_mask=key_padding_mask,
                              need_weights=False)[0]
        xs = xs.contiguous().view(bz, patch_num, ch_num, half).transpose(1, 2)
        xt = self.self_attn_t(xt, xt, xt,
                              attn_mask=attn_mask,
                              key_padding_mask=key_padding_mask,
                              need_weights=False)[0]
        xt = xt.contiguous().view(bz, ch_num, patch_num, half)
        x = torch.cat((xs, xt), dim=3)
        return self.dropout1(x)

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)


class PatchEmbedding(nn.Module):
    """Embed raw signal patches into ``d_model``-dimensional tokens.

    The embedding is the sum of three terms:

    * a time-domain term: ``proj_in``, a stack of Conv2d/GroupNorm/GELU applied
      along each patch;
    * a frequency-domain term: ``spectral_proj``, a linear map of the magnitude of
      each patch's real FFT (``norm='forward'``, i.e. scaled by ``1/patch_size``);
    * a positional term: ``positional_encoding``, a depthwise Conv2d with a
      (19, 7) kernel over the (ch_num, patch_num) grid, applied to the sum of the
      first two terms.

    Args:
        in_dim: samples per patch. Masked patches are overwritten with the
            ``in_dim``-long ``mask_encoding``, and the spectral branch expects
            ``in_dim // 2 + 1`` rFFT bins, so the input's last dim should equal
            ``in_dim``.
        d_model: output embedding size.

    NOTE(review): the reshape after ``proj_in`` requires
    ``25 * ((in_dim - 1) // 25 + 1) == d_model`` (25 output channels times the
    strided conv's output width). For example, 200/200 and 100/100 both satisfy this.
    """

    def __init__(self, in_dim, d_model):
        super().__init__()
        self.d_model = d_model
        # Submodule creation order is kept so RNG consumption and state_dict keys are unchanged.
        self.positional_encoding = nn.Sequential(
            nn.Conv2d(in_channels=d_model, out_channels=d_model, kernel_size=(19, 7), stride=(1, 1), padding=(9, 3),
                      groups=d_model),
        )
        # Fixed (non-trainable) zero vector written over masked patches.
        self.mask_encoding = nn.Parameter(torch.zeros(in_dim), requires_grad=False)

        self.proj_in = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=25, kernel_size=(1, 49), stride=(1, 25), padding=(0, 24)),
            nn.GroupNorm(5, 25),
            nn.GELU(),

            nn.Conv2d(in_channels=25, out_channels=25, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1)),
            nn.GroupNorm(5, 25),
            nn.GELU(),

            nn.Conv2d(in_channels=25, out_channels=25, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1)),
            nn.GroupNorm(5, 25),
            nn.GELU(),
        )
        # rfft of an in_dim-long patch has in_dim // 2 + 1 bins (101 for the default 200).
        self.spectral_proj = nn.Sequential(
            nn.Linear(in_dim // 2 + 1, d_model),
            nn.Dropout(0.1),
        )

    def forward(self, x, mask=None):
        """Embed ``x``, optionally replacing masked patches first.

        Args:
            x: 4-D tensor unpacked as ``(bz, ch_num, patch_num, patch_size)``.
            mask: optional tensor. Entries equal to 1 select the patches of ``x`` that
                are replaced by ``mask_encoding`` (presumably shaped
                ``(bz, ch_num, patch_num)`` — confirm with callers). ``x`` itself is
                not modified.

        Returns:
            Tensor of shape ``(bz, ch_num, patch_num, d_model)``.
        """
        bz, ch_num, patch_num, patch_size = x.shape
        if mask is None:
            mask_x = x
        else:
            # Clone so the caller's tensor is not modified in place.
            mask_x = x.clone()
            mask_x[mask == 1] = self.mask_encoding

        # Time-domain branch: all ch_num * patch_num patches become rows of a 1-channel image.
        mask_x = mask_x.contiguous().view(bz, 1, ch_num * patch_num, patch_size)
        patch_emb = self.proj_in(mask_x)
        # (bz, 25, rows, width) -> (bz, rows, 25, width) -> flatten 25*width into d_model.
        patch_emb = patch_emb.permute(0, 2, 1, 3).contiguous().view(bz, ch_num, patch_num, self.d_model)

        # Frequency-domain branch: magnitude spectrum of every patch.
        mask_x = mask_x.contiguous().view(bz * ch_num * patch_num, patch_size)
        spectral = torch.fft.rfft(mask_x, dim=-1, norm='forward')
        spectral = torch.abs(spectral).contiguous().view(bz, ch_num, patch_num, -1)
        spectral_emb = self.spectral_proj(spectral)
        patch_emb = patch_emb + spectral_emb

        # Positional term: depthwise conv over the (ch_num, patch_num) grid, channels = d_model.
        positional_embedding = self.positional_encoding(patch_emb.permute(0, 3, 1, 2))
        positional_embedding = positional_embedding.permute(0, 2, 3, 1)

        return patch_emb + positional_embedding


class CBraMod(nn.Module):
    """Patch embedding, a stack of split-attention encoder layers, and a linear output head.

    Args:
        in_dim: samples per input patch (passed to ``PatchEmbedding``).
        out_dim: output size of the final linear projection.
        d_model: embedding / model width.
        dim_ffn: hidden size of each layer's feed-forward block.
        n_layer: number of encoder layers.
        n_head: attention heads per layer (split evenly between its two attentions).
    """

    def __init__(
            self,
            in_dim: int = 200,
            out_dim: int = 200,
            d_model: int = 200,
            dim_ffn: int = 800,
            n_layer: int = 12,
            n_head: int = 8,
    ):
        super().__init__()
        # NOTE(review): ``patch_size`` is set from d_model, not in_dim (they coincide at
        # the defaults). It is not read in this class; confirm intent with users of it.
        self.patch_size = d_model
        self.out_dim = out_dim
        # Construction order is unchanged: it fixes RNG consumption during init and the
        # state_dict layout that pretrained checkpoints rely on.
        self.patch_embedding = PatchEmbedding(in_dim, d_model)
        template_layer = TransformerEncoderLayer(
            d_model=d_model,
            n_head=n_head,
            dim_ffn=dim_ffn,
            batch_first=True,
            norm_first=True,
            activation=F.gelu,
        )
        self.encoder = TransformerEncoder(template_layer, num_layers=n_layer, enable_nested_tensor=False)
        self.proj_out = nn.Sequential(nn.Linear(d_model, out_dim))
        self.apply(_weights_init)

    def forward(self, x, mask=None):
        """Embed ``x`` (replacing patches where ``mask == 1``), encode, and project to ``out_dim``."""
        return self.proj_out(self.encoder(self.patch_embedding(x, mask)))


if __name__ == '__main__':
    # Smoke test: run a 2-layer encoder on random input and print the shapes.
    # Fall back to CPU so the script also runs on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    encoder_layer = TransformerEncoderLayer(
        d_model=256, n_head=4, dim_ffn=1024, batch_first=True, norm_first=True,
        activation=F.gelu
    )
    encoder = TransformerEncoder(encoder_layer, num_layers=2, enable_nested_tensor=False)
    encoder = encoder.to(device)

    # Input laid out as (bz, ch_num, patch_num, d_model).
    a = torch.randn((4, 19, 30, 256)).to(device)
    b = encoder(a)
    print(a.shape, b.shape)
