

import torch
import torch.nn as nn
from typing import List
import argparse
from tqdm import tqdm
from isaaclab.utils import configclass
from dataclasses import MISSING
from rsl_rl.rsl_rl.utils.utils import resolve_nn_activation
from abc import abstractmethod
from rsl_rl.rsl_rl.addons.submodule import NNSubmodule



def build_mlp(input_dims: int, hidden_dims: list[int], output_dims: int, activation_name: str = "elu"):
    """Build a fully connected network as an ``nn.Sequential``.

    The layout is ``Linear -> activation`` for every hidden layer, followed by a
    final ``Linear`` to ``output_dims`` with no activation after it.

    Args:
        input_dims: Number of input features of the first linear layer.
        hidden_dims: Widths of the hidden layers. An entry of ``-1`` is replaced
            by ``input_dims``. May be empty, in which case the network is a single
            ``Linear(input_dims, output_dims)``.
        output_dims: Number of output features of the last linear layer.
        activation_name: Name handed to ``resolve_nn_activation``.

    Returns:
        The assembled ``nn.Sequential``.

    Note:
        The one module returned by ``resolve_nn_activation`` is reused at every
        activation position rather than instantiated per layer.
    """
    # resolve activation function
    activation = resolve_nn_activation(activation_name)
    # layer widths from input to last hidden layer; -1 means "same as the input"
    widths = [input_dims, *(input_dims if dim == -1 else dim for dim in hidden_dims)]

    network_layers = []
    # hidden layers: consecutive width pairs, each followed by the activation
    for in_features, out_features in zip(widths[:-1], widths[1:]):
        network_layers.append(nn.Linear(in_features, out_features))
        network_layers.append(activation)
    # output layer (no activation); with no hidden layers this maps input -> output
    network_layers.append(nn.Linear(widths[-1], output_dims))
    return nn.Sequential(*network_layers)


class MLPResidualBlock(nn.Module):
    """
        MLP residual block, which is a feedforward neural network with skip connections.
        Hidden layers must be at least 2.

        The block computes ``activation(x + mlp(x))`` where ``mlp`` is a stack of
        ``hidden_layers`` linear layers of width ``hidden_dim``, with the activation
        between consecutive layers and none after the last one.

        Args:
            hidden_dim: Input and output width of every linear layer in the block.
            hidden_layers: Number of linear layers in the block (>= 2).
            activation: Either an activation module, or a name that is resolved
                through ``resolve_nn_activation``. The same module object is used
                between the layers and after the skip connection.
    """
    def __init__(self, hidden_dim, hidden_layers, activation):
        super().__init__()
        assert hidden_layers >= 2, "MLPResidualBlock requires hidden_layers >= 2"
        self.hidden_layers = hidden_layers
        # isinstance (not an exact type check) so str subclasses are resolved too
        if isinstance(activation, str):
            activation = resolve_nn_activation(activation)
        layers = []
        # every layer except the last is followed by the activation
        for _ in range(hidden_layers - 1):
            layers.append(nn.Linear(hidden_dim, hidden_dim))
            layers.append(activation)
        # the last layer is left linear: the activation is applied after the skip sum
        layers.append(nn.Linear(hidden_dim, hidden_dim))
        self.mlp_block = nn.Sequential(*layers)
        self.final_activation = activation

    def forward(self, x):
        """Return ``final_activation(x + mlp_block(x))``."""
        return self.final_activation(x + self.mlp_block(x))
    

class ResNetBlocks(NNSubmodule):
    """Residual MLP backbone applied to a column slice of the observations.

    The network stored in ``self.residual_blocks`` is, in order:

    1. ``Linear(input_dim, residual_block_hidden_dim)`` followed by the activation,
    2. ``num_residual_blocks`` instances of ``MLPResidualBlock``,
    3. ``Linear(residual_block_hidden_dim, backbone_output_dim)``.

    Args:
        input_dim: Input width of the first linear layer. It has to match the
            width of the slice selected by ``input_slice``.
        backbone_output_dim: Output width of the final linear layer.
        num_residual_blocks: Number of residual blocks between the first and the
            final linear layer.
        residual_block_hidden_layers: Number of linear layers inside each
            residual block.
        residual_block_hidden_dim: Width of the residual blocks.
        input_slice: ``[start, stop]`` column range taken from the input along
            dim 1. The default ``[0, 0]`` selects an empty range, so a real range
            has to be supplied before the module can process data.
        activation: Activation name handed to ``resolve_nn_activation``.
    """

    def __init__(
        self,
        input_dim: int,
        backbone_output_dim: int,
        num_residual_blocks: int = 8,
        residual_block_hidden_layers=2,
        residual_block_hidden_dim=128,
        input_slice: List[int] = [0, 0],
        activation="elu",
    ):
        # NOTE: these attributes are assigned before ``super().__init__()``;
        # presumably the NNSubmodule constructor reads them - keep this order.
        self.input_dim = input_dim
        self.backbone_output_dim = backbone_output_dim
        # Store a copy of list arguments so instances never alias the shared
        # default list (or a caller's list); other values are stored unchanged.
        self.input_slice = list(input_slice) if isinstance(input_slice, list) else input_slice
        self.module_type = "resnet_blocks"
        super().__init__()
        activation = resolve_nn_activation(activation)

        # input projection -> residual blocks -> output projection
        residual_blocks = [nn.Sequential(nn.Linear(input_dim, residual_block_hidden_dim), activation)]
        for _ in range(num_residual_blocks):
            residual_blocks.append(
                MLPResidualBlock(
                    residual_block_hidden_dim,
                    residual_block_hidden_layers,
                    activation,
                )
            )
        residual_blocks.append(nn.Linear(residual_block_hidden_dim, backbone_output_dim))
        self.residual_blocks = nn.Sequential(*residual_blocks)

    @torch.no_grad()
    def get_latent(self, observations):
        """Return the output of every stage of the network, stacked on a new dim 0.

        Runs without gradient tracking. The result has ``num_residual_blocks + 2``
        entries along dim 0 (input projection, each residual block, output
        projection).

        NOTE: ``torch.stack`` needs all entries to have the same shape. The last
        stage emits ``backbone_output_dim`` features while all earlier stages emit
        ``residual_block_hidden_dim``, so this only works when the two are equal.
        """
        assert self.input_slice is not None
        x = observations[:, self.input_slice[0]:self.input_slice[1]]
        resnet_features = []
        for layer in self.residual_blocks:
            x = layer(x)
            resnet_features.append(x)
        # [num_stages, B, feature_dim] - assumes observations is (B, obs_dim)
        return torch.stack(resnet_features, dim=0)

    def forward_RL(self, observations):
        """Alias of :meth:`forward`."""
        return self.forward(observations)

    def forward(self, x):
        """Slice ``x`` along dim 1 with ``input_slice`` and run it through the network."""
        assert self.input_slice is not None
        out = self.residual_blocks(x[:, self.input_slice[0]:self.input_slice[1]])
        return out



@configclass
class ResNetBlocksConfig:
    """Configuration for the ``ResNetBlocks`` submodule."""

    class_name: str = "ResNetBlocks"
    """The class name of the submodule."""

    input_dim: int = 33
    """The input dimension of the submodule. It should equal the width of ``input_slice``,
    since the sliced observations are fed directly to the first linear layer."""

    # NOTE(review): this field is unannotated in the original; presumably
    # ``configclass`` infers its type from the default. Confirm that annotating it
    # does not change the generated field order before adding ``: int``.
    backbone_output_dim = 128
    """The output dimension of the backbone MLP. This should not be too large to avoid enlarging the input of final MLPs in actor-critic."""

    input_slice: List[int] = MISSING
    """The ``[start, stop]`` slice of the input tensor within the total observation input. This should be set in RL config instances."""


    num_residual_blocks: int = 8
    """The number of residual blocks in the MLP."""

    residual_block_hidden_layers: int = 2
    """The number of linear layers in each residual block. Must be at least 2."""

    residual_block_hidden_dim: int = 128
    """The hidden dimension of the residual block."""

    activation: str = "elu"
    """The activation function of the residual block."""



