"""
General networks for pytorch.

Algorithm-specific networks should go else-where.
"""
import copy
import torch
import numpy as np
from torch import nn as nn
from torch.nn import functional as F
from torch.distributions import Normal

from rlkit.policies.base import Policy
from rlkit.torch import pytorch_util as ptu
from rlkit.torch.core import PyTorchModule
from rlkit.torch.data_management.normalizer import TorchFixedNormalizer
from rlkit.torch.modules import LayerNorm


def identity(x):
    """Return ``x`` unchanged (a no-op activation function)."""
    result = x
    return result


class Mlp(PyTorchModule):
    """
    Fully connected network with configurable hidden layers.

    Each hidden layer is an ``nn.Linear`` whose weight is initialised by
    ``hidden_init`` and whose bias is filled with ``b_init_value``; it is
    followed by ``hidden_activation``. The output layer's weight and bias are
    drawn from ``U(-init_w, init_w)`` and passed through ``output_activation``.
    """

    def __init__(
            self,
            hidden_sizes,
            output_size,
            input_size,
            init_w=3e-3,
            hidden_activation=F.relu,
            output_activation=identity,
            hidden_init=ptu.fanin_init,
            b_init_value=0.1,
            layer_norm=False,
            layer_norm_kwargs=None,
    ):
        """
        :param hidden_sizes: widths of the hidden layers, in order.
        :param output_size: width of the output layer.
        :param input_size: width of the input.
        :param init_w: half-width of the uniform init range of the output layer.
        :param hidden_activation: activation applied after each hidden layer.
        :param output_activation: activation applied to the output layer.
        :param hidden_init: initialiser called on each hidden weight; its
            return value is ignored, so it must modify the tensor in place.
        :param b_init_value: constant fill value for the hidden biases.
        :param layer_norm: if True, create a ``LayerNorm`` per hidden layer.
        :param layer_norm_kwargs: extra keyword arguments forwarded to every
            ``LayerNorm`` constructor (``None`` means no extra arguments).
        """
        # Called before any other local is bound so that ``locals()`` holds
        # only the constructor arguments.
        self.save_init_params(locals())
        super().__init__()

        if layer_norm_kwargs is None:
            layer_norm_kwargs = dict()

        self.input_size = input_size
        self.output_size = output_size
        self.hidden_sizes = hidden_sizes
        self.hidden_activation = hidden_activation
        self.output_activation = output_activation
        self.layer_norm = layer_norm
        # Plain lists for ordered iteration; the modules themselves are
        # registered through ``__setattr__`` below.
        self.fcs = []
        self.layer_norms = []
        in_size = input_size

        for i, next_size in enumerate(hidden_sizes):
            fc = nn.Linear(in_size, next_size)
            in_size = next_size
            hidden_init(fc.weight)
            fc.bias.data.fill_(b_init_value)
            self.__setattr__("fc{}".format(i), fc)
            self.fcs.append(fc)

            if self.layer_norm:
                ln = LayerNorm(next_size, **layer_norm_kwargs)
                self.__setattr__("layer_norm{}".format(i), ln)
                self.layer_norms.append(ln)

        self.last_fc = nn.Linear(in_size, output_size)
        self.last_fc.weight.data.uniform_(-init_w, init_w)
        self.last_fc.bias.data.uniform_(-init_w, init_w)

    def forward(self, input, return_preactivations=False):
        """
        Run the network on ``input``.

        :param return_preactivations: if True, also return the output layer's
            value before ``output_activation``.
        :return: ``output``, or ``(output, preactivation)``.
        """
        h = input
        for i, fc in enumerate(self.fcs):
            h = fc(h)
            # NOTE(review): layer norm is skipped for the LAST hidden layer
            # even though its LayerNorm module is created in __init__. This is
            # kept as-is to preserve the behaviour of existing models; confirm
            # whether it is intended.
            if self.layer_norm and i < len(self.fcs) - 1:
                h = self.layer_norms[i](h)
            h = self.hidden_activation(h)
        preactivation = self.last_fc(h)
        output = self.output_activation(preactivation)
        if return_preactivations:
            return output, preactivation
        else:
            return output


class FlattenMlp(Mlp):
    """
    ``Mlp`` that takes several input tensors and concatenates them along
    dim 1 before the forward pass.
    """

    def forward(self, *inputs, **kwargs):
        """Concatenate ``inputs`` on dim 1, then run the base ``Mlp.forward``."""
        return super().forward(torch.cat(inputs, dim=1), **kwargs)


class MlpPolicy(Mlp, Policy):
    """
    Convenience wrapper that exposes an ``Mlp`` through the ``Policy`` interface.

    When ``obs_normalizer`` is set, observations are normalised before they
    reach the network.
    """

    def __init__(
            self,
            *args,
            obs_normalizer: TorchFixedNormalizer = None,
            **kwargs
    ):
        # Must stay the first statement: ``locals()`` is snapshotted here.
        self.save_init_params(locals())
        super().__init__(*args, **kwargs)
        self.obs_normalizer = obs_normalizer

    def forward(self, obs, **kwargs):
        """Optionally normalise ``obs``, then run the underlying MLP."""
        normalized = (
            self.obs_normalizer.normalize(obs) if self.obs_normalizer else obs
        )
        return super().forward(normalized, **kwargs)

    def get_action(self, obs_np):
        """Return the action for a single observation and an empty info dict."""
        batched_obs = obs_np[None]
        first_action = self.get_actions(batched_obs)[0, :]
        return first_action, {}

    def get_actions(self, obs):
        """Evaluate the policy on a batch of observations via ``eval_np``."""
        result = self.eval_np(obs)
        return result


class TanhMlpPolicy(MlpPolicy):
    """
    ``MlpPolicy`` whose output activation is fixed to ``torch.tanh``.
    """

    def __init__(self, *args, **kwargs):
        # Must stay the first statement: ``locals()`` is snapshotted here.
        self.save_init_params(locals())
        super().__init__(
            *args,
            output_activation=torch.tanh,
            **kwargs
        )


class MlpEncoder(FlattenMlp):
    """
    Context encoder implemented as a ``FlattenMlp``.
    """

    def reset(self, num_tasks=1):
        """No-op reset hook; ``num_tasks`` is accepted and ignored."""
        return None


# Lower / upper bounds used with ``torch.clamp`` on predicted log standard deviations.
LOG_SIG_MIN, LOG_SIG_MAX = -20, 2

class shared_policy(nn.Module):
    """Four-layer MLP trunk mapping ``input_dim`` features to a ``latent_dim`` latent.

    Every layer is Linear -> LayerNorm -> ReLU, so the output is non-negative.
    """

    def __init__(self, input_dim: int, hidden_dim: int, latent_dim: int):
        super().__init__()
        # Attributes are created in the order fc1, ln1, ..., fc4, ln4. This
        # fixes both the parameter-initialisation RNG order and the state_dict keys.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.ln1 = nn.LayerNorm(hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.ln2 = nn.LayerNorm(hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)
        self.ln3 = nn.LayerNorm(hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, latent_dim)
        self.ln4 = nn.LayerNorm(latent_dim)

    def forward(self, x):
        stages = (
            (self.fc1, self.ln1),
            (self.fc2, self.ln2),
            (self.fc3, self.ln3),
            (self.fc4, self.ln4),
        )
        out = x
        for linear, norm in stages:
            out = F.relu(norm(linear(out)))
        return out
    

class dynamics_filter(nn.Module):
    """Single Linear + ReLU layer mapping (latent_action, action) to a new latent."""

    def __init__(self, action_dim, latent_dim):
        super().__init__()
        in_features = latent_dim + action_dim
        self.filter = nn.Linear(in_features, latent_dim)

    def forward(self, latent_action, action):
        joint = torch.cat((latent_action, action), dim=1)
        return F.relu(self.filter(joint))


class soft_task_filter(nn.Module):
    """Tanh-squashed Gaussian head over a latent vector.

    ``forward`` returns ``(action, log_prob, mean)``:
    - ``action`` is a reparameterised sample passed through tanh;
    - ``log_prob`` is its log-density including the tanh correction, summed
      over the last dim with that dim kept;
    - ``mean`` is tanh of the Gaussian mean.
    """

    def __init__(self, latent_dim: int, output_dim: int):
        super().__init__()
        self.fc_mean = nn.Linear(latent_dim, output_dim)
        self.fc_logstd = nn.Linear(latent_dim, output_dim)

    def forward(self, latent_action):
        mu = self.fc_mean(latent_action)
        log_sigma = self.fc_logstd(latent_action).clamp(LOG_SIG_MIN, LOG_SIG_MAX)
        dist = Normal(mu, log_sigma.exp())
        pre_tanh = dist.rsample()
        # log(1 - tanh(u)^2) == 2 * (log 2 - u - softplus(-2u)), a numerically stable form.
        log_det = 2 * (np.log(2) - pre_tanh - F.softplus(-2 * pre_tanh))
        log_prob = (dist.log_prob(pre_tanh) - log_det).sum(-1, keepdim=True)
        return torch.tanh(pre_tanh), log_prob, torch.tanh(mu)
    
class stochastic_actor(nn.Module):
    """Tanh-squashed Gaussian policy on top of a frozen ``shared_policy`` trunk.

    The trunk has ``requires_grad`` switched off at construction, so only
    ``fc_mean`` and ``fc_logstd`` are trained.
    """

    def __init__(self, obs_dim, action_dim, hidden_dim, latent_dim):
        super(stochastic_actor, self).__init__()
        self.shared_layer = shared_policy(obs_dim, hidden_dim, latent_dim)
        self.fc_mean = nn.Linear(latent_dim, action_dim)
        self.fc_logstd = nn.Linear(latent_dim, action_dim)
        # Freeze the shared trunk; only the task-specific heads are trained.
        self.shared_layer.requires_grad_(False)

        # Small-gain orthogonal weights and zero bias for both heads.
        nn.init.orthogonal_(self.fc_mean.weight, gain=0.01)
        nn.init.constant_(self.fc_mean.bias, 0.0)

        nn.init.orthogonal_(self.fc_logstd.weight, gain=0.01)
        nn.init.constant_(self.fc_logstd.bias, 0.0)

    @staticmethod
    def _squashed_log_prob(normal, pre_tanh):
        """Compute the log-density of ``tanh(pre_tanh)``, summed over the last dim (kept)."""
        log_prob = normal.log_prob(pre_tanh)
        # Change of variables for tanh: subtract log(1 - tanh(u)^2), written
        # in the numerically stable form 2 * (log 2 - u - softplus(-2u)).
        log_prob = log_prob - 2 * (np.log(2) - pre_tanh - F.softplus(-2 * pre_tanh))
        return log_prob.sum(-1, keepdim=True)

    @staticmethod
    def _frozen_linear(layer, x):
        """Apply ``layer`` to ``x`` using a detached snapshot of its parameters.

        Gradients reach ``x`` but never ``layer``'s weight or bias.
        """
        return F.linear(x, layer.weight.detach().clone(), layer.bias.detach().clone())

    def forward(self, obs):
        """Return the Gaussian ``(mean, log_std)`` for ``obs``.

        ``log_std`` is clamped to ``[LOG_SIG_MIN, LOG_SIG_MAX]``.
        """
        latent_action = self.shared_layer(obs)
        mean = self.fc_mean(latent_action)
        log_std = self.fc_logstd(latent_action)
        log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
        return mean, log_std

    def select_action(self, obs, deterministic=False):
        """Act on a raw observation without tracking gradients.

        ``obs`` is converted with ``torch.Tensor`` and moved to the trunk's device.

        :return: if ``deterministic``, the numpy array ``tanh(mean)`` (no
            sample is drawn). Otherwise ``(action, log_prob)``, where ``action``
            is a numpy array and ``log_prob`` stays a tensor on the model device.
        """
        device = next(self.shared_layer.parameters()).device
        with torch.no_grad():
            mean, log_std = self(torch.Tensor(obs).to(device))
            if deterministic:
                return torch.tanh(mean).cpu().numpy()
            normal = Normal(mean, torch.exp(log_std))
            pre_tanh = normal.sample()
            log_prob = self._squashed_log_prob(normal, pre_tanh)
            action = torch.tanh(pre_tanh)
        return action.cpu().numpy(), log_prob

    def action(self, obs):
        """Draw a reparameterised squashed action.

        :return: ``(action, log_prob)``; gradients flow through both.
        """
        mean, log_std = self(obs)
        normal = Normal(mean, torch.exp(log_std))
        pre_tanh = normal.rsample()
        return torch.tanh(pre_tanh), self._squashed_log_prob(normal, pre_tanh)

    def detached_task_filterd_action(self, latent_action):
        """Sample a squashed action from an externally supplied latent.

        The heads use a detached snapshot of their parameters, so gradients
        flow into ``latent_action`` but not into ``fc_mean`` / ``fc_logstd``.
        The returned log-probability includes the tanh correction, consistent
        with ``action`` and ``select_action``.

        :return: ``(action, log_prob)``.
        """
        mean = self._frozen_linear(self.fc_mean, latent_action)
        log_std = self._frozen_linear(self.fc_logstd, latent_action)
        log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
        normal = Normal(mean, torch.exp(log_std))
        pre_tanh = normal.rsample()
        log_prob = self._squashed_log_prob(normal, pre_tanh)
        return torch.tanh(pre_tanh), log_prob
    

class shared_policy2(nn.Module):
    """Three-layer MLP trunk mapping ``hidden_dim`` features to a ``latent_dim`` latent.

    Every layer is Linear -> LayerNorm -> ReLU, so the output is non-negative.
    """

    def __init__(self, hidden_dim: int, latent_dim: int):
        super().__init__()
        # Creation order fc1, ln1, fc2, ln2, fc3, ln3 fixes the init RNG
        # order and the state_dict keys.
        self.fc1 = nn.Linear(hidden_dim, hidden_dim)
        self.ln1 = nn.LayerNorm(hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.ln2 = nn.LayerNorm(hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, latent_dim)
        self.ln3 = nn.LayerNorm(latent_dim)

    def forward(self, x):
        stages = (
            (self.fc1, self.ln1),
            (self.fc2, self.ln2),
            (self.fc3, self.ln3),
        )
        out = x
        for linear, norm in stages:
            out = F.relu(norm(linear(out)))
        return out
    
class shared_policy3(nn.Module):
    """Six-layer MLP trunk mapping ``hidden_dim`` features to a ``latent_dim`` latent.

    Every layer is Linear -> LayerNorm -> ReLU, so the output is non-negative.
    """

    def __init__(self, hidden_dim: int, latent_dim: int):
        super().__init__()
        # Creation order fc1, ln1, ..., fc6, ln6 fixes the init RNG order and
        # the state_dict keys.
        self.fc1 = nn.Linear(hidden_dim, hidden_dim)
        self.ln1 = nn.LayerNorm(hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.ln2 = nn.LayerNorm(hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)
        self.ln3 = nn.LayerNorm(hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, hidden_dim)
        self.ln4 = nn.LayerNorm(hidden_dim)
        self.fc5 = nn.Linear(hidden_dim, hidden_dim)
        self.ln5 = nn.LayerNorm(hidden_dim)
        self.fc6 = nn.Linear(hidden_dim, latent_dim)
        self.ln6 = nn.LayerNorm(latent_dim)

    def forward(self, x):
        stages = (
            (self.fc1, self.ln1),
            (self.fc2, self.ln2),
            (self.fc3, self.ln3),
            (self.fc4, self.ln4),
            (self.fc5, self.ln5),
            (self.fc6, self.ln6),
        )
        out = x
        for linear, norm in stages:
            out = F.relu(norm(linear(out)))
        return out


class stochastic_actor2(nn.Module):
    """Tanh-squashed Gaussian policy with the structure
    ``fc_head`` (trainable) -> frozen ``shared_policy2`` trunk -> ``fc_mean`` / ``fc_logstd`` (trainable).
    """

    def __init__(self, obs_dim, action_dim, hidden_dim, latent_dim):
        super(stochastic_actor2, self).__init__()
        self.fc_head = nn.Linear(obs_dim, hidden_dim)
        self.shared_layer = shared_policy2(hidden_dim, latent_dim)
        self.fc_mean = nn.Linear(latent_dim, action_dim)
        self.fc_logstd = nn.Linear(latent_dim, action_dim)
        # Freeze the shared trunk; only the task-specific layers are trained.
        self.shared_layer.requires_grad_(False)

        # Small-gain orthogonal weights and zero bias for the task layers.
        nn.init.orthogonal_(self.fc_head.weight, gain=0.01)
        nn.init.constant_(self.fc_head.bias, 0.0)

        nn.init.orthogonal_(self.fc_mean.weight, gain=0.01)
        nn.init.constant_(self.fc_mean.bias, 0.0)

        nn.init.orthogonal_(self.fc_logstd.weight, gain=0.01)
        nn.init.constant_(self.fc_logstd.bias, 0.0)

    @staticmethod
    def _squashed_log_prob(normal, pre_tanh):
        """Compute the log-density of ``tanh(pre_tanh)``, summed over the last dim (kept)."""
        log_prob = normal.log_prob(pre_tanh)
        # Change of variables for tanh: subtract log(1 - tanh(u)^2), written
        # in the numerically stable form 2 * (log 2 - u - softplus(-2u)).
        log_prob = log_prob - 2 * (np.log(2) - pre_tanh - F.softplus(-2 * pre_tanh))
        return log_prob.sum(-1, keepdim=True)

    @staticmethod
    def _frozen_linear(layer, x):
        """Apply ``layer`` to ``x`` using a detached snapshot of its parameters.

        Gradients reach ``x`` but never ``layer``'s weight or bias.
        """
        return F.linear(x, layer.weight.detach().clone(), layer.bias.detach().clone())

    def forward(self, obs):
        """Return the Gaussian ``(mean, log_std)`` for ``obs``.

        ``log_std`` is clamped to ``[LOG_SIG_MIN, LOG_SIG_MAX]``.
        """
        latent_state = F.relu(self.fc_head(obs))
        latent_action = self.shared_layer(latent_state)
        mean = self.fc_mean(latent_action)
        log_std = self.fc_logstd(latent_action)
        log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
        return mean, log_std

    def select_action(self, obs, deterministic=False):
        """Act on a raw observation without tracking gradients.

        ``obs`` is converted with ``torch.Tensor`` and moved to the trunk's device.

        :return: if ``deterministic``, the numpy array ``tanh(mean)`` (no
            sample is drawn). Otherwise ``(action, log_prob)``, where ``action``
            is a numpy array and ``log_prob`` stays a tensor on the model device.
        """
        device = next(self.shared_layer.parameters()).device
        with torch.no_grad():
            mean, log_std = self(torch.Tensor(obs).to(device))
            if deterministic:
                return torch.tanh(mean).cpu().numpy()
            normal = Normal(mean, torch.exp(log_std))
            pre_tanh = normal.sample()
            log_prob = self._squashed_log_prob(normal, pre_tanh)
            action = torch.tanh(pre_tanh)
        return action.cpu().numpy(), log_prob

    def action(self, obs, deterministic=False):
        """Return a squashed action for ``obs``.

        :return: ``tanh(mean)`` alone when ``deterministic``; otherwise
            ``(action, log_prob)`` from a reparameterised sample.
        """
        mean, log_std = self(obs)
        if deterministic:
            return torch.tanh(mean)
        normal = Normal(mean, torch.exp(log_std))
        pre_tanh = normal.rsample()
        return torch.tanh(pre_tanh), self._squashed_log_prob(normal, pre_tanh)

    def eval_action(self, obs, action):
        """Compute the log-probability of an already-squashed ``action`` under the policy at ``obs``."""
        mean, log_std = self(obs)
        normal = Normal(mean, torch.exp(log_std))
        epsilon = 1e-6
        # Keep |action| < 1 so that atanh stays finite.
        pre_tanh = torch.atanh(torch.clamp(action, -1 + epsilon, 1 - epsilon))
        return self._squashed_log_prob(normal, pre_tanh)

    def detached_task_filterd_action(self, latent_action):
        """Sample a squashed action from an externally supplied latent.

        The heads use a detached snapshot of their parameters, so gradients
        flow into ``latent_action`` but not into ``fc_mean`` / ``fc_logstd``.
        The returned log-probability includes the tanh correction, consistent
        with ``action`` and ``eval_action``.

        :return: ``(action, log_prob)``.
        """
        mean = self._frozen_linear(self.fc_mean, latent_action)
        log_std = self._frozen_linear(self.fc_logstd, latent_action)
        log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
        normal = Normal(mean, torch.exp(log_std))
        pre_tanh = normal.rsample()
        log_prob = self._squashed_log_prob(normal, pre_tanh)
        return torch.tanh(pre_tanh), log_prob

    def detached_task_filtered_state(self, obs):
        """Embed ``obs`` with ``relu(fc_head(obs))``, without training ``fc_head``.

        Gradients flow into ``obs`` but not into ``fc_head``'s parameters.
        """
        return F.relu(self._frozen_linear(self.fc_head, obs))
    

class critic(nn.Module):
    """Critic network: concatenates ``(state, action)`` along dim 1, applies
    three Linear + ReLU layers, then a linear head with a single output."""

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, 1)

    def forward(self, state, action):
        h = torch.cat((state, action), dim=1)
        for layer in (self.fc1, self.fc2, self.fc3):
            h = F.relu(layer(h))
        return self.out(h)
    
class shared_critic(nn.Module):
    """Critic made of a linear ``critic_filter`` over ``(state, action)``
    followed by a ``shared_critic_backbone`` producing the value."""

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.critic_filter = nn.Linear(input_dim, hidden_dim)
        self.shared_critic_backbone = shared_critic_backbone(hidden_dim)

    def forward(self, state, action):
        """Return the backbone's value for the filtered ``(state, action)`` pair."""
        return self.shared_critic_backbone(self.filter(state, action))

    def filter(self, state, action):
        """Concatenate ``state`` and ``action`` on dim 1 and apply ``critic_filter``.

        The result is linear (no activation is applied here).
        """
        joint = torch.cat((state, action), dim=1)
        return self.critic_filter(joint)
    
class shared_critic_backbone(nn.Module):
    """Value trunk: ReLU on the incoming latent, two Linear + ReLU layers, and
    a linear head with a single output."""

    def __init__(self, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, 1)

    def forward(self, latent_vector):
        # The input is activated here; presumably it is a pre-activation
        # (e.g. the linear output of a critic filter).
        h = F.relu(latent_vector)
        for layer in (self.fc1, self.fc2):
            h = F.relu(layer(h))
        return self.out(h)
    

# Shared Vector Field
class SVF(nn.Module):
    """Shared vector field network used for flow matching.

    ``forward(t, xt, obs)`` concatenates its inputs along dim 1 and maps them
    to a non-negative latent of size ``output_dim``. ``flow`` builds the
    interpolant ``x_t`` between noise ``x0`` and target ``x1 = action``.
    ``d_flow`` returns its time derivative (the regression target).
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super(SVF, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.ln1 = nn.LayerNorm(hidden_dim)

        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, output_dim)

        # Minimum noise scale (sigma_min) of the interpolation path.
        self.std_min = 0.0

    def forward(self, t, xt, obs):
        """Map ``(t, xt, obs)``, concatenated along dim 1, to a latent vector."""
        x = torch.cat([t, xt, obs], dim=1)
        x = F.relu(self.ln1(self.fc1(x)))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        latent_action = F.relu(self.fc4(x))
        return latent_action

    def flow(self, t, x0, action):
        """Return x_t = (1 - (1 - std_min) * t) * x0 + t * x1, with x1 = action."""
        return (1 - (1 - self.std_min) * t) * x0 + t * action

    def d_flow(self, action, x0=None):
        """Return dx_t/dt = x1 - (1 - std_min) * x0, with x1 = action.

        :param action: target x1, shape (action_dim,) or (batch, action_dim).
        :param x0: the noise sample used to build x_t with ``flow``. Pass it
            so that the target matches the interpolant. When omitted, fresh
            standard-normal noise of shape (1, action_dim) or
            (batch, action_dim) is drawn; that noise is independent of any
            x0 used elsewhere.
        """
        if x0 is None:
            if action.dim() == 1:
                x0 = torch.randn(1, action.shape[0]).to(action.device)
            else:
                x0 = torch.randn(action.shape[0], action.shape[1]).to(action.device)
        return action - (1 - self.std_min) * x0


class SVF_Policy(nn.Module):
    """Flow-matching policy built around a frozen ``SVF`` vector field.

    ``first_layer`` embeds the observation, the frozen ``svf`` produces a
    latent, and ``last_layer`` followed by tanh maps it to action space.
    """

    def __init__(self, obs_dim, action_dim, hidden_dim, latent_dim):
        super(SVF_Policy, self).__init__()
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.first_layer = nn.Linear(obs_dim, hidden_dim)
        self.svf = SVF(
            input_dim=hidden_dim + action_dim + 1,
            hidden_dim=hidden_dim,
            output_dim=latent_dim)
        self.last_layer = nn.Linear(latent_dim, action_dim)

        # When initialised as an SVF_Policy, the svf is frozen and not trained.
        self.svf.requires_grad_(False)

    @staticmethod
    def _frozen_linear(layer, x):
        """Apply ``layer`` to ``x`` using a detached snapshot of its parameters.

        Gradients reach ``x`` but never ``layer``'s weight or bias.
        """
        return F.linear(x, layer.weight.detach().clone(), layer.bias.detach().clone())

    @staticmethod
    def _as_batch(x):
        """Add a leading batch dim to a 1-D tensor; leave other tensors unchanged."""
        return x.unsqueeze(0) if x.dim() == 1 else x

    def forward(self, obs, action, t=None, x0=None):
        """Predict the tanh-squashed velocity at a point on the noise->action path.

        :param obs: observation(s), shape (obs_dim,) or (batch, obs_dim).
        :param action: target action(s) x1, shape (action_dim,) or
            (batch, action_dim).
        :param t: optional time(s), shape (batch, 1). Sampled from U(0, 1)
            when omitted.
        :param x0: optional noise, shape (batch, action_dim). Standard normal
            when omitted. Passing it lets the caller build the regression
            target from the same noise sample.
        :return: tensor of shape (batch, action_dim), where batch is 1 for
            unbatched input.
        """
        batch = 1 if action.dim() == 1 else action.shape[0]
        if t is None:
            # (batch, 1) so it can be concatenated along dim 1 inside SVF.
            t = torch.rand(batch, 1).to(action.device)
        if x0 is None:
            x0 = torch.randn(batch, action.shape[-1]).to(action.device)
        latent_state = self._as_batch(F.relu(self.first_layer(obs)))
        xt = self.svf.flow(t, x0, action)
        latent_action = self.svf(t, xt, latent_state)
        return torch.tanh(self.last_layer(latent_action))

    def detached_task_filtered_state(self, obs):
        """Embed ``obs`` as ``relu(first_layer(obs))`` without training ``first_layer``."""
        return F.relu(self._frozen_linear(self.first_layer, obs))

    def detached_task_filtered_action(self, latent_action):
        """Map a latent to an action as ``tanh(last_layer(latent))`` without training ``last_layer``."""
        return torch.tanh(self._frozen_linear(self.last_layer, latent_action))

    def euler_integration(self, latent_state, steps=1000):
        """Integrate the learned velocity from t=0 to t=1 with forward Euler.

        :param latent_state: embedded observation(s), shape (hidden_dim,) or
            (batch, hidden_dim).
        :param steps: number of Euler steps (step size ``1 / steps``).
        :return: actions of shape (batch, action_dim), where batch is 1 for
            unbatched input, clamped to [-1, 1].
        """
        latent_state = self._as_batch(latent_state)
        batch = latent_state.shape[0]
        device = latent_state.device
        xt = torch.randn(batch, self.action_dim).to(device)
        dt = 1.0 / steps
        for t in np.linspace(0, 1, steps + 1)[:-1]:
            # One time value per batch row, so it concatenates with xt on dim 1.
            t_tensor = torch.full((batch, 1), float(t), dtype=torch.float32, device=device)
            xt = xt + dt * torch.tanh(self.last_layer(self.svf(t_tensor, xt, latent_state)))
        # Ensure actions are within bounds.
        return torch.clamp(xt, -1, 1)

        
