
import torch
from torch import nn
from torch import distributions as pyd
from torch.distributions.utils import _standard_normal
import numpy as np
import torch.nn.functional as F
import numbers
import math
from typing import Any


##########################
# Initialization utils
##########################

# Initialization for parallel layers
def parallel_orthogonal_(tensor, gain=1):
    """Fill ``tensor`` in place with (semi-)orthogonal matrices, one per leading index.

    A 2-D tensor is delegated to ``nn.init.orthogonal_``. For a tensor with 3 or
    more dimensions, dimension 0 indexes independent ensemble members; member ``i``
    is viewed as a ``(size(1), numel // (size(0) * size(1)))`` matrix and receives
    its own independently drawn (semi-)orthogonal matrix, scaled by ``gain``.

    Args:
        tensor: tensor to initialize in place (must be viewable as
            ``(n_parallel, rows, cols)`` for the >=3-D path).
        gain: multiplicative scale applied to the orthogonal matrices.

    Returns:
        The same ``tensor`` object.

    Raises:
        ValueError: if ``tensor`` has fewer than 2 dimensions.
    """
    if tensor.ndimension() < 2:
        raise ValueError("Only tensors with 2 or more dimensions are supported")
    if tensor.ndimension() == 2:
        return nn.init.orthogonal_(tensor, gain=gain)
    if tensor.numel() == 0:
        # nothing to fill; also avoids dividing by a zero-sized dimension below
        return tensor

    n_parallel = tensor.size(0)
    rows = tensor.size(1)
    cols = tensor.numel() // n_parallel // rows
    flattened = tensor.new_empty(n_parallel, rows, cols).normal_(0, 1)

    # QR yields orthonormal columns for tall matrices; wide ones are factored transposed.
    wide = rows < cols
    qs = []
    for member in flattened.unbind(0):
        mat = member.t() if wide else member
        q, r = torch.linalg.qr(mat)
        # Make Q uniform according to https://arxiv.org/pdf/math-ph/0609050.pdf
        q = q * torch.diag(r, 0).sign()
        qs.append(q.t() if wide else q)

    stacked = torch.stack(qs, dim=0)
    with torch.no_grad():
        tensor.view_as(stacked).copy_(stacked)
        tensor.mul_(gain)
    return tensor

def weight_init(m):
    """Initialize a single module in place; meant to be used via ``module.apply(weight_init)``.

    * ``nn.Linear``: orthogonal weight (gain 1), zero bias.
    * ``DenseParallel``: per-member orthogonal weight with the ReLU gain, zero bias.
    * any other module exposing ``reset_parameters``: that method is called.
    """
    def _zero_bias(layer):
        # a layer built without bias stores None, which has no `.data`
        if hasattr(layer.bias, "data"):
            layer.bias.data.fill_(0.0)

    if isinstance(m, nn.Linear):
        nn.init.orthogonal_(m.weight.data)
        _zero_bias(m)
    elif isinstance(m, DenseParallel):
        parallel_orthogonal_(m.weight.data, nn.init.calculate_gain("relu"))
        _zero_bias(m)
    elif hasattr(m, "reset_parameters"):
        m.reset_parameters()


##########################
# Update utils
##########################

def _soft_update_params(net_params: Any, target_net_params: Any, tau: float):
    torch._foreach_mul_(target_net_params, 1 - tau)
    torch._foreach_add_(target_net_params, net_params, alpha=tau)

def soft_update_params(net, target_net, tau) -> None:
    """Move ``target_net``'s parameters a fraction ``tau`` toward ``net``'s, in place.

    ``tau`` is clamped to ``[0, 1]``. Parameters are paired in ``.parameters()``
    order, so both networks must share the same architecture.
    """
    clamped_tau = float(min(max(tau, 0), 1))
    online, target = ([p.data for p in module.parameters()] for module in (net, target_net))
    _soft_update_params(online, target, clamped_tau)

class eval_mode:
    """Context manager that switches models to eval mode and restores their exact prior modes.

    On enter, every model's ``training`` flag is recorded before any model is
    switched. For ``nn.Module`` models, the flag of every submodule is recorded as
    well. On exit, ``model.train(prev)`` is called for each model. Then every
    recorded submodule flag is re-applied, because ``train()`` is recursive and
    would otherwise overwrite submodules that were deliberately in a different mode.
    This also makes overlapping arguments (the same model twice, or a parent and
    its child) restore correctly.
    """

    def __init__(self, *models) -> None:
        self.models = models
        self.prev_states = []
        self._module_states = []

    def __enter__(self) -> None:
        # Snapshot everything first so that overlapping models record their true prior state.
        self.prev_states = [model.training for model in self.models]
        self._module_states = [
            (module, module.training)
            for model in self.models
            if isinstance(model, nn.Module)
            for module in model.modules()
        ]
        for model in self.models:
            model.train(False)

    def __exit__(self, *args) -> None:
        for model, state in zip(self.models, self.prev_states):
            model.train(state)
        # Undo the recursive overwrite performed by train() above.
        for module, state in self._module_states:
            module.training = state


##########################
# Creation utils
##########################

def build_backward(obs_dim, z_dim, cfg):
    """Build a ``BackwardMap`` from ``obs_dim`` inputs to ``z_dim`` outputs.

    Reads ``cfg.hidden_dim``, ``cfg.hidden_layers`` and ``cfg.norm``.
    """
    return BackwardMap(
        goal_dim=obs_dim,
        z_dim=z_dim,
        hidden_dim=cfg.hidden_dim,
        hidden_layers=cfg.hidden_layers,
        norm=cfg.norm,
    )

def build_encoder(obs_dim, action_dim, zs_dim, za_dim, zsa_dim, cfg, use_pixel=False):
    """Build an ``EncoderMap``.

    Reads ``cfg.hidden_dim``, ``cfg.enc_horizon``, ``cfg.hidden_layers`` and ``cfg.norm``.
    """
    return EncoderMap(
        goal_dim=obs_dim,
        action_dim=action_dim,
        zs_dim=zs_dim,
        za_dim=za_dim,
        zsa_dim=zsa_dim,
        hidden_dim=cfg.hidden_dim,
        enc_horizon=cfg.enc_horizon,
        use_pixel=use_pixel,
        hidden_layers=cfg.hidden_layers,
        norm=cfg.norm,
    )

def build_encoder_unit_norm(obs_dim, action_dim, zs_dim, za_dim, zsa_dim, cfg, use_pixel=False):
    """Build an ``EncoderMapUnitNorm``; reads the same ``cfg`` fields as ``build_encoder``."""
    return EncoderMapUnitNorm(
        goal_dim=obs_dim,
        action_dim=action_dim,
        zs_dim=zs_dim,
        za_dim=za_dim,
        zsa_dim=zsa_dim,
        hidden_dim=cfg.hidden_dim,
        enc_horizon=cfg.enc_horizon,
        use_pixel=use_pixel,
        hidden_layers=cfg.hidden_layers,
        norm=cfg.norm,
    )



def build_forward(obs_dim, z_dim, action_dim, cfg, output_dim=None ,layer_norm=False):
    """Build the forward map F(s, z, a) for the ensemble strategy in ``cfg.ensemble_mode``.

    Args:
        obs_dim, z_dim, action_dim: input sizes.
        cfg: config providing ``ensemble_mode`` plus the fields used by the builders.
        output_dim: output width; the builders fall back to ``z_dim`` when falsy.
        layer_norm: use the layer-normalized variant (batch mode only).

    Raises:
        NotImplementedError: for ``ensemble_mode == "vmap"``.
        ValueError: for any other unknown ``ensemble_mode`` (was an ``assert``,
            which disappears under ``python -O``).
    """
    mode = cfg.ensemble_mode
    if mode == "seq":
        # NOTE(review): SequetialFMap has no layer_norm option; it is ignored in this mode.
        return SequetialFMap(obs_dim, z_dim, action_dim, cfg, output_dim)
    if mode == "vmap":
        raise NotImplementedError("vmap ensemble mode is currently not supported")
    if mode != "batch":
        raise ValueError("Invalid value for ensemble_mode. Use {'batch', 'seq', 'vmap'}")
    return _build_batch_forward(obs_dim, z_dim, action_dim, cfg, output_dim, layer_norm=layer_norm)


def build_forward_state_only(obs_dim, z_dim, cfg, output_dim=None):
    """Build an action-free forward map F(s, z) (built with ``action_dim = 0``).

    Raises:
        NotImplementedError: for ``ensemble_mode == "vmap"``.
        ValueError: for any other unknown ``ensemble_mode``.
    """
    mode = cfg.ensemble_mode
    if mode == "seq":
        # NOTE(review): members are state-only maps taking (obs, z); confirm
        # SequetialFMap.forward is called accordingly.
        return SequetialFMap(obs_dim, z_dim, 0, cfg, output_dim)
    if mode == "vmap":
        raise NotImplementedError("vmap ensemble mode is currently not supported")
    if mode != "batch":
        raise ValueError("Invalid value for ensemble_mode. Use {'batch', 'seq', 'vmap'}")
    return _build_batch_forward(obs_dim, z_dim, 0, cfg, output_dim, parallel=False)

def _build_batch_forward(obs_dim, z_dim, action_dim, cfg, output_dim=None, parallel=True,layer_norm=False):
    if cfg.model == "residual":
        forward_cls = ResidualForwardMap
    elif cfg.model == "simple":
        if action_dim == 0:
            forward_cls = ForwardMapStateOnly
            return forward_cls(obs_dim, z_dim, cfg.hidden_dim, cfg.hidden_layers, cfg.embedding_layers, 1, output_dim)
        else:
            if layer_norm:
                forward_cls = ForwardMapLayerNorm
            else:   
                forward_cls = ForwardMap
    else:
        raise ValueError(f"Unsupported forward_map model {cfg.model}")
    num_parallel = cfg.num_parallel if parallel else 1
    return forward_cls(obs_dim, z_dim, action_dim, cfg.hidden_dim, cfg.hidden_layers, cfg.embedding_layers, num_parallel, output_dim)

def build_actor(obs_dim, z_dim, action_dim, cfg):
    """Build a z-conditioned actor: ``ResidualActor`` for ``"residual"``, ``Actor`` for ``"simple"``.

    Raises:
        ValueError: if ``cfg.model`` is neither ``"residual"`` nor ``"simple"``.
    """
    model = cfg.model
    if model not in ("residual", "simple"):
        raise ValueError(f"Unsupported actor model {model}")
    actor_cls = ResidualActor if model == "residual" else Actor
    return actor_cls(obs_dim, z_dim, action_dim, cfg.hidden_dim, cfg.hidden_layers, cfg.embedding_layers)

def build_hierarchical_actor(action_dim):
    """Build a ``HierarchicalActor`` holding one learnable ``action_dim`` vector."""
    actor = HierarchicalActor(action_dim)
    return actor

def build_hierarchical_actor_normal(action_dim, expl_logstd):
    """Build a ``HierarchicalActorNormal`` whose log-std is initialized to ``expl_logstd``."""
    actor = HierarchicalActorNormal(action_dim, expl_logstd)
    return actor



def build_simple_critic(obs_dim, action_dim, cfg):
    """Build a ``SimpleCritic`` ensemble sized by ``cfg.residual_hidden_dim``, ``cfg.num_parallel`` and ``cfg.residual_hidden_layers``."""
    return SimpleCritic(
        obs_dim,
        action_dim,
        hidden_dim=cfg.residual_hidden_dim,
        num_parallel=cfg.num_parallel,
        num_hidden_layers=cfg.residual_hidden_layers,
    )


def build_simple_actor(obs_dim, action_dim, cfg):
    """Build a single-member ``SimpleActor`` using ``cfg.hidden_dim`` and ``cfg.hidden_layers``."""
    return SimpleActor(
        obs_dim,
        action_dim,
        hidden_dim=cfg.hidden_dim,
        num_parallel=1,
        num_hidden_layers=cfg.hidden_layers,
    )

def build_discriminator(obs_dim, z_dim, cfg):
    """Build a ``Discriminator`` over ``(obs, z)`` using ``cfg.hidden_dim`` and ``cfg.hidden_layers``."""
    return Discriminator(obs_dim, z_dim, hidden_dim=cfg.hidden_dim, hidden_layers=cfg.hidden_layers)

def linear(input_dim, output_dim, num_parallel=1):
    """Return a dense layer.

    This is ``nn.Linear`` for a single network, or ``DenseParallel`` holding
    ``num_parallel`` independent layers when ``num_parallel > 1``.
    """
    if not num_parallel > 1:
        return nn.Linear(input_dim, output_dim)
    return DenseParallel(input_dim, output_dim, n_parallel=num_parallel)

def layernorm(input_dim, num_parallel=1):
    """Return ``nn.LayerNorm(input_dim)``, or a ``ParallelLayerNorm`` ensemble when ``num_parallel > 1``."""
    if not num_parallel > 1:
        return nn.LayerNorm(input_dim)
    return ParallelLayerNorm([input_dim], n_parallel=num_parallel)


##########################
# Simple MLP models
##########################

class BackwardMap(nn.Module):
    """MLP backward map ``B: goal -> z``, optionally followed by ``Norm()``.

    Layout: Linear -> LayerNorm -> Tanh, then ``hidden_layers - 1`` x (Linear -> ReLU),
    then a final Linear to ``z_dim``.
    """

    def __init__(self, goal_dim, z_dim, hidden_dim, hidden_layers: int = 2, norm=True) -> None:
        super().__init__()
        stem = [nn.Linear(goal_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.Tanh()]
        hidden = [
            module
            for _ in range(hidden_layers - 1)
            for module in (nn.Linear(hidden_dim, hidden_dim), nn.ReLU())
        ]
        head = [nn.Linear(hidden_dim, z_dim)]
        if norm:
            head.append(Norm())
        self.net = nn.Sequential(*stem, *hidden, *head)

    def forward(self, x):
        return self.net(x)


def simple_embedding(input_dim, hidden_dim, hidden_layers, num_parallel=1):
    """MLP embedding from ``input_dim`` to ``hidden_dim // 2`` features.

    Layout: Linear -> LayerNorm -> Tanh, ``hidden_layers - 2`` x (Linear -> ReLU),
    then Linear -> ReLU down to ``hidden_dim // 2``. Layers are ensemble-parallel
    when ``num_parallel > 1``.
    """
    assert hidden_layers >= 2, "must have at least 2 embedding layers"
    modules = [linear(input_dim, hidden_dim, num_parallel), layernorm(hidden_dim, num_parallel), nn.Tanh()]
    modules.extend(
        module
        for _ in range(hidden_layers - 2)
        for module in (linear(hidden_dim, hidden_dim, num_parallel), nn.ReLU())
    )
    modules.extend((linear(hidden_dim, hidden_dim // 2, num_parallel), nn.ReLU()))
    return nn.Sequential(*modules)


def simple_embedding_layer_norm(input_dim, hidden_dim, hidden_layers, num_parallel=1):
    """Like ``simple_embedding`` but with a LayerNorm after every ReLU.

    Layout: Linear -> LayerNorm -> Tanh, ``hidden_layers - 2`` x
    (Linear -> ReLU -> LayerNorm), then Linear -> ReLU -> LayerNorm down to
    ``hidden_dim // 2``.
    """
    assert hidden_layers >= 2, "must have at least 2 embedding layers"
    modules = [linear(input_dim, hidden_dim, num_parallel), layernorm(hidden_dim, num_parallel), nn.Tanh()]
    modules.extend(
        module
        for _ in range(hidden_layers - 2)
        for module in (linear(hidden_dim, hidden_dim, num_parallel), nn.ReLU(),
                       layernorm(hidden_dim, num_parallel))
    )
    half = hidden_dim // 2
    modules.extend((linear(hidden_dim, half, num_parallel), nn.ReLU(), layernorm(half, num_parallel)))
    return nn.Sequential(*modules)

class RandomShiftsAug(nn.Module):
    """Random-shift image augmentation for square image batches.

    Each image is replicate-padded by ``pad`` pixels per side. It is then resampled
    bilinearly with ``F.grid_sample`` over an original-size window. The window is
    displaced by an integer pixel offset drawn uniformly from ``[0, 2 * pad]`` per
    axis and per sample.
    """

    def __init__(self, pad) -> None:
        super().__init__()
        self.pad = pad

    def forward(self, x) -> torch.Tensor:
        x = x.float()
        batch, _, height, width = x.size()
        assert height == width
        x = F.pad(x, (self.pad,) * 4, 'replicate')
        padded = height + 2 * self.pad
        half_pixel = 1.0 / padded
        # normalized [-1, 1] centres of the padded image's pixels; keep the first `height`
        coords = torch.linspace(-1.0 + half_pixel, 1.0 - half_pixel, padded,
                                device=x.device, dtype=x.dtype)[:height]
        xs = coords.unsqueeze(0).repeat(height, 1).unsqueeze(2)
        base_grid = torch.cat([xs, xs.transpose(1, 0)], dim=2).unsqueeze(0).repeat(batch, 1, 1, 1)

        offsets = torch.randint(0, 2 * self.pad + 1, size=(batch, 1, 1, 2),
                                device=x.device, dtype=x.dtype)
        offsets *= 2.0 / padded
        return F.grid_sample(x, base_grid + offsets, padding_mode='zeros', align_corners=False)



class EncoderMap(nn.Module):
    """State encoder ``zs`` plus a latent one-step model ``model_unroll``.

    ``zs`` maps an observation to a ``zs_dim`` embedding. It uses a 4-layer CNN
    (``cnn_zs``) when ``use_pixel`` is set, otherwise ``BaseMLP``. ``model_unroll``
    predicts a ``zs_dim`` vector from ``(zs, action)``. It uses the ``za`` action
    embedding, the ``zsa`` joint MLP and a linear ``model`` head.

    Args:
        goal_dim: observation size for the MLP encoder (unused in pixel mode).
        action_dim: action size.
        zs_dim, za_dim, zsa_dim: widths of the state, action and joint embeddings.
        hidden_dim: hidden width passed to ``BaseMLP``.
        enc_horizon: stored as ``self.enc_horizon``; not used inside this class.
        use_pixel: use the CNN encoder.
        hidden_layers: accepted for interface compatibility; not used here.
        norm: forwarded to ``BaseMLP``. It also selects the rescaled normalization
            in ``model_unroll`` and disables mean-centering in ``features``.
    """

    def __init__(self, goal_dim, action_dim, zs_dim, za_dim,zsa_dim, hidden_dim, enc_horizon,use_pixel=False, hidden_layers: int = 2, norm=True) -> None:
        super().__init__()
        self.zs_dim = zs_dim #512
        self.za_dim = za_dim #256
        self.zsa_dim = zsa_dim #512
        self.hdim = hidden_dim #512
        self.enc_horizon= enc_horizon #5
        self.use_pixel = use_pixel
        # flattened CNN feature size; presumably 32 channels x 5 x 5 for 64x64 frames
        self.repr_dim = 800
        self.history = 3  # stacked frames; the CNN expects history * 3 input channels
        activ = 'elu'
        self.activ = getattr(F, activ)
        if use_pixel:
            self.zs_cnn1 = nn.Conv2d(self.history * 3, 32, 3, stride=2)
            self.zs_cnn2 = nn.Conv2d(32, 32, 3, stride=2)
            self.zs_cnn3 = nn.Conv2d(32, 32, 3, stride=2)
            self.zs_cnn4 = nn.Conv2d(32, 32, 3, stride=1)
            self.zs_lin = nn.Linear(self.repr_dim, zs_dim)
            self.zs = self.cnn_zs
        else:
            self.zs = BaseMLP(goal_dim, self.zs_dim, self.hdim, activ, norm)
        self.za = nn.Linear(action_dim, self.za_dim)
        self.zsa = BaseMLP(self.zs_dim + self.za_dim, self.zsa_dim, self.hdim, activ, norm)
        self.model = nn.Linear(self.zsa_dim, self.zs_dim)
        self.norm = norm
        self.image_width = 64
        self.aug = RandomShiftsAug(pad=(self.image_width // 21))
        # running_mean tracks the output of `self.zs` (width zs_dim) in `forward` and is
        # subtracted from it in `features`; sizing it zsa_dim crashed whenever zs_dim != zsa_dim.
        self.register_buffer('running_mean', torch.zeros(zs_dim))
        self.register_buffer('running_std', torch.ones(zsa_dim))  # not read in this class

    def forward(self, x):
        """Encode ``x`` with ``self.zs`` and update the EMA (rate 0.005) of the batch mean."""
        out = self.zs(x)
        with torch.no_grad():
            self.running_mean = 0.995 * self.running_mean + 0.005 * out.mean(dim=0)
        return out

    def cnn_zs(self, state: torch.Tensor):
        """CNN state encoder: random shift, rescale, 4 ELU convs, then ``ln_activ`` projection.

        Assumes pixel values in [0, 255] (they are mapped by ``/255 - 0.5``) — TODO confirm.
        The random-shift augmentation is applied on every call.
        """
        with torch.no_grad():
            state = self.aug(state)
        state = state/255. - 0.5
        zs = self.activ(self.zs_cnn1(state))
        zs = self.activ(self.zs_cnn2(zs))
        zs = self.activ(self.zs_cnn3(zs))
        zs = self.activ(self.zs_cnn4(zs)).reshape(state.shape[0], -1)
        return ln_activ(self.zs_lin(zs), self.activ)

    def encoder(self,state:torch.Tensor, augment=True):
        """Return ``(flattened pre-activation conv4 features, projected features)``.

        The random shift is applied only when ``augment`` is true.
        """
        if augment:
            with torch.no_grad():
                state = self.aug(state)
        state = state/255. - 0.5
        zs = self.activ(self.zs_cnn1(state))
        zs = self.activ(self.zs_cnn2(zs))
        zs = self.activ(self.zs_cnn3(zs))
        zs = (self.zs_cnn4(zs))
        backward = self.activ(zs).reshape(state.shape[0], -1)
        backward = ln_activ(self.zs_lin(backward), self.activ)
        return zs.reshape(state.shape[0], -1),backward

    def features(self, x):
        """Encode ``x``; when ``norm`` is off, subtract the running mean (does not update it)."""
        if self.norm:
            return self.zs(x)
        else:
            return self.zs(x)-self.running_mean.reshape(1,-1)

    def model_unroll(self,zs, action):
        """Predict the next ``zs_dim`` embedding from ``(zs, action)``.

        When ``norm`` is set, the output is rescaled to L2 norm ``sqrt(zs_dim)``.
        """
        za = self.activ(self.za(action))
        zsa = self.zsa(torch.cat([zs, za], 1))
        zsa = self.model(zsa)

        if self.norm:
            return math.sqrt(zsa.shape[-1]) * F.normalize(zsa, dim=-1)
        else:
            return zsa


class EncoderMapUnitNorm(nn.Module):
    """Variant of ``EncoderMap`` whose embeddings are unit-L2-normalized when ``norm`` is set.

    It uses ``BaseMLPUnitNorm`` for the MLP encoders. ``cnn_zs`` and
    ``model_unroll`` return ``F.normalize(...)`` outputs (norm 1) instead of the
    ``sqrt(d)``-scaled ones. Constructor arguments have the same meaning as for
    ``EncoderMap``. ``enc_horizon`` is only stored, and ``hidden_layers`` is unused.
    """

    def __init__(self, goal_dim, action_dim, zs_dim, za_dim,zsa_dim, hidden_dim, enc_horizon,use_pixel=False, hidden_layers: int = 2, norm=True) -> None:
        super().__init__()
        self.zs_dim = zs_dim #512
        self.za_dim = za_dim #256
        self.zsa_dim = zsa_dim #512
        self.hdim = hidden_dim #512
        self.enc_horizon= enc_horizon #5
        self.use_pixel = use_pixel
        # flattened CNN feature size; presumably 32 channels x 5 x 5 for 64x64 frames
        self.repr_dim = 800
        self.history = 3  # stacked frames; the CNN expects history * 3 input channels
        activ = 'elu'
        self.activ = getattr(F, activ)
        if use_pixel:
            self.zs_cnn1 = nn.Conv2d(self.history * 3, 32, 3, stride=2)
            self.zs_cnn2 = nn.Conv2d(32, 32, 3, stride=2)
            self.zs_cnn3 = nn.Conv2d(32, 32, 3, stride=2)
            self.zs_cnn4 = nn.Conv2d(32, 32, 3, stride=1)
            self.zs_lin = nn.Linear(self.repr_dim, zs_dim)
            self.zs = self.cnn_zs
        else:
            self.zs = BaseMLPUnitNorm(goal_dim, self.zs_dim, self.hdim, activ, norm)
        self.za = nn.Linear(action_dim, self.za_dim)
        self.zsa = BaseMLPUnitNorm(self.zs_dim + self.za_dim, self.zsa_dim, self.hdim, activ, norm)
        self.model = nn.Linear(self.zsa_dim, self.zs_dim)
        self.norm = norm
        self.image_width = 64
        self.aug = RandomShiftsAug(pad=(self.image_width // 21))
        # running_mean tracks `self.zs` outputs (width zs_dim); sizing it zsa_dim
        # crashed in `forward`/`features` whenever zs_dim != zsa_dim.
        self.register_buffer('running_mean', torch.zeros(zs_dim))
        self.register_buffer('running_std', torch.ones(zsa_dim))  # not read in this class

    def forward(self, x):
        """Encode ``x`` with ``self.zs`` and update the EMA (rate 0.005) of the batch mean."""
        out = self.zs(x)
        with torch.no_grad():
            self.running_mean = 0.995 * self.running_mean + 0.005 * out.mean(dim=0)
        return out

    def cnn_zs(self, state: torch.Tensor):
        """CNN state encoder: random shift, rescale, 4 ELU convs, linear projection.

        With ``norm`` the projection is unit-L2-normalized. Previously the non-norm
        path raised UnboundLocalError. It now falls back to ``ln_activ``, matching
        ``EncoderMap.cnn_zs`` (NOTE(review): confirm this is the desired fallback).
        Assumes pixel values in [0, 255] — TODO confirm.
        """
        with torch.no_grad():
            state = self.aug(state)
        state = state/255. - 0.5
        zs = self.activ(self.zs_cnn1(state))
        zs = self.activ(self.zs_cnn2(zs))
        zs = self.activ(self.zs_cnn3(zs))
        zs = self.activ(self.zs_cnn4(zs)).reshape(state.shape[0], -1)
        projected = self.zs_lin(zs)
        if self.norm:
            return F.normalize(projected, dim=-1)
        return ln_activ(projected, self.activ)

    def encoder(self,state:torch.Tensor, augment=True):
        """Return ``(flattened pre-activation conv4 features, ln_activ-projected features)``.

        NOTE(review): the projected features here use ``ln_activ`` even when ``norm``
        is set, unlike ``cnn_zs``; confirm this asymmetry is intended.
        """
        if augment:
            with torch.no_grad():
                state = self.aug(state)
        state = state/255. - 0.5
        zs = self.activ(self.zs_cnn1(state))
        zs = self.activ(self.zs_cnn2(zs))
        zs = self.activ(self.zs_cnn3(zs))
        zs = (self.zs_cnn4(zs))
        backward = self.activ(zs).reshape(state.shape[0], -1)
        backward = ln_activ(self.zs_lin(backward), self.activ)
        return zs.reshape(state.shape[0], -1),backward

    def features(self, x):
        """Encode ``x``; when ``norm`` is off, subtract the running mean (does not update it)."""
        if self.norm:
            return self.zs(x)
        else:
            return self.zs(x)-self.running_mean.reshape(1,-1)

    def model_unroll(self,zs, action):
        """Predict the next ``zs_dim`` embedding from ``(zs, action)``; unit-normalized when ``norm`` is set."""
        za = self.activ(self.za(action))
        zsa = self.zsa(torch.cat([zs, za], 1))
        zsa = self.model(zsa)

        if self.norm:
            return F.normalize(zsa, dim=-1)
        else:
            return zsa


class ForwardMap(nn.Module):
    """Forward map F(s, a, z), built from two embeddings fused by an MLP head.

    ``embed_z`` embeds ``[obs, z]`` and ``embed_sa`` embeds ``[obs, action]``. Each
    produces ``hidden_dim // 2`` features. Their concatenation ``[sa, z]`` passes
    through ``hidden_layers`` x (Linear -> ReLU) and a final projection to
    ``output_dim``, or to ``z_dim`` when ``output_dim`` is falsy. With
    ``num_parallel > 1`` every layer is an ensemble-parallel layer.
    """

    def __init__(self, obs_dim, z_dim, action_dim, hidden_dim, hidden_layers: int = 1,
                 embedding_layers: int = 2, num_parallel: int = 2, output_dim=None) -> None:
        super().__init__()
        self.z_dim = z_dim
        self.num_parallel = num_parallel
        self.hidden_dim = hidden_dim

        self.embed_z = simple_embedding(obs_dim + z_dim, hidden_dim, embedding_layers, num_parallel)
        self.embed_sa = simple_embedding(obs_dim + action_dim, hidden_dim, embedding_layers, num_parallel)

        head = []
        for _ in range(hidden_layers):
            head.extend([linear(hidden_dim, hidden_dim, num_parallel), nn.ReLU()])
        head.append(linear(hidden_dim, output_dim or z_dim, num_parallel))
        self.Fs = nn.Sequential(*head)

    def forward(self, obs: torch.Tensor, z: torch.Tensor, action: torch.Tensor):
        if self.num_parallel > 1:
            # share the batch across ensemble members -> leading dim of size num_parallel
            obs, z, action = (t.expand(self.num_parallel, -1, -1) for t in (obs, z, action))
        z_features = self.embed_z(torch.cat([obs, z], dim=-1))
        sa_features = self.embed_sa(torch.cat([obs, action], dim=-1))
        return self.Fs(torch.cat([sa_features, z_features], dim=-1))


class ForwardMapLayerNorm(nn.Module):
    """``ForwardMap`` variant with a LayerNorm after every hidden ReLU.

    The embeddings come from ``simple_embedding_layer_norm``. Each head hidden
    layer is Linear -> ReLU -> LayerNorm, followed by a final Linear to
    ``output_dim`` (or to ``z_dim`` when ``output_dim`` is falsy).
    """

    def __init__(self, obs_dim, z_dim, action_dim, hidden_dim, hidden_layers: int = 1,
                 embedding_layers: int = 2, num_parallel: int = 2, output_dim=None) -> None:
        super().__init__()
        self.z_dim = z_dim
        self.num_parallel = num_parallel
        self.hidden_dim = hidden_dim

        self.embed_z = simple_embedding_layer_norm(obs_dim + z_dim, hidden_dim, embedding_layers, num_parallel)
        self.embed_sa = simple_embedding_layer_norm(obs_dim + action_dim, hidden_dim, embedding_layers, num_parallel)

        head = []
        for _ in range(hidden_layers):
            head.extend([linear(hidden_dim, hidden_dim, num_parallel), nn.ReLU(),
                         layernorm(hidden_dim, num_parallel)])
        head.append(linear(hidden_dim, output_dim or z_dim, num_parallel))
        self.Fs = nn.Sequential(*head)

    def forward(self, obs: torch.Tensor, z: torch.Tensor, action: torch.Tensor):
        if self.num_parallel > 1:
            obs, z, action = (t.expand(self.num_parallel, -1, -1) for t in (obs, z, action))
        z_features = self.embed_z(torch.cat([obs, z], dim=-1))
        sa_features = self.embed_sa(torch.cat([obs, action], dim=-1))
        return self.Fs(torch.cat([sa_features, z_features], dim=-1))



class ForwardMapStateOnly(nn.Module):
    """Action-free forward map F(s, z).

    It has the same structure as ``ForwardMap``, but the second embedding sees only ``obs``.
    """

    def __init__(self, obs_dim, z_dim, hidden_dim, hidden_layers: int = 1,
                 embedding_layers: int = 2, num_parallel: int = 1, output_dim=None) -> None:
        super().__init__()
        self.z_dim = z_dim
        self.num_parallel = num_parallel
        self.hidden_dim = hidden_dim

        self.embed_z = simple_embedding(obs_dim + z_dim, hidden_dim, embedding_layers, num_parallel)
        self.embed_s = simple_embedding(obs_dim, hidden_dim, embedding_layers, num_parallel)

        head = []
        for _ in range(hidden_layers):
            head.extend([linear(hidden_dim, hidden_dim, num_parallel), nn.ReLU()])
        head.append(linear(hidden_dim, output_dim or z_dim, num_parallel))
        self.Fs = nn.Sequential(*head)

    def forward(self, obs: torch.Tensor, z: torch.Tensor):
        if self.num_parallel > 1:
            obs, z = (t.expand(self.num_parallel, -1, -1) for t in (obs, z))
        z_features = self.embed_z(torch.cat([obs, z], dim=-1))
        s_features = self.embed_s(obs)
        return self.Fs(torch.cat([s_features, z_features], dim=-1))


class SwiGLU(nn.Module):
    """Gated activation: split the last dim into ``(value, gate)`` halves and return ``value * silu(gate)``."""

    def forward(self, x):
        value, gate = torch.chunk(x, 2, dim=-1)
        return value * F.silu(gate)

class SimpleCritic(nn.Module):
    """Ensemble critic: a plain ReLU MLP over ``[obs, action]``, parallel across ``num_parallel`` members.

    Layout: Linear -> ReLU, ``num_hidden_layers - 1`` x (Linear -> ReLU), then a
    Linear to ``output_dim`` outputs. ``output_dim`` defaults to 1; it was
    previously accepted but ignored.
    """

    def __init__(self, obs_dim, action_dim, hidden_dim, num_parallel: int = 2, num_hidden_layers: int = 2, output_dim: int = 1) -> None:
        super().__init__()
        self.num_parallel = num_parallel
        self.hidden_dim = hidden_dim
        seq = [linear(obs_dim + action_dim, hidden_dim, num_parallel), nn.ReLU()]
        for _ in range(num_hidden_layers - 1):
            seq += [linear(hidden_dim, hidden_dim, num_parallel), nn.ReLU()]
        # int() keeps callers that passed the old float-annotated default (e.g. 1.0) working
        seq += [linear(hidden_dim, int(output_dim), num_parallel)]
        self.Qs = nn.Sequential(*seq)

    def forward(self, obs: torch.Tensor, action):
        """Evaluate all members on ``[obs, action]``, or on ``obs`` alone when ``action`` is None.

        With ``action=None`` the critic must have been built with ``action_dim=0`` so
        that the first layer's input width matches.
        """
        if self.num_parallel > 1:
            obs = obs.expand(self.num_parallel, -1, -1)
            if action is not None:
                action = action.expand(self.num_parallel, -1, -1)
        h = obs if action is None else torch.cat([obs, action], dim=-1)
        return self.Qs(h)
    
class SimpleActor(nn.Module):
    """An actor built from a plain ReLU MLP (optionally ensemble-parallel).

    The MLP maps ``obs`` to a tanh-squashed mean. ``forward`` returns a
    ``TruncatedNormal`` centred there, with the caller-supplied ``std`` applied to
    every dimension.
    """

    def __init__(self, obs_dim, action_dim, hidden_dim, num_parallel: int = 1, num_hidden_layers: int = 2) -> None:
        super().__init__()
        self.num_parallel = num_parallel
        self.hidden_dim = hidden_dim
        layers = [linear(obs_dim, hidden_dim, num_parallel), nn.ReLU()]
        for _ in range(num_hidden_layers - 1):
            layers.extend([linear(hidden_dim, hidden_dim, num_parallel), nn.ReLU()])
        layers.append(linear(hidden_dim, action_dim, num_parallel))
        self.actor = nn.Sequential(*layers)

    def forward(self, obs: torch.Tensor, std):
        if self.num_parallel > 1:
            obs = obs.expand(self.num_parallel, -1, -1)
        mu = torch.tanh(self.actor(obs))
        return TruncatedNormal(mu, torch.ones_like(mu) * std)

    

    
class HierarchicalActor(nn.Module):
    """A state-independent actor holding a single learnable action vector.

    ``forward`` tiles the vector over the batch size of ``obs`` (the observation
    values are not used). It rescales each copy to L2 norm ``sqrt(action_dim)``.
    """

    def __init__(self, action_dim) -> None:
        super().__init__()
        initial = torch.randn((1, action_dim), requires_grad=True, dtype=torch.float32)
        self.action = torch.nn.parameter.Parameter(initial)

    def forward(self, obs):
        tiled = self.action.repeat_interleave(obs.shape[0], dim=0)
        scale = math.sqrt(tiled.shape[-1])
        return scale * F.normalize(tiled, dim=-1)


class HierarchicalActorNormal(nn.Module):
    """A state-independent Gaussian actor with a learnable mean and a learnable log-std.

    Both parameters have shape ``(1, action_dim)``. ``logstd`` sets the initial log-std.
    """

    def __init__(self, action_dim,logstd) -> None:
        super().__init__()
        self.action = torch.nn.parameter.Parameter(torch.randn((1,action_dim),requires_grad=True,dtype=torch.float32))
        self.action_log_std = torch.nn.Parameter(torch.full((1, action_dim), fill_value=logstd, dtype=torch.float32))

    def forward(self, obs, eval=False):
        """Sample an action for each element of the batch.

        Args:
            obs: its leading dimension sets the batch size; ``None`` means a batch of 1.
                Observation values are not used.
            eval: if true, return only the tiled mean.

        Returns:
            With ``eval``: the mean, shape ``(batch, action_dim)``.
            Otherwise: ``(action_normalized, log_prob)``. ``action_normalized`` is the
            sample rescaled to L2 norm ``sqrt(action_dim)``. ``log_prob`` has shape
            ``(batch, 1)`` and is the Gaussian log-density of the *pre-normalization*
            sample.
        """
        batch_size = 1 if obs is None else obs.shape[0]
        action_mean = self.action.repeat(batch_size, 1)
        action_log_std = self.action_log_std.repeat(batch_size, 1)

        if eval:
            # NOTE(review): returned without the sqrt(d) normalization applied to samples
            # below — confirm callers expect this.
            return action_mean

        # A single noise draw of shape (1, d) is shared by the whole batch.
        epsilon = torch.randn_like(self.action).repeat(batch_size, 1)
        action = action_mean + epsilon * torch.exp(action_log_std)

        # Std and sample are detached, so gradients of log_prob reach only the mean.
        normal_dist = torch.distributions.Normal(action_mean, torch.exp(action_log_std.detach()))
        log_prob = normal_dist.log_prob(action.detach()).sum(dim=-1, keepdim=True)

        # Normalization is applied after log_prob, so log_prob does not account for it.
        action_normalized = math.sqrt(action.shape[-1]) * F.normalize(action, dim=-1)
        return action_normalized, log_prob

 

class SequetialFMap(nn.Module):
    """Ensemble of independent (non-parallel) forward maps evaluated one after another.

    It holds ``cfg.num_parallel`` members built by ``_build_batch_forward(...,
    parallel=False)`` and stacks their outputs along a new leading dimension.
    """

    def __init__(self, obs_dim, z_dim, action_dim, cfg, output_dim=None):
        super().__init__()
        self.models = nn.ModuleList([_build_batch_forward(obs_dim, z_dim, action_dim,
                                                          cfg, output_dim, parallel=False) for _ in range(cfg.num_parallel)])

    def forward(self, obs: torch.Tensor, z: torch.Tensor, action: torch.Tensor = None) -> torch.Tensor:
        """Run every member and stack the results into ``(num_members, ...)``.

        When ``action`` is None the members are called as ``model(obs, z)``. This is
        needed for the state-only maps created when ``action_dim == 0``; always
        passing ``action`` made them uncallable.
        """
        args = (obs, z) if action is None else (obs, z, action)
        return torch.stack([model(*args) for model in self.models])


class Actor(nn.Module):
    """z-conditioned policy with a tanh-squashed MLP mean and a caller-supplied fixed std.

    ``embed_z`` embeds ``[obs, z]`` and ``embed_s`` embeds ``obs``. The
    concatenation ``[s, z]`` goes through ``hidden_layers`` x (Linear -> ReLU) and
    a final Linear to ``action_dim``.
    """

    def __init__(self, obs_dim, z_dim, action_dim, hidden_dim, hidden_layers: int = 1,
                 embedding_layers: int = 2) -> None:
        super().__init__()

        self.embed_z = simple_embedding(obs_dim + z_dim, hidden_dim, embedding_layers)
        self.embed_s = simple_embedding(obs_dim, hidden_dim, embedding_layers)

        layers = []
        for _ in range(hidden_layers):
            layers.extend([linear(hidden_dim, hidden_dim), nn.ReLU()])
        layers.append(linear(hidden_dim, action_dim))
        self.policy = nn.Sequential(*layers)

    def _mean_and_std(self, obs, z, std):
        """Return the tanh-squashed mean and a same-shaped tensor filled with ``std``."""
        joint = self.embed_z(torch.cat([obs, z], dim=-1))
        state = self.embed_s(obs)
        mu = torch.tanh(self.policy(torch.cat([state, joint], dim=-1)))
        return mu, torch.ones_like(mu) * std

    def forward(self, obs, z, std):
        """Action distribution as a ``TruncatedNormal``."""
        mu, scale = self._mean_and_std(obs, z, std)
        return TruncatedNormal(mu, scale)

    def get_normal(self, obs, z, std):
        """Same mean/std as ``forward``, but as an untruncated ``torch.distributions.Normal``."""
        mu, scale = self._mean_and_std(obs, z, std)
        return torch.distributions.Normal(mu, scale)


class Discriminator(nn.Module):
    """Binary classifier ``D(s, z)`` in (0, 1) applied to the concatenation ``[z, obs]``.

    Trunk: Linear -> LayerNorm -> Tanh, ``hidden_layers - 1`` x (Linear -> ReLU),
    then a Linear to one logit.
    """

    def __init__(self, obs_dim, z_dim, hidden_dim, hidden_layers) -> None:
        super().__init__()
        layers = [nn.Linear(obs_dim + z_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.Tanh()]
        layers.extend(
            module
            for _ in range(hidden_layers - 1)
            for module in (nn.Linear(hidden_dim, hidden_dim), nn.ReLU())
        )
        layers.append(nn.Linear(hidden_dim, 1))
        self.trunk = nn.Sequential(*layers)

    def forward(self, obs: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        return torch.sigmoid(self.compute_logits(obs, z))

    def compute_logits(self, obs: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        return self.trunk(torch.cat([z, obs], dim=1))

    def compute_reward(self, obs: torch.Tensor, z: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
        """Log-odds ``log D - log(1 - D)``, with ``D`` clamped to ``[eps, 1 - eps]``."""
        prob = self.forward(obs, z).clamp(eps, 1 - eps)
        return prob.log() - (1 - prob).log()


##########################
# Residual models
##########################

class ResidualBlock(nn.Module):
    """Pre-norm residual block: ``x + Mish(Linear(LayerNorm(x)))`` (ensemble-parallel if requested)."""

    def __init__(self, dim, num_parallel: int = 1):
        super().__init__()
        self.mlp = nn.Sequential(layernorm(dim, num_parallel), linear(dim, dim, num_parallel), nn.Mish())

    def forward(self, x):
        update = self.mlp(x)
        return x + update


class Block(nn.Module):
    """Pre-norm layer: LayerNorm -> Linear, followed by Mish when ``activation`` is truthy."""

    def __init__(self, input_dim, output_dim, activation, num_parallel: int = 1):
        super().__init__()
        layers = [layernorm(input_dim, num_parallel), linear(input_dim, output_dim, num_parallel)]
        if activation:
            layers.append(nn.Mish())
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        return self.mlp(x)


def residual_embedding(input_dim, hidden_dim, hidden_layers, num_parallel=1):
    """Residual embedding from ``input_dim`` to ``hidden_dim // 2`` features.

    Layout: a Block to ``hidden_dim``, ``hidden_layers - 2`` ResidualBlocks, then a
    Block down to ``hidden_dim // 2`` (both Blocks use Mish).
    """
    assert hidden_layers >= 2, "must have at least 2 embedding layers"
    stem = Block(input_dim, hidden_dim, True, num_parallel)
    body = [ResidualBlock(hidden_dim, num_parallel) for _ in range(hidden_layers - 2)]
    head = Block(hidden_dim, hidden_dim // 2, True, num_parallel)
    return nn.Sequential(stem, *body, head)


class ResidualForwardMap(nn.Module):
    """Residual-block version of ``ForwardMap``.

    The ``[obs, z]`` and ``[obs, action]`` residual embeddings (``hidden_dim // 2``
    each) are concatenated. They pass through ``hidden_layers`` ResidualBlocks and
    a non-activated Block to ``output_dim``, or to ``z_dim`` when ``output_dim`` is
    falsy.
    """

    def __init__(self, obs_dim, z_dim, action_dim, hidden_dim, hidden_layers: int = 1,
                 embedding_layers: int = 2, num_parallel: int = 2, output_dim=None) -> None:
        super().__init__()
        self.z_dim = z_dim
        self.num_parallel = num_parallel
        self.hidden_dim = hidden_dim

        self.embed_z = residual_embedding(obs_dim + z_dim, hidden_dim, embedding_layers, num_parallel)
        self.embed_sa = residual_embedding(obs_dim + action_dim, hidden_dim, embedding_layers, num_parallel)

        blocks = [ResidualBlock(hidden_dim, num_parallel) for _ in range(hidden_layers)]
        blocks.append(Block(hidden_dim, output_dim or z_dim, False, num_parallel))
        self.Fs = nn.Sequential(*blocks)

    def forward(self, obs: torch.Tensor, z: torch.Tensor, action: torch.Tensor):
        if self.num_parallel > 1:
            obs, z, action = (t.expand(self.num_parallel, -1, -1) for t in (obs, z, action))
        z_features = self.embed_z(torch.cat([obs, z], dim=-1))
        sa_features = self.embed_sa(torch.cat([obs, action], dim=-1))
        return self.Fs(torch.cat([sa_features, z_features], dim=-1))


class ResidualActor(nn.Module):
    """Residual-block version of ``Actor``: a tanh-squashed mean with a caller-supplied std, as a ``TruncatedNormal``."""

    def __init__(self, obs_dim, z_dim, action_dim, hidden_dim, hidden_layers: int = 1,
                 embedding_layers: int = 2) -> None:
        super().__init__()

        self.embed_z = residual_embedding(obs_dim + z_dim, hidden_dim, embedding_layers)
        self.embed_s = residual_embedding(obs_dim, hidden_dim, embedding_layers)

        self.policy = nn.Sequential(
            *(ResidualBlock(hidden_dim) for _ in range(hidden_layers)),
            Block(hidden_dim, action_dim, False),
        )

    def forward(self, obs, z, std):
        joint = self.embed_z(torch.cat([obs, z], dim=-1))
        state = self.embed_s(obs)
        mu = torch.tanh(self.policy(torch.cat([state, joint], dim=-1)))
        return TruncatedNormal(mu, torch.ones_like(mu) * std)


##########################
# Helper modules
##########################

class DenseParallel(nn.Module):
    """A bank of ``n_parallel`` independent linear layers evaluated in one batched matmul.

    With ``n_parallel`` None or 1 this is an ordinary linear layer: the weight has
    the ``nn.Linear`` layout ``(out_features, in_features)`` and ``F.linear`` is used.
    Otherwise the weight is ``(n_parallel, in_features, out_features)`` and the
    bias is ``(n_parallel, 1, out_features)``. ``forward`` then takes input
    ``(n_parallel, batch, in_features)`` and returns
    ``(n_parallel, batch, out_features)``. ``bias=False`` is supported in both
    layouts; the parallel case previously raised NotImplementedError.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        n_parallel: int,
        bias: bool = True,
        device=None,
        dtype=None,
        reset_params=True,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super(DenseParallel, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.n_parallel = n_parallel
        if n_parallel is None or (n_parallel == 1):
            self.weight = nn.Parameter(torch.empty((out_features, in_features), **factory_kwargs))
            if bias:
                self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
            else:
                self.register_parameter("bias", None)
        else:
            self.weight = nn.Parameter(
                torch.empty((n_parallel, in_features, out_features), **factory_kwargs)
            )
            if bias:
                self.bias = nn.Parameter(
                    torch.empty((n_parallel, 1, out_features), **factory_kwargs)
                )
            else:
                self.register_parameter("bias", None)
        if reset_params:
            self.reset_parameters()

    def load_module_list_weights(self, module_list) -> None:
        """Copy the weights and biases of ``n_parallel`` ``nn.Linear``-like modules into this bank.

        Module ``i`` becomes member ``i``. This is intended for the parallel layout.
        """
        with torch.no_grad():
            assert len(module_list) == self.n_parallel
            weight_list = [m.weight.T for m in module_list]
            target_weight = torch.stack(weight_list, dim=0)
            self.weight.data.copy_(target_weight.data)
            # `if self.bias:` called bool() on a multi-element tensor and always raised.
            if self.bias is not None:
                bias_list = [ln.bias.unsqueeze(0) for ln in module_list]
                target_bias = torch.stack(bias_list, dim=0)
                self.bias.data.copy_(target_bias.data)

    # TODO why do these layers have their own reset scheme?
    def reset_parameters(self) -> None:
        """nn.Linear-style uniform init of the weight and bias.

        NOTE(review): for the 3-D parallel weight, PyTorch's fan-in helper treats
        dim 1 as input maps and dim 2 as the receptive field. It therefore computes
        fan_in = in_features * out_features rather than in_features, which gives
        smaller bounds than a per-member nn.Linear. Kept as-is to preserve existing
        initialization.
        """
        nn.init.kaiming_uniform_(self.weight, a=np.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / np.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        if self.n_parallel is None or (self.n_parallel == 1):
            return F.linear(input, self.weight, self.bias)
        if self.bias is None:
            return torch.bmm(input, self.weight)
        return torch.baddbmm(self.bias, input, self.weight)

    def extra_repr(self) -> str:
        return "in_features={}, out_features={}, n_parallel={}, bias={}".format(
            self.in_features, self.out_features, self.n_parallel, self.bias is not None
        )


class ParallelLayerNorm(nn.Module):
    """Layer normalization over one trailing dimension with per-member affine params.

    With ``n_parallel`` None or 1 the affine weight/bias have shape
    ``normalized_shape`` (like ``nn.LayerNorm``). Otherwise they have shape
    ``(n_parallel, 1, *normalized_shape)`` so they broadcast against an input whose
    leading dim indexes the parallel member -- presumably
    ``(n_parallel, batch, features)``; TODO confirm with callers.

    Args:
        normalized_shape: int or single-element sequence giving the feature size.
        n_parallel: number of parallel members, or None for a single member.
        eps: value added to the variance for numerical stability.
        elementwise_affine: whether to learn a per-member scale and shift.
        device, dtype: forwarded to parameter construction.

    Raises:
        ValueError: if ``normalized_shape`` has more than one dimension.
    """

    def __init__(self, normalized_shape, n_parallel, eps=1e-5, elementwise_affine=True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = [normalized_shape]
        if len(normalized_shape) != 1:
            raise ValueError(
                "ParallelLayerNorm supports exactly one normalized dimension, "
                f"got {list(normalized_shape)}"
            )
        self.n_parallel = n_parallel
        self.normalized_shape = list(normalized_shape)
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            if n_parallel is None or n_parallel == 1:
                param_shape = list(self.normalized_shape)
            else:
                # The singleton dim lets the params broadcast over the batch axis.
                param_shape = [n_parallel, 1, *self.normalized_shape]
            self.weight = nn.Parameter(torch.empty(param_shape, **factory_kwargs))
            self.bias = nn.Parameter(torch.empty(param_shape, **factory_kwargs))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Reset the affine scale to ones and the shift to zeros (if present)."""
        if self.elementwise_affine:
            nn.init.ones_(self.weight)
            nn.init.zeros_(self.bias)

    def load_module_list_weights(self, module_list) -> None:
        """Copy affine parameters from one layer norm per parallel member.

        Args:
            module_list: sequence of modules exposing ``weight`` and ``bias`` of
                shape ``normalized_shape`` (e.g. ``nn.LayerNorm``). Its length must
                equal ``n_parallel`` (or 1 when ``n_parallel`` is None).

        Raises:
            ValueError: if the number of modules does not match.
        """
        expected = 1 if self.n_parallel is None else self.n_parallel
        if len(module_list) != expected:
            raise ValueError(
                f"expected {expected} modules, got {len(module_list)}"
            )
        if not self.elementwise_affine:
            return
        with torch.no_grad():
            target_weight = torch.stack([ln.weight for ln in module_list], dim=0)
            target_bias = torch.stack([ln.bias for ln in module_list], dim=0)
            # view_as maps the stacked (n, D) tensor onto either the single-member
            # (D,) layout or the parallel (n, 1, D) layout.
            self.weight.data.copy_(target_weight.view_as(self.weight))
            self.bias.data.copy_(target_bias.view_as(self.bias))

    def forward(self, input):
        """Normalize over the last dim, then apply the per-member scale and shift."""
        # Affine params are applied manually so the (n, 1, D) shapes can broadcast.
        norm_input = F.layer_norm(
            input, self.normalized_shape, None, None, self.eps)
        if self.elementwise_affine:
            return (norm_input * self.weight) + self.bias
        return norm_input

    def extra_repr(self) -> str:
        return (f'{self.normalized_shape}, eps={self.eps}, '
                f'elementwise_affine={self.elementwise_affine}, '
                f'n_parallel={self.n_parallel}')


class TruncatedNormal(pyd.Normal):
    """Normal distribution whose samples are squashed into ``[low + eps, high - eps]``.

    The clamp is straight-through: the returned value is clamped, but its gradient
    w.r.t. the unclamped sample is the identity.
    """

    def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6) -> None:
        super().__init__(loc, scale, validate_args=False)
        self.low, self.high, self.eps = low, high, eps
        # Signed distance from the mean to each bound (not read inside this class).
        self.noise_upper_limit = high - self.loc
        self.noise_lower_limit = low - self.loc

    def _clamp(self, x) -> torch.Tensor:
        """Clamp ``x`` into the open support while passing gradients straight through."""
        lower = self.low + self.eps
        upper = self.high - self.eps
        clamped = torch.clamp(x, lower, upper).detach()
        return x - x.detach() + clamped

    def sample(self, clip=None, sample_shape=torch.Size()) -> torch.Tensor:  # type: ignore
        """Draw ``loc + scale * N(0, 1)``, optionally clipping the noise to ``[-clip, clip]``.

        Unlike ``Normal.sample`` this is not wrapped in ``no_grad``, so gradients
        reach ``loc``.
        """
        shape = self._extended_shape(sample_shape)
        noise = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
        noise *= self.scale
        if clip is not None:
            noise = noise.clamp(-clip, clip)
        return self._clamp(self.loc + noise)


class Norm(nn.Module):
    """Rescale the last dimension to have L2 norm ``sqrt(last_dim)``."""

    def forward(self, x) -> torch.Tensor:
        radius = math.sqrt(x.shape[-1])
        unit = F.normalize(x, dim=-1)
        return radius * unit




def weight_init(layer: nn.Module) -> None:
    """Xavier-uniform init (ReLU gain) for Linear/Conv2d layers, with zeroed biases.

    Intended to be passed to ``nn.Module.apply``; modules of any other type are
    left untouched.

    NOTE(review): this rebinds the module-level name ``weight_init``, replacing the
    earlier ``weight_init`` defined in this file (which also handled
    ``DenseParallel`` and ``reset_parameters``). Python resolves the global at call
    time, so ``self.apply(weight_init)`` in this module uses this version. Confirm
    the shadowing is intended.
    """
    if isinstance(layer, (nn.Linear, nn.Conv2d)):
        gain = nn.init.calculate_gain('relu')
        nn.init.xavier_uniform_(layer.weight.data, gain)
        if layer.bias is not None:
            layer.bias.data.fill_(0.0)


def ln_activ(x: torch.Tensor, activ):
    """Layer-normalize ``x`` over its last dimension (no affine), then apply ``activ``."""
    last_dim = x.shape[-1]
    normalized = F.layer_norm(x, (last_dim,))
    return activ(normalized)


class BaseMLP(nn.Module):
    """Three-layer MLP whose hidden activations are layer-normalized.

    Args:
        input_dim: size of the input features.
        output_dim: size of the output.
        hdim: width of both hidden layers.
        activ: name of an activation function in ``torch.nn.functional``.
        norm: when truthy, the output is rescaled to L2 norm ``sqrt(output_dim)``.
    """

    def __init__(self, input_dim: int, output_dim: int, hdim: int, activ: str='elu', norm=True):
        super().__init__()
        # Layer construction order is kept: it fixes RNG consumption and apply() order.
        self.l1 = nn.Linear(input_dim, hdim)
        self.l2 = nn.Linear(hdim, hdim)
        self.l3 = nn.Linear(hdim, output_dim)
        self.norm = norm
        self.activ = getattr(F, activ)
        self.apply(weight_init)

    def forward(self, x: torch.Tensor):
        hidden = x
        for layer in (self.l1, self.l2):
            hidden = ln_activ(layer(hidden), self.activ)
        out = self.l3(hidden)
        if not self.norm:
            return out
        return math.sqrt(out.shape[-1]) * F.normalize(out, dim=-1)


class BaseMLPUnitNorm(nn.Module):
    """Three-layer MLP with layer-normalized hidden activations and unit-norm output.

    Args:
        input_dim: size of the input features.
        output_dim: size of the output.
        hdim: width of both hidden layers.
        activ: name of an activation function in ``torch.nn.functional``.
        norm: when truthy, the output is L2-normalized to unit length.
    """

    def __init__(self, input_dim: int, output_dim: int, hdim: int, activ: str='elu', norm=True):
        super().__init__()
        # Layer construction order is kept: it fixes RNG consumption and apply() order.
        self.l1 = nn.Linear(input_dim, hdim)
        self.l2 = nn.Linear(hdim, hdim)
        self.l3 = nn.Linear(hdim, output_dim)
        self.activ = getattr(F, activ)
        self.norm = norm
        self.apply(weight_init)

    def forward(self, x: torch.Tensor):
        h = ln_activ(self.l1(x), self.activ)
        h = ln_activ(self.l2(h), self.activ)
        projected = self.l3(h)
        return F.normalize(projected, dim=-1) if self.norm else projected
