# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
import torch.nn.functional as F

class actor(nn.Module):
    """Actor network: a three-layer MLP with PReLU activations.

    Maps an input whose last dimension is ``input_dim`` through two hidden
    layers to an output whose last dimension is ``action_dim``.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        action_dim: Size of the last dimension of the output tensor.
        hidden1: Width of the first hidden layer (default 300, the value
            previously hard-coded).
        hidden2: Width of the second hidden layer (default 200, the value
            previously hard-coded).
    """

    def __init__(self, input_dim, action_dim, *, hidden1=300, hidden2=200):
        super().__init__()
        # Kept as attributes for callers that read the layer widths.
        self.h1 = hidden1
        self.h2 = hidden2
        # Submodule names and registration order define the state_dict layout
        # (and the RNG draw order for weight init), so keep them unchanged to
        # stay compatible with existing checkpoints and seeded runs.
        self.f1 = nn.Linear(input_dim, self.h1)
        self.r1 = nn.PReLU()
        self.f2 = nn.Linear(self.h1, self.h2)
        self.r2 = nn.PReLU()
        self.f3 = nn.Linear(self.h2, action_dim)
        self.r3 = nn.PReLU()

    def forward(self, x):
        """Return the network output for ``x``.

        ``x`` must have ``input_dim`` as its last dimension. Any leading
        dimensions (e.g. batch) are carried through by ``nn.Linear``.
        """
        x = self.r1(self.f1(x))
        x = self.r2(self.f2(x))
        # NOTE(review): the output passes through a learnable PReLU, so it is
        # unbounded above. If actions must lie in a fixed range (a tanh head is
        # the usual choice), confirm that this is intended.
        return self.r3(self.f3(x))

class critic(nn.Module):
    """Critic network: a three-layer MLP with ReLU activations and a scalar head.

    Maps an input whose last dimension is ``input_dim`` to an output whose
    last dimension is 1. No activation is applied to the output.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        action_dim: Accepted but not used by this class. ``forward`` only
            consumes the tensor passed to it. It is kept so existing call
            sites keep working.
        hidden1: Width of the first hidden layer (default 300, the value
            previously hard-coded).
        hidden2: Width of the second hidden layer (default 200, the value
            previously hard-coded).
    """

    def __init__(self, input_dim, action_dim, *, hidden1=300, hidden2=200):
        super().__init__()
        # Kept as attributes for callers that read the layer widths.
        self.h1 = hidden1
        self.h2 = hidden2
        # Submodule names (including the irregular ``fs1``) and registration
        # order define the state_dict layout, so keep them unchanged.
        self.fs1 = nn.Linear(input_dim, self.h1)
        self.r1 = nn.ReLU()
        self.f2 = nn.Linear(self.h1, self.h2)
        self.r2 = nn.ReLU()
        self.f3 = nn.Linear(self.h2, 1)

    def forward(self, state):
        """Return the unactivated scalar output for ``state``.

        ``state`` must have ``input_dim`` as its last dimension. The result
        has the same leading dimensions and a last dimension of 1.
        """
        x = self.r1(self.fs1(state))
        x = self.r2(self.f2(x))
        return self.f3(x)


class model:
    """Hold one ``actor`` and one ``critic`` instance side by side.

    This is a plain container, not an ``nn.Module``. Module methods such as
    ``.parameters()`` or ``.to()`` must be called on ``.actor`` and
    ``.critic`` individually.

    Args:
        actor_indim: Input width passed to the actor.
        actor_dim: Output width of the actor. It is also forwarded to the
            critic as its ``action_dim`` argument.
        critic_dim: Input width passed to the critic.
    """

    def __init__(self, actor_indim, actor_dim, critic_dim):
        # Build the actor before the critic. Construction draws from the
        # global RNG for weight init, so this order fixes which weights each
        # network receives under a given seed.
        policy_net = actor(actor_indim, actor_dim)
        value_net = critic(critic_dim, actor_dim)

        self.actor = policy_net
        self.critic = value_net

