import torch

def contact_points_2d(gc_positions, gc_model_values, gc_model_dict, cfg, device, debug=False, get_moments = False, kinematics = None):
    """
    Calculate the ground contact forces and either the CoP or the ankle moments of a 2D contact-point model.

    For every contact point a penetration-based vertical force and a smoothed Coulomb friction
    force are computed from its model parameters (mu, k, b) and its position/velocity channels.
    Channels read from ``gc_positions[cp]`` (meaning inferred from usage; confirm against the
    producer): 0 = x, 1 = dx, 3 = y (height), 4 = dy.

    :param gc_positions: Dict contact-point name -> tensor indexed as ``[:, :, channel]``.
    :param gc_model_values: Tensor of ground contact model values. The last dim follows
        ``grf_columns`` (``<key>_mu``, ``<key>_k``, ``<key>_b``, ``heel_x``, ``heel_y``, ``toe_x``, ``toe_y``).
    :param gc_model_dict: Dict contact-point name -> {'key': column prefix, 'parent': kinematics key}.
        Must contain 'r_heel', 'r_toe', 'l_heel' and 'l_toe'.
    :param cfg: The global configuration (``gc_smoothing_factor`` and the ground contact column list).
    :param device: Device the per-contact-point forces are moved to.
    :param debug: If True, read the column list from ``cfg.dataset_variables`` instead of
        ``cfg.datamodule.dataset_variables``.
    :param get_moments: Selects the output form (see returns).
    :param kinematics: Dict parent name -> tensor (channels 0 = x, 3 = y). Required if ``get_moments``.
    :return: If ``get_moments`` is False: ``(grf, cop)``.
        ``grf`` concatenates [fx, fy] for r_heel, r_toe, l_heel, l_toe (last dim 8).
        ``cop`` holds the heel/toe location columns, repeated for both feet.
        If ``get_moments`` is True: ``(grf_m, m, grf)``.
        ``grf_m`` holds the per-foot summed [fx, fy] (right, left; last dim 4).
        ``m`` holds the per-foot moment about each contact point's parent (right, left; last dim 2).
    :raises AssertionError: If ``get_moments`` is True and ``kinematics`` is None.
    """
    if get_moments:
        # Fail fast, before any force computation, when the moment path cannot be evaluated.
        assert kinematics is not None, "Kinematics need to be provided to calculate moments"

    smoothing_factor_coulomb = cfg.gc_smoothing_factor  # Lower if learning is great
    grf_columns = cfg.datamodule.dataset_variables.ground_contact_model if not debug else cfg.dataset_variables.ground_contact_model

    # Notes on the GC model implementation:
    # * (1 - b * dy) is deliberately not clamped with a relu. Negative forces need a lot of
    #   lift-off speed, so this is not a problem and helps learning (it would be for e.g. high jump).
    # * There needs to be a gradient even when the foot is above the ground. Therefore
    #   fy = fy - relu(y) * 1e-1 is applied *after* fx is computed, so the friction term is unaffected.
    # * The original 2D model had a slightly different (smoother) implementation. This should
    #   not make a difference for a first-order solver.
    # * The '- 1e-2 * fy * mu * dx' term makes friction keep rising with sliding speed, so there is
    #   still a gradient once tanh saturates.
    cp_grfs = {}
    for contact_point, cp_info in gc_model_dict.items():
        contact_key = cp_info['key']
        pos = gc_positions[contact_point]
        mu = gc_model_values[:, :, grf_columns.index(f"{contact_key}_mu")]
        k = gc_model_values[:, :, grf_columns.index(f"{contact_key}_k")]
        b = gc_model_values[:, :, grf_columns.index(f"{contact_key}_b")]
        k = k * 9.81  # Convert to N/m (BW/m)

        fy = k * torch.nn.functional.softplus(-pos[:, :, 3], beta=3e2) * \
             (1 - b * pos[:, :, 4])
        fx = -fy * mu * torch.tanh(pos[:, :, 1] / smoothing_factor_coulomb) - \
            1e-2 * fy * mu * pos[:, :, 1]

        # Small restoring term above ground, applied after fx (see notes above).
        fy = fy - torch.relu(pos[:, :, 3]) * 1e-1
        cp_grfs[contact_point] = torch.stack([fx, fy], dim=-1).to(device)

    grf = torch.cat([cp_grfs["r_heel"], cp_grfs["r_toe"], cp_grfs["l_heel"], cp_grfs["l_toe"]], dim=-1)

    if not get_moments:
        # Per-contact-point output. The CoP is taken directly from the model's heel/toe location
        # columns (shared by both feet), so it is only an approximation.
        hx = grf_columns.index("heel_x")
        hy = grf_columns.index("heel_y")
        tx = grf_columns.index("toe_x")
        ty = grf_columns.index("toe_y")
        cop = gc_model_values[:, :, [hx, hy, tx, ty, hx, hy, tx, ty]].to(device)
        return grf, cop

    # Moment output: one resultant force per foot plus its moment about the parent segments.
    grf_m = torch.cat([cp_grfs["r_heel"] + cp_grfs["r_toe"], cp_grfs["l_heel"] + cp_grfs["l_toe"]], dim=-1)
    m = {}
    for contact_point, cp_info in gc_model_dict.items():
        parent = kinematics[cp_info['parent']]
        pos = gc_positions[contact_point]
        # Lever arm laid out as [-(y_cp - y_parent), x_cp - x_parent], so that
        # d . [fx, fy] = r_x * fy - r_y * fx, i.e. the z-component of r x F.
        d = torch.cat([
            -(pos[:, :, 3:4] - parent[:, :, 3:4]),  # range slicing keeps the trailing dim
            pos[:, :, 0:1] - parent[:, :, 0:1],
        ], dim=-1)
        m[contact_point] = torch.sum(d * cp_grfs[contact_point], dim=-1, keepdim=True)

    m = torch.cat([m["r_heel"] + m["r_toe"], m["l_heel"] + m["l_toe"]], dim=-1)
    return grf_m, m, grf

def sliding_contact_point(gc_positions, gc_model_values, gc_model_dict, cfg, device, debug=False, get_moments = False, kinematics = None, learned_mu = None):
    """
    Calculate per-foot ground contact forces and ankle moments from a single sliding contact point.

    For each foot the contact point is a smooth blend of its heel and toe points,
        cp = w * heel + (1 - w) * toe,   w = tanh(7 * angle) / 2 + 0.5,
    where ``angle`` is channel 6 of the ankle kinematics (w -> 1 selects the heel, w -> 0 the toe).

    :param gc_positions: Dict with 'r_heel', 'r_toe', 'l_heel', 'l_toe' tensors indexed as
        ``[:, :, channel]``. Channels read (inferred from usage; confirm against the producer):
        0 = x, 1 = dx, 3 = y (height), 4 = dy.
    :param gc_model_values: Tensor whose last-dim channels 2, 3 and 4 are read as k, b and mu.
        NOTE(review): these are fixed indices, unlike contact_points_2d, which looks columns up
        by name. Confirm the layout.
    :param gc_model_dict: Unused here; kept for signature compatibility with contact_points_2d.
    :param cfg: The global configuration (``gc_smoothing_factor`` is read).
    :param device: Device the per-foot forces are moved to.
    :param debug: Unused here; kept for signature compatibility.
    :param get_moments: Only the moment output is implemented. If False, a warning is printed
        and ``(None, None)`` is returned.
    :param kinematics: Dict with 'ankle_r' / 'ankle_l' tensors (channels 0 = x, 3 = y, 6 = angle). Required.
    :param learned_mu: Optional tensor. ``learned_mu[:, :, 0]`` / ``[:, :, 1]`` scale mu for the
        right / left foot and replace the velocity-dependent friction law by ``fx = -fy * mu``.
    :return: ``(grf_m, m, grf_m, cp_mix)``.
        ``grf_m`` concatenates [fx, fy] for the right and left foot (last dim 4).
        ``m`` holds the moments about each ankle (right, left; last dim 2).
        ``cp_mix`` maps 'ankle_r' / 'ankle_l' to the blended contact-point tensor.
        Returns ``(None, None)`` if ``get_moments`` is False.
    :raises AssertionError: If ``kinematics`` is None.
    """
    # Kinematics drive both the contact-point blend and the moments, so validate before any use.
    assert kinematics is not None, "Kinematics need to be provided to calculate moments"
    if not get_moments:
        print('Warning: Sliding contact point not implemented for CoP')
        return None, None

    smoothing_factor_coulomb = cfg.gc_smoothing_factor  # Lower if learning is great
    # The model parameters are shared by both feet, so read them once.
    k = gc_model_values[:, :, 2] * 9.81  # Convert to N/m (BW/m)
    b = gc_model_values[:, :, 3]
    mu_base = gc_model_values[:, :, 4]

    cp_grfs = {}
    cp_mix_ = {}
    feet = (('ankle_r', 'r_heel', 'r_toe'), ('ankle_l', 'l_heel', 'l_toe'))
    for idx, (foot, cp_heel, cp_toe) in enumerate(feet):
        angle = kinematics[foot][:, :, 6:7]  # range slicing keeps the trailing dim for broadcasting
        cp_ratio = torch.tanh(angle * 7) / 2 + 0.5
        cp_mixed = cp_ratio * gc_positions[cp_heel] + (1 - cp_ratio) * gc_positions[cp_toe]

        fy = k * torch.nn.functional.softplus(-cp_mixed[:, :, 3], beta=3e2) * \
             (1 - b * cp_mixed[:, :, 4])
        if learned_mu is not None:
            # Per-foot scaling; a fresh tensor per foot so the right-foot factor never leaks into the left.
            mu = mu_base * learned_mu[:, :, idx]
            fx = -fy * mu
        else:
            fx = -fy * mu_base * torch.tanh(cp_mixed[:, :, 1] / smoothing_factor_coulomb) - \
                1e-2 * fy * mu_base * cp_mixed[:, :, 1]

        # Small restoring term above ground, applied after fx so friction is unaffected.
        fy = fy - torch.relu(cp_mixed[:, :, 3]) * 1e-1
        cp_grfs[foot] = torch.stack([fx, fy], dim=-1).to(device)
        cp_mix_[foot] = cp_mixed

    grf_m = torch.cat([cp_grfs["ankle_r"], cp_grfs["ankle_l"]], dim=-1)

    m = {}
    for foot, _, _ in feet:
        ankle = kinematics[foot]
        cp_pos = cp_mix_[foot]
        # Lever arm laid out as [-(y_cp - y_ankle), x_cp - x_ankle], so that
        # d . [fx, fy] = r_x * fy - r_y * fx (z-component of r x F).
        d = torch.cat([
            -(cp_pos[:, :, 3:4] - ankle[:, :, 3:4]),
            cp_pos[:, :, 0:1] - ankle[:, :, 0:1],
        ], dim=-1)
        m[foot] = torch.sum(d * cp_grfs[foot], dim=-1, keepdim=True)

    m = torch.cat([m["ankle_r"], m["ankle_l"]], dim=-1)
    return grf_m, m, grf_m, cp_mix_





if __name__ == "__main__":
    # This module only provides library functions; there is nothing to run as a script.
    ...