import numpy as np
import torch

# NOTE(review): this switches on autograd anomaly detection process-wide as a
# side effect of importing the module. It is a debugging aid that slows every
# backward pass; consider scoping it with
# `with torch.autograd.detect_anomaly():` around the code being debugged.
torch.autograd.set_detect_anomaly(mode=True)

def metabolic_cost_kimr15(IK_data, torques):
    """
    Joint-space metabolic cost following the KIMR formulation:

        E_dot = h_M * |q_dot|_max * |M|  +  h_SL * |M q_dot|
                + q_cc * (M q_dot)_max  +  M q_dot

    where M are joint moments (``torques``), q_dot joint velocities and
    ``M q_dot`` the joint power. h_SL takes the shortening value when the
    power is positive and the lengthening value when it is negative.
    Maxima are taken over dim -2 of the selected velocity slice.

    The per-sample rate is summed over dim -2, divided by the summed absolute
    value of ``IK_data[:, :, 1]`` (over dim 1) and finally summed over the
    joints (dim 1).

    :param IK_data: tensor indexed as ``IK_data[:, :, k]``; joint velocities
        are read from feature columns 10, 13, 16, 19, 22, 25.
        Presumably laid out as (batch, time, features) -- TODO confirm.
    :param torques: joint moments that broadcast against the six selected
        velocity columns (presumably (batch, time, 6)).
    :return: metabolic cost; shape (batch,) for 3-D inputs.
    """
    # Constants
    h_m_ = 0.054      # maintenance heat coefficient
    h_sl_s_ = 0.283   # heat coefficient applied when joint power > 0
    h_sl_l_ = 1.423   # heat coefficient applied when joint power < 0
    q_cc_ = 0.004     # co-contraction heat coefficient

    # Feature columns of IK_data holding the joint velocities matching the
    # columns of `torques`; created on IK_data's device.
    idx_IK_data_qdot = torch.arange(10, 28, 3, device=IK_data.device)
    qdot = IK_data[:, :, idx_IK_data_qdot]

    # Joint power, shared by the heat-rate and work terms.
    power = torques * qdot
    abs_power = torch.abs(power)

    # Maintenance heat rate: peak |q_dot| along dim -2 times |M|.
    h_m = h_m_ * torch.max(torch.abs(qdot), dim=-2, keepdim=True).values * torch.abs(torques)

    # Shortening / lengthening heat rate. The sign must be tested on the raw
    # power: taking abs() first (as before) made the lengthening branch
    # unreachable. torch.where also avoids in-place masked writes on a tensor
    # that participates in autograd.
    h_sl = torch.where(power > 0, h_sl_s_ * abs_power, h_sl_l_ * abs_power)

    # Co-contraction heat rate: peak joint power along dim -2.
    h_cc = q_cc_ * torch.max(power, dim=-2, keepdim=True).values

    # Mechanical work rate.
    work = power

    # Metabolic cost; the epsilon guards against a zero normaliser.
    metabolic_rate = h_m + h_sl + h_cc + work
    metabolic_cost = torch.sum(metabolic_rate, dim=-2) / (
        torch.sum(torch.abs(IK_data[:, :, 1]), dim=1, keepdim=True) + 1e-6
    )

    return torch.sum(metabolic_cost, dim=1)
