"""
This module implements only the recurrent part of IndRNN, together with two variants that gate the input features at each time step. It is based on the implementation at
https://github.com/StefOe/indrnn-pytorch/blob/master/indrnn.py.
Because only the recurrent part is implemented here, the input projection w_{ih} * x + b_{ih} must be computed beforehand, e.g. by a fully connected or convolutional layer.
Please cite the following paper if you find it useful.
Shuai Li, Wanqing Li, Chris Cook, Ce Zhu, and Yanbo Gao. "Independently Recurrent Neural Network (IndRNN): Building A Longer and Deeper RNN," 
In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 5457-5466. 2018.
@inproceedings{li2018independently,
  title={Independently recurrent neural network (indrnn): Building A longer and deeper RNN},
  author={Li, Shuai and Li, Wanqing and Cook, Chris and Zhu, Ce and Gao, Yanbo},
  booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
  pages={5457--5466},
  year={2018}
}
"""


import torch
from torch.nn import Parameter
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import math


class IndRNNCell_onlyrecurrent(nn.Module):
    r"""An IndRNN cell with ReLU non-linearity (recurrent part only).

    The input is expected to be already projected, i.e. it is
    ``w_{ih} * x + b_{ih}`` computed by a preceding layer.

    .. math::
        input = w_{ih} * x + b_{ih}
        h' = \relu(input + w_{hh} (*) h)
    With (*) being element-wise vector multiplication.

    Args:
        hidden_size: The number of features in the hidden state h.
        hidden_max_abs: Accepted for interface compatibility; this class does
            not use it (no clipping of ``weight_hh`` happens here).
        recurrent_init: Optional callable applied in place to ``weight_hh``.
            If None, ``weight_hh`` is drawn from U(0, 1).

    Inputs: input, hidden
        - **input** (batch, hidden_size): pre-projected input features.
        - **hidden** (batch, hidden_size): previous hidden state for each
          element in the batch.

    ``weight_hh`` is broadcast over all leading dimensions, so any pair of
    mutually broadcastable shapes ending in ``hidden_size`` is accepted.

    Outputs: h'
        - **h'** (batch, hidden_size): the next hidden state for each element
          in the batch.
    """

    def __init__(self, hidden_size,
                 hidden_max_abs=None, recurrent_init=None):
        super(IndRNNCell_onlyrecurrent, self).__init__()
        self.hidden_size = hidden_size
        # NOTE(review): stored only so callers can inspect it; unused here.
        self.hidden_max_abs = hidden_max_abs
        self.recurrent_init = recurrent_init
        # One recurrent weight per hidden unit (element-wise recurrence).
        self.weight_hh = Parameter(torch.empty(hidden_size))
        self.reset_parameters()

    def reset_parameters(self):
        """(Re)initialise ``weight_hh``: U(0, 1) by default, else ``recurrent_init``."""
        for name, weight in self.named_parameters():
            if "weight_hh" in name:
                if self.recurrent_init is None:
                    nn.init.uniform_(weight, a=0, b=1)
                else:
                    self.recurrent_init(weight)

    def forward(self, input, hx):
        """Return ``relu(input + hx * weight_hh)`` (one recurrence step)."""
        # Plain broadcasting replaces unsqueeze/expand, which assumed a 2-D hx.
        return F.relu(input + hx * self.weight_hh)


class IndRNN_onlyrecurrent(nn.Module):
    r"""Applies an IndRNN with `ReLU` non-linearity to an input sequence.
    This is only the recurrent part where the input is already processed with w_{ih} * x + b_{ih}.

    For each element in the input sequence, the layer computes:

    .. math::

        h_t = \relu(input_t + w_{hh} (*) h_{(t-1)})

    where :math:`h_t` is the hidden state at time `t`, :math:`input_t` is the
    input at time `t`, and (*) is element-wise multiplication.

    Args:
        hidden_size: The number of features in the hidden state `h`.
        recurrent_init: Optional callable applied in place to the recurrent
            weight. If None, the weight is drawn from U(0, 1).
        **kwargs: Forwarded to :class:`IndRNNCell_onlyrecurrent`.

    Inputs: input, h_0
        - **input** of shape `(seq_len, batch, hidden_size)`: the pre-projected
          features of the input sequence. Packed sequences are not supported.
        - **h_0** of shape `(batch, hidden_size)`: the initial hidden state for
          each element in the batch. Defaults to zeros if not provided.

    Outputs: output
        - **output** of shape `(seq_len, batch, hidden_size)`: the hidden state
          at every time step.
    """

    def __init__(self, hidden_size, recurrent_init=None, **kwargs):
        super(IndRNN_onlyrecurrent, self).__init__()
        self.hidden_size = hidden_size
        self.recurrent_init = recurrent_init
        # Give the cell the same init scheme, so that a later
        # ``self.indrnn_cell.reset_parameters()`` stays consistent with ours.
        self.indrnn_cell = IndRNNCell_onlyrecurrent(
            hidden_size, recurrent_init=recurrent_init, **kwargs)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every parameter whose name contains ``weight_hh``."""
        for name, weight in self.named_parameters():
            if "weight_hh" in name:
                if self.recurrent_init is None:
                    nn.init.uniform_(weight, a=0, b=1)
                else:
                    self.recurrent_init(weight)

    def forward(self, input, h0=None):
        """Run the recurrence over the leading (time) dimension of ``input``.

        Raises:
            RuntimeError: if ``h0``'s last two dims differ from ``input``'s.
        """
        assert input.dim() == 2 or input.dim() == 3
        if h0 is None:
            h0 = input.new_zeros(input.size(-2), input.size(-1))
        elif h0.size(-1) != input.size(-1) or h0.size(-2) != input.size(-2):
            raise RuntimeError(
                'The initial hidden state must match the last two dims (batch, features) '
                'of the input. Expected {}, got {}'.format(
                    tuple(input.shape[-2:]), tuple(h0.shape)))
        hx = h0
        outputs = []
        for input_t in input:  # one slice per time step
            hx = self.indrnn_cell(input_t, hx)
            outputs.append(hx)
        return torch.stack(outputs, 0)

class AdaptiveIndRNN_onlyrecurrent(nn.Module):
    r"""IndRNN (recurrent part only) with stochastic, learned input-feature selection.
    The input is expected to be already processed with w_{ih} * x + b_{ih}.

    For each element in the input sequence, the layer computes:

    .. math::

        h_t = \relu(s_t (*) input_t + w_{hh} (*) h_{(t-1)})

    where :math:`s_0` is all ones and, for :math:`t > 0`, :math:`s_t` is an
    independent per-feature Gumbel-sigmoid (binary Concrete) sample whose
    logits are ``selector_layer(h_{t-1})``. (*) is element-wise multiplication.

    Args:
        hidden_size: The number of features in the hidden state `h`. The input
            feature size must equal it.
        recurrent_init: Optional callable applied in place to the recurrent
            weight. If None, the weight is drawn from U(0, 1).
        hard: If True, selection weights are thresholded at 0.5 to 0/1 in the
            forward pass, with straight-through gradients.
        temperature: Relaxation temperature of the selection samples.
        **kwargs: Forwarded to :class:`IndRNNCell_onlyrecurrent`.

    Inputs: input, h_0
        - **input** of shape `(seq_len, batch, hidden_size)`: the pre-projected
          features of the input sequence. Packed sequences are not supported.
        - **h_0** of shape `(batch, hidden_size)`: the initial hidden state for
          each element in the batch. Defaults to zeros if not provided.

    Outputs: output
        - **output** of shape `(seq_len, batch, hidden_size)`

    After each forward call, ``selection_weights`` holds the per-step weights
    and ``num_selections`` the value returned by :meth:`regularizer`.
    """

    def __init__(self, hidden_size, recurrent_init=None, hard=False, temperature=0.05, **kwargs):
        super(AdaptiveIndRNN_onlyrecurrent, self).__init__()
        self.hidden_size = hidden_size
        self.recurrent_init = recurrent_init
        self.indrnn_cell = IndRNNCell_onlyrecurrent(
            hidden_size, recurrent_init=recurrent_init, **kwargs)

        # Maps h_{t-1} to one selection logit per input feature. Built here
        # rather than lazily in forward() so its parameters are registered
        # before an optimizer / load_state_dict sees the module, and so it
        # follows .to()/.cuda() like every other submodule.
        self.selector_layer = nn.Linear(hidden_size, hidden_size)
        self.hard = hard
        self.temperature = temperature

        # Filled in by forward(); initialised so the accessors work beforehand.
        self.selection_weights = []
        self.num_selections = 0
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every parameter whose name contains ``weight_hh``."""
        for name, weight in self.named_parameters():
            if "weight_hh" in name:
                if self.recurrent_init is None:
                    nn.init.uniform_(weight, a=0, b=1)
                else:
                    self.recurrent_init(weight)

    def sample_gumbel(self, shape, eps=1e-20, device=None, dtype=None):
        """Draw Gumbel(0, 1) noise ``-log(-log(U + eps) + eps)`` with U ~ U[0, 1).

        ``device`` defaults to ``"cuda"`` (the previous hard-coded behaviour);
        :meth:`gumbel_softmax_sample` passes the logits' device and dtype.
        """
        if device is None:
            device = "cuda"
        u = torch.rand(shape, device=device, dtype=dtype)
        return -torch.log(-torch.log(u + eps) + eps)

    def gumbel_softmax_sample(self, logits, temperature, hard=False, deterministic=False, eps=1e-20):
        """Relaxed sample from ``logits``.

        A last dimension of size 1 is treated as independent Bernoulli logits
        (sigmoid with logistic noise). Otherwise the method uses a softmax over
        the last dimension with Gumbel noise.

        Args:
            logits: Unnormalised logits.
            temperature: Relaxation temperature; lower is closer to discrete.
            hard: If True, the forward value is ``y > 0.5`` as 0/1, while
                gradients are those of the soft sample (straight-through). For
                the softmax branch this is a threshold, not a one-hot argmax.
            deterministic: If True, skip the noise (and ``hard``) and return the
                plain sigmoid / softmax of ``logits``.
            eps: Numerical guard inside the logarithms.
        """
        if deterministic:
            if logits.shape[-1] == 1:
                return torch.sigmoid(logits)
            return F.softmax(logits, dim=-1)

        if logits.shape[-1] == 1:
            # log(u) - log(1 - u) is logistic noise (difference of two Gumbels).
            noise = torch.rand_like(logits)
            y = logits + torch.log(noise + eps) - torch.log(1 - noise + eps)
            y = torch.sigmoid(y / temperature)
        else:
            y = logits + self.sample_gumbel(
                logits.size(), eps=eps, device=logits.device, dtype=logits.dtype)
            y = F.softmax(y / temperature, dim=-1)

        if hard:
            # (y - y.detach()) is exactly 0 in the forward pass but carries the
            # soft gradient; a bare comparison would cut the autograd graph and
            # leave selector_layer and the regularizer without gradient.
            return (y > 0.5).to(y.dtype) + (y - y.detach())
        return y

    def forward(self, input, h0=None):
        """Run the gated recurrence over the leading (time) dimension of ``input``.

        Raises:
            RuntimeError: if ``h0``'s last two dims differ from ``input``'s.
        """
        assert input.dim() == 2 or input.dim() == 3
        if h0 is None:
            h0 = input.new_zeros(input.size(-2), input.size(-1))
        elif h0.size(-1) != input.size(-1) or h0.size(-2) != input.size(-2):
            raise RuntimeError(
                'The initial hidden state must match the last two dims (batch, features) '
                'of the input. Expected {}, got {}'.format(
                    tuple(input.shape[-2:]), tuple(h0.shape)))

        hx = h0
        outputs = []
        prev_logits = None  # selector logits produced at the previous step
        self.selection_weights = []
        self.num_selections = 0
        for input_t in input:
            if prev_logits is None:
                # First step: no previous hidden state to select from.
                weights = torch.ones_like(input_t)
            else:
                # (batch, feature) -> (batch, feature, 1): one Bernoulli per feature.
                weights = self.gumbel_softmax_sample(
                    prev_logits.unsqueeze(-1), self.temperature, hard=self.hard).squeeze(-1)
                input_t = input_t * weights
                self.num_selections = self.num_selections + weights.sum()
            self.selection_weights.append(weights)

            hx = self.indrnn_cell(input_t, hx)
            prev_logits = self.selector_layer(hx)
            outputs.append(hx)

        # Normalise by time * batch * features; the all-ones first step counts
        # only in the denominator.
        self.num_selections = self.num_selections / input.numel()
        return torch.stack(outputs, 0)

    def regularizer(self):
        """Return the normalised selection count from the last forward pass.

        Presumably meant to be added to the training loss as a sparsity penalty.
        """
        return self.num_selections

    def get_selection_weights(self):
        """Return the last forward pass's per-step selection weights as NumPy arrays."""
        return [w.detach().cpu().numpy() for w in self.selection_weights]

class AttentionIndRNN_onlyrecurrent(nn.Module):
    r"""IndRNN (recurrent part only) with deterministic, learned input-feature gating.
    The input is expected to be already processed with w_{ih} * x + b_{ih}.

    For each element in the input sequence, the layer computes:

    .. math::

        h_t = \relu(a_t (*) input_t + w_{hh} (*) h_{(t-1)})

    where :math:`a_0` is all ones and, for :math:`t > 0`, :math:`a_t` is the
    per-feature sigmoid of ``selector_layer(h_{t-1})`` (optionally binarised,
    see ``hard``). (*) is element-wise multiplication.

    Args:
        hidden_size: The number of features in the hidden state `h`. The input
            feature size must equal it.
        recurrent_init: Optional callable applied in place to the recurrent
            weight. If None, the weight is drawn from U(0, 1).
        hard: If True, gate values are thresholded to 0/1 in the forward pass,
            with straight-through gradients.
        threshold: Threshold used when ``hard`` is True.
        **kwargs: Forwarded to :class:`IndRNNCell_onlyrecurrent`.

    Inputs: input, h_0
        - **input** of shape `(seq_len, batch, hidden_size)`: the pre-projected
          features of the input sequence. Packed sequences are not supported.
        - **h_0** of shape `(batch, hidden_size)`: the initial hidden state for
          each element in the batch. Defaults to zeros if not provided.

    Outputs: output
        - **output** of shape `(seq_len, batch, hidden_size)`

    After each forward call, ``selection_weights`` holds the per-step weights
    and ``num_selections`` the value returned by :meth:`regularizer`.
    """

    def __init__(self, hidden_size, recurrent_init=None, hard=False, threshold=0.5, **kwargs):
        super(AttentionIndRNN_onlyrecurrent, self).__init__()
        self.hidden_size = hidden_size
        self.recurrent_init = recurrent_init
        self.indrnn_cell = IndRNNCell_onlyrecurrent(
            hidden_size, recurrent_init=recurrent_init, **kwargs)

        # Maps h_{t-1} to one gate logit per input feature. Built here rather
        # than lazily in forward() so its parameters are registered before an
        # optimizer / load_state_dict sees the module, and so it follows
        # .to()/.cuda() like every other submodule.
        self.selector_layer = nn.Linear(hidden_size, hidden_size)
        self.hard = hard
        self.threshold = threshold

        # Filled in by forward(); initialised so the accessors work beforehand.
        self.selection_weights = []
        self.num_selections = 0
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every parameter whose name contains ``weight_hh``."""
        for name, weight in self.named_parameters():
            if "weight_hh" in name:
                if self.recurrent_init is None:
                    nn.init.uniform_(weight, a=0, b=1)
                else:
                    self.recurrent_init(weight)

    def attention_weights(self, logits, hard=False, threshold=0.5):
        """Map ``logits`` to gate values.

        A last dimension of size 1 uses a sigmoid; otherwise the method uses a
        softmax over the last dimension. With ``hard`` the forward value is
        ``y > threshold`` as 0/1, while gradients are those of ``y``
        (straight-through).
        """
        if logits.shape[-1] == 1:
            y = torch.sigmoid(logits)
        else:
            y = F.softmax(logits, dim=-1)

        if not hard:
            return y
        # (y - y.detach()) is exactly 0 in the forward pass but keeps the graph;
        # a bare comparison would leave selector_layer without gradient.
        return (y > threshold).to(y.dtype) + (y - y.detach())

    def forward(self, input, h0=None):
        """Run the gated recurrence over the leading (time) dimension of ``input``.

        Raises:
            RuntimeError: if ``h0``'s last two dims differ from ``input``'s.
        """
        assert input.dim() == 2 or input.dim() == 3
        if h0 is None:
            h0 = input.new_zeros(input.size(-2), input.size(-1))
        elif h0.size(-1) != input.size(-1) or h0.size(-2) != input.size(-2):
            raise RuntimeError(
                'The initial hidden state must match the last two dims (batch, features) '
                'of the input. Expected {}, got {}'.format(
                    tuple(input.shape[-2:]), tuple(h0.shape)))

        hx = h0
        outputs = []
        prev_logits = None  # selector logits produced at the previous step
        self.selection_weights = []
        self.num_selections = 0
        for input_t in input:
            if prev_logits is None:
                # First step: no previous hidden state to gate from.
                weights = torch.ones_like(input_t)
            else:
                # (batch, feature) -> (batch, feature, 1): one sigmoid gate per feature.
                weights = self.attention_weights(
                    prev_logits.unsqueeze(-1), hard=self.hard, threshold=self.threshold).squeeze(-1)
                input_t = input_t * weights
                self.num_selections = self.num_selections + weights.sum()
            self.selection_weights.append(weights)

            hx = self.indrnn_cell(input_t, hx)
            prev_logits = self.selector_layer(hx)
            outputs.append(hx)

        # Normalise by time * batch * features; the all-ones first step counts
        # only in the denominator.
        self.num_selections = self.num_selections / input.numel()
        return torch.stack(outputs, 0)

    def regularizer(self):
        """Return the normalised gate sum from the last forward pass.

        Presumably meant to be added to the training loss as a sparsity penalty.
        """
        return self.num_selections

    def get_selection_weights(self):
        """Return the last forward pass's per-step gate values as NumPy arrays."""
        return [w.detach().cpu().numpy() for w in self.selection_weights]
