import logging
from turtle import pos, position
import torch
from torch import mode, optim, nn, Tensor
from torchmetrics import MeanMetric
import lightning as L
from lightning.pytorch.utilities.types import STEP_OUTPUT, OptimizerLRScheduler
from src.multibody_sim.ground_contact import contact_points_2d, sliding_contact_point
from src.utils.utils import GLOBAL_CFG, really_safe_normalise_in_place
from src.utils.metabolics import metabolic_cost_kimr15

from src.kinematics.kinematics_2d import *
import pickle

from src.multibody_sim.equations_of_motion import (
    equation_of_motions,
    Kmatrix_loss_moveEst,
)

class SspinnLitModule(L.LightningModule):
    def __init__(
        self,
        model: nn.Module,
        optimizer: optim.Optimizer,
        scheduler: OptimizerLRScheduler,
        criterion,
        input_noise,
        loss_weights,
        input_variables,
        estimated_variables,
        loss_d_variables,
        optimize_constants=None,
    ):
        """Store the network, the training configuration and the loss trackers.

        Args:
            model: Network that is called on the concatenated input variables.
            optimizer: Stored on ``self.optimizer``; not used within this method.
            scheduler: Stored on ``self.scheduler``; not used within this method.
            criterion: Callable ``criterion(prediction, target)`` used by all loss terms.
            input_noise: Relative noise level; stored on ``self.noise``.
            loss_weights: Object exposing the per-term weights (``wr``, ``wt``, ...).
            input_variables: Batch keys that are concatenated into the network input.
            estimated_variables: Mapping from output name to its list of channels.
            loss_d_variables: Stored on ``self.loss_d_variables``.
            optimize_constants: Optional configuration with the fields ``run``,
                ``constants``, ``optimizer`` and ``freeze_model``. ``None`` or
                ``run is False`` disables the constant optimization.
        """
        super().__init__()
        # Makes the init arguments reachable through ``self.hparams`` and stores them in the checkpoint
        self.save_hyperparameters(ignore=["model", "criterion"], logger=False)

        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler

        self.criterion = criterion
        self.loss_weights = loss_weights
        self.noise = input_noise

        self.input_variables = input_variables
        self.estimated_variables = estimated_variables
        self.loss_d_variables = loss_d_variables

        # One running mean per loss term and stage: ``train_loss``, ``train_loss_r``, ..., ``val_loss_sym``
        metric_suffixes = (
            "", "_r", "_t", "_p", "_torques", "_gc",
            "_sliding", "_mee", "_foot_speed", "_sym",
        )
        for stage in ("train", "val"):
            for suffix in metric_suffixes:
                setattr(self, f"{stage}_loss{suffix}", MeanMetric())

        self.freeze_model = False

        if optimize_constants is None or optimize_constants.run is False:
            self.constant_optimization = False
            return

        self.optimize_constants = optimize_constants.constants
        # Holds the optimizer *configuration* until ``init_constants_optimization`` replaces it
        self.constant_optimizer = optimize_constants.optimizer
        # The constants themselves are created from the first batch, see ``model_step``
        self.init_constants = False
        self.constant_optimization = True
        if optimize_constants.freeze_model:
            for weight in self.model.parameters():
                weight.requires_grad_(False)

    def init_constants_optimization(self, batch: dict[str, Tensor]):
        # Initialize the constants with the first batch
        constant_dict = {}
        for var in self.optimize_constants:
            constant_dict[var] = batch[var][0:1, 0:1, :].clone() # Only take one frame, zero_slice
            # Make sure the constant is trainable
            constant_dict[var].requires_grad = True

        # Initialize the constant optimizer
        lr = self.constant_optimizer.lr
        decay = self.constant_optimizer.weight_decay

        self.constants = constant_dict
        if self.constant_optimizer.target == "Adam":
            self.constant_optimizer = optim.Adam(constant_dict.values(), lr=lr, weight_decay=decay)
        else: # Default to SGD
            self.constant_optimizer = optim.SGD(constant_dict.values(), lr=lr, weight_decay=decay)

    def get_batch(self, batch: dict[str, Tensor]):
        # Get the batch for the constant optimization
        for var in self.optimize_constants:
            if var == 'body_constants':
                # unleafify the tensor - no optimization of the inertia and gravity
                vars_to_optim = [0,1,2,4,5,6,8,9,10,13] # Only consider masses and lengths and coms but not torso com
                batch[var][:,:,vars_to_optim] = self.constants[var][:,:,vars_to_optim] * torch.ones_like(batch[var])[:,:,vars_to_optim]
            elif var == 'imu_offsets':
                vars_to_optim = [2,3,4,5,6,7,8,9,10,11,12,13] # No pelvis IMU
                batch[var][:,:,vars_to_optim] = self.constants[var][:,:,vars_to_optim] * torch.ones_like(batch[var])[:,:,vars_to_optim]
            elif var == 'imu_rotations':
                vars_to_optim = [0,1,2,4,5] # No foot IMUs
                batch[var][:,:,vars_to_optim] = self.constants[var][:,:,vars_to_optim] * torch.ones_like(batch[var])[:,:,vars_to_optim]
            else:
                batch[var] = self.constants[var] * torch.ones_like(batch[var])
        return batch


    def calculate_time_loss(self, IK_pred: Tensor, fps: float = 100.0, euler = "BE", mode = "kinematics", positions = None):
        # time based loss for predicted derivatives (eg. a_pelvis -> da_pelvis -> dda_pelvis)
        # eg. for a_pelvis:
        # taking the derivative of a_pelivs should be the same as da_pelvis
        # taking the derivative of da_pelvis should be the same as dda_pelvis
        # If mode == positions, then we check the positions instead of the velocities
        loss_t = 0
        if mode == "kinematics" or mode == "both":
            IK_pred_diff = torch.diff(IK_pred, dim=1) * fps
            if euler == "BE":
                loss_pos_vel = IK_pred_diff[:, :, 0::3] - IK_pred[:, 1:, 1::3]
                loss_vel_acc = IK_pred_diff[:, :, 1::3] - IK_pred[:, 1:, 2::3]
            elif euler == "FE":
                loss_pos_vel = IK_pred_diff[:, :, 0::3] - IK_pred[:, :-1, 1::3]
                loss_vel_acc = IK_pred_diff[:, :, 1::3] - IK_pred[:, :-1, 2::3]
            else:
                raise ValueError(f"Invalid Integration Rule representation: {euler} (Should be 'BE' or 'FE')")
            # Loss between tx and d_tx should be zero:
            loss_pos_vel[:, :, 0] = 0
            loss_t = self.criterion(
                loss_pos_vel / (torch.std(IK_pred[:, :, 1::3], dim=1, keepdim=True) + 1e-8),
                torch.zeros_like(loss_pos_vel, device=self.device)
            )
            loss_t += self.criterion(
                loss_vel_acc / (torch.std(IK_pred[:, :, 2::3], dim=1, keepdim=True) + 1e-8),
                torch.zeros_like(loss_vel_acc, device=self.device)
            )
        if mode == "positions" or mode == "both":
            position, gc_positions = positions
            for key in position.keys():
                if key in ['hip_r','hip_l']:
                    continue # Hip is the same position as pelvis
                pos_pred_diff = torch.diff(position[key], dim=1) * fps
                if euler == "BE":
                    loss_pos = pos_pred_diff[:, :, 0::3] - position[key][:, 1:, 1::3]
                    loss_vel = pos_pred_diff[:, :, 1::3] - position[key][:, 1:, 2::3]
                elif euler == "FE":
                    loss_pos = pos_pred_diff[:, :, 0::3] - position[key][:, :-1, 1::3]
                    loss_vel = pos_pred_diff[:, :, 1::3] - position[key][:, :-1, 2::3]
                else:
                    raise ValueError(f"Invalid Integration Rule representation: {euler} (Should be 'BE' or 'FE')")

                if key == 'pelvis':
                    loss_pos[:, :, 0] = 0

                loss_t += self.criterion(
                    loss_pos / (torch.std(position[key][:, :, 1::3], dim=1, keepdim=True) + 1e-8),
                    torch.zeros_like(loss_pos, device=self.device)
                )
                loss_t += self.criterion(
                    loss_vel / (torch.std(position[key][:, :, 2::3], dim=1, keepdim=True) + 1e-8),
                    torch.zeros_like(loss_vel, device=self.device)
                )

            for key in gc_positions.keys():
                pos_pred_diff = torch.diff(gc_positions[key], dim=1) * fps
                if euler == "BE":
                    loss_pos = pos_pred_diff[:, :, 0::3] - gc_positions[key][:, 1:, 1::3]
                    loss_vel = pos_pred_diff[:, :, 1::3] - gc_positions[key][:, 1:, 2::3]
                elif euler == "FE":
                    loss_pos = pos_pred_diff[:, :, 0::3] - gc_positions[key][:, :-1, 1::3]
                    loss_vel = pos_pred_diff[:, :, 1::3] - gc_positions[key][:, :-1, 2::3]
                else:
                    raise ValueError(f"Invalid Euler angle representation: {euler} (Should be 'BE' or 'FE')")


                loss_t += self.criterion(
                    loss_pos / (torch.std(gc_positions[key][:, :, 1::3], dim=1, keepdim=True) + 1e-8),
                    torch.zeros_like(loss_pos, device=self.device)
                ) / 2
                loss_t += self.criterion(
                    loss_vel / (torch.std(gc_positions[key][:, :, 2::3], dim=1, keepdim=True) + 1e-8),
                    torch.zeros_like(loss_vel, device=self.device)
                ) / 2 # Divide by two to not double-weight the contact points
            loss_t = loss_t / 3.5 #Divide by 3.5 so that the weight is the same as in the kinematics case and it matches the equations in the paper

        if mode == "both":
            return loss_t / 2

        return loss_t

    def calculate_physics_loss(
        self,
        IK_pred: Tensor,
        torques: Tensor,
        ground_reaction_forces: Tensor,
        center_of_pressures: Tensor,
        body_constants: Tensor,
    ):
        """Criterion of the equation-of-motion residual against zero.

        All arguments are forwarded unchanged to ``Kmatrix_loss_moveEst``; the
        residual it returns is moved to ``self.device`` and compared with a
        zero tensor of the same shape.
        """
        residual = Kmatrix_loss_moveEst(
            IK_pred,
            torques,
            ground_reaction_forces,
            center_of_pressures,
            body_constants,
            1,
            device=self.device,
        )
        residual_on_device = residual.to(self.device)
        target = torch.zeros_like(residual, device=self.device)
        return self.criterion(residual_on_device, target)

    def calculate_limit_losses(self, IK_pred: Tensor):
        joint_lower_ = torch.tensor([-np.pi / 3, -np.pi / 3, -np.pi, -np.pi / 3,
                                                 -np.pi / 3, -np.pi, -np.pi / 3], device=self.device)
        joint_upper_ = torch.tensor([np.pi / 3, np.pi / 3, 0.1, np.pi / 3,
                                                np.pi / 3, 0.1, np.pi / 3], device=self.device)
        y_lower = 0
        y_upper = 2 # Can't do high jump like this

        # 6::3 are the angles, 3 is the y position
        upper_viol = torch.relu(IK_pred[:,:,6::3]-joint_upper_)
        lower_viol = torch.relu(-IK_pred[:,:,6::3]+joint_lower_)
        y_upper_viol = torch.relu(IK_pred[:,:,3]-y_upper)
        y_lower_viol = torch.relu(-IK_pred[:,:,3]+y_lower)

        loss_l_angle = self.criterion(upper_viol + lower_viol, torch.zeros_like(upper_viol, device=self.device))
        loss_l_y = self.criterion(y_upper_viol + y_lower_viol, torch.zeros_like(y_upper_viol, device=self.device))

        speed_upper_ = 10
        upper_viol_speed = torch.relu(IK_pred[:, :, 1]-speed_upper_)
        lower_viol_speed = torch.relu(-IK_pred[:, :, 1]-speed_upper_)
        loss_speed = self.criterion(upper_viol_speed + lower_viol_speed, torch.zeros_like(upper_viol_speed, device=self.device))

        return loss_l_angle + loss_l_y + loss_speed

    def calculate_grf_bounds_loss(self, grf: Tensor):
        lower_ = 0.2 * 9.81 # In N / BW - bound set to 0.2
        avg_grf_y = torch.mean(grf[:,:,1::2], dim = -2, keepdim=True)
        viol = torch.relu(lower_ - avg_grf_y)
        return self.criterion(viol, torch.zeros_like(viol, device=self.device))

    def calculate_gc_loss(self, y_hat: Tensor, gc_positions: dict[str, Tensor], reconstructed_data: dict[str, Tensor], fps, euler = "BE"):
        gc_pred = y_hat["gc_model"]

        gc_fk = torch.cat([gc_positions[key][:, :, [0, 3, 1, 4]] for key in ['r_heel', 'r_toe', 'l_heel', 'l_toe']], dim=-1)

        # Make X relative to the pelvis (local to global)
        gc_pred[:,:,[0,4,8,12]] = gc_pred[:,:,[0,4,8,12]] + reconstructed_data['pelvis'][:,:,0:1]
        loss_gc = self.criterion(gc_pred, gc_fk)
        ## Replace the values in gc_positions with the predicted values
        gc_positions['r_heel'][:, :, [0, 3, 1, 4]] = gc_pred[:, :, :4]
        gc_positions['r_toe'][:, :, [0, 3, 1, 4]] = gc_pred[:, :, 4:8]
        gc_positions['l_heel'][:, :, [0, 3, 1, 4]] = gc_pred[:, :, 8:12]
        gc_positions['l_toe'][:, :, [0, 3, 1, 4]] = gc_pred[:, :, 12:16]

        ## Add time-loss to loss_gc
        gc_pred_diff = torch.diff(gc_pred, dim=1) * fps
        if euler == "BE":
            loss_x = gc_pred_diff[:, :, 0::4] - gc_pred[:, 1:, 2::4]  # gc_pred_x is local to the pelvis, gc_pred_dx is global
            loss_y = gc_pred_diff[:, :, 1::4] - gc_pred[:, 1:, 3::4]
        elif euler == "FE":
            loss_x = gc_pred_diff[:, :, 0::4] - gc_pred[:, :-1, 2::4]
            loss_y = gc_pred_diff[:, :, 1::4] - gc_pred[:, :-1, 3::4]
        else:
            raise ValueError(f"Invalid Euler angle representation: {GLOBAL_CFG.euler} (Should be 'BE' or 'FE')")
        # Loss between tx and d_tx should be zero:
        loss_gc += self.criterion(
            loss_x / (torch.std(gc_pred[:, :, 2::4], dim=1, keepdim=True) + 1e-8), torch.zeros_like(loss_x, device=self.device)
        ) * 0.5
        loss_gc += self.criterion(
            loss_y / (torch.std(gc_pred[:, :, 3::4], dim=1, keepdim=True) + 1e-8), torch.zeros_like(loss_y, device=self.device)
        ) * 0.5

        return loss_gc, gc_positions

    def calculate_ankle_loss(self, y_hat: Tensor, batch: dict[str, Tensor], reconstructed_data: dict[str, Tensor], fps, euler = "BE"):
        """Loss between network-predicted ankle kinematics and the reconstructed ones.

        For every entry of the module-level ``gc_model`` the predicted channels
        of its parent ankle are copied into a zero tensor shaped like the
        reconstructed parent, shifted by the pelvis channel 0, compared with the
        reconstruction, and checked for consistency between finite differences
        and predicted derivatives. The contact position is obtained with
        ``get_joint_position``.

        Returns:
            Tuple ``(loss_gc, gc_positions, ankle_globals, learned_mu)`` where
            ``learned_mu`` is ``tanh`` of the last two ``y_hat['gc_model']`` channels.

        Raises:
            ValueError: If ``euler`` is neither "BE" nor "FE".
        """
        gc_model_columns = GLOBAL_CFG.datamodule.dataset_variables.ground_contact_model
        # 'orig': channels of y_hat['gc_model'] that belong to the ankle,
        # 'new': channels of the reconstructed parent tensor they are written to
        parent_kin_pred = {
            'ankle_r': {'orig': [0, 1, 2, 3, 4, 5], 'new': [0, 1, 3, 4, 6, 7]},
            'ankle_l': {'orig': [6, 7, 8, 9, 10, 11], 'new': [0, 1, 3, 4, 6, 7]},
        }

        gc_positions = {}
        ankle_globals = {}
        loss_gc = 0
        for gc in gc_model.keys():
            contact = gc_model[gc]
            parent = contact["parent"]
            rpdata = torch.zeros_like(reconstructed_data[parent])
            source_channels = parent_kin_pred[parent]['orig']
            target_channels = parent_kin_pred[parent]['new']
            rpdata[:, :, target_channels] = y_hat['gc_model'][:, :, source_channels]

            key = contact["key"]
            first = gc_model_columns.index(f"{key}_x")
            last = gc_model_columns.index(f"{key}_y") + 1
            joint_offset = batch["ground_contact_model"][:, :, first:last]

            # Channel 0 is shifted by the pelvis channel 0 before it is used any further
            rpdata[:, :, 0] = rpdata[:, :, 0] + reconstructed_data['pelvis'][:, :, 0]
            ankle_globals[parent] = rpdata
            gc_positions[gc] = get_joint_position(
                rpdata,
                torch.zeros_like(reconstructed_data[parent]),
                joint_offset,
                self.device,
            )

            loss_gc += self.criterion(
                rpdata[:, :, target_channels],
                reconstructed_data[parent][:, :, target_channels],
            )

            ank_pred_diff = torch.diff(rpdata, dim=1) * fps
            if euler == "BE":
                reference = rpdata[:, 1:, 1::3]
            elif euler == "FE":
                reference = rpdata[:, :-1, 1::3]
            else:
                raise ValueError(f"Invalid Euler angle representation: {euler} (Should be 'BE' or 'FE')")
            error_ = ank_pred_diff[:, :, 0::3] - reference

            scale = torch.std(rpdata[:, :, 1::3], dim=1, keepdim=True) + 1e-8
            loss_gc += self.criterion(error_ / scale, torch.zeros_like(error_, device=self.device))

        # Extract 2 more dimensions from the ground contact model as the learned_mu
        learned_mu = torch.tanh(y_hat['gc_model'][:, :, -2:])
        return loss_gc, gc_positions, ankle_globals, learned_mu

    def forward(self, data: dict) -> dict[str, Tensor]:
        x = torch.cat([data[var] for var in self.input_variables], dim=-1)
        y_hat = self.model(x)
        estimation = {}

        prev_idx = 0
        # split the output into the estimated variables
        for k, v in self.estimated_variables.items():
            estimation[k] = y_hat[:, :, prev_idx : prev_idx + len(v)]
            prev_idx += len(v)

        return estimation

    def on_train_start(self):
        """Clear validation metrics and restart early stopping before training.

        By default lightning executes validation step sanity checks before
        training starts, so every validation metric is reset here to keep the
        sanity-check values out of the first logged epoch. The early stopping
        state is restarted as well in case training is continued.
        """
        validation_metrics = (
            self.val_loss,
            self.val_loss_r,
            self.val_loss_t,
            self.val_loss_p,
            self.val_loss_torques,
            self.val_loss_gc,
            self.val_loss_sliding,
            self.val_loss_mee,
            self.val_loss_foot_speed,
            self.val_loss_sym,
        )
        for metric in validation_metrics:
            metric.reset()

        # Early stopping needs to be reset in case we continue training.
        # The trainer reports None when no EarlyStopping callback is configured.
        early_stopping = self.trainer.early_stopping_callback
        if early_stopping is None:
            return
        early_stopping.wait_count = 0
        # "Worst possible" score depends on the monitored direction
        worst_score = torch.tensor([torch.inf])
        if getattr(early_stopping, "mode", "min") == "max":
            worst_score = -worst_score
        early_stopping.best_score = worst_score


    def model_step(self, batch: dict, val: bool = False):
        """Evaluate all losses for one batch, optionally with symmetric conditioning.

        If constant optimization is enabled, the first call creates the
        trainable constants from the batch; every later call first applies the
        gradients that are still stored from the previous backward pass. The
        constants are written into the batch and pickled to
        ``<log_dir>/constants.pkl`` on every call.

        If ``self.loss_weights.wsym`` is positive, the batch is additionally
        evaluated with left and right channels swapped. The distance between the
        original outputs and the un-mirrored outputs of the mirrored batch is
        added to the total loss, so that the model cannot keep behaviours such
        as "always do the right step faster".

        Args:
            batch: Dict of tensors as consumed by ``get_losses``.
            val: Forwarded to ``get_losses``.

        Returns:
            The tuple produced by ``get_losses``; with symmetric conditioning the
            first entry holds the combined loss and the second to last entry
            the symmetry loss.

        Raises:
            KeyError: If symmetric conditioning is enabled and the batch contains
                an entry without a mirror index.
        """
        if self.constant_optimization:
            # On step 0 only the constants are initialised. Their gradients are kept
            # across steps, so the optimizer step can run before the actual model step.
            if self.init_constants:
                self.constant_optimizer.step()
            else:
                self.init_constants_optimization(batch)
                self.init_constants = True
            batch = self.get_batch(batch)
            self.constant_optimizer.zero_grad()

            # pickle constants at the logdir
            with open(self.logger.log_dir + "/constants.pkl", "wb") as f:
                cons = {k: v.detach().cpu().numpy() for k, v in self.constants.items()}
                pickle.dump(cons, f)

        if self.loss_weights.wsym <= 0:
            return self.get_losses(batch, val)

        out_orig = self.get_losses(batch, val)
        symmetry_idx_input = {
            "IMU_data": [*range(3),*range(3*3+3,6*3+3),*range(0*3+3,3*3+3)], # 1st is pelvis, then 3 right, 3 left
            "body_constants": [*range(16)], # Symmetric body assumptions
            "ground_contact_model": [*range(10)], # Symmetric ground contact model
            "imu_offsets": [*range(2),*range(3*2+2,6*2+2),*range(0*2+2,3*2+2)], # 1st is pelvis, then 2 right, 2 left
            "speed": [1, 0], # Swap the foot speeds
        }
        symmetry_idx_output = {
            'IK_data': [*range(9),*range(3*3+9,6*3+9),*range(0*3+9,3*3+9)],
            'torques': [*range(3,6), *range(0,3)],
            'gc_model': [*range(8,16), *range(0,8)] if GLOBAL_CFG.gc_ss_level == 'cps' else [*range(6,12),*range(0,6),-3,-4,-1,-2],
        }

        # NOTE(review): get_losses reads batch["imu_rotations"], but no mirror index is
        # defined for it above, so this path fails for such batches. The correct channel
        # permutation has to be confirmed before an entry is added.
        unmapped = [key for key in batch.keys() if key not in symmetry_idx_input]
        if unmapped:
            raise KeyError(f"No left/right mirror index defined for batch entries: {unmapped}")

        batch_sym = {
            key: batch[key][:, :, symmetry_idx_input[key]].clone().detach() # CUDA assert triggered if not detached
            for key in batch.keys()
        }

        out_sym = self.get_losses(batch_sym, val)
        loss_sym = 0
        yhat_orig = out_orig[-1]
        yhat_sym = out_sym[-1]
        for key in yhat_orig.keys():
            # 1e-8 guards against constant channels, as in the other normalised losses
            scale = torch.std(yhat_orig[key], dim=1, keepdim=True) + 1e-8
            l_sym = (yhat_orig[key] - yhat_sym[key][:, :, symmetry_idx_output[key]]) / scale
            loss_sym += self.criterion(l_sym, torch.zeros_like(l_sym, device=self.device))

        out_orig = list(out_orig)
        # Total Loss, 0.5 to adjust total weighting
        out_orig[0] = 0.5 * out_orig[0] + 0.5 * out_sym[0] + loss_sym * self.loss_weights.wsym
        # get_losses returns 0 in the second to last slot; fill it with the symmetry loss for logging
        out_orig[-2] += loss_sym
        return tuple(out_orig)

    def get_losses(self, batch: dict, val=False):
        """Run the network on ``batch`` and evaluate every loss term.

        Args:
            batch: Dict of tensors. The entries "IMU_data", "body_constants",
                "imu_offsets", "imu_rotations", "ground_contact_model" and
                "speed" are read.
            val: If true, no input noise is added.

        Returns:
            Tuple ``(loss, loss_r, loss_t, loss_p, loss_torques, loss_gc,
            loss_mee, loss_sliding, loss_foot_speed, 0, y_hat)``. The ``0`` is a
            placeholder for the symmetry loss that ``model_step`` fills in.

        Raises:
            ValueError: If ``GLOBAL_CFG.gc_model`` is neither 'cp' nor 'sliding'.
        """
        # Get sparsity: the three channels of every IMU listed in GLOBAL_CFG.sparsity are zeroed
        sparsity = [3*i + j for i in GLOBAL_CFG.sparsity for j in range(3)]
        bd2 = batch['IMU_data'].clone()
        bd2[:, :, sparsity] = 0

        if self.noise > 0 and not val:
            # Find std of the data and add noise with std = std * noise
            std = torch.std(bd2, dim=-2, keepdim=True)
            noise = torch.randn_like(bd2) * std * self.noise

            # The constants currently receive zero-scaled noise; the random draws are kept as they are
            y_hat = self.forward({
                "IMU_data": bd2 + noise,
                "body_constants": batch["body_constants"] + torch.randn_like(batch["body_constants"]) * 0.00,
                "imu_offsets": batch["imu_offsets"] + torch.randn_like(batch["imu_offsets"]) * 0.00,
                "imu_rotations": batch["imu_rotations"] + torch.randn_like(batch["imu_rotations"]) * 0.00,
                "ground_contact_model": batch["ground_contact_model"] + torch.randn_like(batch["ground_contact_model"]) * 0.00,
            })
        else:
            y_hat = self.forward(
                {
                    "IMU_data": bd2,
                    "body_constants": batch["body_constants"],
                    "imu_offsets": batch["imu_offsets"],
                    "imu_rotations": batch["imu_rotations"],
                    "ground_contact_model": batch["ground_contact_model"],
                } # Never put in IMU speed as an information
            )

        IK_pred = y_hat["IK_data"]
        # Replace the total x position with the cumulative sum of the x velocity.
        # This writes in place, so IK_pred (the same tensor) sees the new channel 0 as well.
        y_hat["IK_data"][:, :, 0] = torch.cumsum(y_hat['IK_data'][:, :, 1], dim=1) / GLOBAL_CFG.fps

        # Functional joint range limits for unambiguous joint angle predictions [-pi, pi]
        loss_l = self.calculate_limit_losses(IK_pred)

        reconstructed_data, imu_data, gc_positions, ankle_imu_global = global_kinematics(
            IK_pred,
            batch["body_constants"],
            batch["imu_offsets"],
            batch["imu_rotations"],
            batch["ground_contact_model"],
            GLOBAL_CFG,
            device=self.device,
        )

        # get the ankle speed from the kinematics
        v_ankle_sim = torch.cat([ankle_imu_global['ankle_r'][:, :, 1:2], ankle_imu_global['ankle_l'][:, :, 1:2]], dim=-1)
        loss_foot_speed = torch.abs(v_ankle_sim - batch['speed'])
        loss_foot_speed = torch.relu(loss_foot_speed - 0.3 * torch.max(batch['speed'], keepdim=True, dim=-2).values) # 30% of the speed as error is allowed -> no exact guidance
        loss_foot_speed = self.criterion(loss_foot_speed, torch.zeros_like(loss_foot_speed, device=self.device))

        ## NN ground contact model loss, only use if wgc > 0 (is specified)
        # NOTE(review): the 'cps' branch does not define learned_mu, which the 'sliding'
        # contact model below requires - confirm that this combination is never configured.
        if self.loss_weights.wgc > 0 and GLOBAL_CFG.gc_ss_level == 'cps':
            loss_gc, gc_positions = self.calculate_gc_loss(y_hat, gc_positions, reconstructed_data, GLOBAL_CFG.fps, GLOBAL_CFG.euler)
        elif self.loss_weights.wgc > 0 and GLOBAL_CFG.gc_ss_level == 'ankle':
            loss_gc, gc_positions, ankle_globals, learned_mu = self.calculate_ankle_loss(y_hat, batch, reconstructed_data, GLOBAL_CFG.fps, GLOBAL_CFG.euler)
            if GLOBAL_CFG.ankle_imu_position == 'ss_foot':
                # reconstruct the ankle IMU data based on the self-supervised foot position
                for foot, imu_idx in zip(['ankle_r','ankle_l'],[3,6]):
                    foot_globals = torch.zeros_like(ankle_globals[foot])
                    foot_globals[:,:,0::3] = ankle_globals[foot][:,:,0::3]
                    foot_globals[:,:,1::3] = ankle_globals[foot][:,:,1::3]
                    foot_globals[:,:,2::3] = reconstructed_data[foot][:,:,2::3]
                    imu_ = get_joint_position(foot_globals, torch.zeros_like(ankle_globals[foot][:,:,-3:]), batch['imu_offsets'][:,:,2*imu_idx:2*imu_idx+2], self.device)
                    imu_data[:,:,3*imu_idx:3*imu_idx+3] = to_local_coordinates_imu(imu_, batch['body_constants'][:,:,-1], self.device) # Same as in global reconstruction
        else:
            # No ground contact loss; the ankle loss is only evaluated for its learned_mu output
            loss_gc = 0
            _,_,_, learned_mu = self.calculate_ankle_loss(y_hat, batch, reconstructed_data, GLOBAL_CFG.fps, GLOBAL_CFG.euler)

        # Reconstruction Loss, normalize by the per-sequence standard deviation of the measured data
        imu_batch = batch["IMU_data"]/torch.std(batch["IMU_data"], dim=1, keepdim=True)
        imu_data = imu_data/torch.std(batch["IMU_data"], dim=1, keepdim=True)
        imu_batch[:,:,sparsity] = 0
        imu_data[:,:,sparsity] = 0
        loss_r = self.criterion(imu_data, imu_batch)

        loss_t = self.calculate_time_loss(IK_pred, GLOBAL_CFG.fps, GLOBAL_CFG.euler, mode=GLOBAL_CFG.euler_mode, positions=(reconstructed_data, gc_positions))

        # Get GC model forces and CoP
        if GLOBAL_CFG.gc_model == 'cp':
            grf, moment, grf_cps = contact_points_2d(
                gc_positions,
                batch["ground_contact_model"],
                gc_model,
                GLOBAL_CFG,
                device = self.device,
                get_moments=True,
                kinematics=reconstructed_data
            )
            # todo: note: sliding might be a side-effect of favoring certain strategies (see symmetric conditioning or torque loss) for a possible fix
            # Auxiliary gc loss to discourage sliding
            grf_y = grf[:, :, [1, 1, 3, 3]]  # grf_r, grf_r, grf_l, grf_l
            cp_v_x = torch.cat([  # 1:2 are the x velocities in a 9-dim tensor
                gc_positions['r_heel'][:, :, 1:2],
                gc_positions['r_toe'][:, :, 1:2],
                gc_positions['l_heel'][:, :, 1:2],
                gc_positions['l_toe'][:, :, 1:2],
            ], dim=-1)
            sliding_penalty = torch.abs(
                torch.relu(grf_y) * torch.relu(torch.abs(cp_v_x) - 0.15))  # Allow SLOW sliding (0.15 m/s)
            loss_sliding = self.criterion(sliding_penalty, torch.zeros_like(sliding_penalty, device=self.device))
        elif GLOBAL_CFG.gc_model == 'sliding':
            mu_ = "learned"
            grf, moment, grf_cps, cp_mix_ = sliding_contact_point(
                gc_positions,
                batch["ground_contact_model"],
                gc_model,
                GLOBAL_CFG,
                device = self.device,
                get_moments=True,
                kinematics=reconstructed_data,
                learned_mu = (learned_mu if mu_ == "learned" else None)
            )
            grf_y = grf[:, :, 1::2]
            cp_v_x = torch.cat([
                cp_mix_['ankle_r'][:, :, 1:2],
                cp_mix_['ankle_l'][:, :, 1:2],
            ], dim=-1)
            if mu_ == "learned":
                sliding_penalty = torch.abs(
                    torch.relu(grf_y) * torch.relu(torch.abs(cp_v_x))) # Allow NO sliding
            else:
                sliding_penalty = torch.abs(
                    torch.relu(grf_y) * torch.relu(torch.abs(cp_v_x) - 0.15)) # Allow SLOW sliding (0.15 m/s)
            loss_sliding = self.criterion(sliding_penalty, torch.zeros_like(sliding_penalty, device=self.device))
        else:
            # Without this guard the missing grf would only surface as a NameError further down
            raise ValueError(f"Invalid ground contact model: {GLOBAL_CFG.gc_model} (Should be 'cp' or 'sliding')")
        loss_grf = self.calculate_grf_bounds_loss(grf)

        # Kane's Loss
        loss_p = self.calculate_physics_loss(
            IK_pred,
            y_hat["torques"],
            grf,
            moment,
            batch["body_constants"],
        )

        # Symmetric regularization (treating both body sides equally) is handled by the
        # caller: model_step evaluates a mirrored batch when loss_weights.wsym > 0.

        # Minimize Joint Torques, with speed weighting if speed > 1 - ref to dorschky, but undesired for standing
        speed = torch.mean(torch.abs(IK_pred[:, :, 1:2]), dim=1, keepdim=True)
        speed_mask = torch.relu(speed - 1) + 1
        loss_torques = self.criterion(y_hat["torques"]/speed_mask, torch.zeros_like(y_hat["torques"], device=self.device))

        mee = metabolic_cost_kimr15(IK_pred, y_hat["torques"])
        loss_mee = self.criterion(mee, torch.zeros_like(mee, device=self.device))

        # Add the body constants to the loss
        if self.constant_optimization:
            if 'body_constants' in self.optimize_constants:
                # Body weight should be 1 (entries 2, 6 and 10 are counted twice, entry 13 once)
                lbw = torch.abs(torch.sum(self.constants['body_constants'][:, :, [2, 2, 6, 6, 10, 10, 13]])-1)
                loss_constants = 1e3*self.criterion(lbw, torch.zeros_like(lbw, device=self.device))
            else:
                loss_constants = 0
        else:
            loss_constants = 0

        loss = (
            + loss_r * self.loss_weights.wr
            + loss_t * self.loss_weights.wt
            + loss_p * self.loss_weights.wp
            + loss_l * self.loss_weights.wl
            + loss_grf * self.loss_weights.wgrf
            + loss_torques * self.loss_weights.wtorques
            + loss_gc * self.loss_weights.wgc
            + loss_sliding * self.loss_weights.wsliding
            + loss_mee * self.loss_weights.wmee
            + loss_foot_speed * self.loss_weights.wfs
            + loss_constants
        )

        return (
            loss,
            loss_r,
            loss_t,
            loss_p,
            loss_torques,
            loss_gc,
            loss_mee,
            loss_sliding,
            loss_foot_speed,
            0,
            y_hat,
        ) # We don't need to monitor bounds and constants, as they are practically enforced

    def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
        """Compute the losses for one training batch and log their epoch averages.

        Returns:
            Dict with the total ``"loss"`` and the network outputs ``"y_hat"``.
        """
        self.model.train()
        (
            loss,
            loss_r,
            loss_t,
            loss_p,
            loss_torques,
            loss_gc,
            loss_mee,
            loss_sliding,
            loss_foot_speed,
            loss_sym,
            y_hat,
        ) = self.model_step(batch)

        # (log name, running mean, value of this step), listed in logging order
        tracked = [
            ("r", self.train_loss_r, loss_r),
            ("t", self.train_loss_t, loss_t),
            ("p", self.train_loss_p, loss_p),
            ("torque", self.train_loss_torques, loss_torques),
            ("gc", self.train_loss_gc, loss_gc),
            ("sym", self.train_loss_sym, loss_sym),
            ("mee", self.train_loss_mee, loss_mee),
            ("fs", self.train_loss_foot_speed, loss_foot_speed),
            ("slid", self.train_loss_sliding, loss_sliding),
        ]
        for _, metric, value in tracked:
            metric(value)
        self.train_loss(loss)

        # Only the average across the epoch is logged, not the individual steps
        log_options = {"on_step": False, "on_epoch": True, "logger": True}
        for name, metric, _ in tracked:
            self.log(f"train_loss/mse_{name}", metric, prog_bar=False, **log_options)
        self.log("train_loss/total_loss", self.train_loss, prog_bar=True, **log_options)

        return {"loss": loss, "y_hat": y_hat}

    def on_train_epoch_end(self) -> None:
        """Hook called at the end of a training epoch; intentionally a no-op."""
        return None

    def validation_step(self, batch, batch_idx) -> STEP_OUTPUT:
        """Run one validation batch and record its losses.

        Puts the wrapped model in eval mode, evaluates ``model_step`` with
        ``val=True``, feeds every returned loss term into its ``val_loss_*``
        metric and registers those metrics with ``self.log``.

        Args:
            batch: The batch forwarded unchanged to ``model_step``.
            batch_idx: Index of the batch (unused here).

        Returns:
            A dict with the total loss under ``"loss"`` and the model output
            returned by ``model_step`` under ``"y_hat"``.
        """
        self.model.eval()

        # Positional layout of the tuple produced by model_step.
        (
            total,
            mse_r,
            mse_t,
            mse_p,
            mse_torques,
            mse_gc,
            mse_mee,
            mse_sliding,
            mse_foot_speed,
            mse_sym,
            y_hat,
        ) = self.model_step(batch, val=True)

        # (log-name suffix, metric, value of this batch), in logging order.
        tracked = (
            ("r", self.val_loss_r, mse_r),
            ("t", self.val_loss_t, mse_t),
            ("p", self.val_loss_p, mse_p),
            ("tor", self.val_loss_torques, mse_torques),
            ("gc", self.val_loss_gc, mse_gc),
            ("sym", self.val_loss_sym, mse_sym),
            ("mee", self.val_loss_mee, mse_mee),
            ("fs", self.val_loss_foot_speed, mse_foot_speed),
            ("slid", self.val_loss_sliding, mse_sliding),
        )

        # Update every metric before anything is handed to the logger.
        for _, metric, batch_value in tracked:
            metric(batch_value)
        self.val_loss(total)

        # on_step is off: only the epoch-level aggregate of each metric is logged.
        epoch_only = dict(on_step=False, on_epoch=True)
        for suffix, metric, _ in tracked:
            self.log(f"val_loss/mse_{suffix}", metric, prog_bar=False, **epoch_only)
        # Only the total loss is shown in the progress bar.
        self.log("val_loss/total_loss", self.val_loss, prog_bar=True, **epoch_only)

        return {"loss": total, "y_hat": y_hat}

    def on_validation_epoch_end(self) -> None:
        """Hook called at the end of a validation epoch; intentionally a no-op."""
        return None

    def test_step(self, batch, batch_idx) -> STEP_OUTPUT:
        """Run one test batch and record the r/t/p/torque and total losses.

        Evaluates ``model_step`` on the batch, updates the ``test_loss_*``
        metrics and registers them with ``self.log`` (epoch-level only).

        Args:
            batch: The batch forwarded unchanged to ``model_step``.
            batch_idx: Index of the batch (unused here).

        Returns:
            A dict holding the individual losses (``"loss_r"``, ``"loss_t"``,
            ``"loss_p"``), the total ``"loss"`` and the model output ``"y_hat"``.
        """
        # model_step returns more entries than this method consumes (total loss,
        # r, t, p, torques, further loss terms, ..., y_hat). Unpacking exactly
        # seven names raised "too many values to unpack"; keep the first five
        # losses and the trailing y_hat and discard the terms not reported here.
        # NOTE(review): validation_step calls model_step(batch, val=True) while
        # this call does not - confirm whether the test pass should set it too.
        loss, loss_r, loss_t, loss_p, loss_torques, *_, y_hat = self.model_step(batch)

        self.test_loss_r(loss_r)
        self.test_loss_t(loss_t)
        self.test_loss_p(loss_p)
        self.test_loss_torques(loss_torques)
        self.test_loss(loss)

        # (log-name suffix, metric) pairs for the per-term losses.
        per_term_metrics = (
            ("r", self.test_loss_r),
            ("t", self.test_loss_t),
            ("p", self.test_loss_p),
            ("tor", self.test_loss_torques),
        )
        for name, metric in per_term_metrics:
            self.log(
                f"test_loss/mse_{name}",
                metric,
                on_step=False,
                on_epoch=True,
                prog_bar=False,
                logger=True,
            )
        self.log(
            "test_loss/total_loss",
            self.test_loss,
            on_step=False,
            on_epoch=True,
            prog_bar=False,
            logger=True,
        )

        return {
            "loss_r": loss_r,
            "loss_t": loss_t,
            "loss_p": loss_p,
            "loss": loss,
            "y_hat": y_hat,
        }

    def on_test_epoch_end(self) -> None:
        """Hook called at the end of a test epoch; intentionally a no-op."""
        return None

    def predict_step(self, batch, batch_idx, dataloader_idx=0) -> STEP_OUTPUT:
        """Return the result of ``self.forward`` for one prediction batch.

        Args:
            batch: The batch passed straight to ``forward``.
            batch_idx: Index of the batch (unused here).
            dataloader_idx: Index of the originating dataloader (unused here).

        Returns:
            Whatever ``self.forward(batch)`` returns.
        """
        prediction = self.forward(batch)
        return prediction

    def configure_optimizers(self):
        """Build the optimizer and, when configured, its LR scheduler.

        ``self.hparams.optimizer`` is called with ``params=self.parameters()``.
        When ``self.hparams.scheduler`` is not ``None`` it is called with
        ``optimizer=<the optimizer>`` and returned together with its settings.

        Returns:
            ``{"optimizer": optimizer}`` when no scheduler is configured,
            otherwise a dict that additionally carries an ``"lr_scheduler"``
            entry monitoring ``"val_loss/total_loss"`` once per epoch.
        """
        hparams = self.hparams
        optimizer = hparams.optimizer(params=self.parameters())  # type: ignore

        # Guard clause: nothing more to configure without a scheduler factory.
        if hparams.scheduler is None:  # type: ignore
            return {"optimizer": optimizer}  # type: ignore

        lr_scheduler_config = {
            "scheduler": hparams.scheduler(optimizer=optimizer),  # type: ignore
            "monitor": "val_loss/total_loss",
            "interval": "epoch",
            "frequency": 1,
        }
        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}


