import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
import os
from copy import deepcopy

from meta_test_algo.network import PPO_Actor, value_function
from meta_test_algo.buffer import RolloutBuffer
from meta_test_algo.base import base


class policy_gradient(base):
    """Vanilla policy-gradient learner with a learned value baseline.

    Each call to ``update_policy`` takes a single Adam step on
    ``-mean(log_prob * advantage)``. The advantages come from GAE computed
    with a value function that is first refit (``vf_learning_iters`` Adam
    steps) on the discounted, reward-scaled returns stored in the buffer.

    NOTE(review): ``n_epochs``, ``eps_clip``, ``entropy_coef`` and
    ``max_grad_norm`` are read from ``kwargs`` but are not referenced by any
    method of this class; confirm whether ``base`` relies on them.
    """

    def __init__(self,
                 obs_dim,
                 action_dim,
                 net_size,
                 latent_action_dim,
                 device,
                 **kwargs):
        super().__init__(obs_dim,
                         action_dim,
                         net_size,
                         latent_action_dim,
                         device,
                         ppo=True,
                         **kwargs
                         )

        # PPO-style hyper-parameters (see the class-level NOTE).
        self.num_epochs = kwargs['n_epochs']
        self.eps_clip = kwargs['eps_clip']
        self.entropy_coef = kwargs['entropy_coef']
        self.max_grad_norm = kwargs['max_grad_norm']

        self.max_path_length = kwargs['max_path_length']
        self.reward_scale = kwargs['reward_scale']
        self.discount = kwargs['meta_test_discount']
        # Minimum number of buffered transitions before an update is run.
        self.update_step = kwargs['update_step']

        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.device = device

        self.policy = PPO_Actor(obs_dim, action_dim, net_size, latent_action_dim).to(self.device)
        self.value = value_function(obs_dim).to(self.device)
        # Snapshot of the freshly initialised value network, kept so the value
        # function could be reset before refitting (currently not done).
        self.initial_value_params = deepcopy(self.value.state_dict())

        self.policy_optimizer = optim.Adam(self.policy.parameters(), lr=1e-4)
        self.value_optimizer = optim.Adam(self.value.parameters(), lr=0.01)

        self.buffer = RolloutBuffer()
        self.mse_loss = nn.MSELoss()

        # Number of Adam steps used to refit the value function per update.
        self.vf_learning_iters = 100
        # GAE lambda; 1 makes the advantage the discounted return minus V(s).
        self.lamda = 1

    def train_value_function(self):
        """Refit the value function on the buffer's discounted returns.

        Runs ``self.vf_learning_iters`` Adam steps of MSE regression from the
        buffered states to ``compute_returns()``, printing the loss after
        every step.

        Returns:
            numpy.ndarray: value predictions for the buffered states after
            fitting, in the raw output shape of ``self.value`` (presumably
            ``(N, 1)`` -- TODO confirm against ``value_function``).
        """
        states = torch.from_numpy(np.array(self.buffer.states)).float().to(self.device)
        rewards_to_go = self.compute_returns()

        # NOTE: the fit warm-starts from the current parameters; resetting via
        # self.value.load_state_dict(self.initial_value_params) is disabled.

        for _ in range(self.vf_learning_iters):
            self.value_optimizer.zero_grad()
            # reshape(-1) instead of squeeze(): squeeze() collapses a
            # single-sample batch to a 0-d tensor that no longer matches the
            # (N,) targets.
            state_values = self.value(states).reshape(-1)
            value_loss = self.mse_loss(state_values, rewards_to_go)
            value_loss.backward()
            self.value_optimizer.step()
            print(f"Value Loss: {value_loss.item()}")

        with torch.no_grad():
            state_values = self.value(states)

        return state_values.cpu().numpy()

    def compute_returns(self):
        """Return the discounted, reward-scaled return for every buffered step.

        Accumulates backwards over the buffer. The running return is reset
        at every transition flagged ``done``, so returns do not leak across
        episode boundaries.

        Returns:
            torch.Tensor: flat float32 tensor of length N on ``self.device``.
        """
        returns = []
        running_return = 0

        for reward, done in zip(reversed(self.buffer.rewards), reversed(self.buffer.dones)):
            if done:
                running_return = 0
            running_return = self.reward_scale * reward + self.discount * running_return
            returns.append(running_return)
        # Built back-to-front; reverse once instead of O(n) insert(0, ...) per step.
        returns.reverse()

        returns = np.array(returns, dtype=np.float32)
        return torch.flatten(torch.from_numpy(returns).float()).to(self.device)

    def compute_gae(self):
        """Compute normalised GAE advantages for the buffered transitions.

        Side effect: refits the value function via ``train_value_function``
        and uses its post-fit predictions as the baseline.

        Rewards are multiplied by ``self.reward_scale``, matching the targets
        the value function was trained on in ``compute_returns``. A ``done``
        flag zeroes the bootstrap value and the carried-over advantage.

        Returns:
            torch.Tensor: float32 advantages of shape (N,) on ``self.device``,
            normalised to zero mean and (when N > 1) unit std.
        """
        # Flatten everything to 1-D so per-step indexing yields scalars.
        rewards = np.asarray(self.buffer.rewards, dtype=np.float32).reshape(-1) * self.reward_scale
        dones = np.asarray(self.buffer.dones, dtype=np.float32).reshape(-1)
        values = np.asarray(self.train_value_function(), dtype=np.float32).reshape(-1)

        advantages = np.zeros_like(rewards)
        next_value = 0.0
        running_advantage = 0.0

        for t in reversed(range(len(rewards))):
            not_done = 1.0 - dones[t]
            # TD error: r_t + gamma * V(s_{t+1}) - V(s_t), with no bootstrap past a done.
            td_error = rewards[t] + self.discount * not_done * next_value - values[t]
            running_advantage = td_error + self.discount * self.lamda * not_done * running_advantage
            advantages[t] = running_advantage
            next_value = values[t]

        advantages = torch.from_numpy(advantages).to(self.device)

        # Normalise the advantages. std() of a single element is NaN
        # (unbiased estimator), so only divide when there is more than one
        # sample.
        centered = advantages - advantages.mean()
        if advantages.numel() > 1:
            centered = centered / (advantages.std() + 1e-8)
        return centered

    def collet_data_and_train_filter(self, env):
        """Roll out one episode into the buffer and update once enough data exists.

        The episode ends when the env reports ``done`` or after
        ``self.max_path_length`` steps. The step-limit cutoff is stored as
        ``done=True``, so it is treated as terminal when returns and
        advantages are computed. ``update_policy`` runs once the buffer holds
        at least ``self.update_step`` transitions.

        Args:
            env: environment whose ``reset()`` returns an observation and whose
                ``step(action)`` returns ``(obs, reward, done, info)``.

        Returns:
            int: number of environment steps taken in this episode.
        """
        o = env.reset()
        env_step = 0
        while env_step < self.max_path_length:
            env_step += 1
            action, log_prob, value = self.select_action(o)
            next_o, r, done, _ = env.step(action)
            if env_step == self.max_path_length:
                done = True
            self.buffer.store_transition(o, action, log_prob, r, done, value)
            o = next_o
            if done:
                break
        if len(self.buffer.states) // self.update_step >= 1:
            self.update_policy()
        return env_step

    def update_policy(self):
        """Take one policy-gradient step on the buffered data, then clear the buffer.

        Side effects: refits the value function (inside ``compute_gae``),
        steps ``self.policy_optimizer`` once, and empties ``self.buffer``.
        """
        states = torch.from_numpy(np.array(self.buffer.states)).float().to(self.device)
        actions = torch.from_numpy(np.array(self.buffer.actions)).float().to(self.device)
        advant_batch = self.compute_gae().detach()

        log_probs, _ = self.policy.evaluate_actions(states, actions)
        # Align log-probs with the (N,) advantages. A (N, 1) log-prob tensor
        # would otherwise broadcast to an (N, N) product and silently pair
        # every log-prob with every advantage; reshape fails loudly on a
        # genuine size mismatch.
        log_probs = log_probs.reshape(advant_batch.shape)
        actor_loss = -torch.mean(log_probs * advant_batch)
        self.policy_optimizer.zero_grad()
        actor_loss.backward()
        self.policy_optimizer.step()
        self.buffer.clear()