import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
import os

from meta_test_algo.network import PPO_Actor, value_function
from meta_test_algo.buffer import RolloutBuffer
from meta_test_algo.base import base


class PPOAgent(base):
    """PPO agent with separate policy and value networks and optimizers.

    Transitions are collected into a ``RolloutBuffer``; once at least
    ``update_step`` transitions are stored, the clipped-surrogate PPO update
    in :meth:`update_policy` is run and the buffer is cleared.

    NOTE(review): ``select_action`` is not defined in this class; it is
    presumably inherited from ``base`` -- confirm.
    """

    def __init__(self,
                 obs_dim,
                 action_dim,
                 net_size,
                 latent_action_dim,
                 device,
                 **kwargs):
        """Build the networks, optimizers and rollout buffer.

        Args:
            obs_dim: Observation dimension, forwarded to the networks.
            action_dim: Action dimension, forwarded to the policy network.
            net_size: Size argument forwarded to ``PPO_Actor``.
            latent_action_dim: Latent action dimension forwarded to
                ``PPO_Actor``.
            device: Torch device the networks and training tensors live on.
            **kwargs: Must contain ``n_epochs``, ``eps_clip``,
                ``entropy_coef``, ``max_grad_norm``, ``max_path_length``,
                ``reward_scale``, ``meta_test_discount`` and ``update_step``.
                All kwargs are also forwarded to the base class.

        Raises:
            KeyError: If one of the required keyword arguments is missing.
        """
        super().__init__(obs_dim,
                         action_dim,
                         net_size,
                         latent_action_dim,
                         device,
                         ppo=True,
                         **kwargs
                         )

        # PPO optimisation hyper-parameters.
        self.num_epochs = kwargs['n_epochs']
        self.eps_clip = kwargs['eps_clip']
        self.entropy_coef = kwargs['entropy_coef']
        self.max_grad_norm = kwargs['max_grad_norm']

        # Rollout / return hyper-parameters.
        self.max_path_length = kwargs['max_path_length']
        self.reward_scale = kwargs['reward_scale']
        self.discount = kwargs['meta_test_discount']
        self.update_step = kwargs['update_step']

        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.device = device

        self.policy = PPO_Actor(obs_dim, action_dim, net_size, latent_action_dim).to(self.device)
        self.value = value_function(obs_dim).to(self.device)

        self.policy_optimizer = optim.Adam(self.policy.parameters(), lr=3e-6)
        self.value_optimizer = optim.Adam(self.value.parameters(), lr=1e-6)

        self.buffer = RolloutBuffer()
        self.mse_loss = nn.MSELoss()

    def collet_data_and_train_filter(self, env):
        """Roll out one episode in ``env`` and update the policy if due.

        The episode runs for at most ``max_path_length`` steps; the last
        allowed step is always stored with ``done=True`` so that returns do
        not leak across episodes in the buffer. After the rollout, the
        policy is updated when the buffer holds at least ``update_step``
        transitions.

        Args:
            env: Environment exposing ``reset()`` and ``step(action)``, where
                ``step`` returns ``(next_obs, reward, done, info)``.

        Returns:
            int: Number of environment steps taken in this episode.
        """
        o = env.reset()
        env_step = 0
        while env_step < self.max_path_length:
            env_step += 1
            action, log_prob, value = self.select_action(o)
            next_o, r, done, _ = env.step(action)
            if env_step == self.max_path_length:
                # Time limit reached: mark the transition as terminal.
                done = True
            self.buffer.store_transition(o, action, log_prob, r, done, value)
            o = next_o
            if done:
                break
        if len(self.buffer.states) >= self.update_step:
            self.update_policy()
        return env_step

    def compute_returns(self):
        """Compute the discounted reward-to-go for every buffered transition.

        Rewards are multiplied by ``reward_scale`` and discounted with
        ``discount``. The running return is reset whenever a transition is
        flagged ``done``, so returns never cross episode boundaries.

        Returns:
            torch.Tensor: Flattened float32 tensor on ``self.device`` with
            the returns in buffer order.
        """
        returns = []
        discounted_reward = 0

        # Walk the buffer backwards, appending (O(1)) and reversing once at
        # the end instead of inserting at the front on every step.
        for reward, done in zip(reversed(self.buffer.rewards), reversed(self.buffer.dones)):
            if done:
                discounted_reward = 0
            discounted_reward = self.reward_scale * reward + self.discount * discounted_reward
            returns.append(discounted_reward)
        returns.reverse()

        returns = np.array(returns, dtype=np.float32)
        return torch.flatten(torch.from_numpy(returns).float()).to(self.device)

    def _buffer_to_tensor(self, values):
        """Convert a buffer field (a sequence of array-likes) to a float tensor on ``self.device``."""
        return torch.from_numpy(np.array(values)).float().to(self.device)

    def update_policy(self):
        """Run ``num_epochs`` PPO updates on the buffered data, then clear the buffer.

        The policy is trained with the clipped surrogate objective plus an
        entropy bonus; the value network is trained with a halved MSE loss
        against the reward-to-go. Both gradient norms are clipped to
        ``max_grad_norm``. Does nothing if the buffer is empty.

        NOTE(review): per-transition log-probs and values are flattened to
        1-D below, which assumes one scalar log-prob and one scalar value per
        transition -- confirm against ``select_action`` / ``evaluate_actions``.
        """
        if len(self.buffer.states) == 0:
            # Nothing to learn from; avoid NaN statistics on empty tensors.
            return

        rewards_to_go = self.compute_returns()
        states = self._buffer_to_tensor(self.buffer.states)
        actions = self._buffer_to_tensor(self.buffer.actions)
        # Flatten so these line up element-wise with the 1-D rewards_to_go
        # rather than silently broadcasting (N, 1) against (N,) into (N, N).
        old_logprobs = self._buffer_to_tensor(self.buffer.logprobs).reshape(-1)
        state_vals = self._buffer_to_tensor(self.buffer.state_values).reshape(-1)

        advantages = rewards_to_go - state_vals
        if advantages.numel() > 1:
            # std() of a single element is NaN, so only normalise real batches.
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-6)

        value_loss = None
        for _ in range(self.num_epochs):
            logprobs, dist_entropy = self.policy.evaluate_actions(states, actions)
            logprobs = logprobs.reshape(-1)
            # reshape(-1) keeps a length-1 batch 1-D, where squeeze() would
            # collapse it to a 0-d tensor and mismatch the target shape.
            state_values = self.value(states).reshape(-1)

            ratios = torch.exp(logprobs - old_logprobs)
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantages

            actor_loss = -torch.min(surr1, surr2).mean() - self.entropy_coef * dist_entropy.mean()
            value_loss = 0.5 * self.mse_loss(state_values, rewards_to_go)

            self.policy_optimizer.zero_grad()
            actor_loss.backward()
            # NOTE(review): debug output kept as-is; consider routing it
            # through `logging` so it can be silenced.
            for name, param in self.policy.named_parameters():
                if param.grad is not None:
                    print(f"Gradient for {name}: {param.grad.norm()}")
            torch.nn.utils.clip_grad_norm_(self.policy.parameters(), max_norm=self.max_grad_norm)
            self.policy_optimizer.step()

            self.value_optimizer.zero_grad()
            value_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.value.parameters(), max_norm=self.max_grad_norm)
            self.value_optimizer.step()

        if value_loss is not None:
            # value_loss is only bound when at least one epoch ran.
            print(value_loss.item())
        self.buffer.clear()