
import torch
import numpy as np
import torch.nn as nn

def fanin_init(size, fanin=None):
    """
    Return a new tensor of shape ``size`` filled with samples from
    U(-1/sqrt(fanin), 1/sqrt(fanin)).

    The original author attributes this initializer to
    https://arxiv.org/abs/1502.01852

    :param size: shape of the tensor to create (``torch.Size``, tuple or list of ints)
    :param fanin: value the bound is derived from; when omitted (or falsy),
        ``size[0]`` is used
    :return: freshly allocated tensor of shape ``size``
    """
    fanin = fanin or size[0]
    bound = 1. / np.sqrt(fanin)
    # ``torch.Tensor(size)`` interprets only a ``torch.Size`` as a shape; a plain
    # tuple or list is taken as *data* and yields a tensor holding the
    # dimensions themselves. ``torch.empty`` always treats its argument as a shape.
    return torch.empty(tuple(size)).uniform_(-bound, bound)


class Val_Func(nn.Module):
    """
    Fully connected value head: ``fc1`` embeds ``Z``, the action is
    concatenated to that embedding, and ``fc2``/``fc3`` map the result to a
    single scalar per sample.
    """

    def __init__(self, Z_dim, action_dim, h1=50, h2=5, eps=0.03):
        """
        :param Z_dim: number of input features of ``fc1``
        :param action_dim: number of action features appended before ``fc2``
        :param h1: output width of ``fc1``
        :param h2: output width of ``fc2``
        :param eps: ``fc3`` weights are drawn from U(-eps, eps)
        """
        # Zero-argument super() binds to this class. The previous
        # ``super(Critic, self)`` raised TypeError on construction because
        # Val_Func is not a subclass of Critic.
        super().__init__()
        # Side effect: resets the global torch RNG before the layers are created.
        torch.manual_seed(0)

        self.Z_dim = Z_dim
        self.action_dim = action_dim

        # NOTE(review): fanin_init derives its bound from size[0], which for an
        # nn.Linear weight is the output width - confirm this is intended.
        self.fc1 = nn.Linear(self.Z_dim, h1)
        self.fc1.weight.data = fanin_init(self.fc1.weight.data.size())

        self.fc2 = nn.Linear(h1 + action_dim, h2)
        self.fc2.weight.data = fanin_init(self.fc2.weight.data.size())

        self.fc3 = nn.Linear(h2, 1)
        self.fc3.weight.data.uniform_(-eps, eps)

        self.relu = nn.ReLU()

    def forward(self, Z, action):
        """
        Compute the scalar value together with the first-layer embedding.

        :param Z: input features [n, Z_dim] (n is batch_size)
        :param action: action [n, action_dim]
        :return: tuple ``(value, embedding)`` with value [n, 1] and the
            ReLU-activated ``fc1`` output [n, h1]
        """
        Z = self.relu(self.fc1(Z))
        # The action enters after the first layer, joined on the feature axis.
        x = torch.cat((Z, action), dim=1)

        x = self.relu(self.fc2(x))
        x = self.fc3(x)

        return x, Z

import torch
import torch.nn as nn
import torch.nn.functional as F


class LSTMDyn(nn.Module):
    """
    Single-layer, batch-first LSTM followed by three linear read-outs:
    two applied to the last time step (``linearA``, ``linearB``) and one
    applied to every time step (``linearStates``).
    """

    def __init__(self, input_size=1, hidden_layer_size=200, state_size=1):
        """
        :param input_size: number of features per time step fed to the LSTM
        :param hidden_layer_size: LSTM hidden width
        :param state_size: output width of the per-step ``linearStates`` read-out
        """
        super().__init__()
        # Side effect: resets the global torch RNG before the layers are created.
        torch.manual_seed(0)
        self.hidden_layer_size = hidden_layer_size

        self.lstm = nn.LSTM(input_size, hidden_layer_size, batch_first=True)

        # NOTE(review): the output widths 121 and 44 are hard-coded; their
        # meaning is not visible in this file - confirm with the caller.
        self.linearA = nn.Linear(hidden_layer_size, 121)
        self.linearB = nn.Linear(hidden_layer_size, 44)

        self.linearStates = nn.Linear(hidden_layer_size, state_size)

    def forward(self, input_seq):
        """
        Run the LSTM from a zero initial state and apply the read-outs.

        :param input_seq: tensor [n, T, input_size] (n is batch_size)
        :return: tuple ``(A, B, states, last)`` with A [n, 121], B [n, 44],
            states [n, T, state_size] and the last-step LSTM output
            [n, hidden_layer_size]
        """
        batch_size = input_seq.shape[0]
        # Allocate the initial hidden/cell state with the input's dtype and
        # device; plain torch.zeros always produced default-dtype CPU tensors,
        # which breaks once the module is moved with .to(...) or .double().
        h0 = input_seq.new_zeros(1, batch_size, self.hidden_layer_size)
        c0 = input_seq.new_zeros(1, batch_size, self.hidden_layer_size)
        self.hidden_cell = (h0, c0)

        # After this call hidden_cell holds the LSTM's final (h_n, c_n).
        lstm_out, self.hidden_cell = self.lstm(input_seq, self.hidden_cell)

        last_step = lstm_out[:, -1, :]
        predictionA = self.linearA(last_step)
        predictionB = self.linearB(last_step)
        predictionStates = self.linearStates(lstm_out)

        return predictionA, predictionB, predictionStates, last_step



class Critic(nn.Module):
    """
    Combines an ``LSTMDyn`` sequence encoder with a ``Val_Func`` value head.
    """

    def __init__(self, Z_dim, action_dim, h1=50, h2=5, eps=0.03):
        """
        :param Z_dim: input width handed to ``Val_Func``
        :param action_dim: action width handed to ``Val_Func``
        :param h1: first hidden width of ``Val_Func``
        :param h2: second hidden width of ``Val_Func``
        :param eps: bound for the last-layer initialisation of ``Val_Func``
        """
        super(Critic, self).__init__()
        # Side effect: resets the global torch RNG.
        torch.manual_seed(0)

        self.Z_dim = Z_dim
        self.action_dim = action_dim

        # Pass the head hyper-parameters through; they used to be accepted
        # here but silently dropped, so Val_Func always ran with its defaults.
        self.value = Val_Func(Z_dim, action_dim, h1=h1, h2=h2, eps=eps)
        # NOTE(review): LSTMDyn is built with its defaults (input_size=1,
        # state_size=1); verify these match the sequences passed to forward().
        self.lstmDyn = LSTMDyn()

    def forward(self, states, actions, dStates):
        """
        Encode the (states, actions) sequence, append summary statistics of
        the inputs and evaluate the value head.

        :param states: state tensor
        :param actions: action tensor
        :param dStates: tensor summarised alongside states and actions
        :return: tuple ``(R, Z, A, B)`` where R and Z are the two outputs of
            the value head and A, B are the first two outputs of ``lstmDyn``
        """
        # torch.cat takes a *sequence* of tensors; the previous call passed
        # ``actions`` in the ``dim`` position and always raised.
        # NOTE(review): joining on the last (feature) axis is an assumption -
        # confirm the intended axis.
        seq = torch.cat((states, actions), dim=-1)
        A, B, _, zDyn = self.lstmDyn(seq)

        # Use torch reductions instead of numpy so the tensors stay on their
        # device and inside the autograd graph. Axis 0 and the population
        # standard deviation (numpy's default, ddof=0) are kept as written.
        states_mean = torch.mean(states, dim=0)
        actions_mean = torch.mean(actions, dim=0)
        dS_mean = torch.mean(dStates, dim=0)

        states_std = torch.std(states, dim=0, unbiased=False)
        actions_std = torch.std(actions, dim=0, unbiased=False)
        dS_std = torch.std(dStates, dim=0, unbiased=False)

        # Concatenates along dim 0, mirroring np.concatenate's default axis.
        # NOTE(review): zDyn is 2-D ([n, hidden]) while the statistics have the
        # input's shape minus its first axis; confirm the reduction and
        # concatenation axes so these shapes are compatible.
        ZPrime = torch.cat((zDyn, states_mean, actions_mean, dS_mean,
                            states_std, actions_std, dS_std))

        R, Z = self.value(ZPrime, actions)

        return R, Z, A, B



