from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F

from pado.core import PadoModule, PadoModuleMixin
from pado.nn.modules import LayerNorm, GroupLayerNorm
from pado.nn.parameter import ParameterModule

# Names exported by a star-import of this module.
__all__ = ["LSTM", "LSTMCell", "LayerNormLSTMCellOnce", "LayerNormLSTMCellEach", "scan_lstm_cell"]


class LSTM(nn.LSTM, PadoModuleMixin):
    """``nn.LSTM`` with custom parameter initialisation and optional recurrent-weight dropout.

    Adapted from
    https://github.com/keitakurita/Better_LSTM_PyTorch/blob/master/better_lstm/model.py

    ``dropout`` is routed to exactly one place:

    * ``variational_dropout=False``: passed to ``nn.LSTM`` as its own ``dropout`` argument.
    * ``variational_dropout=True``: ``nn.LSTM`` is built with dropout 0.0 and the rate is
      applied to every ``weight_hh*`` parameter at the start of each forward call.

    NOTE(review): the weight dropout assigns the dropped tensor back to the parameter's
    ``.data``, so in training mode the stored recurrent weights themselves are overwritten
    on every forward call and the masks accumulate. Confirm this is intended.
    """

    def __init__(self,
                 input_dim: int,
                 hidden_dim: int,
                 num_layers: int = 1,
                 bias: bool = True,
                 batch_first: bool = True,
                 dropout: float = 0.0,
                 bidirectional: bool = False, *,
                 init_forget_bias: bool = True,
                 variational_dropout: bool = False) -> None:
        """
        :param input_dim:           forwarded to ``nn.LSTM`` as its first positional argument
        :param hidden_dim:          forwarded to ``nn.LSTM`` as its second positional argument
        :param num_layers:          forwarded to ``nn.LSTM``
        :param bias:                forwarded to ``nn.LSTM``
        :param batch_first:         forwarded to ``nn.LSTM`` (note the default here is True)
        :param dropout:             dropout rate; where it is applied depends on ``variational_dropout``
        :param bidirectional:       forwarded to ``nn.LSTM``
        :param init_forget_bias:    if True, the forget-gate slice of every bias is set to 1.0
        :param variational_dropout: if True, apply ``dropout`` to the ``weight_hh*`` parameters
                                    instead of handing it to ``nn.LSTM``
        """
        # Split the single rate into the part handled by nn.LSTM and the part handled here.
        if variational_dropout:
            builtin_rate, weight_rate = 0.0, dropout
        else:
            builtin_rate, weight_rate = dropout, 0.0

        nn.LSTM.__init__(self, input_dim, hidden_dim, num_layers, bias, batch_first,
                         builtin_rate, bidirectional)
        PadoModuleMixin.__init__(self)

        self.variational_dropout = weight_rate  # a rate (float), not a flag
        self._initialize_parameters(init_forget_bias)

    def _initialize_parameters(self, init_forget_bias: bool = True):
        """Re-initialise all parameters, dispatching on the parameter name.

        ``weight_hh*`` -> orthogonal, ``weight_ih*`` -> Xavier uniform, ``*bias*`` -> zeros.
        With ``init_forget_bias`` the slice ``[hidden_size, 2 * hidden_size)`` of every bias
        (the forget gate) is set to 1.0.
        """
        forget_gate = slice(self.hidden_size, self.hidden_size * 2)
        for name, param in self.named_parameters():
            if "weight_hh" in name:
                nn.init.orthogonal_(param)
                continue
            if "weight_ih" in name:
                nn.init.xavier_uniform_(param)
                continue
            if "bias" not in name:
                continue
            nn.init.zeros_(param)
            if init_forget_bias:
                param.data[forget_gate] = 1.0

    def _variational_drop_weight(self):
        """Apply dropout with rate ``self.variational_dropout`` to every ``weight_hh*`` parameter.

        The result replaces the parameter's ``.data``; see the NOTE(review) on the class.
        """
        for name, param in self.named_parameters():
            if "weight_hh" not in name:
                continue
            dropped = F.dropout(param.data, p=self.variational_dropout, training=self.training)
            getattr(self, name).data = dropped.contiguous()

    def forward(self,
                x: torch.Tensor,
                state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
                ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Run ``nn.LSTM.forward``, applying recurrent-weight dropout first when it is enabled.

        :param x:       input, passed unchanged to ``nn.LSTM.forward``
        :param state:   optional initial ``(h, c)``, passed unchanged to ``nn.LSTM.forward``
        :return:        ``(output, (h, c))`` exactly as produced by ``nn.LSTM.forward``
        """
        if self.variational_dropout:
            self._variational_drop_weight()
        return super().forward(x, state)


class LSTMCell(PadoModule):
    """Single-step LSTM cell with explicit weights.

    The four gates are stacked along the first weight dimension in the order
    [in, forget, cell, out]. ``dropout`` is used in exactly one of two ways:

    * ``variational_dropout=False``: dropout on the returned output ``hy``
      (the ``hy`` stored in the returned state is left untouched).
    * ``variational_dropout=True``: dropout on ``weight_hh`` in every forward call.

    Weights and biases are held in ``ParameterModule`` wrappers; calling a wrapper is
    assumed to yield the wrapped tensor (see ``pado.nn.parameter``).
    """

    def __init__(self,
                 input_dim: int,
                 hidden_dim: int,
                 bias: bool = True,
                 dropout: float = 0.0,
                 *, init_forget_bias: bool = True,
                 variational_dropout: bool = False):
        """
        :param input_dim:           size of the last dimension of the input
        :param hidden_dim:          size of the hidden and cell states
        :param bias:                if False, ``bias_ih`` and ``bias_hh`` are None
        :param dropout:             dropout rate; where it is applied depends on ``variational_dropout``
        :param init_forget_bias:    if True, the forget-gate slice of every bias is set to 1.0
        :param variational_dropout: if True, ``dropout`` is applied to ``weight_hh`` instead of the output
        """
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        stacked_dim = 4 * hidden_dim  # four gates stacked on dim 0
        self.weight_ih = ParameterModule(torch.empty(stacked_dim, input_dim))
        self.weight_hh = ParameterModule(torch.empty(stacked_dim, hidden_dim))

        self.use_bias = bias
        self.bias_ih = ParameterModule(torch.zeros(stacked_dim)) if bias else None
        self.bias_hh = ParameterModule(torch.zeros(stacked_dim)) if bias else None

        # Exactly one of the two rates is non-zero.
        if variational_dropout:
            self.dropout, self.variational_dropout = 0.0, dropout
        else:
            self.dropout, self.variational_dropout = dropout, 0.0
        self._initialize_parameters(init_forget_bias)

    def _initialize_parameters(self, init_forget_bias: bool = True):
        """Initialise parameters by name: ``*weight*`` -> Xavier uniform, ``*bias*`` -> zeros.

        With ``init_forget_bias`` the slice ``[hidden_dim, 2 * hidden_dim)`` of every bias
        (the forget gate) is set to 1.0.
        """
        forget_gate = slice(self.hidden_dim, self.hidden_dim * 2)
        for name, param in self.named_parameters():
            if "weight" in name:
                nn.init.xavier_uniform_(param.data)
                continue
            if "bias" not in name:
                continue
            nn.init.zeros_(param.data)
            if init_forget_bias:
                param.data[forget_gate] = 1.0

    def generate_zero_state(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build an all-zero ``(hx, cx)`` pair matching ``x`` in batch size, dtype and device.

        Only ``x.shape[0]`` is read, so ``x`` may have any number of trailing dimensions.
        """
        state_shape = (x.shape[0], self.hidden_dim)
        zero_hx = torch.zeros(state_shape, dtype=x.dtype, device=x.device)
        zero_cx = torch.zeros(state_shape, dtype=x.dtype, device=x.device)
        return zero_hx, zero_cx

    def check_state(self,
                    x: torch.Tensor,
                    state: Tuple[torch.Tensor, torch.Tensor]) -> None:
        """Validate that ``state`` has two entries, each shaped ``(x.shape[0], hidden_dim)``.

        :raises ValueError: on a wrong number of entries or a shape mismatch
        """
        if len(state) != 2:
            raise ValueError("LSTMCell state should be tuple of two tensors.")

        batch_size = x.shape[0]
        expected_shape = (batch_size, self.hidden_dim)
        for idx, tensor in enumerate(state):
            if tensor.shape != expected_shape:
                raise ValueError(
                    f"LSTMCell state[{idx}] shape {tensor.shape} should be ({batch_size}, {self.hidden_dim}).")

    def lstm_core(self,
                  x: torch.Tensor,
                  weight_ih: torch.Tensor,
                  weight_hh: torch.Tensor,
                  bias_ih: Optional[torch.Tensor],
                  bias_hh: Optional[torch.Tensor],
                  state: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute one LSTM step from explicit weights.

        :param x:           (batch_size, input_dim)
        :param weight_ih:   (4 * hidden_dim, input_dim)
        :param weight_hh:   (4 * hidden_dim, hidden_dim)
        :param bias_ih:     (4 * hidden_dim,), only read when ``self.use_bias``
        :param bias_hh:     (4 * hidden_dim,), only read when ``self.use_bias``
        :param state:       (hx, cx), each (batch_size, hidden_dim)
        :return:            (hy, cy), each (batch_size, hidden_dim)
        """
        prev_h, prev_c = state

        # gate order: [in, forget, cell, out]
        pre_act = torch.mm(x, weight_ih.t()) + torch.mm(prev_h, weight_hh.t())
        if self.use_bias:
            pre_act = pre_act + (bias_hh + bias_ih)

        i_pre, f_pre, g_pre, o_pre = pre_act.chunk(4, dim=1)
        i_gate = torch.sigmoid(i_pre)
        f_gate = torch.sigmoid(f_pre)
        g_gate = torch.tanh(g_pre)
        o_gate = torch.sigmoid(o_pre)

        cy = (f_gate * prev_c) + (i_gate * g_gate)
        hy = o_gate * torch.tanh(cy)
        return hy, cy

    def forward(self,
                x: torch.Tensor,
                state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
                ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        :param x:           (batch_size, input_dim)
        :param state:       (batch_size, hidden_dim) / (batch_size, hidden_dim);
                            zeros are used when None
        :return:
                            output (batch_size, hidden_dim), with output dropout applied if enabled
                            new state (hy, cy), each (batch_size, hidden_dim), without output dropout
        """
        if state is None:
            state = self.generate_zero_state(x)
        else:
            self.check_state(x, state)

        weight_ih = self.weight_ih()
        # variational drop: rate is 0.0 unless variational_dropout was requested
        weight_hh = F.dropout(self.weight_hh(), self.variational_dropout, training=self.training)

        bias_ih = self.bias_ih() if self.use_bias else None
        bias_hh = self.bias_hh() if self.use_bias else None

        hy, cy = self.lstm_core(x, weight_ih, weight_hh, bias_ih, bias_hh, state)

        output = hy
        if (self.variational_dropout == 0) and self.dropout > 0:  # output drop
            output = F.dropout(hy, self.dropout, training=self.training)
        return output, (hy, cy)


class LayerNormLSTMCellOnce(LSTMCell):
    """
    LSTM cell with layer normalisation, following https://arxiv.org/pdf/1909.12415.pdf.

    For efficiency a single LayerNorm of size ``4 * hidden_dim`` is applied to each stacked
    projection (input and hidden) rather than one per gate, so results can differ slightly
    from the paper. Closer to
    https://github.com/pytorch/pytorch/blob/master/benchmarks/fastrnns/custom_lstms.py

    Note that ``bias`` defaults to False here.
    """

    def __init__(self,
                 input_dim: int,
                 hidden_dim: int,
                 bias: bool = False,
                 dropout: float = 0.0,
                 *, eps: float = 1e-5,
                 init_forget_bias: bool = True,
                 variational_dropout: bool = False):
        """
        :param eps:     passed to every normalisation layer
        (all other parameters are forwarded unchanged to ``LSTMCell.__init__``)
        """
        super().__init__(input_dim, hidden_dim, bias, dropout,
                         init_forget_bias=init_forget_bias,
                         variational_dropout=variational_dropout)
        # Built after the base initialiser has run, so LSTMCell's name-based parameter
        # initialisation never sees the normalisation layers.
        stacked_dim = 4 * self.hidden_dim
        self.ih_norm = LayerNorm(stacked_dim, eps=eps)
        self.hh_norm = LayerNorm(stacked_dim, eps=eps)
        self.candidate_norm = LayerNorm(self.hidden_dim, eps=eps)

    def lstm_core(self,
                  x: torch.Tensor,
                  weight_ih: torch.Tensor,
                  weight_hh: torch.Tensor,
                  bias_ih: Optional[torch.Tensor],
                  bias_hh: Optional[torch.Tensor],
                  state: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        """One LSTM step with normalised projections.

        Each projection is normalised before the two are summed; biases (if any) are added
        after normalisation. ``candidate_norm`` is applied to ``cy`` only inside the
        computation of ``hy``; the returned ``cy`` is the un-normalised cell state.

        :return:    (hy, cy), each (batch_size, hidden_dim)
        """
        prev_h, prev_c = state

        # gate order: [in, forget, cell, out]
        input_proj = self.ih_norm(torch.mm(x, weight_ih.t()))
        hidden_proj = self.hh_norm(torch.mm(prev_h, weight_hh.t()))
        pre_act = input_proj + hidden_proj
        if self.use_bias:
            pre_act = pre_act + (bias_hh + bias_ih)

        i_pre, f_pre, g_pre, o_pre = pre_act.chunk(4, dim=1)
        i_gate = torch.sigmoid(i_pre)
        f_gate = torch.sigmoid(f_pre)
        g_gate = torch.tanh(g_pre)
        o_gate = torch.sigmoid(o_pre)

        cy = (f_gate * prev_c) + (i_gate * g_gate)
        hy = o_gate * torch.tanh(self.candidate_norm(cy))
        return hy, cy


class LayerNormLSTMCellEach(LSTMCell):
    """
    LSTM cell with layer normalisation, following https://arxiv.org/pdf/1909.12415.pdf.

    Unlike the merged variant ("V1", presumably ``LayerNormLSTMCellOnce``), the stacked
    projections go through a ``GroupLayerNorm`` with ``num_groups=4``, presumably one group
    per gate (confirm against ``GroupLayerNorm``).

    Note that ``bias`` defaults to False here.
    """

    def __init__(self,
                 input_dim: int,
                 hidden_dim: int,
                 bias: bool = False,
                 dropout: float = 0.0,
                 *, eps: float = 1e-5,
                 init_forget_bias: bool = True,
                 variational_dropout: bool = False):
        """
        :param eps:     passed to every normalisation layer
        (all other parameters are forwarded unchanged to ``LSTMCell.__init__``)
        """
        super().__init__(input_dim, hidden_dim, bias, dropout,
                         init_forget_bias=init_forget_bias,
                         variational_dropout=variational_dropout)
        # Built after the base initialiser has run, so LSTMCell's name-based parameter
        # initialisation never sees the normalisation layers.
        stacked_dim = 4 * self.hidden_dim
        self.ih_norm = GroupLayerNorm(stacked_dim, num_groups=4, eps=eps)
        self.hh_norm = GroupLayerNorm(stacked_dim, num_groups=4, eps=eps)
        self.candidate_norm = LayerNorm(self.hidden_dim, eps=eps)

    def lstm_core(self,
                  x: torch.Tensor,
                  weight_ih: torch.Tensor,
                  weight_hh: torch.Tensor,
                  bias_ih: Optional[torch.Tensor],
                  bias_hh: Optional[torch.Tensor],
                  state: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        """One LSTM step with group-normalised projections.

        Each projection is normalised before the two are summed; biases (if any) are added
        after normalisation. ``candidate_norm`` is applied to ``cy`` only inside the
        computation of ``hy``; the returned ``cy`` is the un-normalised cell state.

        :return:    (hy, cy), each (batch_size, hidden_dim)
        """
        prev_h, prev_c = state

        # gate order: [in, forget, cell, out]
        input_proj = self.ih_norm(torch.mm(x, weight_ih.t()))
        hidden_proj = self.hh_norm(torch.mm(prev_h, weight_hh.t()))
        pre_act = input_proj + hidden_proj
        if self.use_bias:
            pre_act = pre_act + (bias_hh + bias_ih)

        i_pre, f_pre, g_pre, o_pre = pre_act.chunk(4, dim=1)
        i_gate = torch.sigmoid(i_pre)
        f_gate = torch.sigmoid(f_pre)
        g_gate = torch.tanh(g_pre)
        o_gate = torch.sigmoid(o_pre)

        cy = (f_gate * prev_c) + (i_gate * g_gate)
        hy = o_gate * torch.tanh(self.candidate_norm(cy))
        return hy, cy


def scan_lstm_cell(cell: nn.Module,
                   x: torch.Tensor,
                   lengths: Optional[torch.Tensor] = None,
                   state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
                   ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Unroll an LSTM cell over the time dimension of a batch-first sequence.

    For every batch element, steps at index >= its length are treated as padding: the
    state is carried over unchanged and the emitted output repeats the last valid hidden
    state (or the initial hidden state if the length is 0).

    :param cell:        LSTMCell-like object; must be callable as ``cell(x_t, state)`` returning
                        ``(output, (hy, cy))`` and provide ``generate_zero_state`` / ``check_state``
    :param x:           (batch_size, seq_length, input_dim)
    :param lengths:     (batch_size,) number of valid steps per batch element;
                        None means every step is valid
    :param state:       [(batch_size, hidden_dim), (batch_size, hidden_dim)];
                        None means ``cell.generate_zero_state(x)``
    :return:
            output:     (batch_size, seq_length, hidden_dim), the hidden state after every step
            state:      [(batch_size, hidden_dim), (batch_size, hidden_dim)], the final state
    """
    batch_size, seq_length = x.shape[:2]

    if lengths is not None:
        # Never mutated below, so moving it to the right device is enough (no copy needed).
        lengths = lengths.to(x.device)

    if state is None:
        state = cell.generate_zero_state(x)
    else:
        cell.check_state(x, state)

    if seq_length == 0:
        # torch.stack rejects an empty list; return an empty (b, 0, d) output instead.
        hx = state[0]
        return hx.new_zeros((batch_size, 0, hx.shape[-1])), state

    outputs = []
    for step, x_t in enumerate(x.unbind(dim=1)):  # x_t: (b, d)
        _, (hy, cy) = cell(x_t, state)
        if lengths is not None:
            valid = (lengths > step).unsqueeze(-1)  # (b, 1)
            # Select instead of multiplying by the mask: NaN/Inf computed on padded steps
            # would survive `value * 0` and poison the carried state.
            hy = torch.where(valid, hy, state[0])
            cy = torch.where(valid, cy, state[1])
        state = (hy, cy)
        outputs.append(hy)

    return torch.stack(outputs, dim=1), state  # (b, s, d)
