from EscapeEnv.common.base_estimator import ActorCriticEstimator
from EscapeEnv.common.scheduler import ConstantParamScheduler
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim



class PPOEstimator(ActorCriticEstimator):
    """Proximal Policy Optimization (PPO) update rule for an actor-critic pair.

    A single optimizer (Adam, or RMSprop when ``use_rms_prop`` is set) covers
    both the actor and the critic parameters. Each call to :meth:`update`
    takes ``loops_per_train`` gradient steps on the clipped-surrogate PPO
    objective plus a value loss and an entropy bonus, using the whole rollout
    returned by the buffer.

    Required ``estimator_kwargs`` keys: ``vf_coef``, ``ent_coef``,
    ``max_grad_norm``, ``use_rms_prop``, ``loops_per_train``.
    Optional key: ``clip_range`` (defaults to 0.2).
    """

    def __init__(self, actor_network, critic_network, learning_rate, gamma, optimizer_kwargs, estimator_kwargs, device) -> None:
        super().__init__(actor_network, critic_network, learning_rate, gamma, optimizer_kwargs, estimator_kwargs, device)

        self.vf_coef = self.estimator_kwargs['vf_coef']
        # NOTE(review): ``update`` overwrites this via ``decay_ent_coef`` on
        # every call, so the configured value only holds before the first update.
        self.ent_coef = self.estimator_kwargs['ent_coef']
        self.max_grad_norm = self.estimator_kwargs['max_grad_norm']
        self.use_rms_prop = self.estimator_kwargs['use_rms_prop']
        self.loops_per_train = self.estimator_kwargs['loops_per_train']
        if self.loops_per_train < 1:
            # ``update`` returns the loss of its last gradient step; with zero
            # steps there is no loss and it would crash with UnboundLocalError.
            raise ValueError(f"loops_per_train must be >= 1, got {self.loops_per_train}")

        # PPO probability-ratio clipping epsilon (previously hard-coded to 0.2).
        self.clip_range = self.estimator_kwargs.get('clip_range', 0.2)

        # One optimizer over both networks so a single backward/step updates both.
        self.parameters = list(actor_network.parameters()) + list(critic_network.parameters())
        if self.use_rms_prop:
            self.optimizer = optim.RMSprop(self.parameters, **self.optimizer_kwargs)
        else:
            self.optimizer = optim.Adam(self.parameters, lr=self.learning_rate)

        self.lr_scheduler = ConstantParamScheduler(self.optimizer, 'lr', self.learning_rate)
        self.schedulers = [self.lr_scheduler]
        self.mse_loss = nn.MSELoss(reduction='mean')

    def decay_ent_coef(self, progress, initial_coef=0.01, decay_end=0.2):
        """Linearly anneal ``self.ent_coef`` from ``initial_coef`` down to 0.

        The coefficient equals ``initial_coef`` at ``progress == 0``, falls
        linearly, and stays at 0 once ``progress >= decay_end``.

        Args:
            progress: training progress; presumably a fraction in [0, 1] —
                confirm against the caller.
            initial_coef: coefficient value at ``progress == 0``.
            decay_end: progress at which the coefficient reaches 0; must be > 0.
        """
        self.ent_coef = initial_coef * max(decay_end - progress, 0) / decay_end

    def update(self, buffer, progress):
        """Run one PPO training phase on the data currently held by ``buffer``.

        This steps every registered scheduler and recomputes the entropy
        coefficient from ``progress``. It then performs ``loops_per_train``
        gradient steps, each over the full rollout (no minibatching).

        Args:
            buffer: rollout buffer whose ``get()`` returns an object exposing
                ``observations``, ``actions``, ``action_mask``, ``log_probs``,
                ``advantages`` and ``returns``.
            progress: training progress forwarded to ``decay_ent_coef``.

        Returns:
            float: the total loss of the last gradient step.
        """
        for schedule in self.schedulers:
            schedule.step()
        self.decay_ent_coef(progress)

        rollout_data = buffer.get()

        # Quantities collected under the old policy are constants in the PPO
        # objective. Detaching stops gradients from flowing into the rollout
        # graph, which would also break the second loop once that graph has
        # been freed. Detaching an already-detached tensor is a no-op.
        old_log_probs = rollout_data.log_probs.detach()
        advantages = rollout_data.advantages.detach()
        # Flatten to match the flattened value predictions, so an (N, 1)
        # target cannot silently broadcast to (N, N) inside the MSE.
        returns = rollout_data.returns.detach().flatten()

        for _ in range(self.loops_per_train):
            log_prob, entropy = self.actor_net.evaluate_actions(rollout_data.observations, rollout_data.actions, rollout_data.action_mask)
            values = self.critic_net(rollout_data.observations).flatten()

            # Ratio pi_new / pi_old; it should be 1 on the first loop, before
            # any optimizer step has changed the actor.
            ratio = torch.exp(log_prob - old_log_probs)

            # Clipped surrogate objective, negated because the optimizer minimizes.
            policy_loss_1 = advantages * ratio
            policy_loss_2 = advantages * torch.clamp(ratio, 1 - self.clip_range, 1 + self.clip_range)
            policy_loss = -torch.min(policy_loss_1, policy_loss_2).mean()

            value_loss = F.mse_loss(values, returns)

            # Minimizing the negative mean entropy acts as an exploration bonus.
            entropy_loss = -torch.mean(entropy)

            loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss

            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.parameters, self.max_grad_norm)
            self.optimizer.step()

        # n_updates is presumably initialized by the base class — not set here.
        self.n_updates += 1
        return loss.item()

if __name__ == '__main__':
    # Scratch check of a masked maximum computed in exp/log space.
    # Entries whose mask is 0 are zeroed after exponentiation, and the max of
    # the rest is taken. The log maps it back, giving approximately the
    # largest unmasked score (4 here).
    # NOTE(review): exp overflows for large scores, and an all-zero mask
    # yields log(0) = -inf.
    scores = torch.tensor([1, 2, 3, 4, 5])
    mask = torch.tensor([0, 1, 0, 1, 0])

    masked_exp = torch.exp(scores) * mask
    largest_legal = torch.log(torch.max(masked_exp, dim=-1).values)
    print(largest_legal)
    
