

import torch
import torch.nn as nn
from typing import List, Union
import argparse
from tqdm import tqdm
from isaaclab.utils import configclass
from dataclasses import MISSING
from rsl_rl.rsl_rl.utils.utils import resolve_nn_activation
from abc import abstractmethod
from rsl_rl.rsl_rl.addons.submodule import NNSubmodule
from itertools import chain



def build_mlp(input_dims: int, hidden_dims: list[int], output_dims: int, activation_name: str = "elu"):
    """Assemble a fully connected network as an ``nn.Sequential``.

    The stack is ``Linear(input_dims, hidden_dims[0])``, then one ``Linear`` per
    consecutive pair of hidden widths, then ``Linear(hidden_dims[-1], output_dims)``.
    The resolved activation follows every linear layer except the final one.

    Args:
        input_dims: Width of the input feature vector.
        hidden_dims: Hidden layer widths; needs at least one entry because the
            first and last entries are indexed directly. An entry of ``-1`` is
            replaced by ``input_dims``.
        output_dims: Output width of the final linear layer.
        activation_name: Name handed to ``resolve_nn_activation``.

    Returns:
        The assembled ``nn.Sequential``. The one activation object returned by
        ``resolve_nn_activation`` is reused at every activation position.
    """
    # a width of -1 is shorthand for "same width as the input"
    widths = [input_dims if width == -1 else width for width in hidden_dims]
    act_fn = resolve_nn_activation(activation_name)

    # input projection
    modules = [nn.Linear(input_dims, widths[0]), act_fn]
    # hidden-to-hidden projections, one per consecutive pair of widths
    for fan_in, fan_out in zip(widths[:-1], widths[1:]):
        modules += [nn.Linear(fan_in, fan_out), act_fn]
    # output projection, with no trailing activation
    modules.append(nn.Linear(widths[-1], output_dims))
    return nn.Sequential(*modules)


def prepare_states_tensor_from_obs(obs):
    """Extract the per-timestep state history from a flat observation batch.

    The column layout is hard-coded for
    P4RL-Pre-Inv-Dynamic-Pedipulation-Flat-Blind-Anymal-D-v0 (see that task's
    config for details). Each group below is a block of 6 history entries and
    is reshaped to ``[batch_size, 6, features]``:

    =========  ========  =================
    columns    group     features per step
    =========  ========  =================
    0:72       jp        12
    144:162    lin_vel   3
    162:180    ang_vel   3
    180:198    grav      3
    =========  ========  =================

    Columns 72:144 (presumably the action history — confirm against the task
    config) are intentionally not read: they are not part of the returned
    state, so slicing and reshaping them would only cost an extra copy.

    Args:
        obs: Observation tensor of shape ``[batch_size, 318]``; only columns
            ``0:198`` listed above are used.

    Returns:
        Tensor of shape ``[batch_size, 5, 21]``: the four groups concatenated
        along the feature axis, with history entry 0 dropped.
    """
    batch_size = obs.shape[0]
    history_length_in_one_obs = 6
    # (start, stop) column ranges of jp, lin_vel, ang_vel and grav
    state_column_ranges = ((0, 72), (144, 162), (162, 180), (180, 198))
    per_step_groups = [
        obs[:, start:stop].reshape(batch_size, history_length_in_one_obs, -1)
        for start, stop in state_column_ranges
    ]
    obs_segment = torch.cat(per_step_groups, dim=-1)  # [b, 6, 21]
    # drop history entry 0 so that 5 entries remain
    return obs_segment[:, 1:, :]  # [b, 5, 21]


# Define MLP Model
class InvDynamicMLP(NNSubmodule):
    """Inverse-dynamics network: per-timestep encoders followed by a shared MLP trunk.

    Sub-networks (all MLPs are built with ``build_mlp`` and ``activation_name="elu"``):

    * ``state_encoder``: maps each ``input_dim_states`` vector to ``representation_dim``.
    * ``action_encoder_pretrain``: maps each ``input_dim_actions`` vector to
      ``representation_dim``; only created when ``input_dim_actions != 0``.
    * ``main_MLP``: maps the flattened, concatenated embeddings to a 128-wide
      penultimate feature.
    * ``output_layer``: linear head from the penultimate feature to ``output_dim``.
    * ``critic_output_layer``: linear head with one output, only created when
      ``weight_path`` is not ``None``.
    """

    def __init__(
        self,
        input_dim_states,
        input_dim_actions,
        input_timesteps,
        output_dim,
        hidden_dims: list[int],
        representation_dim,
        input_slice_actions=(0, 0),
        weight_path=None,
        mode="joints_only",
        finetune_frozen=False,
        **kwargs,
    ):
        """Build the sub-networks and optionally load pretrained weights.

        Args:
            input_dim_states: Per-timestep state dimension (input width of the state encoder).
            input_dim_actions: Per-timestep action dimension (input width of the
                action encoder). ``0`` builds the model without an action encoder.
            input_timesteps: Timestep count used to size ``main_MLP`` and to
                reshape the inputs in :meth:`forward_RL`.
            output_dim: Output width of ``output_layer``.
            hidden_dims: Hidden widths of ``main_MLP``.
            representation_dim: Output width of the state and action encoders.
            input_slice_actions: ``[start, stop]`` column range of the
                observation that :meth:`forward_RL` slices to obtain the known
                actions. The default is an immutable tuple so that instances
                never share one mutable default object.
            weight_path: ``None`` skips loading and creates no critic head.
                ``"random_init"`` keeps the random initialisation but still
                creates the critic head. Any other value is passed to
                ``torch.load`` and loaded as this module's state dict.
            mode: ``"joints_only"`` or ``"joints_plus_base"``; only read by
                :meth:`forward_RL`.
            finetune_frozen: When ``True`` and ``weight_path`` is not ``None``,
                :meth:`freeze_pretrained_weights` is called after loading.
            **kwargs: Unexpected arguments; they are printed and ignored.

        Raises:
            RuntimeError: If reading the checkpoint or loading it into the
                module fails. The original error is chained as the cause.
        """
        super().__init__()
        if kwargs:
            print(
                "InvDynamicMLP.__init__ got unexpected arguments, which will be ignored: "
                + str(list(kwargs))
            )
        self.module_type = "inv_dynamic"

        self.input_slice_actions = input_slice_actions
        self.input_dim_states = input_dim_states
        self.input_dim_actions = input_dim_actions
        self.input_timesteps = input_timesteps

        self.mode = mode

        penultimate_dim = 128  # TODO: this probably should be a config

        self.hidden_dim = hidden_dims
        self.state_encoder = build_mlp(input_dim_states, [128], representation_dim, activation_name="elu")

        # all pretrained modules
        if input_dim_actions != 0:
            self.action_encoder_pretrain = build_mlp(input_dim_actions, [128], representation_dim, activation_name="elu")
            # trunk input: (input_timesteps + 1) state embeddings plus
            # (input_timesteps - 1) action embeddings = 2 * input_timesteps embeddings
            self.main_MLP = build_mlp(representation_dim * 2 * input_timesteps, hidden_dims, penultimate_dim, activation_name="elu")
        else:
            self.input_dim_actions = 0
            print("The inverse dynamics model does not depend on history action info.")
            # trunk input: (input_timesteps + 1) state embeddings only
            self.main_MLP = build_mlp(representation_dim * (input_timesteps + 1), hidden_dims, penultimate_dim, activation_name="elu")

        self.output_layer = nn.Linear(penultimate_dim, output_dim)

        if weight_path is not None:
            if weight_path != "random_init":
                # Only the load itself is guarded, so that unrelated failures
                # are not reported as a checkpoint problem.
                try:
                    self.load_state_dict(torch.load(weight_path, weights_only=True))
                except Exception as err:
                    raise RuntimeError(
                        "Failed to load pretrained weights; Check if the path is correct and the architectures match."
                    ) from err
            else:
                print("The inv dynamics module is initialized with random weights: only for comparison.")
            self.weight_path = weight_path
            # The critic head is created after loading, so the checkpoint is
            # matched against the module without a critic head.
            self.critic_output_layer = nn.Linear(penultimate_dim, 1)
            if finetune_frozen:
                self.freeze_pretrained_weights()

    def _pretrained_modules(self):
        """Return every sub-network except the critic head, in a fixed order."""
        modules = [self.state_encoder]
        # the action encoder only exists when the model was built with actions
        if self.input_dim_actions != 0:
            modules.append(self.action_encoder_pretrain)
        modules.extend((self.main_MLP, self.output_layer))
        return modules

    def freeze_pretrained_weights(self):
        """Disable gradients for every sub-network except the critic head."""
        for module in self._pretrained_modules():
            module.requires_grad_(False)

    def forward_RL(self, obs, itm_actions, act_rep_in_raw=True, critic=False):
        """Forward an RL observation batch through the encoders and the trunk.

        To be used in the RL stage.

        Args:
            obs: Flat observation batch. It is passed to
                ``prepare_states_tensor_from_obs`` and, when the model has an
                action encoder, sliced with ``input_slice_actions``.
            itm_actions: Tensor that takes the place of the next-state input.
                It is encoded with ``state_encoder`` when ``act_rep_in_raw`` is
                ``True``; otherwise it is used directly as an embedding.
            act_rep_in_raw: Whether ``itm_actions`` still has to be encoded.
            critic: Select ``critic_output_layer`` instead of ``output_layer``.
                The critic head only exists when the module was built with a
                ``weight_path``.

        Returns:
            Output of the selected head applied to the trunk features.

        Raises:
            NotImplementedError: If ``mode`` is ``"joints_only"``.
            ValueError: If ``mode`` is not a known mode.
        """
        assert self.input_slice_actions is not None
        if self.mode == "joints_only":
            raise NotImplementedError("The inv dynamics module is not implemented for joints_only mode.")
        elif self.mode == "joints_plus_base":
            states_tensor = prepare_states_tensor_from_obs(obs)
        else:
            raise ValueError("Unknown mode for the inv dynamics module. Please check the config code.")

        has_action_encoder = self.input_dim_actions != 0
        if has_action_encoder:
            slice_start, slice_stop = self.input_slice_actions[0], self.input_slice_actions[1]
            actions_tensor_known = obs[:, slice_start:slice_stop]
            actions_tensor_known = actions_tensor_known.reshape(
                actions_tensor_known.shape[0], self.input_timesteps - 1, self.input_dim_actions
            )

        states_tensor = states_tensor.reshape(states_tensor.shape[0], self.input_timesteps, self.input_dim_states)
        embedded_states = torch.flatten(self.state_encoder(states_tensor), start_dim=1)

        # TODO: if not directly in state space, following should be changed
        if act_rep_in_raw:
            pseudo_next_p_embedding = self.state_encoder(itm_actions)
        else:
            pseudo_next_p_embedding = itm_actions

        features = [embedded_states, pseudo_next_p_embedding]
        # Same rule as forward(): action embeddings are only appended when the
        # model was built with an action encoder, matching the width of main_MLP.
        if has_action_encoder:
            features.append(torch.flatten(self.action_encoder_pretrain(actions_tensor_known), start_dim=1))

        x = self.main_MLP(torch.cat(features, dim=-1))
        head = self.critic_output_layer if critic else self.output_layer
        return head(x)

    def forward(self, states, actions=None):
        """Predict the output from a state history and, if used, an action history.

        Args:
            states: ``[batch_size, input_timesteps+1, input_dim_states]``.
            actions: ``[batch_size, input_timesteps-1, input_dim_actions]``;
                ignored when the model was built with ``input_dim_actions == 0``.

        Returns:
            Output of ``output_layer`` applied to the trunk features.
        """
        embedded_states = torch.flatten(self.state_encoder(states), start_dim=1)
        if self.input_dim_actions != 0:
            embedded_actions = torch.flatten(self.action_encoder_pretrain(actions), start_dim=1)
            x = torch.cat((embedded_states, embedded_actions), dim=-1)
        else:
            x = embedded_states
        x = self.main_MLP(x)
        x = self.output_layer(x)
        return x

    def get_optim_group_for_reduced_lr(self):
        """Return an iterator over the parameters of every sub-network except the critic head."""
        return chain(*(module.parameters() for module in self._pretrained_modules()))

    def get_optim_group_for_base_lr(self):
        """Return an iterator over the critic head's parameters.

        The critic head only exists when the module was built with a ``weight_path``.
        """
        return chain(self.critic_output_layer.parameters())



@configclass
class InvDynamicsSubmoduleConfig:
    """Configuration for constructing an ``InvDynamicMLP`` submodule."""

    class_name: str = "InvDynamicMLP"
    """The class name of the submodule."""

    input_dim_states: int = MISSING
    """The per-timestep state dimension, i.e. the input width of the state encoder."""

    input_dim_actions: int = MISSING
    """The per-timestep action dimension, i.e. the input width of the action encoder. 0 builds the model without an action encoder.

    NOTE(review): the remark "12 for quadruped since we have 4 legs with 3 joints each" describes a per-joint
    dimension; confirm whether it applies to this field, to ``input_dim_states``, or to both.
    """

    input_timesteps: int = MISSING
    """The timestep count used to size the main MLP input and to reshape the inputs during RL."""

    output_dim: int = MISSING
    """The output dimension of the submodule. 6 for pose, 3 for translation."""

    hidden_dims: list[int] = [512, 256, 128]
    """The hidden dimension of the MLP."""

    representation_dim: int = 128
    """The output dimension of the state and action encoders. This should not be too large to avoid enlarging the input of final MLPs in actor-critic."""

    input_slice_actions: list[int] = MISSING
    """The ``[start, stop]`` column range of the known actions within the observation. This should be set in RL config instances."""

    backbone_output_dim: int = 128 + 128
    """The output dimension of the backbone MLP. This should not be too large to avoid enlarging the input of final MLPs in actor-critic.

    NOTE(review): ``InvDynamicMLP.__init__`` has no parameter of this name, so if this field is forwarded as a
    keyword argument it is reported as unexpected and ignored.
    """

    weight_path: str | None = None
    """The path to the pretrained model weights. If None, the model will be initialized from scratch. The special value "random_init" skips loading but is otherwise treated like a path."""

    mode: str = "joints_only"
    """The mode of the model. Can be "joints_only" or "joints_plus_base". This does not affect training, only affects the inference during RL, where "joints_only" currently raises NotImplementedError."""

    finetune_frozen: bool = False
    """Whether to freeze the pretrained weights during finetuning. If True, the pretrained weights will not be updated during finetuning. Only takes effect when ``weight_path`` is not None."""