import lightning as L
import torch
import os
from torch.utils.data import Dataset, RandomSampler, DataLoader
import h5py
import warnings
from torch.utils.data import random_split
from rsl_rl.addons.invdynamics.inv_dynamics_module import InvDynamicsMLP
from lightning.pytorch.profilers import SimpleProfiler
import wandb
from lightning.pytorch.loggers import WandbLogger
import time
# from torch.autograd.functional import jacobian
from datetime import datetime
from typing import List, Sequence
import re

JOINT_NAMES = ['LF_HAA', 'LH_HAA', 'RF_HAA', 'RH_HAA', 'LF_HFE', 'LH_HFE', 'RF_HFE', 'RH_HFE', 'LF_KFE', 'LH_KFE', 'RF_KFE', 'RH_KFE']


def make_inv_input(x, a):
    """Split a window of states and actions into inverse-dynamics inputs and target.

    Both tensors are indexed with three axes; presumably (batch, time, feature).

    Returns:
        A tuple ``(x_cut, a_cut, delta_desired, a_target)``: every state except the
        last one, the actions strictly between the first and the last step, the
        state change over the final transition, and the action of the last step.
    """
    final_state = x[:, -1, :]
    previous_state = x[:, -2, :]
    state_history = x[:, :-1]
    action_history = a[:, 1:-1]
    return state_history, action_history, final_state - previous_state, a[:, -1, :]

def make_fwd_input(x, a):
    """Split a window of states and actions into forward-dynamics inputs and target.

    Both tensors are indexed with three axes; presumably (batch, time, feature).

    Returns:
        A tuple ``(x_cut, a_cut, x_delta_target)``: every state except the last
        one, every action except the first one, and the state change over the
        final transition.
    """
    final_state = x[:, -1, :]
    previous_state = x[:, -2, :]
    return x[:, :-1], a[:, 1:], final_state - previous_state


# `resolve_matching_names` is copied from `isaaclab.utils.string` so that this module can be used
# without importing isaaclab, which is problematic when Isaac Lab has not been started.

def resolve_matching_names(
    keys: str | Sequence[str], list_of_strings: Sequence[str], preserve_order: bool = False
) -> tuple[list[int], list[str]]:
    """Match a list of query regular expressions against a list of strings and return the matched indices and names.

    When a list of query regular expressions is provided, the function checks each target string against each
    query regular expression and returns the indices of the matched strings and the matched strings.

    If the :attr:`preserve_order` is True, the ordering of the matched indices and names is the same as the order
    of the provided list of strings. This means that the ordering is dictated by the order of the target strings
    and not the order of the query regular expressions.

    If the :attr:`preserve_order` is False, the ordering of the matched indices and names is the same as the order
    of the provided list of query regular expressions.

    For example, consider the list of strings is ['a', 'b', 'c', 'd', 'e'] and the regular expressions are ['a|c', 'b'].
    If :attr:`preserve_order` is False, then the function will return the indices of the matched strings and the
    strings as: ([0, 1, 2], ['a', 'b', 'c']). When :attr:`preserve_order` is True, it will return them as:
    ([0, 2, 1], ['a', 'c', 'b']).

    Note:
        The function does not sort the indices. It returns the indices in the order they are found.

    Args:
        keys: A regular expression or a list of regular expressions to match the strings in the list.
        list_of_strings: A list of strings to match.
        preserve_order: Whether to preserve the order of the query keys in the returned values. Defaults to False.

    Returns:
        A tuple of lists containing the matched indices and names.

    Raises:
        ValueError: When multiple matches are found for a string in the list.
        ValueError: When not all regular expressions are matched.
    """
    # resolve name keys
    if isinstance(keys, str):
        keys = [keys]
    # find matching patterns
    index_list = []
    names_list = []
    key_idx_list = []
    # book-keeping to check that we always have a one-to-one mapping
    # i.e. each target string should match only one regular expression
    target_strings_match_found = [None for _ in range(len(list_of_strings))]
    keys_match_found = [[] for _ in range(len(keys))]
    # loop over all target strings
    for target_index, potential_match_string in enumerate(list_of_strings):
        for key_index, re_key in enumerate(keys):
            if re.fullmatch(re_key, potential_match_string):
                # check if match already found
                if target_strings_match_found[target_index]:
                    raise ValueError(
                        f"Multiple matches for '{potential_match_string}':"
                        f" '{target_strings_match_found[target_index]}' and '{re_key}'!"
                    )
                # add to list
                target_strings_match_found[target_index] = re_key
                index_list.append(target_index)
                names_list.append(potential_match_string)
                key_idx_list.append(key_index)
                # add for regex key
                keys_match_found[key_index].append(potential_match_string)
    # reorder keys if they should be returned in order of the query keys
    if preserve_order:
        reordered_index_list = [None] * len(index_list)
        global_index = 0
        for key_index in range(len(keys)):
            for key_idx_position, key_idx_entry in enumerate(key_idx_list):
                if key_idx_entry == key_index:
                    reordered_index_list[key_idx_position] = global_index
                    global_index += 1
        # reorder index and names list
        index_list_reorder = [None] * len(index_list)
        names_list_reorder = [None] * len(index_list)
        for idx, reorder_idx in enumerate(reordered_index_list):
            index_list_reorder[reorder_idx] = index_list[idx]
            names_list_reorder[reorder_idx] = names_list[idx]
        # update
        index_list = index_list_reorder
        names_list = names_list_reorder
    # check that all regular expressions are matched
    if not all(keys_match_found):
        # make this print nicely aligned for debugging
        msg = "\n"
        for key, value in zip(keys, keys_match_found):
            msg += f"\t{key}: {value}\n"
        msg += f"Available strings: {list_of_strings}\n"
        # raise error
        raise ValueError(
            f"Not all regular expressions are matched! Please check that the regular expressions are correct: {msg}"
        )
    # return
    return index_list, names_list



def _switch_anymal_joints_left_right_pretrain(joint_data: torch.Tensor, joint_names: List[str]) -> torch.Tensor:
    """Applies a left-right symmetry transformation to the joint data tensor.

    The last axis of ``joint_data`` is indexed per joint, in the order of ``joint_names``.
    Values of joints named ``L*_*`` and ``R*_*`` are swapped, and afterwards the sign of
    every ``*_HAA`` joint is flipped.

    Args:
        joint_data: Tensor whose last axis has one entry per name in ``joint_names``.
        joint_names: Joint names used to locate the left, right and HAA entries.

    Returns:
        A new tensor of the same shape; ``joint_data`` is not modified. Entries of joints
        that match neither the left nor the right pattern stay zero.

    Raises:
        ValueError: Propagated from ``resolve_matching_names`` when a pattern matches no joint.
    """
    # resolve each pattern once; all lookups only depend on ``joint_names``
    left_ids = resolve_matching_names(("L._.*"), joint_names, preserve_order=True)[0]
    right_ids = resolve_matching_names(("R._.*"), joint_names, preserve_order=True)[0]
    haa_ids = resolve_matching_names((".*_HAA"), joint_names, preserve_order=True)[0]

    joint_data_switched = torch.zeros_like(joint_data)
    # left <-- right
    joint_data_switched[..., left_ids] = joint_data[..., right_ids]
    # right <-- left
    joint_data_switched[..., right_ids] = joint_data[..., left_ids]

    # Flip the sign of the HAA joints
    joint_data_switched[..., haa_ids] *= -1

    return joint_data_switched

def _switch_anymal_joints_front_back_pretrain(joint_data: torch.Tensor, joint_names: List[str]) -> torch.Tensor:
    """Mirror per-joint data between the front and the hind legs.

    The last axis of ``joint_data`` is indexed per joint, in the order of ``joint_names``.
    Values of joints named ``*F_*`` and ``*H_*`` are exchanged, then the sign of every
    ``*_HFE`` and ``*_KFE`` joint is inverted. The result is a new tensor; entries of
    joints matching neither pattern remain zero.
    """
    front_ids, _ = resolve_matching_names((".F_.*"), joint_names, preserve_order=True)
    hind_ids, _ = resolve_matching_names((".H_.*"), joint_names, preserve_order=True)
    flexion_ids, _ = resolve_matching_names((".*_HFE|.*_KFE"), joint_names, preserve_order=True)

    mirrored = torch.zeros_like(joint_data)
    # front legs receive the hind-leg values and vice versa
    mirrored[..., front_ids] = joint_data[..., hind_ids]
    mirrored[..., hind_ids] = joint_data[..., front_ids]
    # HFE and KFE joints change sign under the front-back mirror
    mirrored[..., flexion_ids] *= -1
    return mirrored


# Reference only: layout of the observation vector handled by the symmetry transforms below.
# @configclass
# class InvInputCfg(ObsGroup):
#     """Observations for PIDM group. This space cannot include past actions!"""
#
#     # observation terms (order preserved)
#     base_lin_vel = ObsTerm(func=mdp.base_lin_vel)  # indices 0:3
#     base_ang_vel = ObsTerm(func=mdp.base_ang_vel)  # indices 3:6
#     projected_gravity = ObsTerm(
#         func=mdp.projected_gravity,
#     )  # indices 6:9
#
#     joint_pos = ObsTerm(func=mdp.joint_pos_rel)  # length 12, indices 9:21
#     joint_vel = ObsTerm(func=mdp.joint_vel_rel)  # length 12, indices 21:33




def _transform_obs_left_right_pretrain(tensor: torch.Tensor) -> torch.Tensor:
    """Return the left-right mirrored version of an observation tensor.

    The last axis is processed as three 3-vectors (linear velocity, angular velocity,
    projected gravity) followed by two 12-entry joint blocks (positions, velocities).
    The input is left untouched; a transformed copy is returned.
    """
    mirrored = tensor.clone()
    dev = mirrored.device

    # per-component signs of the three leading 3-vectors, in observation order:
    # lin vel, ang vel, projected gravity
    vector_signs = ([1, -1, 1], [-1, 1, -1], [1, -1, 1])
    cursor = 0
    for signs in vector_signs:
        block = slice(cursor, cursor + 3)
        mirrored[..., block] = mirrored[..., block] * torch.tensor(signs, device=dev)
        cursor += 3

    # joint_pos first, joint_vel second; both use the same joint permutation
    for _ in range(2):
        block = slice(cursor, cursor + 12)
        mirrored[..., block] = _switch_anymal_joints_left_right_pretrain(
            mirrored[..., block],
            joint_names=JOINT_NAMES,
        )
        cursor += 12

    # the blocks above must cover the whole observation
    assert cursor == mirrored.shape[-1]
    return mirrored


def _transform_obs_front_back_pretrain(tensor: torch.Tensor) -> torch.Tensor:
    """Applies a front-back symmetry transformation to the observation tensor.

    The last axis is processed as three 3-vectors (linear velocity, angular velocity,
    projected gravity) followed by two 12-entry joint blocks (positions, velocities).

    Args:
        tensor: Observation tensor; its last axis must have exactly 33 entries.

    Returns:
        A transformed copy; the input tensor is not modified.

    Raises:
        AssertionError: If the last axis is not fully covered by the blocks above.
    """
    tensor = tensor.clone()
    device = tensor.device

    # per-component signs of the three leading 3-vectors, in observation order:
    # lin vel, ang vel, projected gravity
    vector_signs = ([-1, 1, 1], [1, -1, -1], [-1, 1, 1])
    idx = 0
    for signs in vector_signs:
        tensor[..., idx : idx + 3] = tensor[..., idx : idx + 3] * torch.tensor(signs, device=device)
        idx += 3

    # joint_pos first, joint_vel second; both use the same joint permutation
    for _ in range(2):
        tensor[..., idx : idx + 12] = _switch_anymal_joints_front_back_pretrain(
            tensor[..., idx : idx + 12],
            joint_names=JOINT_NAMES,
        )
        idx += 12

    assert idx == tensor.shape[-1]
    return tensor

def _transform_action_left_right_pretrain(tensor: torch.Tensor) -> torch.Tensor:
    """Return the left-right mirrored version of an action tensor.

    The last axis holds one action per joint in ``JOINT_NAMES`` order. The input is
    copied before it is handed to the joint-switch helper.
    """
    action_copy = tensor.clone()
    return _switch_anymal_joints_left_right_pretrain(action_copy, joint_names=JOINT_NAMES)


def _transform_action_front_back_pretrain(tensor: torch.Tensor) -> torch.Tensor:
    """Applies a front-back symmetry transformation to the action tensor.

    Args:
        tensor: Action tensor whose last axis holds one action per joint in ``JOINT_NAMES`` order.

    Returns:
        A new tensor with front/hind joints exchanged and HFE/KFE signs flipped; the input
        is not modified.
    """
    action_copy = tensor.clone()
    return _switch_anymal_joints_front_back_pretrain(action_copy, joint_names=JOINT_NAMES)


class DynamicSlidingWindowDataset(Dataset):
    """Fixed-length sliding windows over trajectories stored in an HDF5 file.

    Every top-level entry of the file that holds both an ``inv_input`` and an ``actions``
    dataset and is at least ``window_size`` steps long is used as one trajectory. Windows
    may start before the first timestep of a trajectory; the missing leading steps are
    filled by repeating the first timestep.
    """

    def __init__(self, h5_path, window_size, load_into_memory=True):
        """
        Args:
            h5_path (str): Path to the HDF5 dataset.
            window_size (int): Number of timesteps in each sliding window.
            load_into_memory (bool): Whether to preload all data into memory.
                                     Default True (backward-compatible).
                                     Should only be set to False when visualizing samples of a
                                     dataset that is too large for the workstation's memory.
        """
        assert window_size >= 1, "Window size must be at least 1"
        self.h5_path = h5_path
        self.window_size = window_size
        self.load_into_memory = load_into_memory

        self.data = []     # (inv_input, actions) per trajectory; only filled if load_into_memory=True
        self.index = []    # (traj_id, start_idx) in memory mode, (group_name, start_idx) in disk mode

        with h5py.File(h5_path, "r") as f:
            print(f"trajectory num: {len(f)}")
            for group_name in f:
                group = f[group_name]
                if not self._is_usable_trajectory(group):
                    continue
                length = group["inv_input"].shape[0]
                if load_into_memory:
                    inv_input = torch.tensor(group["inv_input"][...])
                    actions = torch.tensor(group["actions"][...])
                    self.data.append((inv_input, actions))
                    traj_ref = len(self.data) - 1
                else:
                    traj_ref = group_name  # store group name for later disk access

                # start indices range from -(window_size - 1) (window holds only the first
                # real step, rest is padding) up to the last start that still fits
                self.index.extend(
                    (traj_ref, start) for start in range(1 - window_size, length - window_size + 1)
                )

    def _is_usable_trajectory(self, group):
        """Return True if ``group`` has both datasets and at least ``window_size`` timesteps."""
        if "inv_input" not in group or "actions" not in group:
            return False
        return group["inv_input"].shape[0] >= self.window_size

    def __len__(self):
        """Number of windows over all usable trajectories."""
        return len(self.index)

    def __getitem__(self, idx):
        """Return the ``(inv_input, actions)`` window with index ``idx``.

        Both returned tensors have ``window_size`` entries along their first axis.
        """
        traj_ref, start = self.index[idx]
        stop = start + self.window_size
        first = max(start, 0)  # first real timestep covered by the window

        if self.load_into_memory:
            inv_input, actions = self.data[traj_ref]
            x = inv_input[first:stop]
            a = actions[first:stop]
        else:
            # Read only the needed window from disk
            with h5py.File(self.h5_path, "r") as f:
                group = f[traj_ref]
                x = torch.tensor(group["inv_input"][first:stop])
                a = torch.tensor(group["actions"][first:stop])

        # Windows with a negative start index are left-padded with copies of the first timestep
        pad = first - start
        if pad > 0:
            x = torch.cat([x[0:1].repeat_interleave(pad, dim=0), x], dim=0)
            a = torch.cat([a[0:1].repeat_interleave(pad, dim=0), a], dim=0)

        assert x.shape[0] == self.window_size and a.shape[0] == self.window_size
        return x, a

    def get_sample_entries_in_file(self, sample_num, seed=None):
        """Draw up to ``sample_num`` distinct windows at random and stack them.

        Note:
            When ``seed`` is given, the global torch RNG is re-seeded with it.

        Returns:
            A tuple ``(x, a)`` of stacked windows, ordered by ascending window index.
        """
        if seed is not None:
            torch.manual_seed(seed)
        indices = torch.sort(torch.randperm(len(self))[:sample_num])[0]
        samples = [self[i] for i in indices.tolist()]
        x_list, a_list = zip(*samples)
        return torch.stack(x_list), torch.stack(a_list)

    def len_timesteps(self):
        """Total number of timesteps over all trajectories used by this dataset."""
        if self.load_into_memory:
            return sum(x.shape[0] for x, _ in self.data)
        # apply the same trajectory filter as __init__, so both modes report the same number
        with h5py.File(self.h5_path, "r") as f:
            return sum(
                f[name]["inv_input"].shape[0] for name in f if self._is_usable_trajectory(f[name])
            )


class INVLightningModule(L.LightningModule):
    """Lightning module that trains and evaluates a dynamics model on (state, action) windows.

    ``mode`` selects the objective: ``"inv"`` regresses the last action of the window from the
    state history and the desired state change; ``"fwd"``, ``"jacobian"`` and ``"dl"`` regress
    a state delta from the model output. All objectives use an L1 loss.
    """

    def __init__(
        self,
        model,
        mode="inv",
        lr=1e-5,
        symmetry_augmentation_anymal: bool = False,
        save_dir_suffix: str | None = None,
        embodiment: str = "anymal",  # or "G1"
    ):
        """
        Args:
            model: The dynamics model to train.
            mode: One of "inv", "fwd", "jacobian", "dl".
            lr: Learning rate of the AdamW optimizer.
            symmetry_augmentation_anymal: If True, three quarters of every batch are replaced by
                left-right / front-back mirrored versions.
            save_dir_suffix: Optional prefix of the run directory name. Intermediate ``.pt``
                snapshots are disabled when it contains "RL".
            embodiment: "anymal" or "G1" (case-insensitive); selects the input-noise layout.
        """
        super().__init__()
        self.save_hyperparameters()
        self.model: InvDynamicsMLP = model
        self.error_per_epoch = []
        self.error_accumulated = 0.0
        self.step_counter = 0
        self.mode = mode
        self.lr = lr
        self.penalize_grad = True
        self.grad_loss_beta = 10
        self.symmetry_augmentation_anymal = symmetry_augmentation_anymal

        # An 83-dim state (9 base entries + 2 * 37 joint entries) is taken as the G1 layout.
        # Resolve this BEFORE storing the attribute so that it agrees with the noise vector.
        if self.model.dim_states == 83:
            embodiment = "G1"
        self.embodiment = embodiment

        # noise order must be: lin_vel, ang_vel, gravity_vector, joint_angles, joint_vels
        joints_per_embodiment = {"anymal": 12, "g1": 37}
        num_joints = joints_per_embodiment.get(str(embodiment).lower())
        if num_joints is None:
            # unknown embodiment: only an error once input noise is actually requested
            self.noise_magnitude = None
        else:
            self.noise_magnitude = torch.tensor(
                [0.1] * 3 + [0.2] * 3 + [0.05] * 3 + [0.01] * num_joints + [1.5] * num_joints
            )

        timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        if save_dir_suffix is None:
            self.save_dir = f"logs/pretrain/lightning/inv_module_training/{timestamp}"
        else:
            self.save_dir = f"logs/pretrain/lightning/inv_module_training/{save_dir_suffix}_{timestamp}"

        # runs started from the RL loop (suffix contains "RL") skip the periodic snapshots
        self.save_intermediate = not (save_dir_suffix is not None and "RL" in save_dir_suffix)

        os.makedirs(self.save_dir, exist_ok=True)

    def on_after_batch_transfer(self, batch, dataloader_idx):
        """Optionally apply symmetry augmentation to a batch after it reached the device.

        The batch is cut into four consecutive parts along dim 0: unchanged, left-right
        mirrored, front-back mirrored, and mirrored both ways.
        """
        x, a = batch
        if self.symmetry_augmentation_anymal is not True:
            return x, a

        # tensor_split always yields exactly 4 (possibly empty) parts; ``chunk(4)`` may return
        # fewer (e.g. for batch sizes 3, 5, 6 or 9), which broke the 4-way unpacking on
        # partial batches. No shuffling is needed, consecutive parts are sufficient.
        x_parts = torch.tensor_split(x, 4)
        a_parts = torch.tensor_split(a, 4)
        new_x = torch.cat([
            x_parts[0],
            _transform_obs_left_right_pretrain(x_parts[1]),
            _transform_obs_front_back_pretrain(x_parts[2]),
            _transform_obs_left_right_pretrain(_transform_obs_front_back_pretrain(x_parts[3])),
        ], dim=0)
        new_a = torch.cat([
            a_parts[0],
            _transform_action_left_right_pretrain(a_parts[1]),
            _transform_action_front_back_pretrain(a_parts[2]),
            _transform_action_left_right_pretrain(_transform_action_front_back_pretrain(a_parts[3])),
        ], dim=0)
        return new_x, new_a

    def forward(self, *args):
        """Delegate to the wrapped model."""
        return self.model(*args)

    def on_train_epoch_start(self):
        """Reset the per-epoch error statistics and save a snapshot every 10th epoch."""
        self.error_accumulated = 0.0
        self.step_counter = 0
        if self.current_epoch % 10 == 0 and self.save_intermediate:
            self.save_model_pt(os.path.join(self.save_dir, f"epoch_{self.current_epoch:04d}.pt"))

    def save_model_pt(self, save_path):
        """
        Save the state dict of the wrapped model to a .pt file.
        :param save_path: Path to save the model.
        """
        torch.save(self.model.state_dict(), save_path)
        print(f"Model saved to {save_path}")

    def on_train_end(self):
        """Save the final weights of the wrapped model into the run directory."""
        self.save_model_pt(os.path.join(self.save_dir, "final_model.pt"))
        return super().on_train_end()

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss and accumulate the mean absolute error."""
        loss, mean_abs_error = self.step(batch, batch_idx)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log("train_error", mean_abs_error, on_epoch=True)
        self.error_accumulated += mean_abs_error.item()
        self.step_counter += 1
        return loss

    def add_noise_to_input(self, x_cut):
        """Add zero-mean Gaussian noise, scaled per feature by ``noise_magnitude``, to ``x_cut``.

        Raises:
            ValueError: If no noise layout is known for the configured embodiment.
        """
        if self.noise_magnitude is None:
            raise ValueError(
                f"No input-noise layout defined for embodiment {self.embodiment!r}; expected 'anymal' or 'G1'."
            )
        return x_cut + torch.randn_like(x_cut)*self.noise_magnitude[None, None, :].to(x_cut.device)

    def relative_command_magnitude_scaled_loss(self, output, a_target, current_step_joint_pos):
        """Compute the relative command magnitude scaled loss."""
        # the multiplier 0.5 needs to be consistent with the scale in the action cfg of the simulation env
        relative_command_magnitude = torch.abs(a_target*0.5 - current_step_joint_pos)
        unscaled_loss = torch.abs(output - a_target)*0.5
        scaled_loss = unscaled_loss / (relative_command_magnitude + 0.2)  # Avoid division by zero

        return scaled_loss.mean()  # Return the mean loss across the batch

    def _predict_state_delta(self, x, a):
        """Run the model in a non-"inv" mode and return ``(predicted_delta, target_delta)``."""
        output = self.forward(x, a)
        delta_full = x[:, -1, :] - x[:, -2, :]

        if self.mode == "fwd":
            # the forward model predicts twist, gravity vector and joint positions (dims 0:21);
            # a 12-dim output is treated as the older joint-positions-only variant (dims 9:21)
            if output.shape[-1] == 12:
                return output, delta_full[..., 9:21]
            return output, delta_full[..., 0:21]

        joint_delta = delta_full[..., 9:21]  # only predict joint positions
        last_action = a[:, -1, :]
        if self.mode == "jacobian":
            jacobian, bias_state = output
            delta_pred = torch.bmm(jacobian, last_action.unsqueeze(-1)).squeeze(-1) + bias_state
            return delta_pred, joint_delta
        if self.mode == "dl":  # decoupled linear
            w, b = output
            return w*last_action.squeeze(-1) + b, joint_delta

        raise ValueError(f"Unsupported mode {self.mode!r}; expected 'inv', 'fwd', 'jacobian' or 'dl'.")

    def step(self, batch, batch_idx):
        """Shared train/validation step.

        Returns:
            A tuple ``(loss, mean_abs_error)`` of scalar tensors.
        """
        ########## only for rebuttal experiment ##########
        # if self.global_step % 50 == 0:
        #     self.save_model_pt(os.path.join(self.save_dir, f"epoch_{self.current_epoch:04d}_step_{self.global_step:06d}.pt"))
        ########## only for rebuttal experiment ##########
        x, a = batch
        x.requires_grad = True  # only needed by the (currently disabled) gradient penalty below

        if self.mode != "inv":
            prediction, target = self._predict_state_delta(x, a)
            loss = torch.nn.functional.l1_loss(prediction, target)
            mean_abs_error = torch.mean(torch.abs(prediction - target))
            return loss, mean_abs_error

        assert not hasattr(self.model, "LSTM_core"), "deprecated feature not working currently"

        x_cut, a_cut, desired_delta_x, a_target = make_inv_input(x, a)

        # NOTE(review): input noise is also applied when this runs as validation step - confirm intended
        output = self.model.forward_inv(self.add_noise_to_input(x_cut), a_cut, desired_delta_x)
        # output = self.model.forward_inv(x_cut, a_cut, desired_delta_x)

        # 1) vanilla l1 loss
        loss = torch.nn.functional.l1_loss(output, a_target)

        # 2) relative command magnitude scaled loss
        # current_step_joint_pos = x_cut[:, -1, 9:21]  # joint positions of the current step
        # loss = self.relative_command_magnitude_scaled_loss(output, a_target, current_step_joint_pos)

        mean_abs_error = torch.mean(torch.abs(output - a_target))

        # if self.penalize_grad and self.training:
        #     # for inverse dynamics model, it makes sense that the gradient of output action w.r.t. to the input desired
        #     # joint angle is positive, i.e., if the desired joint angle increases, the output action should also increase.

        #     # Compute gradient of output w.r.t. input x
        #     grad_outputs = torch.zeros_like(output) # shape: (batch_size, input_dim_actions)
        #     grad_outputs
        #     J_full = jacobian(model_output, x)
        #     grads = torch.autograd.grad(outputs=output, inputs=desired_delta_x,
        #                                 grad_outputs=grad_outputs,
        #                                 create_graph=True, retain_graph=True)[0] # shape: (batch_size, window_size, input_dim)

        #     # Select the gradient for the specific input feature you care about
        #     grad_penalty = torch.relu(-grads[:, -1, 9:21]).mean()   # penalize when gradient < 0

        #     self.log("train_grad_penalty", grad_penalty, on_step=True, on_epoch=True)

        #     loss = loss + self.grad_loss_beta * grad_penalty

        return loss, mean_abs_error

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss and error."""
        loss, mean_abs_error = self.step(batch, batch_idx)
        self.log("val_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log("val_error", mean_abs_error, on_epoch=True)
        # NOTE(review): these counters are shared with training_step, so ``error_per_epoch``
        # mixes training and validation errors when a validation loader is used - confirm intended
        self.error_accumulated += mean_abs_error.item()
        self.step_counter += 1
        return loss

    def on_train_epoch_end(self):
        """Store the mean of the errors accumulated since the epoch started (0.0 if none)."""
        self.error_per_epoch.append(self.error_accumulated / self.step_counter if self.step_counter > 0 else 0.0)

    def configure_optimizers(self):
        """AdamW over the wrapped model's parameters."""
        return torch.optim.AdamW(self.model.parameters(), lr=self.lr, weight_decay=1e-2)

    @torch.no_grad()
    def validate_batch_detail(self, batch, symmetry_left_right_transform=False, embodiment: str = "anymal"):
        """
        Evaluate the model on a batch and return element-wise errors and target magnitudes.
        :param batch: A batch ``(x, a)`` of data.
        :param symmetry_left_right_transform: If True, evaluate on the left-right mirrored batch.
        :param embodiment: "anymal" or "G1" (case-insensitive); selects the joint-position slice
            used for the target magnitude in "inv" mode.
        :return: A tuple of flattened tensors (absolute errors, target magnitudes).
        """
        x, a = batch

        if symmetry_left_right_transform:
            x = _transform_obs_left_right_pretrain(x)
            a = _transform_action_left_right_pretrain(a)

        if self.mode != "inv":
            prediction, target = self._predict_state_delta(x, a)
            mean_abs_error = torch.abs(prediction - target)
            target_magnitude = torch.abs(target)
            return mean_abs_error.flatten(), target_magnitude.flatten()

        # if relative command
        # a_tm1 = a[:, -1, :]
        # mean_abs_error = torch.abs(output - a_tm1)
        # target_magnitude = torch.abs(a_tm1)

        # if absolute command
        x_cut, a_cut, delta_desired, a_target = make_inv_input(x, a)
        output = self.model.forward_inv(x_cut, a_cut, delta_desired)
        mean_abs_error = torch.abs(output - a_target)*0.5

        # accept the same spellings as __init__ ("G1") and the historical lowercase "g1"
        embodiment_key = str(embodiment).lower()
        if embodiment_key == "anymal":
            target_magnitude = torch.abs(a_target*0.5 - x[:, -2, 9:21])
        elif embodiment_key == "g1":
            target_magnitude = torch.abs(a_target*0.5 - x[:, -2, 9:9+37])
        else:
            raise NotImplementedError("Embodiment not supported for computing target magnitude in validate_batch_detail.")

        return mean_abs_error.flatten(), target_magnitude.flatten()

# dataset = WikiText2()
# dataloader = DataLoader(dataset)
# model = LightningTransformer(vocab_size=dataset.vocab_size)

# trainer = L.Trainer(fast_dev_run=100)
# trainer.fit(model=model, train_dataloaders=dataloader)





class INVModelsEnsemble:
    def __init__(self, inv_dynamics_cfg, device):
        """
        Initialize the ensemble with a list of models.
        :param models: List of models to be included in the ensemble.
        """
        self.device = device
        self.ensemble_size = inv_dynamics_cfg["ensemble_size"]
        self.models:torch.nn.ModuleList = torch.nn.ModuleList(
            [eval(inv_dynamics_cfg["class_name"])(device=device, **inv_dynamics_cfg) for _ in range(self.ensemble_size)]
        )

    def get_intrinsic_reward(self, x, a, dones):
        # raise NotImplementedError("This method is in construction currently.")
        if self.ensemble_size == 1:
            raise NotImplementedError
            return self.get_intrinsic_reward_single_model(x_t, x_tp1, a_t, dones)
        else:
            return self.get_intrinsic_reward_ensemble(x, a, dones)

    def get_intrinsic_reward_single_for_batch(self, x, a):
        x_t, x_tp1, a_tp = x[:, 0], x[:, 1], a[:, 1]
        return self.get_intrinsic_reward_single_model(x_t.to(self.device), x_tp1.to(self.device), a_tp.to(self.device), dones=torch.zeros(x_t.shape[0]))  # dones can be None for batch processing    

    # @torch.no_grad()
    # def get_intrinsic_reward_single_model(self, x_t, x_tp1, a_t, dones):
    #     return self.models[0].get_pred_error_as_intrinsic_reward(x_t, x_tp1, a_t, dones)

    def get_intrinsic_reward_ensemble_for_batch(self, x, a):
        raise NotImplementedError("This method is in construction currently.")
        x_t, x_tp1, a_tp = x[:, 0], x[:, 1], a[:, 1]
        return self.get_intrinsic_reward_ensemble(x_t.to(self.device), x_tp1.to(self.device), dones=torch.zeros(x_t.shape[0]))  # dones can be None for batch processing

    @torch.no_grad()
    def get_intrinsic_reward_ensemble(self, x, a, dones):
        """
        For an ensemble of inverse dynamics models, compute the intrinsic reward as the disagreement among the models.
        :param x_t: Current state tensor.
        :param x_tp1: Next state tensor.
        :return: List of outputs from each model.
        """
        x_cut, a_cut, delta_desired, a_target = make_inv_input(x, a)
        a_t_pred = torch.stack([model.forward_inv(x_cut, a_cut, delta_desired) for model in self.models], dim=1) # [num_envs, ensemble_size, action_dim]
        std = torch.std(a_t_pred, dim=1).mean(-1) # [num_envs, action_dim]
        reward_mat = torch.zeros_like(std)
        reward_mat[~(dones.to(torch.bool))] = std[~(dones.to(torch.bool))] # to bool is very important!
        return reward_mat


    def check_if_data_is_sufficient_for_training(self, dataset_path):
        """
        Check if the dataset is sufficient for training the inverse dynamics model.
        :param dataset_path: Path to the dataset file.
        :return: True if the dataset is sufficient, False otherwise.
        """
        dataset = DynamicSlidingWindowDataset(h5_path=dataset_path, window_size=2)
        print("num samples in the dataset:", len(dataset))
        samples_number_needed = 10 * self.models[0].get_number_of_trainable_parameters()
        if len(dataset) < samples_number_needed:
            warnings.warn(
                f"The number of samples in the dataset may be too small for training the inverse dynamics model! \n Number of samples: {len(dataset)}, Ratio num_samples/num_trainable_parameters: {len(dataset)/(0.1*samples_number_needed)}."
            )
            return False
        else:
            return True

    def reinitialized_and_train_model_mlp(self, model: InvDynamicsMLP, dataset: DynamicSlidingWindowDataset, epochs: int = 10, batch_size: int = 32, replacement: bool = False, save_path: str = None, wandb_log=False):
        # TODO: support more informative save_dir_suffix
        model.reinitialize_weights()
        sampler = RandomSampler(dataset, replacement=replacement, num_samples=len(dataset))
        dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0, sampler=sampler)
        l_model = INVLightningModule(model=model, mode=model.mode, lr=1e-4, save_dir_suffix="RL")
        if wandb_log:
            wandb_logger = WandbLogger(log_model="all", project="inv_dynamics_mlp")
            trainer = L.Trainer(max_epochs=epochs, log_every_n_steps=10, gradient_clip_val=5.0, logger=wandb_logger)
        else:
            trainer = L.Trainer(max_epochs=epochs, log_every_n_steps=10, gradient_clip_val=5.0)
        trainer.fit(model=l_model, train_dataloaders=dataloader)
        if save_path is not None:
            trainer.save_checkpoint(save_path)
            print(f"Model saved to {save_path}")
        return l_model.error_per_epoch

    def retrain_models(self, dataset_path, model_save_dir, iteration_num: int = 0, epochs: int = 3, window_size: int = 2):
        # reinitialize the data every time we retrain the model, because the data samples may increase
        dataset = DynamicSlidingWindowDataset(h5_path=dataset_path, window_size=window_size)
        print("num samples in the dataset:", len(dataset))
        full_dataset_size = len(dataset)

        samples_number_needed = 10 * self.models[0].get_number_of_trainable_parameters()

        ratio_of_training_samples = samples_number_needed / len(dataset)
        if ratio_of_training_samples < 1.0:
            dataset, _ = random_split(dataset, [ratio_of_training_samples, 1-ratio_of_training_samples]) 
        else:
            print(f"Using the full dataset for training, only {100/ratio_of_training_samples:.2f}% of the data amount recommended for training.")

        epoch_errors = []
        # reinitialize the model
        for i, model in enumerate(self.models):   
            print(f"P4RL: Re-training Inv Model... {i+1}/{self.ensemble_size}")
            error_per_epoch = self.reinitialized_and_train_model_mlp(model, dataset, epochs=epochs, batch_size=1024, replacement=True)
            model.to(self.device)  # ensure the model is on the correct device after training
            epoch_errors.append(error_per_epoch)

        self.save_models(iteration_num, model_save_dir)
        return epoch_errors, full_dataset_size
    

    def save_models(self, iteration_num, model_save_dir):
        """
        Save the models to a directory.
        :param iteration_num: The iteration number to be included in the filename.
        """
        save_dir = os.path.join(model_save_dir, f"iteration_{iteration_num:04d}")
        os.makedirs(save_dir, exist_ok=True)
        for i, model in enumerate(self.models):
            model_path = os.path.join(save_dir, f"model_{i}.pt")
            torch.save(model.state_dict(), model_path)
            print(f"Model {i} saved to {model_path}")


    def load_models(self, load_dir):
        """
        Load the models from a directory.
        :param model_save_dir: The directory where the models are saved.
        :param iteration_num: The iteration number to be included in the filename.
        """
        for i, model in enumerate(self.models):
            model_path = os.path.join(load_dir, f"model_{i}.pt")
            if os.path.exists(model_path):
                model.load_state_dict(torch.load(model_path, weights_only=True))
                print(f"Model {i} loaded from {model_path}")
            else:
                raise FileNotFoundError(f"Model file {model_path} not found.")
            
    def check_parameter_explosion(self):
        """
        Check if the model parameters are exploding.
        :return: True if the parameters are exploding, False otherwise.
        """
        for model in self.models:
            for param in model.parameters():
                if torch.isnan(param).any() or torch.isinf(param).any():
                    print("Model parameters are exploding!")
                    return True
                
        print("Model parameters are stable.")
        return False