import torch
import numpy as np
import torch.nn as nn

def fanin_init(size, fanin=None):
    """
    Create a tensor of shape ``size`` drawn uniformly from ``[-v, v)``,
    where ``v = sqrt(2 / fanin)``.

    The scaling follows the fan-in idea from https://arxiv.org/abs/1502.01852

    :param size: shape of the tensor to create (``torch.Size``, tuple or list
        of ints)
    :param fanin: fan-in used for the bound; when omitted (or falsy) the first
        entry of ``size`` is used. NOTE(review): for an ``nn.Linear`` weight
        the first entry is ``out_features``, so pass ``fanin`` explicitly if
        the true fan-in is wanted.
    :return: a new float tensor of shape ``size``
    """
    fanin = fanin or size[0]
    v = np.sqrt(2./fanin)
    # torch.Tensor(seq) interprets a plain tuple/list as *data*, not as a
    # shape; torch.empty always treats its argument as a shape.
    return torch.empty(tuple(size)).uniform_(-v, v)


class LSTME(nn.Module):
    """
    Single-layer LSTM encoder that summarises a sequence by the LSTM output
    at its last time step.

    :param input_size: number of features per time step
    :param hidden_layer_size: size of the LSTM hidden state (and of the output)
    :param state_size: unused; kept so existing callers keep working
    """

    def __init__(self, input_size=2, hidden_layer_size=100, state_size=1):
        super().__init__()
        # Reseeds the global torch RNG so the LSTM weights are reproducible.
        torch.manual_seed(0)
        self.hidden_layer_size = hidden_layer_size

        self.lstm = nn.LSTM(input_size, hidden_layer_size, batch_first=True)

    def forward(self, input_seq):
        """
        Encode a batch of sequences.

        :param input_seq: tensor [n, seq_len, input_size] (batch first)
        :return: LSTM output at the last time step, [n, hidden_layer_size]
        """
        batch_size = input_seq.shape[0]
        state_shape = (1, batch_size, self.hidden_layer_size)
        # Start every call from a zero state allocated next to the input
        # (same device and dtype) instead of guessing the device from
        # torch.cuda.is_available(), which breaks CPU models on CUDA hosts.
        self.hidden_cell = (input_seq.new_zeros(state_shape),
                            input_seq.new_zeros(state_shape))
        lstm_out, self.hidden_cell = self.lstm(input_seq, self.hidden_cell)

        return lstm_out[:, -1, :]


class Critic(nn.Module):
    """
    Critic network returning Q(s, a).

    The state fed to the first layer is the concatenation of an LSTM encoding
    of previous state/action history and a dynamic state vector ``ZDyn``;
    the action is concatenated in before the second layer.

    :param Z_dim: input width of the first layer; must equal the LSTM hidden
        size plus the width of ``ZDyn``
    :param action_dim: width of the action vector
    :param h1: width of the first hidden layer
    :param h2: width of the second hidden layer
    :param eps: bound of the uniform init of the output layer weights
    """

    def __init__(self, Z_dim, action_dim, h1=50, h2=5, eps=0.03):
        super(Critic, self).__init__()
        # Reseeds the global torch RNG so the initial weights are reproducible.
        torch.manual_seed(0)

        self.state_dim = Z_dim
        self.action_dim = action_dim

        self.fc1 = nn.Linear(Z_dim, h1)
        self.fc1.weight.data = fanin_init(self.fc1.weight.data.size())

        self.fc2 = nn.Linear(h1 + action_dim, h2)
        self.fc2.weight.data = fanin_init(self.fc2.weight.data.size())

        self.fc3 = nn.Linear(h2, 1)
        self.fc3.weight.data.uniform_(-eps, eps)

        self.relu = nn.ReLU()

        # Encoder for the history of previous state/action features.
        self.lstm = LSTME(input_size=15)

    def _encode_history(self, prevS_A, device):
        """Encode the history into a [n, lstm_hidden] tensor on ``device``."""
        if torch.is_tensor(prevS_A):
            # Already a batched [n, seq_len, features] tensor.
            return self.lstm(prevS_A.to(device))

        feature_dim = self.lstm.lstm.input_size
        # Sequences may differ in length, so each one is encoded on its own.
        # Indexing (rather than iterating) keeps the original access pattern.
        encoded = [
            self.lstm(
                torch.from_numpy(
                    prevS_A[i].reshape(1, -1, feature_dim)
                ).float().to(device)
            )
            for i in range(len(prevS_A))
        ]
        if not encoded:
            return torch.zeros(0, self.lstm.hidden_layer_size, device=device)
        return torch.cat(encoded, dim=0)

    def forward(self, ZDyn, actionC,prevS_A):
        """
        return critic Q(s,a)
        :param ZDyn: dynamic state tensor [n, Z_dim - lstm_hidden]
            (n is batch_size)
        :param actionC: action tensor [n, action_dim]
        :param prevS_A: history, either a tensor [n, seq_len, features] or a
            length-n sequence of numpy arrays, each reshapeable to
            [seq_len, features]
        :return: tuple (Q(s,a) [n, 1], first hidden layer activation [n, h1])
        """
        # Run wherever the module's parameters live instead of branching on
        # torch.cuda.is_available(); the CPU branch previously raised
        # NameError and called .cuda() unconditionally.
        device = self.fc1.weight.device

        history = self._encode_history(prevS_A, device)

        Z = torch.cat((history, ZDyn.to(device)), dim=1)
        Z = self.relu(self.fc1(Z))
        x = torch.cat((Z, actionC.to(device)), dim=1)

        x = self.relu(self.fc2(x))
        x = self.fc3(x)

        return x,Z