import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from tianshou.utils.net.common import MLP
from einops import rearrange, pack


class SeqFeatAttn(nn.Module):
  """Feature-wise attention between every query step and every key step.

  For each pair of a query position ``t`` and a key position ``s`` the module
  forms a ``d x d`` score matrix from the outer product of the query and key
  feature vectors, normalises it with a softmax over the key-feature axis and
  uses it to mix the features of the value vector at ``s``.

  With ``nheads > 1`` the inputs are first projected by bias-free linear
  layers into ``nheads`` heads of width ``dmodel`` each, and the concatenated
  heads are projected back to ``dmodel``. With ``nheads == 1`` no projection
  layers are created and the raw inputs are used directly.

  Args:
    dmodel: Feature width of the inputs (and of each head); also sets the
      ``sqrt(dmodel)`` score scaling.
    nheads: Number of attention heads, must be >= 1.
    dropout: Dropout probability applied to the final output.

  Raises:
    ValueError: If ``dmodel`` or ``nheads`` is smaller than 1.
  """

  def __init__(self, dmodel, nheads, dropout=0.0):
    super().__init__()

    # Fail at construction time instead of with an opaque reshape error (or
    # NaN scores) on the first forward pass.
    if nheads < 1:
      raise ValueError(f'nheads must be >= 1, got {nheads}')
    if dmodel < 1:
      raise ValueError(f'dmodel must be >= 1, got {dmodel}')

    self.nheads = nheads
    self.dmodel = dmodel

    # Linear layers for Q, K, V projections
    if nheads > 1:
      self.W_Q = nn.Linear(dmodel, dmodel * nheads, bias=False)
      self.W_K = nn.Linear(dmodel, dmodel * nheads, bias=False)
      self.W_V = nn.Linear(dmodel, dmodel * nheads, bias=False)
      # Output linear layer
      self.W_O = nn.Linear(dmodel * nheads, dmodel, bias=False)
    else:
      # Single head: no projections are applied in forward().
      self.W_Q = self.W_K = self.W_V = self.W_O = None

    self.dropout = nn.Dropout(dropout)

  def forward(self, Q, K, V=None):
    """Attend every query step to every key step, feature by feature.

    2D inputs are treated as sequences of length 1.

    Args:
      Q: Queries, [bsz, t, dmodel] or [bsz, dmodel].
      K: Keys, [bsz, s, dmodel] or [bsz, dmodel].
      V: Values, same layout as ``K``; defaults to ``K`` when omitted.

    Returns:
      Tensor of shape [bsz, t, s, dmodel]. Only when all three inputs were 2D
      are the singleton ``t`` and ``s`` axes folded away, giving
      [bsz, dmodel].
    """
    if V is None:
      V = K

    # Remember whether the caller passed plain [bsz, dmodel] vectors so the
    # singleton sequence axes added below can be removed again at the end.
    all_2d = all(x.ndim == 2 for x in (Q, K, V))
    Q, K, V = (
        rearrange(x, 'b d -> b 1 d') if x.ndim == 2 else x for x in (Q, K, V)
    )

    # Linear projections (multi-head only)
    if self.nheads > 1:
      Q, K, V = self.W_Q(Q), self.W_K(K), self.W_V(V)

    # Split the feature axis into heads: [b, len, h*d] -> [b, h, len, d]
    q = rearrange(Q, 'b t (h d) -> b h t d', h=self.nheads)
    k = rearrange(K, 'b s (h d) -> b h s d', h=self.nheads)
    v = rearrange(V, 'b s (h d) -> b h s d', h=self.nheads)

    # Feature-wise scores: outer product of query and key features for every
    # (t, s) pair -> [b, h, t, s, d, d]. The scale is a plain Python float so
    # no tensor is allocated per call.
    scale = self.dmodel**0.5
    scores = torch.einsum('bhti,bhsj->bhtsij', q, k) / scale
    # Softmax normalization over the key-feature axis
    alpha = F.softmax(scores, dim=-1)
    # Mix the value features with the attention weights -> [b, h, t, s, d]
    attended = torch.einsum('bhtsij,bhsj->bhtsi', alpha, v)

    # Merge heads back into the feature axis: [b, t, s, h*d]
    output = rearrange(attended, 'b h t s d -> b t s (h d)')
    if self.nheads > 1:
      output = self.W_O(output)

    # restore 2D: [b, 1, 1, d] -> [b, d]
    if all_2d:
      output = rearrange(output, 'b t s d -> b (t s d)')

    return self.dropout(output)


if __name__ == '__main__':
  # Self-check: with a single head (no learned projections) the batched
  # module must reproduce a plain per-pair feature-attention reference, both
  # when called pair by pair with 2D inputs and when called once with the
  # full 3D sequences.
  bsz = 512
  dmodel = 40
  S_len = 2
  T_len = 4
  nheads = 1
  dropout = 0

  Q = torch.rand((bsz, T_len, dmodel))
  K = torch.rand((bsz, S_len, dmodel))
  V = torch.rand((bsz, S_len, dmodel))

  feat_seq_attn = SeqFeatAttn(dmodel=dmodel, nheads=nheads, dropout=dropout)
  # Evaluation mode keeps dropout inactive, so the outputs stay comparable to
  # the dropout-free reference even if `dropout` above is changed.
  feat_seq_attn.eval()

  def true_feat_attn(Q, K, V):
    """Reference feature attention for one (query, key, value) triple.

    All inputs are [bsz, dmodel]; the result is [bsz, dmodel].
    """
    # todo: remove Q K V's seq_len dim if Q K V is 3d rather than 2d
    _, d_model = Q.shape
    scores = torch.einsum('bi,bj->bij', Q, K) / (d_model**0.5)
    alpha = F.softmax(scores, dim=-1)
    return torch.einsum('bij,bj->bi', alpha, V)

  def assert_same(actual, expected):
    """Assert equal shapes and values up to floating-point round-off.

    The module and the reference reduce over differently shaped tensors, so
    bit-identical results are not guaranteed; compare with a tolerance.
    """
    assert actual.shape == expected.shape, (actual.shape, expected.shape)
    assert torch.allclose(actual, expected, rtol=1e-5, atol=1e-6)

  per_query = []
  for t in range(T_len):
    Qt = Q[:, t, :]
    per_key = []
    for s in range(S_len):
      Ks = K[:, s, :]
      Vs = V[:, s, :]
      expected = true_feat_attn(Qt, Ks, Vs)
      # 2D path: one query vector against one key/value vector.
      assert_same(feat_seq_attn(Qt, Ks, Vs), expected)
      per_key.append(expected)
    per_query.append(torch.stack(per_key, dim=1))  # [bsz, S_len, dmodel]
  ts_result = torch.stack(per_query, dim=1)  # [bsz, T_len, S_len, dmodel]

  # 3D path: all (t, s) pairs in a single call.
  assert_same(feat_seq_attn(Q, K, V), ts_result)
