
from torch.distributions import Categorical
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.parallel import DataParallel
import numpy as np

from collections import deque

from torch.distributions import MultivariateNormal

# Pick the compute device once at import time: the first CUDA GPU when one is
# visible, otherwise the CPU. The choice is printed for the user.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
    torch.cuda.empty_cache()
    print(f"Device set to : {torch.cuda.get_device_name(device)}")
else:
    device = torch.device('cpu')
    print("Device set to : cpu")

class Policy_REINFORCEMENT(nn.Module):
    """Categorical policy network trained with the REINFORCE policy-gradient rule.

    Per-episode protocol: call ``sample_action`` at every step (it stores the
    log-probability of the sampled action), ``put_reward`` / ``put_done`` for
    that step, and ``train_net`` once when the episode ends.

    ``self.optimizer`` is ``None`` after construction; the caller must assign
    an optimizer over ``self.parameters()`` before calling ``train_net``.

    Args:
        input_size: width of the (flattened) state vector.
        action_num: number of discrete actions.
        gamma: discount factor used for the returns.
    """

    def __init__(self,input_size,action_num,gamma):
        super(Policy_REINFORCEMENT, self).__init__()
        self.input_size = input_size
        self.action_num = action_num
        self.gamma = gamma
        self.optimizer = None  # must be set by the caller before train_net()

        self.affine1 = nn.Linear(self.input_size, self.input_size*2)
        self.affine2 = nn.Linear(self.input_size*2, self.input_size * 3)
        self.affine3 = nn.Linear(self.input_size * 3, self.input_size * 2)
        self.affine4 = nn.Linear(self.input_size*2, self.action_num)
        self.dn1 = nn.Dropout(0.5)
        self.dn2 = nn.Dropout(0.5)
        self.dn3 = nn.Dropout(0.5)
        # NOTE: the BatchNorm layers are registered (they appear in
        # parameters()/state_dict()) but forward() does not use them.
        self.bn1 = nn.BatchNorm1d(self.input_size*2)
        self.bn2 = nn.BatchNorm1d(self.input_size * 3)
        self.bn3 = nn.BatchNorm1d(self.input_size * 2)

        # per-episode storage, filled by sample_action / put_reward / put_done
        self.saved_log_probs = []
        self.rewards = []
        self.dones = []

        # guards the division when standardising returns
        self.eps = np.finfo(np.float32).eps.item()

    def forward(self, x):
        """Return action probabilities of shape (batch, action_num).

        ``x`` is reshaped to ``(-1, input_size)`` first, so a single state
        yields a batch of one.
        """
        x = x.reshape(-1,self.input_size)

        # three Linear -> Dropout -> ReLU stages
        for affine, dropout in ((self.affine1, self.dn1),
                                (self.affine2, self.dn2),
                                (self.affine3, self.dn3)):
            x = F.relu(dropout(affine(x)))

        action_scores = self.affine4(x)
        return F.softmax(action_scores,dim=-1)

    def sample_action(self,state):
        """Sample an action for ``state``, remember its log-probability, return it as an int."""
        probs = self.forward(state)

        # draw an action index according to the predicted probabilities
        m = Categorical(probs)
        action = m.sample()

        # kept (with its graph) for the policy-gradient loss in train_net
        self.saved_log_probs.append(m.log_prob(action))

        # a Python scalar, so the caller holds no reference to the graph
        return action.item()

    def put_reward(self,reward):
        self.rewards.append(reward)

    def put_done(self,done):
        self.dones.append(done)

    def _standardized_returns(self):
        """Return discounted returns G_t for ``self.rewards``, standardised, as float32."""
        returns = []
        running = 0
        for r in reversed(self.rewards):
            running = r + self.gamma * running
            returns.append(running)
        returns.reverse()

        returns = torch.tensor(returns, dtype=torch.float32)
        mean = returns.mean()
        if returns.numel() > 1:
            return (returns - mean) / (returns.std() + self.eps)
        # The unbiased std of a single value is NaN, which would poison the
        # loss and every weight; centring alone keeps it finite (-> 0).
        return returns - mean

    def train_net(self):
        """Apply one REINFORCE update from the stored episode, then clear the episode buffers.

        An empty episode performs no update (and does not touch the optimizer).
        """
        if self.saved_log_probs and self.rewards:
            returns = self._standardized_returns()

            # loss over the whole episode (trajectory)
            policy_loss = [-log_prob * R
                           for log_prob, R in zip(self.saved_log_probs, returns)]

            self.optimizer.zero_grad()
            torch.cat(policy_loss).sum().backward()
            self.optimizer.step()

        # episode finished: reset storage
        self.saved_log_probs = []
        self.rewards = []
        self.dones = []


class RolloutBuffer:
    """Parallel per-step lists collected during a rollout.

    Index ``i`` of every list belongs to the same environment step.
    ``clear`` empties the lists in place, so references to them held
    elsewhere stay valid.
    """

    def __init__(self):
        self.actions, self.states, self.logprobs = [], [], []
        self.rewards, self.state_values, self.is_terminals = [], [], []

    def clear(self):
        for storage in (self.actions, self.states, self.logprobs,
                        self.rewards, self.state_values, self.is_terminals):
            storage.clear()



class ActorCritic_MLP(nn.Module):
    """Separate MLP actor (softmax over actions) and MLP critic (scalar value).

    Both networks use two 64-unit tanh hidden layers. ``forward`` is
    intentionally unsupported; use ``act`` for sampling and ``evaluate`` for
    training.
    """

    def __init__(self, state_dim, action_dim):
        super(ActorCritic_MLP, self).__init__()
        width = 64

        # policy head: state -> action probabilities
        self.actor = nn.Sequential(
            nn.Linear(state_dim, width), nn.Tanh(),
            nn.Linear(width, width), nn.Tanh(),
            nn.Linear(width, action_dim), nn.Softmax(dim=-1),
        )

        # value head: state -> V(state)
        self.critic = nn.Sequential(
            nn.Linear(state_dim, width), nn.Tanh(),
            nn.Linear(width, width), nn.Tanh(),
            nn.Linear(width, 1),
        )

    def act(self, state):
        """Sample an action; return (action, log-prob, value), all detached."""
        dist = Categorical(self.actor(state))
        chosen = dist.sample()
        chosen_logprob = dist.log_prob(chosen)
        value = self.critic(state)
        return chosen.detach(), chosen_logprob.detach(), value.detach()

    def evaluate(self, state, action):
        """Return (log-prob of ``action``, critic values, policy entropy) with gradients."""
        dist = Categorical(self.actor(state))
        return dist.log_prob(action), self.critic(state), dist.entropy()

    def forward(self):
        raise NotImplementedError

class PPO_MLP:
    """Clipped-objective PPO agent built on ``ActorCritic_MLP``.

    Protocol: for each environment step call ``sample_action`` (records state,
    action, log-prob and critic value), then ``put_reward`` and ``put_done``
    for that same step. ``train_net`` runs ``K_epochs`` gradient steps on the
    collected rollout, copies the weights into ``policy_old`` and clears the
    buffer.

    Args:
        state_dim: input width of the actor/critic MLPs.
        action_dim: number of discrete actions.
        lr_actor: Adam learning rate for the actor parameter group.
        lr_critic: Adam learning rate for the critic parameter group.
        gamma: discount factor for the Monte-Carlo returns.
        K_epochs: optimisation passes per ``train_net`` call.
        eps_clip: clipping range for the probability ratio.
    """

    def __init__(self, state_dim, action_dim, lr_actor, lr_critic, gamma, K_epochs, eps_clip):
        self.gamma = gamma
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs

        self.buffer = RolloutBuffer()

        # `policy` is optimised; `policy_old` (not in the optimizer) is used to
        # act and is re-synced from `policy` at the end of every train_net.
        self.policy = ActorCritic_MLP(state_dim, action_dim).to(device)
        self.optimizer = torch.optim.Adam([
            {'params': self.policy.actor.parameters(), 'lr': lr_actor},
            {'params': self.policy.critic.parameters(), 'lr': lr_critic}
        ])

        self.policy_old = ActorCritic_MLP(state_dim, action_dim).to(device)
        self.policy_old.load_state_dict(self.policy.state_dict())

        self.MseLoss = nn.MSELoss()

    def sample_action(self, state):
        """Sample an action from ``policy_old`` for ``state`` and record the step.

        Returns:
            The sampled action index as a Python scalar (``action.item()``).
        """
        with torch.no_grad():
            state = torch.tensor(state, dtype=torch.float).to(device)
            action, action_logprob, state_val = self.policy_old.act(state)

        self.buffer.states.append(state)
        self.buffer.actions.append(action)
        self.buffer.logprobs.append(action_logprob)
        self.buffer.state_values.append(state_val)

        return action.item()

    def _normalized_returns(self):
        """Discounted returns (reset at terminal steps), standardised, as a float32 tensor on ``device``."""
        returns = []
        discounted_reward = 0
        for reward, is_terminal in zip(reversed(self.buffer.rewards), reversed(self.buffer.is_terminals)):
            if is_terminal:
                discounted_reward = 0
            discounted_reward = reward + (self.gamma * discounted_reward)
            returns.append(discounted_reward)
        returns.reverse()  # append + reverse instead of O(n^2) insert(0, ...)

        returns = torch.tensor(returns, dtype=torch.float32).to(device)
        mean = returns.mean()
        if returns.numel() > 1:
            return (returns - mean) / (returns.std() + 1e-7)
        # The unbiased std of a single value is NaN; centring alone keeps the
        # loss (and therefore the weights) finite.
        return returns - mean

    def train_net(self):
        """Run ``K_epochs`` PPO updates on the buffered rollout, sync ``policy_old``, clear the buffer.

        An empty buffer is a no-op (apart from clearing it).
        """
        if not self.buffer.rewards or not self.buffer.states:
            self.buffer.clear()
            return

        # Monte Carlo estimate of returns (also the critic's regression target)
        rewards = self._normalized_returns()

        # Stack the per-step records. reshape (not squeeze) keeps the batch
        # dimension even for a single step or state_dim == 1; states stored as
        # (state_dim,) or (1, state_dim) both become (n, state_dim).
        n = len(self.buffer.states)
        old_states = torch.stack(self.buffer.states, dim=0).reshape(n, -1).detach().to(device)
        old_actions = torch.stack(self.buffer.actions, dim=0).reshape(-1).detach().to(device)
        old_logprobs = torch.stack(self.buffer.logprobs, dim=0).reshape(-1).detach().to(device)
        old_state_values = torch.stack(self.buffer.state_values, dim=0).reshape(-1).detach().to(device)

        # advantage = return - value predicted by the acting policy
        advantages = rewards - old_state_values

        for _ in range(self.K_epochs):
            logprobs, state_values, dist_entropy = self.policy.evaluate(old_states, old_actions)

            # (n, 1) -> (n,) to match `rewards`, also when n == 1
            state_values = state_values.reshape(-1)

            # ratio pi_theta / pi_theta_old
            ratios = torch.exp(logprobs - old_logprobs)

            # clipped surrogate objective
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantages

            loss = -torch.min(surr1, surr2) + 0.5 * self.MseLoss(state_values, rewards) - 0.01 * dist_entropy

            self.optimizer.zero_grad()
            loss.mean().backward()
            self.optimizer.step()

        # copy new weights into the acting policy
        self.policy_old.load_state_dict(self.policy.state_dict())

        self.buffer.clear()

    def put_reward(self,reward):
        self.buffer.rewards.append(reward)

    def put_done(self,done):
        self.buffer.is_terminals.append(done)

class StateLSTM(nn.Module):
    """Encode a sequence of state vectors into one flat feature vector.

    A 2-layer, unidirectional, batch-first LSTM whose hidden size equals
    ``input_size`` is run over the sequence. The final hidden states ``h_n``
    and cell states ``c_n`` of every layer/direction are concatenated and
    flattened into ``node_num = hidden_size * 2 * num_layers * direction_num``
    features.

    Args:
        input_size: per-step feature width (also used as the LSTM hidden size).
        batch_size: not used by this module; kept so existing callers keep working.
    """

    def __init__(self,input_size,batch_size):
        super(StateLSTM, self).__init__()

        self.lstm = torch.nn.LSTM(input_size = input_size,
                                  hidden_size = input_size,
                                  batch_first=True,
                                  dropout=0.5,
                                  bidirectional=False,
                                  num_layers=2)

        self.direction_num = 2 if self.lstm.bidirectional else 1
        state_rows = self.lstm.num_layers * self.direction_num

        # Zero initial (h, c) for the unbatched sampling path. Registered as
        # non-persistent buffers so they move/cast together with the LSTM
        # weights on .to()/.cuda()/.double(), without entering state_dict().
        self.register_buffer(
            "h0_for_sampling_action",
            torch.zeros(state_rows, self.lstm.hidden_size, dtype=torch.float),
            persistent=False)
        self.register_buffer(
            "c0_for_sampling_action",
            torch.zeros(state_rows, self.lstm.hidden_size, dtype=torch.float),
            persistent=False)

        # width of the flattened [h_n; c_n] encoding
        self.node_num = self.lstm.hidden_size * 2 * state_rows

    def sample_action_forward(self, x):
        """Encode one unbatched sequence.

        Args:
            x: (seq_len, input_size) tensor -- unbatched, matching the 2-D
               initial states.

        Returns:
            1-D tensor of length ``node_num``: the rows of h_n followed by the
            rows of c_n, flattened row-major.
        """
        _, (h_n, c_n) = self.lstm(x, (self.h0_for_sampling_action, self.c0_for_sampling_action))
        return torch.cat([h_n, c_n], dim=0).reshape(-1)

    def train_forward(self, x):
        """Encode a batch of sequences.

        Args:
            x: batch-first (batch, seq_len, input_size) tensor.

        Returns:
            (batch, node_num) tensor; row ``b`` uses the same layout as
            ``sample_action_forward`` applied to ``x[b]``.
        """
        state_rows = self.lstm.num_layers * self.direction_num
        # zero initial states created on the input's device
        h0 = torch.zeros(state_rows, x.size(0), self.lstm.hidden_size,
                         dtype=torch.float, device=x.device)
        c0 = torch.zeros_like(h0)

        _, (h_n, c_n) = self.lstm(x, (h0, c0))
        stacked = torch.cat([h_n, c_n], dim=0)  # (2 * state_rows, batch, hidden)
        # batch first, then flatten (row, hidden) exactly as in sample_action_forward
        return stacked.permute(1, 0, 2).reshape(x.size(0), -1)



class ActorCritic_LSTM(nn.Module):
    """Actor-critic whose actor and critic MLP heads share one ``StateLSTM`` encoder.

    ``act`` encodes a single unbatched sequence via ``sample_action_forward``;
    ``evaluate`` encodes a batch via ``train_forward``. ``forward`` is
    intentionally unsupported.
    """

    def __init__(self, state_dim, action_dim, lstm_batch_size):
        super(ActorCritic_LSTM, self).__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.lstm_batch_size = lstm_batch_size

        # shared sequence encoder; its output width feeds both heads
        self.StateLSTM = StateLSTM(self.state_dim, self.lstm_batch_size)
        width = self.node_num_after_LSTM = self.StateLSTM.node_num

        # policy head: encoding -> action probabilities
        self.actor = nn.Sequential(
            nn.Linear(width, width), nn.Tanh(),
            nn.Linear(width, width), nn.Tanh(),
            nn.Linear(width, action_dim), nn.Softmax(dim=-1),
        )

        # value head: encoding -> scalar value
        self.critic = nn.Sequential(
            nn.Linear(width, width), nn.Tanh(),
            nn.Linear(width, width), nn.Tanh(),
            nn.Linear(width, 1),
        )

    def act(self, state):
        """Sample an action for one sequence; return (action, log-prob, value), detached."""
        features = self.StateLSTM.sample_action_forward(state)
        dist = Categorical(self.actor(features))
        chosen = dist.sample()
        chosen_logprob = dist.log_prob(chosen)
        value = self.critic(features)
        return chosen.detach(), chosen_logprob.detach(), value.detach()

    def evaluate(self, state, action):
        """Return (log-prob of ``action``, critic values, entropy) for a batch of sequences."""
        features = self.StateLSTM.train_forward(state)
        dist = Categorical(self.actor(features))
        return dist.log_prob(action), self.critic(features), dist.entropy()

    def forward(self):
        raise NotImplementedError

class PPO_LSTM:
    """Clipped-objective PPO agent built on ``ActorCritic_LSTM``.

    Protocol: for each environment step call ``sample_action``, then
    ``put_reward`` and ``put_done`` for that same step. ``train_net`` runs
    ``K_epochs`` gradient steps on the collected rollout, copies the weights
    into ``policy_old`` and clears the buffer.

    Args:
        state_dim: per-step feature width of the state sequence.
        action_dim: number of discrete actions.
        lr_actor: Adam learning rate for the LSTM encoder and actor groups.
        lr_critic: Adam learning rate for the critic group.
        gamma: discount factor for the Monte-Carlo returns.
        K_epochs: optimisation passes per ``train_net`` call.
        eps_clip: clipping range for the probability ratio.
        update_sample_num: forwarded to ``ActorCritic_LSTM`` as ``lstm_batch_size``.
    """

    def __init__(self, state_dim, action_dim, lr_actor, lr_critic, gamma, K_epochs, eps_clip, update_sample_num):
        self.gamma = gamma
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs
        self.lstm_batch_size = update_sample_num

        self.buffer = RolloutBuffer()

        # `policy` is optimised; `policy_old` (not in the optimizer) is used to
        # act and is re-synced from `policy` at the end of every train_net.
        self.policy = ActorCritic_LSTM(state_dim, action_dim, self.lstm_batch_size).to(device)
        self.optimizer = torch.optim.Adam([
            {'params': self.policy.StateLSTM.parameters(), 'lr': lr_actor},
            {'params': self.policy.actor.parameters(), 'lr': lr_actor},
            {'params': self.policy.critic.parameters(), 'lr': lr_critic}
        ])

        self.policy_old = ActorCritic_LSTM(state_dim, action_dim, self.lstm_batch_size).to(device)
        self.policy_old.load_state_dict(self.policy.state_dict())

        self.MseLoss = nn.MSELoss()

    def sample_action(self, state):
        """Sample an action from ``policy_old`` for ``state`` and record the step.

        Returns:
            The sampled action index as a Python scalar (``action.item()``).
        """
        with torch.no_grad():
            state = torch.tensor(state, dtype=torch.float).to(device)
            action, action_logprob, state_val = self.policy_old.act(state)

        self.buffer.states.append(state)
        self.buffer.actions.append(action)
        self.buffer.logprobs.append(action_logprob)
        self.buffer.state_values.append(state_val)

        return action.item()

    def _normalized_returns(self):
        """Discounted returns (reset at terminal steps), standardised, as a float32 tensor on ``device``."""
        returns = []
        discounted_reward = 0
        for reward, is_terminal in zip(reversed(self.buffer.rewards), reversed(self.buffer.is_terminals)):
            if is_terminal:
                discounted_reward = 0
            discounted_reward = reward + (self.gamma * discounted_reward)
            returns.append(discounted_reward)
        returns.reverse()  # append + reverse instead of O(n^2) insert(0, ...)

        returns = torch.tensor(returns, dtype=torch.float32).to(device)
        mean = returns.mean()
        if returns.numel() > 1:
            return (returns - mean) / (returns.std() + 1e-7)
        # The unbiased std of a single value is NaN; centring alone keeps the
        # loss (and therefore the weights) finite.
        return returns - mean

    def train_net(self):
        """Run ``K_epochs`` PPO updates on the buffered rollout, sync ``policy_old``, clear the buffer.

        An empty buffer is a no-op (apart from clearing it).
        """
        if not self.buffer.rewards or not self.buffer.states:
            self.buffer.clear()
            return

        # Monte Carlo estimate of returns (also the critic's regression target)
        rewards = self._normalized_returns()

        # Presumably each stored state is a 2-D (seq_len, state_dim) tensor,
        # since StateLSTM.sample_action_forward uses 2-D initial states -- TODO
        # confirm. Stacking (without squeeze) gives the batch-first
        # (n, seq_len, state_dim) input train_forward needs, even when n,
        # seq_len or state_dim is 1.
        old_states = torch.stack(self.buffer.states, dim=0).detach().to(device)
        old_actions = torch.stack(self.buffer.actions, dim=0).reshape(-1).detach().to(device)
        old_logprobs = torch.stack(self.buffer.logprobs, dim=0).reshape(-1).detach().to(device)
        old_state_values = torch.stack(self.buffer.state_values, dim=0).reshape(-1).detach().to(device)

        # advantage = return - value predicted by the acting policy
        advantages = rewards - old_state_values

        for _ in range(self.K_epochs):
            logprobs, state_values, dist_entropy = self.policy.evaluate(old_states, old_actions)

            # (n, 1) -> (n,) to match `rewards`, also when n == 1
            state_values = state_values.reshape(-1)

            # ratio pi_theta / pi_theta_old
            ratios = torch.exp(logprobs - old_logprobs)

            # clipped surrogate objective
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantages

            loss = -torch.min(surr1, surr2) + 0.5 * self.MseLoss(state_values, rewards) - 0.01 * dist_entropy

            self.optimizer.zero_grad()
            loss.mean().backward()
            self.optimizer.step()

        # copy new weights into the acting policy
        self.policy_old.load_state_dict(self.policy.state_dict())

        self.buffer.clear()

    def put_reward(self,reward):
        self.buffer.rewards.append(reward)

    def put_done(self,done):
        self.buffer.is_terminals.append(done)


class Actors(nn.Module):
    """Ensemble of four independent softmax policy MLPs whose outputs are averaged.

    Each head is Linear(state_dim, 64) -> Tanh -> Linear(64, 64) -> Tanh ->
    Linear(64, action_dim) -> Softmax; the mean of four distributions is
    again a distribution.
    """

    def __init__(self, state_dim, action_dim):
        super(Actors, self).__init__()

        def make_head():
            return nn.Sequential(
                nn.Linear(state_dim, 64), nn.Tanh(),
                nn.Linear(64, 64), nn.Tanh(),
                nn.Linear(64, action_dim), nn.Softmax(dim=-1),
            )

        self.sub_actor1 = make_head()
        self.sub_actor2 = make_head()
        self.sub_actor3 = make_head()
        self.sub_actor4 = make_head()

    def forward(self,x):
        heads = (self.sub_actor1, self.sub_actor2, self.sub_actor3, self.sub_actor4)
        total = heads[0](x)
        for head in heads[1:]:
            total = total + head(x)
        return total / 4



class Critics(nn.Module):
    """Ensemble of four independent value MLPs whose scalar outputs are averaged.

    Each head is Linear(state_dim, 64) -> Tanh -> Linear(64, 64) -> Tanh ->
    Linear(64, 1).
    """

    def __init__(self, state_dim):
        super(Critics, self).__init__()

        def make_head():
            return nn.Sequential(
                nn.Linear(state_dim, 64), nn.Tanh(),
                nn.Linear(64, 64), nn.Tanh(),
                nn.Linear(64, 1),
            )

        self.sub_critic1 = make_head()
        self.sub_critic2 = make_head()
        self.sub_critic3 = make_head()
        self.sub_critic4 = make_head()

    def forward(self,x):
        heads = (self.sub_critic1, self.sub_critic2, self.sub_critic3, self.sub_critic4)
        total = heads[0](x)
        for head in heads[1:]:
            total = total + head(x)
        return total / 4


class Ensembled_ActorCritic_MLP(nn.Module):
    """Actor-critic whose actor is the ``Actors`` ensemble and critic the ``Critics`` ensemble.

    ``forward`` is intentionally unsupported; use ``act`` and ``evaluate``.
    """

    def __init__(self, state_dim, action_dim):
        super(Ensembled_ActorCritic_MLP, self).__init__()
        self.actor = Actors(state_dim, action_dim)
        self.critic = Critics(state_dim)

    def act(self, state):
        """Sample an action; return (action, log-prob, value), all detached."""
        dist = Categorical(self.actor(state))
        picked = dist.sample()
        picked_logprob = dist.log_prob(picked)
        value = self.critic(state)
        return picked.detach(), picked_logprob.detach(), value.detach()

    def evaluate(self, state, action):
        """Return (log-prob of ``action``, critic values, entropy) with gradients."""
        dist = Categorical(self.actor(state))
        return dist.log_prob(action), self.critic(state), dist.entropy()

    def forward(self):
        raise NotImplementedError


class PPO_EnsembledMLP:
    """Clipped-objective PPO agent built on ``Ensembled_ActorCritic_MLP``.

    Protocol: for each environment step call ``sample_action``, then
    ``put_reward`` and ``put_done`` for that same step. ``train_net`` runs
    ``K_epochs`` gradient steps on the collected rollout, copies the weights
    into ``policy_old`` and clears the buffer.

    Args:
        state_dim: input width of the actor/critic MLPs.
        action_dim: number of discrete actions.
        lr_actor: Adam learning rate for the actor ensemble.
        lr_critic: Adam learning rate for the critic ensemble.
        gamma: discount factor for the Monte-Carlo returns.
        K_epochs: optimisation passes per ``train_net`` call.
        eps_clip: clipping range for the probability ratio.
        GPUparallel: wrap both policies in ``DataParallel`` when true.
    """

    def __init__(self, state_dim, action_dim, lr_actor, lr_critic, gamma, K_epochs, eps_clip, GPUparallel ):
        self.gamma = gamma
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs
        self.GPUparallel = GPUparallel
        self.buffer = RolloutBuffer()

        # `policy` is optimised; `policy_old` (not in the optimizer) is used to
        # act and is re-synced from `policy` at the end of every train_net.
        self.policy = Ensembled_ActorCritic_MLP(state_dim, action_dim).to(device)

        self.optimizer = torch.optim.Adam([
            {'params': self.policy.actor.parameters(), 'lr': lr_actor},
            {'params': self.policy.critic.parameters(), 'lr': lr_critic}
        ])

        self.policy_old = Ensembled_ActorCritic_MLP(state_dim, action_dim).to(device)
        self.policy_old.load_state_dict(self.policy.state_dict())

        self.MseLoss = nn.MSELoss()

        # Both are wrapped (or neither), so their state_dict keys keep matching.
        # NOTE(review): act/evaluate are always called on `.module` (see
        # _unwrap), so DataParallel's scatter/gather is never exercised.
        if GPUparallel:
            self.policy = DataParallel(self.policy)
            self.policy_old = DataParallel(self.policy_old)

    def _unwrap(self, model):
        """Return the wrapped module when ``GPUparallel`` is set, otherwise ``model`` itself."""
        return model.module if self.GPUparallel else model

    def sample_action(self, state):
        """Sample an action from ``policy_old`` for ``state`` and record the step.

        Returns:
            The sampled action index as a Python scalar (``action.item()``).
        """
        with torch.no_grad():
            state = torch.tensor(state, dtype=torch.float).to(device)
            action, action_logprob, state_val = self._unwrap(self.policy_old).act(state)

        self.buffer.states.append(state)
        self.buffer.actions.append(action)
        self.buffer.logprobs.append(action_logprob)
        self.buffer.state_values.append(state_val)

        return action.item()

    def _normalized_returns(self):
        """Discounted returns (reset at terminal steps), standardised, as a float32 tensor on ``device``."""
        returns = []
        discounted_reward = 0
        for reward, is_terminal in zip(reversed(self.buffer.rewards), reversed(self.buffer.is_terminals)):
            if is_terminal:
                discounted_reward = 0
            discounted_reward = reward + (self.gamma * discounted_reward)
            returns.append(discounted_reward)
        returns.reverse()  # append + reverse instead of O(n^2) insert(0, ...)

        returns = torch.tensor(returns, dtype=torch.float32).to(device)
        mean = returns.mean()
        if returns.numel() > 1:
            return (returns - mean) / (returns.std() + 1e-7)
        # The unbiased std of a single value is NaN; centring alone keeps the
        # loss (and therefore the weights) finite.
        return returns - mean

    def train_net(self):
        """Run ``K_epochs`` PPO updates on the buffered rollout, sync ``policy_old``, clear the buffer.

        An empty buffer is a no-op (apart from clearing it).
        """
        if not self.buffer.rewards or not self.buffer.states:
            self.buffer.clear()
            return

        # Monte Carlo estimate of returns (also the critic's regression target)
        rewards = self._normalized_returns()

        # Stack the per-step records. reshape (not squeeze) keeps the batch
        # dimension even for a single step or state_dim == 1; states stored as
        # (state_dim,) or (1, state_dim) both become (n, state_dim).
        n = len(self.buffer.states)
        old_states = torch.stack(self.buffer.states, dim=0).reshape(n, -1).detach().to(device)
        old_actions = torch.stack(self.buffer.actions, dim=0).reshape(-1).detach().to(device)
        old_logprobs = torch.stack(self.buffer.logprobs, dim=0).reshape(-1).detach().to(device)
        old_state_values = torch.stack(self.buffer.state_values, dim=0).reshape(-1).detach().to(device)

        # advantage = return - value predicted by the acting policy
        advantages = rewards - old_state_values

        learner = self._unwrap(self.policy)
        for _ in range(self.K_epochs):
            logprobs, state_values, dist_entropy = learner.evaluate(old_states, old_actions)

            # (n, 1) -> (n,) to match `rewards`, also when n == 1
            state_values = state_values.reshape(-1)

            # ratio pi_theta / pi_theta_old
            ratios = torch.exp(logprobs - old_logprobs)

            # clipped surrogate objective
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantages

            loss = -torch.min(surr1, surr2) + 0.5 * self.MseLoss(state_values, rewards) - 0.01 * dist_entropy

            self.optimizer.zero_grad()
            loss.mean().backward()
            self.optimizer.step()

        # copy new weights into the acting policy (same wrapping on both sides)
        self.policy_old.load_state_dict(self.policy.state_dict())

        self.buffer.clear()

    def put_reward(self,reward):
        self.buffer.rewards.append(reward)

    def put_done(self,done):
        self.buffer.is_terminals.append(done)


