import torch
import torch.nn as nn
from torch.distributions import MultivariateNormal
from torch.distributions import Categorical

################################## set device ##################################
# print("============================================================================================")
# # set device to cpu or cuda
# device = torch.device('cpu')
# if(torch.cuda.is_available()):
#     device = torch.device('cuda:0')
#     torch.cuda.empty_cache()
#     print("Device set to : " + str(torch.cuda.get_device_name(device)))
# else:
#     print("Device set to : cpu")
# print("============================================================================================")


################################## PPO Policy ##################################
class RolloutBuffer:
    """Container of parallel per-step lists collected during a rollout.

    Each attribute named in ``_FIELDS`` is a plain Python list; entries at
    the same index across the lists belong to the same step.
    """

    # Names of the list attributes, in creation order.
    _FIELDS = ("actions", "states", "nextstates", "logprobs", "rewards", "is_terminals")

    def __init__(self):
        for field in self._FIELDS:
            setattr(self, field, [])

    def clear(self):
        """Empty every list in place, keeping the list objects themselves."""
        for field in self._FIELDS:
            getattr(self, field).clear()



def init_norm_layer(input_dim, norm_layer):
    """Build the normalisation module inserted after a hidden linear layer.

    Args:
        input_dim: Number of features of the layer being normalised.
        norm_layer: ``"batchnorm"`` for an ``nn.BatchNorm1d`` without affine
            parameters and without running statistics, or ``None`` for a
            pass-through ``nn.Identity``.

    Returns:
        The requested ``nn.Module``.

    Raises:
        ValueError: If ``norm_layer`` is neither ``"batchnorm"`` nor ``None``.
            Failing here avoids handing ``None`` to ``nn.Sequential``, which
            would otherwise fail later with an unrelated-looking error.
    """
    if norm_layer is None:
        return nn.Identity()
    if norm_layer == "batchnorm":
        # NOTE(review): eps=0 leaves no guard against a zero batch variance
        # in the normalisation denominator -- confirm this is intended.
        return nn.BatchNorm1d(input_dim, eps=0, momentum=None,
                              affine=False, track_running_stats=False)
    raise ValueError(
        f"Unsupported norm_layer {norm_layer!r}; expected 'batchnorm' or None"
    )


class ActorCritic(nn.Module):
    """Gaussian-policy actor and state-value critic built from small tanh MLPs.

    The actor maps a state to the mean of a per-dimension Normal distribution
    over actions. The log standard deviation is a state-independent
    ``nn.Parameter`` initialised to -1 for every action dimension. The critic
    maps a state to a single value output.

    Args:
        state_dim: Input feature size of both networks.
        action_dim: Output size of the actor (number of action dimensions).
        has_continuous_action_space: Stored on the instance; not consulted by
            any method of this class (all of them build a Normal distribution).
        action_std_init: Accepted for interface compatibility; unused here.
        device: Stored as ``self.device``. The constructor does not move the
            module; call ``.to(device)`` on the instance to do so.
        norm_layer: ``"batchnorm"`` or ``None``; passed to ``init_norm_layer``
            for the two hidden layers of the actor. The critic has no
            normalisation layers.
    """

    def __init__(self, state_dim, action_dim, has_continuous_action_space, action_std_init, device, norm_layer=None):
        super().__init__()

        self.has_continuous_action_space = has_continuous_action_space
        self.device = device
        self.norm_layer = norm_layer

        self.action_dim = action_dim
        # Registered as a module parameter, so it moves together with the
        # module on ``.to(device)``. No separate device move is done here
        # because ``Tensor.to`` is out-of-place and its result would be lost.
        self.action_logstd = nn.Parameter(-1.0 * torch.ones(action_dim))

        # actor: state -> action mean
        self.actor = nn.Sequential(
            nn.Linear(state_dim, 64),
            init_norm_layer(64, self.norm_layer),
            nn.Tanh(),
            nn.Linear(64, 64),
            init_norm_layer(64, self.norm_layer),
            nn.Tanh(),
            nn.Linear(64, action_dim),
        )

        # critic: state -> value
        self.critic = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, 1),
        )

    def forward(self):
        """Not supported; use ``act``, ``evaluate``, ``v`` or the action getters."""
        raise NotImplementedError

    def _mean_and_std(self, state):
        """Return the actor output for ``state`` and ``exp(action_logstd)``."""
        return self.actor(state), torch.exp(self.action_logstd)

    def act(self, state):
        """Sample an action for ``state``.

        Returns:
            ``(action, log_prob)``, both detached from the autograd graph.
            ``log_prob`` is the sum of per-dimension log-probabilities over
            the last axis.
        """
        action_mean, action_std = self._mean_and_std(state)

        dist = torch.distributions.Normal(action_mean, action_std)
        action = dist.sample()
        action_logprob = torch.sum(dist.log_prob(action), -1)

        return action.detach(), action_logprob.detach()

    def get_action(self, state):
        """Return ``(action_mean, action_std)`` with gradient tracking enabled."""
        return self._mean_and_std(state)

    def select_action(self, state):
        """Return ``(action_mean, action_std)`` computed under ``torch.no_grad()``."""
        with torch.no_grad():
            action_mean, action_std = self._mean_and_std(state)

        return action_mean, action_std

    def v(self, state):
        """Return the critic output for ``state``."""
        return self.critic(state)

    def evaluate(self, state, action):
        """Score given actions under the current policy.

        Args:
            state: Input passed to both the actor and the critic.
            action: Actions to score. When ``action_dim == 1`` it is reshaped
                to ``(-1, 1)`` so it lines up with the actor output.

        Returns:
            ``(action_logprobs, state_values, dist_entropy)`` where the
            log-probabilities and entropies are summed over the last axis and
            ``state_values`` is the raw critic output.
        """
        action_mean, action_std = self._mean_and_std(state)
        dist = torch.distributions.Normal(action_mean, action_std)

        # For single-action environments.
        if self.action_dim == 1:
            action = action.reshape(-1, self.action_dim)

        action_logprobs = torch.sum(dist.log_prob(action), -1)
        dist_entropy = torch.sum(dist.entropy(), -1)
        state_values = self.critic(state)

        return action_logprobs, state_values, dist_entropy

class PPO:
    """Clipped-surrogate PPO trainer with separate actor and critic optimizers.

    Holds a trainable ``policy``, a copy ``policy_old`` kept in eval mode and
    re-synced after every ``update``, and a ``RolloutBuffer`` that ``update``
    consumes and clears.

    Args:
        state_dim: Forwarded to ``ActorCritic``.
        action_dim: Forwarded to ``ActorCritic``.
        lr_actor: Adam learning rate for ``policy.actor``.
        lr_critic: Adam learning rate for ``policy.critic``.
        gamma: Discount factor used for returns and GAE.
        K_epochs: Number of optimisation passes over the buffer per ``update``.
        eps_clip: Clipping range of the probability ratio.
        has_continuous_action_space: Stored and forwarded to ``ActorCritic``.
        action_std_init: Forwarded to ``ActorCritic``.
        device: Device the networks and the training tensors are moved to.
        gae: If true, ``update`` uses generalised advantage estimation;
            otherwise discounted returns minus the critic value.
        lambda_: GAE lambda (only used when ``gae`` is true).
        entropy_coef: Weight of the entropy bonus in the actor loss.
        critic_coef: Multiplier applied to the critic loss.
    """

    def __init__(self, state_dim, action_dim, lr_actor, lr_critic, gamma, K_epochs, eps_clip,
                 has_continuous_action_space, action_std_init=0.6, device='cpu',
                 *, gae=False, lambda_=0.95, entropy_coef=0.01, critic_coef=0.5):

        self.has_continuous_action_space = has_continuous_action_space

        self.gamma = gamma
        self.gae = gae
        self.lambda_ = lambda_
        self.entropy_coef = entropy_coef
        self.critic_coef = critic_coef
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs

        self.device = device

        self.buffer = RolloutBuffer()

        self.policy = ActorCritic(state_dim, action_dim, has_continuous_action_space, action_std_init, device=self.device).to(self.device)
        # NOTE(review): ``policy.action_logstd`` is passed to neither
        # optimizer below, so no optimizer step here ever changes it --
        # confirm that a fixed action std is intended.
        self.actor_optimizer = torch.optim.Adam([
            {'params': self.policy.actor.parameters(), 'lr': lr_actor}
        ])
        self.critic_optimizer = torch.optim.Adam([{'params': self.policy.critic.parameters(), 'lr': lr_critic}])

        self.policy_old = ActorCritic(state_dim, action_dim, has_continuous_action_space, action_std_init, device=self.device).to(self.device)
        self.policy_old.load_state_dict(self.policy.state_dict())
        self.policy_old.eval()

        self.critic_loss = nn.SmoothL1Loss()

    def update(self, time_step, writer):
        """Run ``K_epochs`` PPO updates on the buffered rollout, then clear it.

        Args:
            time_step: Global step passed to every ``writer.add_scalar`` call.
            writer: Object exposing ``add_scalar(tag, value, step)`` and
                ``flush()``.

        The buffer entries must be tensors (they are combined with
        ``torch.stack``) and the buffer must not be empty.
        """
        # Convert the per-step lists to tensors on the training device.
        old_states = torch.squeeze(torch.stack(self.buffer.states, dim=0)).detach().to(self.device)
        old_actions = torch.squeeze(torch.stack(self.buffer.actions, dim=0)).detach().to(self.device)
        old_logprobs = torch.squeeze(torch.stack(self.buffer.logprobs, dim=0)).detach().to(self.device)
        old_rewards = torch.squeeze(torch.stack(self.buffer.rewards, dim=0)).detach().to(self.device)
        old_is_terminals = torch.squeeze(torch.stack(self.buffer.is_terminals, dim=0)).detach().to(self.device)
        num_buffer = len(old_states)

        # Targets are computed once, without gradients, from the current critic.
        with torch.no_grad():
            old_values = self.policy.v(old_states).float().squeeze().to(self.device)
            # Value of the state reached after the last stored step (bootstrap).
            next_value = self.policy.v(self.buffer.nextstates[-1]).to(self.device)
            # NOTE(review): in both branches the final step is masked with
            # is_terminals[-1] (the flag stored at that same index) while every
            # earlier step t is masked with is_terminals[t + 1]. Confirm which
            # convention the code that fills the buffer follows.
            if self.gae:
                advantages = torch.zeros_like(old_rewards).to(self.device)
                lastgaelam = 0
                for t in reversed(range(num_buffer)):
                    if t == num_buffer - 1:
                        nextnonterminal = 1.0 - torch.tensor([self.buffer.is_terminals[-1]]).to(self.device)
                        nextvalues = next_value
                    else:
                        nextnonterminal = 1.0 - old_is_terminals[t + 1]
                        nextvalues = old_values[t + 1]
                    # One-step TD error, then the backward GAE recursion.
                    delta = old_rewards[t] + self.gamma * nextvalues * nextnonterminal - old_values[t]
                    advantages[t] = lastgaelam = delta + self.gamma * self.lambda_ * nextnonterminal * lastgaelam
                returns = advantages + old_values
            else:
                returns = torch.zeros_like(old_rewards).to(self.device)
                for t in reversed(range(num_buffer)):
                    if t == num_buffer - 1:
                        nextnonterminal = 1.0 - torch.tensor([self.buffer.is_terminals[-1]]).to(self.device)
                        next_return = next_value
                    else:
                        nextnonterminal = 1.0 - old_is_terminals[t + 1]
                        next_return = returns[t + 1]
                    # Backward discounted-return recursion.
                    returns[t] = old_rewards[t] + self.gamma * nextnonterminal * next_return
                advantages = returns - old_values

        # Optimize policy for K epochs
        for _ in range(self.K_epochs):

            # Re-evaluate the stored actions under the current policy.
            logprobs, state_values, dist_entropy = self.policy.evaluate(old_states, old_actions)

            # Match state_values dimensions with the returns tensor.
            state_values = torch.squeeze(state_values)

            # Probability ratio between current and data-collecting policy.
            ratio = torch.exp(logprobs.squeeze() - old_logprobs)
            surr1 = ratio * advantages
            surr2 = torch.clamp(ratio, 1 - self.eps_clip, 1 + self.eps_clip) * advantages

            # Clipped surrogate objective with an entropy bonus.
            actor_loss = (-torch.min(surr1, surr2) - self.entropy_coef * dist_entropy).mean()

            critic_loss = self.critic_coef * self.critic_loss(state_values, returns)

            # The "action_logstd" tag records the mean of exp(action_logstd).
            writer.add_scalar("actor_loss", actor_loss, time_step)
            writer.add_scalar("critic_loss", critic_loss, time_step)
            writer.add_scalar("action_logstd", torch.exp(self.policy.action_logstd).mean(), time_step)
            writer.flush()

            # Actor step, then critic step, each with its own optimizer.
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            self.critic_optimizer.step()

        # Copy new weights into old policy
        self.policy_old.load_state_dict(self.policy.state_dict())

        # clear buffer
        self.buffer.clear()

    def save(self, checkpoint_path):
        """Write ``policy_old``'s state dict to ``checkpoint_path``."""
        torch.save(self.policy_old.state_dict(), checkpoint_path)

    def load(self, checkpoint_path):
        """Restore both ``policy`` and ``policy_old`` from ``checkpoint_path``.

        ``save`` stores ``policy_old``; loading the weights into both networks
        keeps them in sync, so neither is left with its initial weights after
        a checkpoint is restored.
        """
        # Deserialize storages as saved (no device remapping); load_state_dict
        # then copies the values into the existing parameters.
        state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
        self.policy.load_state_dict(state_dict)
        self.policy_old.load_state_dict(state_dict)
        
        
       


