import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
from rlkit.torch.networks import shared_policy, shared_policy2, shared_policy3
import torch.nn.init as init

import numpy as np

# Clamp range for the predicted log standard deviation in PPO_Actor.forward;
# exp() of these bounds gives a std range of roughly [2.1e-9, 7.39].
LOG_SIG_MIN, LOG_SIG_MAX = -20, 2

class PPO_Actor(nn.Module):
    """Diagonal-Gaussian actor for PPO.

    Observation -> ``first_layer`` -> LayerNorm -> ReLU -> ``shared_policy2``
    -> two linear heads giving the action mean and the log standard deviation
    (clamped to ``[LOG_SIG_MIN, LOG_SIG_MAX]``).

    Args:
        obs_dim: size of the observation vector.
        action_dim: number of action dimensions.
        hidden_dim: width of ``first_layer`` and the LayerNorm.
        latent_dim: input size of the two heads; presumably the output size of
            ``shared_policy2`` -- TODO confirm.
    """

    def __init__(self, obs_dim, action_dim, hidden_dim, latent_dim):
        super(PPO_Actor, self).__init__()
        # Attribute names double as state_dict keys; keep them stable.
        self.first_layer = nn.Linear(obs_dim, hidden_dim)
        self.ln = nn.LayerNorm(hidden_dim)
        self.shared_layer = shared_policy2(hidden_dim, latent_dim)
        self.fc_mean = nn.Linear(latent_dim, action_dim)
        self.fc_logstd = nn.Linear(latent_dim, action_dim)

    def forward(self, obs):
        """Return ``(mean, log_std)`` of the action distribution for ``obs``."""
        latent_state = F.relu(self.ln(self.first_layer(obs)))
        latent_action = self.shared_layer(latent_state)
        mean = self.fc_mean(latent_action)
        log_std = self.fc_logstd(latent_action)
        # Clamp so exp(log_std) neither underflows to 0 nor blows up.
        log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
        return mean, log_std

    def _distribution(self, obs):
        """Build the Normal distribution described by ``forward(obs)``."""
        mean, log_std = self(obs)
        return Normal(mean, torch.exp(log_std))

    def select_action(self, obs):
        """Sample an action clamped to ``[-1, 1]`` and its log-probability.

        Runs without gradient tracking.

        Returns:
            ``(action, log_prob)``. If the sampled action is 2-D with a batch
            size of 1, ``action`` is returned as a NumPy array with the batch
            axis removed; otherwise it stays a tensor. ``log_prob`` is always a
            tensor summed over the last axis with ``keepdim=True``.
        """
        with torch.no_grad():
            normal = self._distribution(obs)
            action = torch.clamp(normal.sample(), -1, 1)
            # Score the clamped action -- the one actually returned -- so this
            # value matches what evaluate_actions() later computes for the same
            # action. Scoring the pre-clamp sample made the PPO ratio deviate
            # from 1 whenever clamping kicked in. (At the boundary this is still
            # a Gaussian density, not the mass of the clipped tail.)
            log_prob = normal.log_prob(action).sum(-1, keepdim=True)
            if action.dim() == 2 and action.shape[0] == 1:
                action = action.squeeze(0).cpu().numpy()
        return action, log_prob

    def evaluate_actions(self, states, actions):
        """Return ``(log_probs, entropy)`` of ``actions`` under the current policy.

        Both outputs are summed over the last axis WITHOUT ``keepdim``, unlike
        the ``log_prob`` returned by ``select_action``.
        NOTE(review): callers comparing the two should align shapes first to
        avoid unintended broadcasting.
        """
        normal = self._distribution(states)
        log_probs = normal.log_prob(actions).sum(dim=-1)
        dist_entropy = normal.entropy().sum(dim=-1)
        return log_probs, dist_entropy
    
class value_function(nn.Module):
    """State-value network V(s): two 256-unit ReLU layers and a scalar head.

    Args:
        obs_dim: size of the observation vector.
    """

    def __init__(self, obs_dim):
        super().__init__()
        width = 256
        # Creation order fixes the seeded initialisation; attribute names are
        # the state_dict keys.
        self.fc1 = nn.Linear(obs_dim, width)
        self.fc2 = nn.Linear(width, width)
        self.out = nn.Linear(width, 1)

    def forward(self, x):
        """Return V(x) with a trailing axis of size 1."""
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.out(hidden)
    
class q_function(nn.Module):
    """Action-value network Q(s, a) on the concatenation of state and action.

    Args:
        input_dim: state size plus action size (width of the concatenation).
        hidden_dim: width of the two hidden ReLU layers.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # Creation order fixes the seeded initialisation; attribute names are
        # the state_dict keys.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, 1)

    def forward(self, state, action):
        """Return Q(state, action) with a trailing axis of size 1.

        ``state`` and ``action`` are concatenated along axis 1, so both are
        expected to be at least 2-D with matching leading (batch) size.
        """
        features = torch.cat((state, action), dim=1)
        for layer in (self.fc1, self.fc2):
            features = F.relu(layer(features))
        return self.out(features)
    
class es_policy(nn.Module):
    """Deterministic tanh policy for evolution-strategies rollouts.

    The body is ``shared_policy`` (imported from ``rlkit.torch.networks``);
    a linear head maps its output to ``action_dim`` values squashed by tanh.
    ``forward`` never tracks gradients.

    Args:
        obs_dim: size of the observation vector.
        action_dim: number of action outputs.
        hidden_dim: forwarded to ``shared_policy``.
        latent_dim: input size of the head; presumably the output size of
            ``shared_policy`` -- TODO confirm.
    """

    def __init__(self, obs_dim, action_dim, hidden_dim, latent_dim):
        super().__init__()
        self.shared_layer = shared_policy(obs_dim, hidden_dim, latent_dim)
        self.last_layer = nn.Linear(latent_dim, action_dim)

    @torch.no_grad()
    def forward(self, obs):
        """Return tanh-squashed actions for ``obs`` (no autograd graph)."""
        return torch.tanh(self.last_layer(self.shared_layer(obs)))

class es_policy2(nn.Module):
    """Deterministic tanh policy for evolution strategies with a resettable head.

    Pipeline: ``first_layer`` -> LayerNorm -> ReLU -> ``shared_policy2`` ->
    ``last_layer`` -> tanh. The head (``last_layer``) is initialised uniformly
    in ``[-w, w]``, and a snapshot of it is kept so :meth:`reset` can restore it.

    Args:
        obs_dim: size of the observation vector.
        action_dim: number of action outputs.
        hidden_dim: width of ``first_layer`` and the LayerNorm.
        latent_dim: input size of ``last_layer``; presumably the output size of
            ``shared_policy2`` -- TODO confirm.
        w: half-width of the uniform initialisation range of ``last_layer``.
    """

    def __init__(self, obs_dim, action_dim, hidden_dim, latent_dim, w):
        super(es_policy2, self).__init__()
        # Creation order fixes the seeded initialisation; attribute names are
        # the state_dict keys.
        self.first_layer = nn.Linear(obs_dim, hidden_dim)
        self.ln = nn.LayerNorm(hidden_dim)
        self.shared_layer = shared_policy2(hidden_dim, latent_dim)
        self.last_layer = nn.Linear(latent_dim, action_dim)

        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim

        self.w_min = -w
        self.w_max = w

        init.uniform_(self.last_layer.weight, a=self.w_min, b=self.w_max)
        init.uniform_(self.last_layer.bias, a=self.w_min, b=self.w_max)

        # Snapshot the initial parameters. These are plain tensor attributes,
        # not registered buffers: they are excluded from state_dict() and are
        # not moved by .to(device).
        self.initial_first_weight = self.first_layer.weight.clone().detach()
        self.initial_first_bias = self.first_layer.bias.clone().detach()
        self.initial_last_weight = self.last_layer.weight.clone().detach()
        self.initial_last_bias = self.last_layer.bias.clone().detach()

    def reset(self, device):
        """Restore ``last_layer`` to its construction-time values.

        Only the head is restored. ``first_layer`` and ``shared_layer`` keep
        their current values, and the ``initial_first_*`` snapshot is not used
        here.

        Args:
            device: unused; kept for call-site compatibility. ``copy_`` already
                handles a snapshot that lives on a different device.
        """
        with torch.no_grad():
            # The head is copied straight from the snapshot. Re-drawing it
            # uniformly first would be overwritten immediately, and its only
            # effect would be to advance the global RNG.
            self.last_layer.weight.data.copy_(self.initial_last_weight)
            self.last_layer.bias.data.copy_(self.initial_last_bias)

    def _action(self, obs):
        """Shared computation of ``forward``/``grad_action``: obs -> tanh action."""
        latent_state = F.relu(self.ln(self.first_layer(obs)))
        latent_action = self.shared_layer(latent_state)
        return torch.tanh(self.last_layer(latent_action))

    def forward(self, obs):
        """Return actions in ``[-1, 1]`` without tracking gradients."""
        with torch.no_grad():
            return self._action(obs)

    def grad_action(self, obs):
        """Return the same actions as ``forward`` but with autograd enabled."""
        return self._action(obs)
    
class es_policy2_nonlinear(nn.Module):
    """Deterministic tanh policy with extra nonlinear layers around ``shared_policy2``.

    Pipeline: ``first_layer`` -> LayerNorm -> ReLU -> ``second_layer`` -> ReLU
    -> ``shared_policy2`` -> ``last_layer1`` -> ReLU -> ``last_layer2`` -> tanh.
    ``forward`` never tracks gradients.

    Args:
        obs_dim: size of the observation vector.
        action_dim: number of action outputs.
        hidden_dim: width of the first two layers and the LayerNorm.
        latent_dim: width of the head layers; presumably the output size of
            ``shared_policy2`` -- TODO confirm.
    """

    def __init__(self, obs_dim, action_dim, hidden_dim, latent_dim):
        super().__init__()
        # Declaration order (including ``ln`` after ``second_layer``) fixes the
        # seeded initialisation and the state_dict ordering; keep it as is.
        self.first_layer = nn.Linear(obs_dim, hidden_dim)
        self.second_layer = nn.Linear(hidden_dim, hidden_dim)
        self.ln = nn.LayerNorm(hidden_dim)
        self.shared_layer = shared_policy2(hidden_dim, latent_dim)
        self.last_layer1 = nn.Linear(latent_dim, latent_dim)
        self.last_layer2 = nn.Linear(latent_dim, action_dim)

    @torch.no_grad()
    def forward(self, obs):
        """Return tanh-squashed actions for ``obs`` (no autograd graph)."""
        # The LayerNorm is applied right after the first layer.
        h = F.relu(self.ln(self.first_layer(obs)))
        h = F.relu(self.second_layer(h))
        z = F.relu(self.last_layer1(self.shared_layer(h)))
        return torch.tanh(self.last_layer2(z))