

import torch
import torch.nn as nn
from typing import List, Union, Tuple
import argparse
from tqdm import tqdm
from isaaclab.utils import configclass
from dataclasses import MISSING
from rsl_rl.rsl_rl.utils.utils import resolve_nn_activation, unpad_trajectories
from abc import abstractmethod
from rsl_rl.rsl_rl.addons.submodule import NNSubmodule
from itertools import chain
from einops import rearrange



def build_mlp(input_dims: int, hidden_dims: list[int], output_dims: int, activation_name: str = "elu"):
    """Build a fully connected network ``input -> hidden ... -> output``.

    The result is an ``nn.Sequential`` of ``nn.Linear`` layers with the resolved
    activation placed after every layer except the last one, so the output is
    left un-activated.

    Args:
        input_dims: Width of the input features.
        hidden_dims: Widths of the hidden layers. An entry of ``-1`` is replaced
            by ``input_dims``. An empty sequence yields a single linear layer
            from ``input_dims`` to ``output_dims``.
        output_dims: Width of the output features.
        activation_name: Name handed to ``resolve_nn_activation``.

    Returns:
        The assembled ``nn.Sequential``.
    """
    # -1 is a placeholder meaning "same width as the input".
    resolved_hidden = [input_dims if dim == -1 else dim for dim in hidden_dims]
    # The single activation module instance is reused between all layers.
    activation = resolve_nn_activation(activation_name)

    # Consecutive pairs of this list are the (in, out) sizes of each linear layer.
    layer_widths = [input_dims, *resolved_hidden, output_dims]
    final_index = len(layer_widths) - 2

    network_layers = []
    for index, (fan_in, fan_out) in enumerate(zip(layer_widths[:-1], layer_widths[1:])):
        network_layers.append(nn.Linear(fan_in, fan_out))
        if index != final_index:
            network_layers.append(activation)
    return nn.Sequential(*network_layers)


# from p4rl.tasks.pedipulation.config.anymal_d.pedipulation_base import ObservationsWithHistoryCfg
# WARNING: the following functions are only usable if the observation config is exactly the same as "ObservationsWithHistoryCfg".

def extract_state_action_input_for_dynamics_module(obs) -> Tuple:
    """Extract s_{t-1}, s_{t} and a_{t-1} from a history-stacked observation.

    The first 198 entries of the last axis of ``obs`` are read as five
    consecutive segments (0:72, 72:144, 144:162, 162:180, 180:198). Each
    segment is reshaped into 6 history slots. The state vector of a slot is the
    concatenation of segments 1, 3, 4 and 5 (``jp``, ``lin_vel``, ``ang_vel``,
    ``grav`` in the original naming); segment 2 holds the actions.

    Args:
        obs: Tensor whose last axis has at least 198 entries; leading axes are
            preserved.

    Returns:
        Tuple ``(s_tm1, s_t, a_tm1)``: the state of the second-to-last slot, the
        state of the last slot, and the action of the last slot.
    """
    num_history_slots = 6

    def _split_history(start, stop):
        # Slice one segment and unfold its last axis into (slot, feature).
        segment = obs[..., start:stop]
        return segment.reshape(*segment.shape[:-1], num_history_slots, -1)

    joint_segment = _split_history(0, 72)
    action_segment = _split_history(72, 144)
    lin_vel_segment = _split_history(144, 162)
    ang_vel_segment = _split_history(162, 180)
    gravity_segment = _split_history(180, 198)

    # Per-slot state vector; the action segment is deliberately left out.
    state_history = torch.cat(
        [joint_segment, lin_vel_segment, ang_vel_segment, gravity_segment], dim=-1
    )

    previous_state = state_history[..., -2, :]
    current_state = state_history[..., -1, :]
    previous_action = action_segment[..., -1, :]
    return previous_state, current_state, previous_action


def get_last_timestep_obs_idx():
    """Return a boolean mask of length 273 selecting a fixed set of observation entries.

    The first five ranges coincide with the last of the 6 history slots of the
    segments sliced in ``extract_state_action_input_for_dynamics_module``.
    NOTE(review): the meaning of entries 258:273 is not visible in this file;
    presumably further last-timestep terms of the observation config - confirm.

    Returns:
        1-D ``torch.bool`` tensor with 48 entries set to ``True``.
    """
    selected_ranges = (
        (60, 72),
        (132, 144),
        (159, 162),
        (177, 180),
        (195, 198),
        (258, 270),
        (270, 273),
    )
    mask = torch.zeros(273, dtype=torch.bool)
    for start, stop in selected_ranges:
        mask[start:stop] = True
    return mask


# Define the GRU-based dynamics model
class DynamicGRU(NNSubmodule):
    """GRU-based dynamics model with an additional entry point for RL.

    Pretraining path (``forward`` / ``forward_traj``): states and actions are
    embedded by two small MLPs, concatenated per timestep, passed through a
    single-layer GRU and projected to ``output_dim`` by a linear layer.

    RL path (``forward_RL``): the pretrained encoders and GRU are reused, and an
    extra MLP (``action_encoder_RL``) produces a "pseudo action" embedding from
    a slice of the RL observation. ``action_encoder_RL`` only exists when
    ``weight_path`` is given.
    """

    def __init__(self, input_dim_states, input_dim_actions, input_timesteps, output_dim, hidden_dims: list[int], representation_dim,
                 backbone_output_dim,
                 input_obs_dim_RL: Union[int, None] = None,
                 RL_crown_dims: Union[list[int], None] = None,
                 weight_path=None,
                 mode="joints_only", finetune_frozen=False,
                 **kwargs):
        """Build the encoders, the GRU and, optionally, the RL action encoder.

        Args:
            input_dim_states: Input width of the state encoder.
            input_dim_actions: Input width of the action encoder. ``0`` disables
                the action encoder; the GRU then consumes state embeddings only.
            input_timesteps: Stored on the instance; not used to build layers.
            output_dim: Output width of the final linear layer.
            hidden_dims: Stored as ``self.hidden_dim``; not used to build layers
                (the MLP backbone it configured is commented out).
            representation_dim: Output width of every encoder MLP.
            backbone_output_dim: The GRU hidden size is
                ``backbone_output_dim - representation_dim``.
            input_obs_dim_RL: Input width of ``action_encoder_RL``. Required when
                ``weight_path`` is given; it must equal the number of entries
                selected by ``get_last_timestep_obs_idx()``.
            RL_crown_dims: Hidden widths of ``action_encoder_RL``. Defaults to
                ``[128, 128]``.
            weight_path: Path of a state dict to load, the literal string
                ``"random_init"`` to keep random weights, or ``None`` to skip
                both loading and the creation of ``action_encoder_RL``.
            mode: Stored on the instance; not read anywhere in this class.
            finetune_frozen: If true (and ``weight_path`` is given), freeze the
                pretrained encoders and the GRU.
            **kwargs: Ignored; a warning listing the keys is printed.

        Raises:
            Exception: If the state dict cannot be loaded (the original error is
                chained as ``__cause__``).
            ValueError: If ``weight_path`` is given but ``input_obs_dim_RL`` is
                ``None``.
        """
        super().__init__()
        if kwargs:
            print(
                "DynamicGRU.__init__ got unexpected arguments, which will be ignored: "
                + str(list(kwargs))
            )
        # None sentinel instead of a mutable list default shared across calls.
        if RL_crown_dims is None:
            RL_crown_dims = [128, 128]

        self.module_type = "dynamic"

        self.backbone_output_dim = backbone_output_dim
        self.input_slice_policy = get_last_timestep_obs_idx()

        has_actions = input_dim_actions != 0
        self.input_dim_states = input_dim_states
        self.input_dim_actions = input_dim_actions if has_actions else 0
        self.input_timesteps = input_timesteps

        self.mode = mode

        penultimate_dim = backbone_output_dim - representation_dim

        self.hidden_dim = hidden_dims
        # all pretrained modules (creation order kept so seeded initialisation is unchanged)
        if has_actions:
            self.action_encoder_pretrain = build_mlp(input_dim_actions, [128], representation_dim, activation_name="elu")
        else:
            print("No action encoder for pretraining, this model now is a series forecasting model for states. You should not see this normally.")
        self.state_encoder = build_mlp(input_dim_states, [128], representation_dim, activation_name="elu")
        # Without an action encoder only the state embedding reaches the GRU, so
        # its input width must shrink accordingly.
        gru_input_dim = 2 * representation_dim if has_actions else representation_dim
        self.main_GRU = nn.GRU(input_size=gru_input_dim, hidden_size=penultimate_dim, num_layers=1, batch_first=False)

        self.output_layer = nn.Linear(penultimate_dim, output_dim)
        # Recurrent state carried between successive inference-mode forward_RL calls.
        self.hidden_states = None

        if weight_path is not None:
            if weight_path != "random_init":
                try:
                    self.load_state_dict(torch.load(weight_path, weights_only=True))
                except Exception as err:
                    raise Exception(
                        "Failed to load pretrained weights; Check if the path is correct and the architectures match."
                    ) from err
            else:
                print("The dynamics module is initialized with random weights: only for comparison.")
            self.weight_path = weight_path
            # only instantiate the action encoder for RL when loading weights,
            # i.e. after load_state_dict so the checkpoint need not contain it
            if input_obs_dim_RL is None:
                raise ValueError("input_obs_dim_RL must be set when weight_path is given.")
            self.action_encoder_RL = build_mlp(input_obs_dim_RL, RL_crown_dims, representation_dim, activation_name="elu")
            self.action_encoder_RL.requires_grad_(True)
            if finetune_frozen:
                self.freeze_pretrained_weights()

    def freeze_pretrained_weights(self):
        """Disable gradients for the state encoder, the action encoder and the GRU.

        ``output_layer`` and ``action_encoder_RL`` are left untouched.
        """
        self.state_encoder.requires_grad_(False)
        if self.input_dim_actions != 0:
            # The action encoder only exists when actions are part of the input.
            self.action_encoder_pretrain.requires_grad_(False)
        self.main_GRU.requires_grad_(False)

    def forward_RL(self, input, masks=None, hidden_states=None):
        """Compute the RL feature vector from history-stacked observations.

        The recurrent state must only ever contain real actions, never the
        pseudo action produced by ``action_encoder_RL``. Each call therefore
        (1) feeds the pair (s_{t-1}, a_{t-1}) to the GRU,
        (2) keeps the resulting hidden state, and
        (3) feeds the pair (s_{t}, pseudo a_{t}) starting from that hidden
            state, discarding the hidden state this second step produces.

        Args:
            input: Observations. ``[T, num_seq, dim]`` in batch mode,
                ``[num_envs, dim]`` in inference mode.
            masks: ``[T, num_seq]`` trajectory masks. Passing them selects batch
                mode (policy update); ``None`` selects inference mode
                (collection).
            hidden_states: Initial GRU hidden state ``[1, num_seq, hidden_size]``;
                required in batch mode, ignored in inference mode.

        Returns:
            Concatenation of the pseudo-action embedding and the GRU output of
            step (3) along the last axis
            (``representation_dim + hidden_size`` features). In batch mode both
            parts go through ``unpad_trajectories`` first.

        Raises:
            ValueError: In batch mode when ``hidden_states`` is ``None``.
        """
        batch_mode = masks is not None
        if batch_mode and hidden_states is None:
            raise ValueError("Hidden states not passed to memory module during policy update")

        # Encoding is identical in both modes; leading axes are preserved.
        s_tm1, s_t, a_tm1 = extract_state_action_input_for_dynamics_module(input)
        embedded_s_tm1, embedded_s_t = self.state_encoder(s_tm1), self.state_encoder(s_t)
        embedded_a_tm1 = self.action_encoder_pretrain(a_tm1)
        pseudo_a_t_embedding = self.action_encoder_RL(input[..., self.input_slice_policy.to(input.device)])
        input_rnn_tm1 = torch.cat((embedded_s_tm1, embedded_a_tm1), dim=-1)
        input_rnn_t = torch.cat((embedded_s_t, pseudo_a_t_embedding), dim=-1)

        if batch_mode:
            # batch mode (policy update): run the real (s_{t-1}, a_{t-1}) sequence
            # from the saved hidden state.
            out_rnn_tm1, _ = self.main_GRU(input_rnn_tm1, hidden_states)  # [T, num_seq, hidden_size]
            # A GRU carries a single state vector (unlike an LSTM with h and c),
            # and its output at every step equals its hidden state at that step.
            # Every (T, N) entry can thus be treated as an independent
            # length-1 sequence with its own initial hidden state.
            T, N, D = out_rnn_tm1.shape
            hidden_states_t = rearrange(out_rnn_tm1, "T N D -> 1 (T N) D")
            flat_input_rnn_t = rearrange(input_rnn_t, "T N D -> 1 (T N) D")
            out_rnn_t, _ = self.main_GRU(flat_input_rnn_t, hidden_states_t)  # [1, (T N), D]
            out_rnn = rearrange(out_rnn_t, "1 (T N) D -> T N D", T=T, N=N, D=D)

            out_rnn = unpad_trajectories(out_rnn, masks)
            out_pseudo_a_t_embedding = unpad_trajectories(pseudo_a_t_embedding, masks)
            out = torch.cat([out_pseudo_a_t_embedding, out_rnn], dim=-1)
        else:
            # inference mode (collection): advance the stored hidden state with the
            # real pair only ...
            _, self.hidden_states = self.main_GRU(input_rnn_tm1.unsqueeze(0), self.hidden_states)
            # ... then take one throw-away step with the pseudo action.
            out_rnn_t, _ = self.main_GRU(input_rnn_t.unsqueeze(0), self.hidden_states)

            out = torch.cat([pseudo_a_t_embedding, out_rnn_t.squeeze(0)], dim=-1)
        return out

    def _run_backbone(self, states, actions):
        """Embed a batch-first sequence and return the GRU outputs, time-first."""
        embedded_states = self.state_encoder(states)
        if self.input_dim_actions != 0:
            embedded_actions = self.action_encoder_pretrain(actions)
            x = torch.cat((embedded_states, embedded_actions), dim=-1)  # [B, input_timesteps, 2*representation_dim]
        else:
            x = embedded_states
        # The GRU is built with batch_first=False, hence the transpose to [T, B, ...].
        x, _ = self.main_GRU(x.transpose(0, 1))
        return x

    def forward(self, states, actions):
        """Predict from the last timestep of the input sequence.

        Args:
            states: ``[batch_size, input_timesteps, input_dim_states]``
            actions: ``[batch_size, input_timesteps, input_dim_actions]``;
                ignored when the model has no action encoder.

        Returns:
            ``[batch_size, output_dim]``: the output layer applied to the GRU
            output of the last timestep, i.e. the prediction of the next state.
        """
        x = self._run_backbone(states, actions)
        return self.output_layer(x[-1])

    def forward_traj(self, states, actions):
        """Predict at every timestep of the input sequence.

        Args:
            states: ``[batch_size, input_timesteps, input_dim_states]``
            actions: ``[batch_size, input_timesteps, input_dim_actions]``;
                ignored when the model has no action encoder.

        Returns:
            ``[batch_size, input_timesteps, output_dim]``
        """
        x = self._run_backbone(states, actions)
        x = self.output_layer(x)
        return x.transpose(0, 1)

    def reset(self, dones=None):
        """Zero the stored hidden state of the environments flagged in ``dones``.

        Does nothing while no hidden state has been stored yet.
        NOTE(review): with ``dones=None`` the index ``dones == 1`` evaluates to
        the Python bool ``False``; confirm that callers always pass a tensor.
        """
        if self.hidden_states is None:
            return
        # Iterating the stored tensor walks its leading (layer) axis; each item
        # is a view, so the in-place write updates self.hidden_states.
        for hidden_state in self.hidden_states:
            hidden_state[..., dones == 1, :] = 0.0

    def get_hidden_states(self):
        """Return the hidden state stored by the last inference-mode ``forward_RL`` call (or ``None``)."""
        return self.hidden_states



@configclass
class DynamicsSubmoduleConfigRNN:
    """Configuration for the recurrent dynamics submodule.

    The field names match the keyword arguments of ``DynamicGRU.__init__``;
    ``class_name`` holds the name of the module class.
    """

    class_name: str = "DynamicGRU"
    """The class name of the submodule."""

    # Input width of the state encoder in DynamicGRU.
    input_dim_states: int = MISSING

    # Input width of the action encoder in DynamicGRU (0 disables the action encoder there).
    input_dim_actions: int = MISSING

    # NOTE(review): the docstring below talks about an "input dimension" (12 joints),
    # which reads like a description of input_dim_states rather than of a timestep
    # count - confirm and move if misplaced.
    input_timesteps: int = MISSING
    """The input dimension of the submodule. 12 for quadruped since we have 4 legs with 3 joints each."""

    output_dim: int = MISSING
    """The output dimension of the submodule. 6 for pose, 3 for translation."""

    # NOTE(review): DynamicGRU stores this value but does not build any layer from it
    # (its MLP backbone is commented out) - confirm whether the field is still needed.
    hidden_dims: list[int] = [512, 256, 128]
    """The hidden dimension of the MLP."""

    RL_crown_dims: list[int] = [128, 128]
    """The hidden dimension of the MLP for RL. This should be set in RL config instances."""

    # In DynamicGRU this is the output width of each encoder MLP.
    # NOTE(review): the docstring below duplicates the one of backbone_output_dim.
    representation_dim: int = 128
    """The output dimension of the backbone MLP. This should not be too large to avoid enlarging the input of final MLPs in actor-critic."""

    input_obs_dim_RL: int | None = None
    """The input dimension of the observation tensor in RL stage. This should be set in RL config instances."""

    # DynamicGRU uses backbone_output_dim - representation_dim as the GRU hidden size.
    backbone_output_dim: int = 128 + 128
    """The output dimension of the backbone MLP. This should not be too large to avoid enlarging the input of final MLPs in actor-critic."""
    
    # DynamicGRU additionally treats the literal string "random_init" as
    # "skip loading, keep random weights".
    weight_path: str | None = None
    """The path to the pretrained model weights. If None, the model will be initialized from scratch."""

    mode: str = "joints_only" 
    """The mode of the model. Can be "joints_only" or "joints_plus_base". This does not affect training, only affects the inference during RL."""

    finetune_frozen: bool = False
    """Whether to freeze the pretrained weights during finetuning. If True, the pretrained weights will not be updated during finetuning."""