import torch
import numpy as np
import torch.nn as nn
import copy
import torch.nn.functional as F
from torch.distributions import Normal
from torch import optim
import random
import math

# Module-wide compute device: CUDA when torch reports it available, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class ReplayBuffer(object):
    """Fixed-capacity circular store of transitions backed by NumPy arrays.

    Every slot holds one ``(state, action, reward, next_state, dead)`` record.
    After ``max_size`` records have been written, further writes wrap around
    and overwrite the slots written earliest.
    """

    def __init__(self, state_dim, action_dim, max_size=int(1e6)):
        """Allocate zero-filled storage for ``max_size`` transitions.

        Args:
            state_dim: Length of one state vector.
            action_dim: Length of one action vector.
            max_size: Number of slots in the buffer.
        """
        self.max_size = max_size
        self.ptr = 0   # slot the next ``add`` call writes to
        self.size = 0  # count of filled slots, capped at ``max_size``

        vector_rows = (max_size, state_dim)
        self.state = np.zeros(vector_rows)
        self.action = np.zeros((max_size, action_dim))
        self.reward = np.zeros(max_size)
        self.next_state = np.zeros(vector_rows)
        self.dead = np.zeros(max_size)

        # Sampled batches are moved to this device.
        use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if use_cuda else "cpu")

    def add(self, state, action, reward, next_state, dead):
        """Write one transition at the current slot and advance the cursor.

        Args:
            state: Value stored in the ``state`` row.
            action: Value stored in the ``action`` row.
            reward: Value stored in the ``reward`` entry.
            next_state: Value stored in the ``next_state`` row.
            dead: Value stored in the ``dead`` entry.
        """
        slot = self.ptr
        pairs = (
            (self.state, state),
            (self.action, action),
            (self.reward, reward),
            (self.next_state, next_state),
            (self.dead, dead),
        )
        for storage, value in pairs:
            storage[slot] = value

        # Wrap around so that the oldest slot is overwritten once full.
        self.ptr = (slot + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample(self, batch_size):
        """Draw ``batch_size`` filled slots uniformly at random, with replacement.

        Args:
            batch_size: Number of transitions to draw.

        Returns:
            A 5-tuple ``(state, action, reward, next_state, dead)`` of float32
            tensors placed on ``self.device``; ``reward`` and ``dead`` are 1-D.
        """
        picks = np.random.randint(0, self.size, size=batch_size)
        columns = (self.state, self.action, self.reward, self.next_state, self.dead)
        return tuple(
            torch.FloatTensor(column[picks]).to(self.device) for column in columns
        )


class DeepQnetwork(nn.Module):
    """Fully connected network mapping a state to one value per action.

    Two ReLU hidden layers of width ``net_width`` are followed by a linear
    output layer with ``actions`` units.
    """

    def __init__(self, state_dim, actions, net_width):
        """Create the three linear layers.

        Args:
            state_dim: Size of the input state vector.
            actions: Number of output units.
            net_width: Width of both hidden layers.
        """
        super().__init__()
        self.l1 = nn.Linear(state_dim, net_width)
        self.l2 = nn.Linear(net_width, net_width)
        self.l3 = nn.Linear(net_width, actions)

    def forward(self, state):
        """Return the output layer's raw (unactivated) values for ``state``."""
        hidden = F.relu(self.l1(state))
        hidden = F.relu(self.l2(hidden))
        return self.l3(hidden)

class discrete_actor(nn.Module):
    """MLP that outputs a categorical distribution per action dimension.

    The final linear layer produces ``action_dim * action_slices`` values,
    which are reshaped to ``(..., action_dim, action_slices)`` and normalised
    with a softmax over the last axis, so each action dimension gets its own
    probability vector over ``action_slices`` choices.
    """

    def __init__(self, state_dim, action_dim, action_slices, net_width):
        """Create the layers.

        Args:
            state_dim: Size of the input state vector.
            action_dim: Number of action dimensions.
            action_slices: Number of discrete choices per action dimension.
            net_width: Width of both hidden layers.
        """
        super(discrete_actor, self).__init__()
        self.l1 = nn.Linear(state_dim, net_width)
        self.l2 = nn.Linear(net_width, net_width)
        self.l3 = nn.Linear(net_width, action_dim * action_slices)
        self.action_dim = action_dim
        self.action_slices = action_slices
        # Normalise over the trailing (slices) axis. For the usual
        # (batch, action_dim, action_slices) tensor this is the same axis as
        # dim=2, but it also works when there are fewer or more leading dims.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, states):
        """Return per-dimension action probabilities.

        Args:
            states: Tensor whose last axis has size ``state_dim``. Any number
                of leading axes (including none) is accepted.

        Returns:
            Tensor of shape ``(*states.shape[:-1], action_dim, action_slices)``
            whose entries sum to 1 along the last axis.
        """
        out = torch.relu(self.l1(states))
        out = torch.relu(self.l2(out))
        out = self.l3(out)
        # Split only the feature axis; leading axes are passed through as-is
        # instead of assuming exactly one batch axis.
        out = out.reshape(*out.shape[:-1], self.action_dim, self.action_slices)
        return self.softmax(out)

class TD3_Actor(nn.Module):
    """MLP mapping a state to an action vector squashed by ``tanh``.

    Two ReLU hidden layers of width ``net_width`` feed a linear layer with
    ``action_dim`` units; the final ``tanh`` bounds every component to
    the interval [-1, 1].
    """

    def __init__(self, state_dim, action_dim, net_width):
        """Create the three linear layers.

        Args:
            state_dim: Size of the input state vector.
            action_dim: Size of the output action vector.
            net_width: Width of both hidden layers.
        """
        super().__init__()
        self.l1 = nn.Linear(state_dim, net_width)
        self.l2 = nn.Linear(net_width, net_width)
        self.l3 = nn.Linear(net_width, action_dim)

    def forward(self, state):
        """Return the tanh-bounded action for ``state``."""
        first_hidden = F.relu(self.l1(state))
        second_hidden = F.relu(self.l2(first_hidden))
        return torch.tanh(self.l3(second_hidden))

class critic(nn.Module):
    """Pair of independent MLP heads that each score a (state, action) pair.

    Both heads take the state and action concatenated along ``dim=1`` and
    pass it through two ReLU hidden layers and a single-unit linear output.
    Layers ``l1``-``l3`` form the first head, ``l4``-``l6`` the second.
    """

    def __init__(self, state_dim, action_dim, net_width):
        """Create the six linear layers.

        Args:
            state_dim: Size of the state vector.
            action_dim: Size of the action vector.
            net_width: Width of the hidden layers in both heads.
        """
        super().__init__()
        joint_dim = state_dim + action_dim

        # First head (Q1).
        self.l1 = nn.Linear(joint_dim, net_width)
        self.l2 = nn.Linear(net_width, net_width)
        self.l3 = nn.Linear(net_width, 1)

        # Second head (Q2).
        self.l4 = nn.Linear(joint_dim, net_width)
        self.l5 = nn.Linear(net_width, net_width)
        self.l6 = nn.Linear(net_width, 1)

    @staticmethod
    def _run_head(joint, first, second, last):
        """Apply ``first`` and ``second`` with ReLU, then ``last`` linearly."""
        return last(F.relu(second(F.relu(first(joint)))))

    def forward(self, state, action):
        """Return ``(q1, q2)``, the outputs of both heads for the given pair."""
        joint = torch.cat((state, action), dim=1)
        first_q = self._run_head(joint, self.l1, self.l2, self.l3)
        second_q = self._run_head(joint, self.l4, self.l5, self.l6)
        return first_q, second_q

    def Q1(self, state, action):
        """Return only the first head's output for the given pair."""
        joint = torch.cat((state, action), dim=1)
        return self._run_head(joint, self.l1, self.l2, self.l3)



















