

import torch
import torch.nn as nn
from typing import List, Union
import argparse
from tqdm import tqdm
from isaaclab.utils import configclass
from dataclasses import MISSING
from rsl_rl.rsl_rl.utils.utils import resolve_nn_activation
from abc import abstractmethod
from rsl_rl.rsl_rl.addons.submodule import NNSubmodule
from itertools import chain



def build_mlp(input_dims: int, hidden_dims: list[int], output_dims: int, activation_name: str = "elu"):
    """Build a fully-connected MLP as an ``nn.Sequential``.

    Layout: ``Linear -> act -> Linear -> act -> ... -> Linear``. Every hidden layer is
    followed by the activation; the output layer is not. The module indices inside the
    returned ``nn.Sequential`` follow this layout, so checkpoints saved from networks built
    with this helper keep matching ``state_dict`` keys.

    Args:
        input_dims: Size of the input features.
        hidden_dims: Widths of the hidden layers. An entry of ``-1`` is replaced by
            ``input_dims``. An empty sequence yields a single ``Linear(input_dims, output_dims)``.
        output_dims: Size of the output features.
        activation_name: Activation name forwarded to ``resolve_nn_activation``.

    Returns:
        The assembled ``nn.Sequential`` network.
    """
    # -1 means "same width as the input"
    hidden_dims = [input_dims if dim == -1 else dim for dim in hidden_dims]
    # a single activation module instance is shared by all hidden layers
    activation = resolve_nn_activation(activation_name)

    layer_dims = [input_dims, *hidden_dims]
    network_layers = []
    for in_dim, out_dim in zip(layer_dims[:-1], layer_dims[1:]):
        network_layers.append(nn.Linear(in_dim, out_dim))
        network_layers.append(activation)
    # output layer: no activation afterwards
    network_layers.append(nn.Linear(layer_dims[-1], output_dims))
    return nn.Sequential(*network_layers)


def prepare_states_tensor_from_obs(obs):
    """
        Extract the state history used by the state encoder from a flat observation tensor.

        obs: [batch_size, 318]. See P4RL-Pre-Dynamic-Pedipulation-Flat-Blind-Anymal-D-v0's config for details.
        Each column block below holds 6 history steps and is split into steps with
        reshape(batch_size, 6, -1):
            0:72    jp       -> 12 values per step
            144:162 lin_vel  ->  3 values per step
            162:180 ang_vel  ->  3 values per step
            180:198 grav     ->  3 values per step
        Columns 72:144 are not used here.

        Returns: [batch_size, 5, 21] - per-step features concatenated in the order above,
        with the first history step dropped.
    """
    batch_size = obs.shape[0]
    n_history = 6
    # (start, stop) column ranges, in the order they are concatenated
    column_ranges = ((0, 72), (144, 162), (162, 180), (180, 198))
    per_step_blocks = [obs[:, start:stop].reshape(batch_size, n_history, -1) for start, stop in column_ranges]
    states = torch.cat(per_step_blocks, dim=-1)  # [b, 6, 21]
    return states[:, 1:, :]  # [b, 5, 21]


def prepare_input_for_pseudo_action_encoder(obs):
    """
        Build the input of the RL-stage action encoder from a flat observation tensor.

        obs: [batch_size, 318]. See P4RL-Pre-Dynamic-Pedipulation-Flat-Blind-Anymal-D-v0's config for details.
        Each column block below holds 6 history steps and is split into steps with
        reshape(batch_size, 6, -1):
            0:72    jp       -> 12 values per step
            72:144  a        -> 12 values per step
            144:162 lin_vel  ->  3 values per step
            162:180 ang_vel  ->  3 values per step
            180:198 grav     ->  3 values per step
            198:270 jv       -> 12 values per step
        The last 3 columns of obs are taken as `command`.

        Returns: [batch_size, 48] - the 45 per-step features of the last history step,
        followed by the 3 command values.
    """
    batch_size = obs.shape[0]
    n_history = 6
    # (start, stop) column ranges, in the order they are concatenated
    column_ranges = ((0, 72), (72, 144), (144, 162), (162, 180), (180, 198), (198, 270))
    per_step_blocks = [obs[:, start:stop].reshape(batch_size, n_history, -1) for start, stop in column_ranges]
    obs_segment = torch.cat(per_step_blocks, dim=-1)  # [b, 6, 45]
    command = obs[..., -3:]  # [b, 3]
    return torch.cat([obs_segment[:, -1, :], command], dim=-1)  # [b, 48]


# Define MLP Model
class DynamicMLP(NNSubmodule):
    """Dynamics model pretrained on state/action histories and reused as a backbone in RL.

    * Pretraining (:meth:`forward`): per-timestep states and actions are embedded by
      ``state_encoder`` / ``action_encoder_pretrain``, flattened, concatenated and passed
      through ``main_MLP`` and ``output_layer``.
    * RL stage (:meth:`forward_RL`): the known action history has ``input_timesteps - 1``
      steps; the missing action slot is filled by ``action_encoder_RL``, which is only
      built when ``weight_path`` is given.
    """

    def __init__(self, input_dim_states, input_dim_actions, input_timesteps, output_dim, hidden_dims: list[int], representation_dim,
                 backbone_output_dim,
                 input_obs_dim_RL: Union[int, None] = None, RL_crown_dims=(128, 128),
                 input_slice_states=(0, 0), input_slice_actions=(0, 0), input_slice_policy=(0, 0), weight_path=None,
                 mode="joints_only", finetune_frozen=False,
                 **kwargs):
        """
        Args:
            input_dim_states: Per-timestep size of the state input.
            input_dim_actions: Per-timestep size of the action input; 0 disables ``action_encoder_pretrain``.
            input_timesteps: Number of state history steps.
            output_dim: Size of the ``output_layer`` prediction.
            hidden_dims: Hidden widths of ``main_MLP``.
            representation_dim: Width of every per-timestep embedding.
            backbone_output_dim: Width of the :meth:`forward_RL` output
                (``representation_dim`` + ``main_MLP`` output width).
            input_obs_dim_RL: Input size of ``action_encoder_RL``; required when ``weight_path`` is set.
            RL_crown_dims: Hidden widths of ``action_encoder_RL``.
            input_slice_states, input_slice_actions, input_slice_policy: ``[start, stop)`` column
                ranges inside the RL observation tensor.
            weight_path: Checkpoint to load. ``"random_init"`` skips loading but still builds the
                RL-stage modules; ``None`` builds the pretraining model only.
            mode: ``"joints_only"`` or ``"joints_plus_base"``; only affects :meth:`forward_RL`.
            finetune_frozen: Freeze the pretrained modules once the RL-stage modules are built.

        Raises:
            ValueError: ``weight_path`` is set but ``input_obs_dim_RL`` is None.
            Exception: The checkpoint could not be loaded (original error chained as ``__cause__``).
        """
        super().__init__()
        if kwargs:
            print(
                "DynamicMLP.__init__ got unexpected arguments, which will be ignored: "
                + str(list(kwargs))
            )
        self.module_type = "dynamic"

        self.backbone_output_dim = backbone_output_dim
        self.input_slice_states = input_slice_states
        self.input_slice_actions = input_slice_actions
        self.input_slice_policy = input_slice_policy

        self.input_dim_states = input_dim_states
        self.input_dim_actions = input_dim_actions
        self.input_timesteps = input_timesteps

        self.mode = mode

        # width of the main_MLP output; forward_RL prepends a representation_dim-wide
        # embedding to it, giving backbone_output_dim columns in total
        penultimate_dim = backbone_output_dim - representation_dim

        self.hidden_dim = hidden_dims

        # --- modules stored in the pretrained checkpoint (construction order kept) ---
        if input_dim_actions != 0:
            self.action_encoder_pretrain = build_mlp(input_dim_actions, [128], representation_dim, activation_name="elu")
        else:
            self.input_dim_actions = 0
            print("No action encoder for pretraining, this model now is a series forecasting model for states. You should not see this normally.")
        self.state_encoder = build_mlp(input_dim_states, [128], representation_dim, activation_name="elu")
        # input: input_timesteps state embeddings + input_timesteps action embeddings
        self.main_MLP = build_mlp(2 * representation_dim * input_timesteps, hidden_dims, penultimate_dim, activation_name="elu")
        self.output_layer = nn.Linear(penultimate_dim, output_dim)

        if weight_path is None:
            return

        # load before building action_encoder_RL so the checkpoint keys match exactly
        if weight_path != "random_init":
            try:
                # map to CPU: the module has not been moved to any device yet, and this also
                # allows loading GPU-saved checkpoints on CPU-only machines
                state_dict = torch.load(weight_path, map_location="cpu", weights_only=True)
                self.load_state_dict(state_dict)
            except Exception as err:
                raise Exception("Failed to load pretrained weights; Check if the path is correct and the architectures match.") from err
        else:
            print("The dynamics module is initialized with random weights: only for comparison.")
        self.weight_path = weight_path

        # only instantiate the action encoder for RL when loading weights
        if input_obs_dim_RL is None:
            raise ValueError("input_obs_dim_RL must be set when weight_path is given; it sizes action_encoder_RL.")
        self.action_encoder_RL = build_mlp(input_obs_dim_RL, RL_crown_dims, representation_dim, activation_name="elu")
        self.action_encoder_RL.requires_grad_(True)
        if finetune_frozen:
            self.freeze_pretrained_weights()

    def freeze_pretrained_weights(self):
        """Disable gradients of ``state_encoder``, ``action_encoder_pretrain`` (if built) and ``main_MLP``."""
        self.state_encoder.requires_grad_(False)
        # absent when the model was built with input_dim_actions == 0
        if hasattr(self, "action_encoder_pretrain"):
            self.action_encoder_pretrain.requires_grad_(False)
        self.main_MLP.requires_grad_(False)

    def forward_RL(self, x):
        """Run the backbone on a flat RL observation tensor.

        Args:
            x: Observation tensor ``[batch_size, D]``.

        Returns:
            ``[batch_size, backbone_output_dim]``: the ``action_encoder_RL`` embedding
            (first ``representation_dim`` columns) followed by the ``main_MLP`` features.

        Raises:
            Exception: ``self.mode`` is not a known mode.
        """
        if self.mode == "joints_only":
            assert self.input_slice_actions is not None and self.input_slice_states is not None
            states_tensor = x[:, self.input_slice_states[0]:self.input_slice_states[1]]
            states_tensor = states_tensor.reshape(states_tensor.shape[0], self.input_timesteps, self.input_dim_states)
            policy_input = x[:, self.input_slice_policy[0]:self.input_slice_policy[1]]
        elif self.mode == "joints_plus_base":
            assert self.input_slice_actions is not None and self.input_slice_states is not None
            # fixed observation layout handled by the module-level helpers
            states_tensor = prepare_states_tensor_from_obs(x)
            policy_input = prepare_input_for_pseudo_action_encoder(x)
        else:
            raise Exception("Unknown mode for the dynamics module. Please check the config code.")

        # the known action history is one step shorter than the state history
        actions_tensor_known = x[:, self.input_slice_actions[0]:self.input_slice_actions[1]]
        actions_tensor_known = actions_tensor_known.reshape(actions_tensor_known.shape[0], self.input_timesteps - 1, self.input_dim_actions)

        embedded_states = torch.flatten(self.state_encoder(states_tensor), start_dim=1)
        embedded_actions = torch.flatten(self.action_encoder_pretrain(actions_tensor_known), start_dim=1)
        # fills the action slot missing from the (input_timesteps - 1)-step known history
        pseudo_a_t_embedding = self.action_encoder_RL(policy_input)

        features = self.main_MLP(torch.cat((embedded_states, embedded_actions, pseudo_a_t_embedding), dim=-1))
        return torch.cat((pseudo_a_t_embedding, features), dim=-1)

    @torch.no_grad()
    def forward_RL_get_final_output(self, x):
        """
            Forward the full observation tensor through the backbone and ``output_layer``.
            To be used for debugging and visualization.
        """
        backbone_out = self.forward_RL(x)
        # forward_RL prepends the action_encoder_RL embedding to the main_MLP features, but
        # output_layer only accepts the main_MLP part, i.e. its last `in_features` columns
        n_main = self.output_layer.in_features
        main_features = backbone_out[..., backbone_out.shape[-1] - n_main:]
        return self.output_layer(main_features)

    def forward(self, states, actions):
        """
            Pretraining forward pass.

            args:
                states: [batch_size, input_timesteps, input_dim_states]
                actions: [batch_size, input_timesteps, input_dim_actions]
            returns:
                [batch_size, output_dim]
        """
        embedded_states = torch.flatten(self.state_encoder(states), start_dim=1)
        if self.input_dim_actions != 0:
            embedded_actions = torch.flatten(self.action_encoder_pretrain(actions), start_dim=1)
            x = torch.cat((embedded_states, embedded_actions), dim=-1)
        else:
            # NOTE(review): main_MLP expects 2 * representation_dim * input_timesteps inputs, so this
            # state-only path only fits if `states` carries 2 * input_timesteps steps - confirm.
            x = embedded_states
        x = self.main_MLP(x)
        return self.output_layer(x)



@configclass
class DynamicsSubmoduleConfig:
    """Configuration of the ``DynamicMLP`` dynamics submodule."""

    class_name: str = "DynamicMLP"
    """The class name of the submodule."""

    input_dim_states: int = MISSING
    """Per-timestep size of the state input (e.g. 12 for a quadruped with 4 legs x 3 joints)."""

    input_dim_actions: int = MISSING
    """Per-timestep size of the action input. 0 disables the pretraining action encoder."""

    input_timesteps: int = MISSING
    """Number of state history steps. In the RL stage the known action history has ``input_timesteps - 1`` steps."""

    output_dim: int = MISSING
    """The output dimension of the submodule. 6 for pose, 3 for translation."""

    hidden_dims: list[int] = [512, 256, 128]
    """The hidden dimensions of the main MLP."""

    RL_crown_dims: list[int] = [128, 128]
    """The hidden dimensions of the RL-stage action encoder. This should be set in RL config instances."""

    representation_dim: int = 128
    """Width of each per-timestep state/action embedding. This should not be too large to avoid enlarging the input of final MLPs in actor-critic."""

    input_slice_states: list[int] = MISSING
    """``[start, stop)`` column range of the state history within the observation tensor ("joints_only" mode). This should be set in RL config instances."""

    input_slice_actions: list[int] = MISSING
    """``[start, stop)`` column range of the known action history within the observation tensor. This should be set in RL config instances."""

    input_slice_policy: list[int] = MISSING
    """``[start, stop)`` column range of the observation fed to the RL-stage action encoder ("joints_only" mode)."""

    input_obs_dim_RL: int | None = None
    """Input size of the RL-stage action encoder; required whenever ``weight_path`` is set. This should be set in RL config instances."""

    backbone_output_dim: int = 128 + 128
    """Width of the RL backbone output: ``representation_dim`` plus the main MLP output width. This should not be too large to avoid enlarging the input of final MLPs in actor-critic."""

    weight_path: str | None = None
    """The path to the pretrained model weights. If None, only the pretraining model is built; "random_init" skips loading but still builds the RL-stage modules."""

    mode: str = "joints_only"
    """The mode of the model. Can be "joints_only" or "joints_plus_base". This does not affect training, only affects the inference during RL."""

    finetune_frozen: bool = False
    """Whether to freeze the pretrained weights during finetuning. If True, the pretrained weights will not be updated during finetuning."""



class DynamicMLPForHAC(NNSubmodule):
    """
        This class is used to build the dynamics module for the hierarchical actor-critic model.
        This class can load the pretrained weights from DynamicMLP class, but runs RL forward pass in compatible way to
        Hierarchical Actor-Critic class.

        When weights are loaded, two heads are added on top of ``main_MLP``: ``critic_output_layer``
        (1 output) and ``action_output_layer`` (``input_dim_actions`` outputs).
    """

    def __init__(self, input_dim_states, input_dim_actions, input_timesteps, output_dim, hidden_dims: list[int], representation_dim,
                 backbone_output_dim,
                 input_slice_actions=(0, 0), weight_path=None,
                 mode="joints_only", finetune_frozen=False,
                 input_slice_states=None,
                 **kwargs):
        """
        Args:
            input_dim_states: Per-timestep size of the state input.
            input_dim_actions: Per-timestep size of the action input; 0 disables ``action_encoder_pretrain``.
            input_timesteps: Number of state history steps.
            output_dim: Size of the pretraining ``output_layer`` prediction.
            hidden_dims: Hidden widths of ``main_MLP``.
            representation_dim: Width of every per-timestep embedding.
            backbone_output_dim: ``representation_dim`` + ``main_MLP`` output width.
            input_slice_actions: ``[start, stop)`` column range of the known action history in the observation.
            weight_path: Checkpoint to load. ``"random_init"`` skips loading but still builds the
                RL heads; ``None`` builds the pretraining model only.
            mode: ``"joints_only"`` or ``"joints_plus_base"``; only affects :meth:`forward_RL`.
            finetune_frozen: Freeze the pretrained modules once the RL heads are built.
            input_slice_states: ``[start, stop)`` column range of the state history in the observation;
                required by :meth:`forward_RL` in ``"joints_only"`` mode.

        Raises:
            Exception: The checkpoint could not be loaded (original error chained as ``__cause__``).
        """
        super().__init__()
        if kwargs:
            print(
                "DynamicMLPForHAC.__init__ got unexpected arguments, which will be ignored: "
                + str(list(kwargs))
            )
        self.module_type = "dynamic"

        self.backbone_output_dim = backbone_output_dim
        self.input_slice_actions = input_slice_actions
        self.input_slice_states = input_slice_states

        self.input_dim_states = input_dim_states
        self.input_dim_actions = input_dim_actions
        self.input_timesteps = input_timesteps

        self.mode = mode

        # width of the main_MLP output, consumed by the RL heads and by output_layer
        penultimate_dim = backbone_output_dim - representation_dim

        self.hidden_dim = hidden_dims

        # --- modules stored in the pretrained checkpoint (construction order kept) ---
        if input_dim_actions != 0:
            self.action_encoder_pretrain = build_mlp(input_dim_actions, [128], representation_dim, activation_name="elu")
        else:
            self.input_dim_actions = 0
            print("No action encoder for pretraining, this model now is a series forecasting model for states. You should not see this normally.")
        self.state_encoder = build_mlp(input_dim_states, [128], representation_dim, activation_name="elu")
        self.main_MLP = build_mlp(2 * representation_dim * input_timesteps, hidden_dims, penultimate_dim, activation_name="elu")
        self.output_layer = nn.Linear(penultimate_dim, output_dim)

        if weight_path is None:
            return

        # load before building the RL heads so the checkpoint keys match exactly
        if weight_path != "random_init":
            try:
                # the module has not been moved to any device yet; load tensors on CPU
                state_dict = torch.load(weight_path, map_location="cpu", weights_only=True)
                self.load_state_dict(state_dict)
            except Exception as err:
                raise Exception("Failed to load pretrained weights; Check if the path is correct and the architectures match.") from err
        else:
            print("The dynamics module is initialized with random weights: only for comparison.")
        self.weight_path = weight_path

        self.critic_output_layer = build_mlp(penultimate_dim, [64, 64], 1, activation_name="elu")
        self.action_output_layer = build_mlp(penultimate_dim, [64, 64], self.input_dim_actions, activation_name="elu")

        if finetune_frozen:
            self.freeze_pretrained_weights()

    def freeze_pretrained_weights(self):
        """Disable gradients of ``state_encoder``, ``action_encoder_pretrain`` (if built) and ``main_MLP``."""
        self.state_encoder.requires_grad_(False)
        # absent when the model was built with input_dim_actions == 0
        if hasattr(self, "action_encoder_pretrain"):
            self.action_encoder_pretrain.requires_grad_(False)
        self.main_MLP.requires_grad_(False)

    def get_optim_group_for_reduced_lr(self):
        """Parameters of the pretrained modules (encoders, ``main_MLP``, ``output_layer``)."""
        return chain(self.state_encoder.parameters(), self.action_encoder_pretrain.parameters(), self.main_MLP.parameters(), self.output_layer.parameters())

    def get_optim_group_for_base_lr(self):
        """Parameters of the RL heads; these only exist when ``weight_path`` was given."""
        return chain(self.critic_output_layer.parameters(), self.action_output_layer.parameters())

    def forward_RL(self, obs, itm_actions, act_rep_in_raw=False, critic=False):
        """Forward the observation tensor through the backbone and one of the RL heads.

        Args:
            obs: Observation tensor ``[batch_size, D]``.
            itm_actions: Intermediate action inserted after the state embeddings. Used as-is when
                ``act_rep_in_raw`` is False; otherwise it is first passed through ``state_encoder``
                (presumably a raw next-state target - confirm with the caller).
            act_rep_in_raw: Whether ``itm_actions`` must be encoded by ``state_encoder``.
            critic: Select ``critic_output_layer`` (True) or ``action_output_layer`` (False).

        Returns:
            ``[batch_size, 1]`` if ``critic`` else ``[batch_size, input_dim_actions]``.

        Raises:
            ValueError: ``"joints_only"`` mode without ``input_slice_states``.
            Exception: ``self.mode`` is not a known mode.
        """
        assert self.input_slice_actions is not None
        if self.mode == "joints_only":
            if self.input_slice_states is None:
                raise ValueError("DynamicMLPForHAC needs input_slice_states in 'joints_only' mode.")
            states_tensor = obs[:, self.input_slice_states[0]:self.input_slice_states[1]]
        elif self.mode == "joints_plus_base":
            states_tensor = prepare_states_tensor_from_obs(obs)
        else:
            raise Exception("Unknown mode for the dynamics module. Please check the config code.")

        # the known action history is one step shorter than the state history
        actions_tensor_known = obs[:, self.input_slice_actions[0]:self.input_slice_actions[1]]
        actions_tensor_known = actions_tensor_known.reshape(actions_tensor_known.shape[0], self.input_timesteps - 1, self.input_dim_actions)

        states_tensor = states_tensor.reshape(states_tensor.shape[0], self.input_timesteps, self.input_dim_states)
        embedded_states = torch.flatten(self.state_encoder(states_tensor), start_dim=1)

        # TODO: if not directly in state space, following should be changed
        if act_rep_in_raw:
            pseudo_next_p_embedding = self.state_encoder(itm_actions)
        else:
            pseudo_next_p_embedding = itm_actions
        embedded_actions = torch.flatten(self.action_encoder_pretrain(actions_tensor_known), start_dim=1)
        x = torch.cat((embedded_states, pseudo_next_p_embedding, embedded_actions), dim=-1)
        x = self.main_MLP(x)
        head = self.critic_output_layer if critic else self.action_output_layer
        return head(x)

    def forward(self, states, actions):
        """
            Pretraining forward pass.

            args:
                states: [batch_size, input_timesteps, input_dim_states]
                actions: [batch_size, input_timesteps, input_dim_actions]
            returns:
                [batch_size, output_dim]
        """
        embedded_states = torch.flatten(self.state_encoder(states), start_dim=1)
        if self.input_dim_actions != 0:
            embedded_actions = torch.flatten(self.action_encoder_pretrain(actions), start_dim=1)
            x = torch.cat((embedded_states, embedded_actions), dim=-1)
        else:
            # NOTE(review): main_MLP expects 2 * representation_dim * input_timesteps inputs, so this
            # state-only path only fits if `states` carries 2 * input_timesteps steps - confirm.
            x = embedded_states
        x = self.main_MLP(x)
        return self.output_layer(x)
    


@configclass
class DynamicsSubmoduleForHACConfig:
    """Configuration of the ``DynamicMLPForHAC`` dynamics submodule."""

    class_name: str = "DynamicMLPForHAC"
    """The class name of the submodule."""

    input_dim_states: int = MISSING
    """Per-timestep size of the state input (e.g. 12 for a quadruped with 4 legs x 3 joints)."""

    input_dim_actions: int = MISSING
    """Per-timestep size of the action input; also the output size of the RL action head."""

    input_timesteps: int = MISSING
    """Number of state history steps. In the RL stage the known action history has ``input_timesteps - 1`` steps."""

    output_dim: int = MISSING
    """The output dimension of the submodule. 6 for pose, 3 for translation."""

    hidden_dims: list[int] = [512, 256, 128]
    """The hidden dimensions of the main MLP."""

    representation_dim: int = 128
    """Width of each per-timestep state/action embedding. This should not be too large to avoid enlarging the input of final MLPs in actor-critic."""

    input_slice_actions: list[int] = MISSING
    """``[start, stop)`` column range of the known action history within the observation tensor. This should be set in RL config instances."""

    backbone_output_dim: int = 128 + 128
    """``representation_dim`` plus the main MLP output width. This should not be too large to avoid enlarging the input of final MLPs in actor-critic."""

    weight_path: str | None = None
    """The path to the pretrained model weights. If None, only the pretraining model is built; "random_init" skips loading but still builds the RL heads."""

    mode: str = "joints_only"
    """The mode of the model. Can be "joints_only" or "joints_plus_base". This does not affect training, only affects the inference during RL. Note: "joints_only" needs a state slice, which this config does not define."""

    finetune_frozen: bool = False
    """Whether to freeze the pretrained weights during finetuning. If True, the pretrained weights will not be updated during finetuning."""