from models import ActorCritic
from torch.optim import Adam
import torch
import torch.multiprocessing
import torch.nn.functional as F

import time
import copy


class PPO:
    """Proximal Policy Optimization learner built on two ActorCritic networks.

    ``ac`` is the network being trained; ``old_ac`` is a snapshot (refreshed
    by :meth:`assign_old_pi`) that provides the reference log-probabilities
    for the clipped surrogate objective and the reference values for
    value-function clipping.
    """

    def __init__(self, env, layer_dims, nr_actions, lr, eps=0.2, v_coeff=0.5,
                 ent_coeff=0.01, max_grad_norm=0.5, dist='Normal', ac_base_type='mlp', vf_clip=1):
        """Build the current/old networks and one Adam optimizer per head.

        Args:
            env: Environment exposing ``observation_space.shape`` and
                ``action_space.shape``.
            layer_dims: List of layer sizes; the last observation dimension
                is prepended before it is handed to ``ActorCritic``.
            nr_actions: Forwarded unchanged to ``ActorCritic``.
            lr: Learning rate shared by both optimizers.
            eps: Clipping range of the probability ratio.
            v_coeff: Weight of the critic loss in the total loss.
            ent_coeff: Weight of the entropy bonus in the total loss.
            max_grad_norm: Gradient-norm clipping threshold; values ``<= 0``
                disable gradient clipping.
            dist: Forwarded to ``ActorCritic`` as ``dist``.
            ac_base_type: Forwarded to ``ActorCritic`` as ``base_type``.
            vf_clip: Value-function clipping range; values ``<= 0`` disable
                value clipping.
        """
        self.lr = lr
        self.env = env
        # get actor critic mlp spec
        ob_shapes = list(env.observation_space.shape)
        ac_shapes = list(env.action_space.shape)
        if not ac_shapes:
            # An action space with an empty shape is treated as one-dimensional.
            ac_shapes = [1]
        ac_layer_dims = [ob_shapes[-1]] + layer_dims

        self.layer_dims = ac_layer_dims
        self.eps = eps
        self.v_coeff = v_coeff
        self.ent_coeff = ent_coeff
        self.ac = ActorCritic(ac_layer_dims, ob_shapes, ac_shapes, nr_actions, 'pi', trainable=True,
                              dist=dist, base_type=ac_base_type)
        self.old_ac = ActorCritic(ac_layer_dims, ob_shapes, ac_shapes, nr_actions, 'old_pi', trainable=False,
                                  dist=dist, base_type=ac_base_type)
        # Reference values for value clipping. ``None`` means "no reference
        # yet"; it is filled by ac_forward() for the batch being processed.
        self.old_value = None

        self.loss = 0
        self.actor_loss = 0
        self.critic_loss = 0
        self.ent = 1
        self.vf_clip = vf_clip
        self.max_grad_norm = max_grad_norm

        print(self.ac.actor)
        print(self.ac.critic)

        self.pi_optimizer = Adam(self.ac.get_actor_parameters(), lr=self.lr)
        self.v_optimizer = Adam(self.ac.get_critic_parameters(), lr=self.lr)

    def get_action(self, obs):
        """Return ``self.ac.sample_policy_action(obs)``."""
        return self.ac.sample_policy_action(obs)

    def get_lprobs(self, obs, acs):
        """Return ``self.ac.lprobs_action(obs, acs)``."""
        return self.ac.lprobs_action(obs, acs)

    def get_value(self, obs):
        """Return ``self.ac.get_value(obs)``."""
        return self.ac.get_value(obs)

    def assign_old_pi(self):
        """Copy the current network's weights into the old (reference) network."""
        self.old_ac.base.load_state_dict(self.ac.base.state_dict())
        self.old_ac.actor.load_state_dict(self.ac.actor.state_dict())
        # The critic is snapshotted as well: ac_forward() uses the old
        # network's value estimate as the reference for value clipping, so it
        # must describe the same pre-update network as the old policy.
        self.old_ac.critic.load_state_dict(self.ac.critic.state_dict())
        self.old_ac.log_std = copy.deepcopy(self.ac.log_std)

    def ac_forward(self, obs, acs):
        """Run the current and the old network on the same batch.

        Side effect: the old network's value estimate for ``obs`` is stored
        in ``self.old_value`` so that :meth:`calc_loss` clips each value
        prediction around the reference belonging to the *same* sample.

        Returns:
            Tuple ``(pi, log_prob, old_log_prob, value)`` where ``pi``,
            ``log_prob`` and ``value`` come from ``self.ac.forward`` and
            ``old_log_prob`` from ``self.old_ac.forward``.
        """
        pi, log_prob, value = self.ac.forward(obs, acs)
        # The old network is only a fixed reference; no autograd graph needed.
        with torch.no_grad():
            _, old_log_prob, old_value = self.old_ac.forward(obs, acs)
        self.old_value = old_value

        return pi, log_prob, old_log_prob, value

    def calc_loss(self, pi, log_prob, old_log_prob, value, returns, advs):
        """Compute the PPO loss for one batch.

        Args:
            pi: Distribution object exposing ``entropy()``.
            log_prob: Log-probabilities of the actions under the current policy.
            old_log_prob: Log-probabilities under the old policy (detached here).
            value: Value predictions of the current critic.
            returns: Regression targets for the critic.
            advs: Advantage estimates.

        Returns:
            Tuple ``(loss, actor_loss, critic_loss, ent)``. ``actor_loss`` is
            the clipped surrogate objective (it enters ``loss`` negated),
            ``critic_loss`` the mean squared value error and ``ent`` the mean
            policy entropy.

        Raises:
            ValueError: If value clipping is active and ``self.old_value``
                does not hold exactly one reference per value prediction.
        """
        ent = torch.mean(pi.entropy())
        ratio = torch.exp(log_prob - old_log_prob.detach())

        # Give the advantages a trailing axis when their rank differs from the ratio's.
        if len(ratio.shape) != len(advs.shape):
            advs = torch.unsqueeze(advs, -1)
        surr = ratio * advs
        clipped_surr = torch.clamp(ratio, 1 - self.eps, 1 + self.eps) * advs
        actor_loss = torch.mean(torch.min(surr, clipped_surr))

        # Bring predictions and targets to the same rank before subtracting.
        value = torch.squeeze(value)
        if len(returns.shape) < len(value.shape):
            returns = torch.unsqueeze(returns, -1)
        if len(value.shape) < len(returns.shape):
            value = torch.unsqueeze(value, -1)

        critic_err = (returns - value) ** 2

        old_value = self.old_value
        if self.vf_clip > 0 and old_value is not None:
            old_value = torch.as_tensor(old_value, dtype=value.dtype, device=value.device).detach()
            if old_value.numel() != value.numel():
                raise ValueError(
                    "old_value holds {} reference value(s) but the batch has {} prediction(s); "
                    "call ac_forward() on the same batch before calc_loss()".format(
                        old_value.numel(), value.numel()))
            old_value = old_value.reshape(value.shape)
            clipped_value = old_value + torch.clamp(value - old_value, -self.vf_clip, self.vf_clip)
            # Pessimistic form: a prediction that left the clip range while
            # moving away from the target still receives a gradient.
            critic_err = torch.max(critic_err, (returns - clipped_value) ** 2)

        critic_loss = torch.mean(critic_err)
        loss = -actor_loss - self.ent_coeff * ent + self.v_coeff * critic_loss

        return loss, actor_loss, critic_loss, ent

    def step(self, loss):
        """Backpropagate ``loss`` and apply one step of both optimizers."""
        self.pi_optimizer.zero_grad()
        self.v_optimizer.zero_grad()
        loss.backward()
        if self.max_grad_norm > 0:
            # NOTE(review): clipping covers ac.actor / ac.critic parameters,
            # while the optimizers were built from get_actor_parameters() /
            # get_critic_parameters(); confirm both cover the same tensors.
            torch.nn.utils.clip_grad_norm_(self.ac.actor.parameters(), self.max_grad_norm)
            torch.nn.utils.clip_grad_norm_(self.ac.critic.parameters(), self.max_grad_norm)
        self.pi_optimizer.step()
        self.v_optimizer.step()

    def update(self, obs, acs, returns, advs):
        """Run one optimization step on a batch.

        The loss terms are kept on ``self.loss``, ``self.actor_loss``,
        ``self.critic_loss`` and ``self.ent``.
        """
        pi, log_prob, old_log_prob, value = self.ac_forward(obs, acs)
        self.loss, self.actor_loss, self.critic_loss, self.ent = self.calc_loss(
            pi, log_prob, old_log_prob, value, returns, advs)
        self.step(self.loss)

