

import torch
import torch.nn as nn
from typing import List
import argparse
from tqdm import tqdm
from isaaclab.utils import configclass
from dataclasses import MISSING
from rsl_rl.rsl_rl.utils.utils import resolve_nn_activation
from abc import abstractmethod
from rsl_rl.rsl_rl.addons.submodule import NNSubmodule



def build_mlp(input_dims: int, hidden_dims: list[int], output_dims: int, activation_name: str = "elu"):
    """Build a fully connected network as an ``nn.Sequential``.

    The network contains one ``nn.Linear`` per entry of ``hidden_dims``, each
    followed by the activation, and ends with a final ``nn.Linear`` to
    ``output_dims`` that has no activation after it.

    Args:
        input_dims: Number of input features of the first linear layer.
        hidden_dims: Width of each hidden layer. An entry of ``-1`` is replaced
            by ``input_dims``. May be empty, in which case the network is a
            single linear layer from ``input_dims`` to ``output_dims``.
        output_dims: Number of output features of the last linear layer.
        activation_name: Name handed to ``resolve_nn_activation``.

    Returns:
        The assembled ``nn.Sequential``.
    """
    # -1 is a placeholder meaning "same width as the input"
    hidden_dims = [input_dims if dim == -1 else dim for dim in hidden_dims]
    # The object returned here is inserted at every activation position,
    # i.e. all positions share one activation object.
    activation = resolve_nn_activation(activation_name)

    # Consecutive pairs of this list are the (in, out) sizes of each linear layer.
    layer_sizes = [input_dims, *hidden_dims, output_dims]
    last_index = len(layer_sizes) - 2

    network_layers = []
    for index, (fan_in, fan_out) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
        network_layers.append(nn.Linear(fan_in, fan_out))
        # no activation after the output layer
        if index != last_index:
            network_layers.append(activation)
    return nn.Sequential(*network_layers)





# Define MLP Model
class KinematicMLP(NNSubmodule):
    """MLP that maps an input vector to per-body output features.

    The model is a ``backbone`` MLP (built by ``build_mlp`` with the "elu"
    activation) followed by a single linear ``output_layer`` that produces
    ``num_bodies * num_output_features_per_body`` values.

    When ``weight_path`` is given, the weights are loaded from that file, all
    parameters are frozen and the module is switched to eval mode.
    """

    def __init__(self, input_dim, num_bodies, num_output_features_per_body, hidden_dims: list[int],
                 backbone_output_dim, input_slice=[0, 0], weight_path=None, **kwargs):
        """Create the backbone and output layer and optionally load weights.

        Args:
            input_dim: Number of input features of the backbone.
            num_bodies: Number of bodies the output layer predicts features for.
            num_output_features_per_body: Number of output features per body.
            hidden_dims: Hidden layer widths passed to ``build_mlp``.
            backbone_output_dim: Output width of the backbone.
            input_slice: ``[start, end]`` column range used by ``get_latents``
                and ``forward_RL`` to cut this module's input out of a wider
                observation tensor.
            weight_path: Optional path to a state dict loaded with
                ``torch.load(..., weights_only=True)``.
            **kwargs: Accepted and ignored.

        Raises:
            Exception: If ``weight_path`` is given and loading the weights fails.
        """
        super().__init__()
        self.module_type = "kinematic"
        self.input_dim = input_dim
        self.backbone_output_dim = backbone_output_dim
        # Copy list arguments so instances never share (or alias) the mutable
        # default / the caller's list; any other value is stored unchanged.
        self.input_slice = list(input_slice) if isinstance(input_slice, list) else input_slice
        self.num_bodies = num_bodies
        self.num_output_features_per_body = num_output_features_per_body
        self.hidden_dim = hidden_dims
        self.backbone = build_mlp(input_dim, hidden_dims, backbone_output_dim, activation_name="elu")
        self.output_layer = nn.Linear(backbone_output_dim, num_bodies * num_output_features_per_body)

        if weight_path is not None:
            try:
                self.load_state_dict(torch.load(weight_path, weights_only=True))
            except Exception as err:
                # Keep the original failure attached as the cause.
                raise Exception(
                    "Failed to load pretrained weights; Check if the path is correct and the architectures match."
                ) from err
            # Loaded weights are used frozen and in eval mode.
            self.requires_grad_(False)
            self.weight_path = weight_path
            self.eval()

    @torch.no_grad()
    def get_latents(self, observations):
        """Return the output of every linear layer of the backbone.

        The columns ``input_slice[0]:input_slice[1]`` of ``observations`` are
        fed through the backbone layer by layer; the tensor produced by each
        ``nn.Linear`` (before the following activation) is collected.

        Returns:
            A list with one tensor per linear layer, in backbone order.
        """
        latents = []
        x = observations[:, self.input_slice[0]:self.input_slice[1]]
        for layer in self.backbone:
            x = layer(x)
            if isinstance(layer, nn.Linear):
                latents.append(x)
        return latents

    @torch.no_grad()
    def forward_RL(self, x):
        """
            This method is used to forward the input tensor through the backbone of the model in no_grad mode.
            The columns ``input_slice[0]:input_slice[1]`` of ``x`` are used as the backbone input.
            To be used in RL stage.
        """
        assert self.input_slice is not None
        x = self.backbone(x[:, self.input_slice[0]:self.input_slice[1]])
        return x

    @torch.no_grad()
    def forward_RL_get_final_output(self, x):
        """
            This method is used to forward from the full observation tensor:
            it applies ``forward_RL`` and then the output layer, in no_grad mode.
            To be used for debugging and visualization.
        """
        x = self.forward_RL(x)
        x = self.output_layer(x)
        return x

    def forward(self, x):
        """Run the backbone and the output layer on ``x``.

        Unlike ``forward_RL``, no slicing is applied and gradients are not
        disabled; ``x`` is passed to the backbone as is.
        """
        x = self.backbone(x)
        x = self.output_layer(x)
        return x



@configclass
class KinematicSubmoduleConfig:
    """Configuration of the kinematic submodule.

    Apart from ``class_name``, the field names mirror the constructor arguments of ``KinematicMLP``.
    """

    class_name: str = "KinematicMLP"
    """The class name of the submodule."""

    input_dim: int = 12
    """The input dimension of the submodule. 12 for quadruped since we have 4 legs with 3 joints each."""

    num_bodies: int = 8
    """The number of bodies for which output features are predicted.

    NOTE(review): this text previously said "17 for quadruped" while the default is 8 - confirm the intended value.
    """

    num_output_features_per_body: int = 6
    """The number of output features per body, e.g. 6 for a pose or 3 for a translation only."""

    hidden_dims: list[int] = [128, 128, 128]
    """The hidden layer widths of the backbone MLP."""

    # NOTE(review): this member has no type annotation. Adding one may change the field order
    # generated by ``configclass`` - confirm before annotating.
    backbone_output_dim = 30
    """The output dimension of the backbone MLP. This should not be too large to avoid enlarging the input of final MLPs in actor-critic."""

    input_slice: list[int] = MISSING
    """The ``[start, end]`` slice of the input tensor within the total observation input. This should be set in RL config instances."""

    weight_path: str | None = None
    """The path to the pretrained model weights. If None, the model will be initialized randomly."""


