from typing import Dict, List, Annotated
from envs import (
    AntBenchmark,
    HopperBenchmark,
    Walker2dBenchmark,
    HalfCheetahBenchmark,
    HumanoidStandUpBenchmark,
    InvertedPendulumBenchmark,
)

def _class_presets(benchmark_cls, *preset_names):
    """Map each name in ``preset_names`` to the same-named attribute of ``benchmark_cls``."""
    return {name: getattr(benchmark_cls, name) for name in preset_names}


# Parameter-bound presets per MuJoCo environment:
# environment name -> preset name -> the class attribute of that name on the
# environment's benchmark class (classes imported from ``envs``).
# NOTE(review): the attribute values are presumably bound specifications used to
# build each benchmark — their type is defined in ``envs``; confirm there.
# Note that "Humanoid" uses HumanoidStandUpBenchmark and "Walker" uses
# Walker2dBenchmark. Each preset name is written once, so a dictionary key can
# never drift from the attribute it looks up.
BENCHMARK_MUJOCO_BOUNDS = {
    "Ant": _class_presets(
        AntBenchmark,
        "ONE_DIM_PARAMS_BOUND_3",
        "TWO_DIM_PARAMS_BOUND_3_3",
        "THREE_DIM_PARAMS_BOUND_3_3_3",
    ),
    "Halfcheetah": _class_presets(
        HalfCheetahBenchmark,
        "ONE_DIM_PARAMS_BOUND_3",
        "ONE_DIM_PARAMS_BOUND_4",
        "TWO_DIM_PARAMS_BOUND_4_7",
        "THREE_DIM_PARAMS_BOUND_3_7_4",
    ),
    "Hopper": _class_presets(
        HopperBenchmark,
        "ONE_DIM_PARAMS_BOUND_3",
        "ONE_DIM_PARAMS_BOUND_2",
        "TWO_DIM_PARAMS_BOUND_3_3",
        "THREE_DIM_PARAMS_BOUND_3_3_4",
    ),
    "Humanoid": _class_presets(
        HumanoidStandUpBenchmark,
        "ONE_DIM_PARAMS_BOUND_16",
        "TWO_DIM_PARAMS_BOUND_16_8",
        "THREE_DIM_PARAMS_BOUND_16_5_8",
    ),
    "InvertedPendulum": _class_presets(
        InvertedPendulumBenchmark,
        "ONE_DIM_PARAMS_BOUND_31",
        "ONE_DIM_PARAMS_BOUND_9",
        "TWO_DIM_PARAMS_BOUND_31_11",
    ),
    "Walker": _class_presets(
        Walker2dBenchmark,
        "ONE_DIM_PARAMS_BOUND_4",
        "TWO_DIM_PARAMS_BOUND_4_5",
        "THREE_DIM_PARAMS_BOUND_4_5_6",
    ),
}


def _fresh_mujoco_preset_bounds():
    """Return a newly built copy of the preset ranges shared by every environment.

    A fresh dict (with fresh lists) is built on each call so environments do not
    share mutable state: editing one environment's range in place cannot leak
    into another environment's entry.
    """
    return {
        "mujoco-vanilla": {"mass_coef": [0.5, 2], "friction_coef": [0.5, 2.5]},
        "mujoco-restrict": {"mass_coef": [0.5, 1.5], "friction_coef": [0.5, 2]},
    }


# Randomization ranges per environment and preset:
# environment name -> preset name -> coefficient name -> [lower, upper].
# Every environment currently uses the same two presets, so they are defined
# once in the factory above instead of being copy-pasted per environment.
# NOTE(review): the coefficients look like multipliers on the model's mass and
# friction — confirm against the code that consumes these ranges.
MUJUCO_BOUNDS = {
    env_name: _fresh_mujoco_preset_bounds()
    for env_name in (
        "Ant",
        "Halfcheetah",
        "Hopper",
        "Humanoid",
        "InvertedPendulum",
        "Walker",
    )
}

# Correctly spelled alias. The misspelled ``MUJUCO_BOUNDS`` is kept so that
# existing imports keep working; both names refer to the same dict.
MUJOCO_BOUNDS = MUJUCO_BOUNDS


# Reference value of each parameter for every benchmark preset:
# environment name -> preset name -> {parameter name: value}.
# The environment and preset keys mirror BENCHMARK_MUJOCO_BOUNDS, and each
# preset lists as many parameters as its ONE/TWO/THREE_DIM prefix states.
# NOTE(review): these look like the environments' nominal physics values
# (body masses, world friction) — confirm against the MuJoCo model files.
# Each inner {parameter: value} dict has the Dict[str, float] shape that
# get_radius_bounds() accepts as its ``parameters`` argument.
BENCHMARK_MUJOCO_REFERENCE = {
    "Ant": {
        "ONE_DIM_PARAMS_BOUND_3": {"torsomass": 0.33},
        "TWO_DIM_PARAMS_BOUND_3_3": {"frontleftlegmass": 0.04, "torsomass": 0.33},
        "THREE_DIM_PARAMS_BOUND_3_3_3": {
            "torsomass": 0.33,
            "frontleftlegmass": 0.04,
            "frontrightlegmass": 0.06,
        },
    },
    # Both one-dimensional presets vary only world friction and share a value.
    "Halfcheetah": {
        "ONE_DIM_PARAMS_BOUND_3": {"worldfriction": 0.4},
        "ONE_DIM_PARAMS_BOUND_4": {"worldfriction": 0.4},
        "TWO_DIM_PARAMS_BOUND_4_7": {"worldfriction": 0.4, "torsomass": 6.36},
        "THREE_DIM_PARAMS_BOUND_3_7_4": {
            "worldfriction": 0.4,
            "torsomass": 6.36,
            "backthighmass": 1.53,
        },
    },
    # Friction is written as the int 1 here (other entries use floats).
    "Hopper": {
        "ONE_DIM_PARAMS_BOUND_3": {"worldfriction": 1},
        "ONE_DIM_PARAMS_BOUND_2": {"worldfriction": 1},
        "TWO_DIM_PARAMS_BOUND_3_3": {"worldfriction": 1, "torsomass": 3.53},
        "THREE_DIM_PARAMS_BOUND_3_3_4": {
            "worldfriction": 1,
            "torsomass": 3.53,
            "thighmass": 0.5,
        },
    },
    "Humanoid": {
        "ONE_DIM_PARAMS_BOUND_16": {"torsomass": 8.32},
        "TWO_DIM_PARAMS_BOUND_16_8": {"torsomass": 8.32, "rightfootmass": 1.77},
        "THREE_DIM_PARAMS_BOUND_16_5_8": {
            "torsomass": 8.32,
            "rightfootmass": 1.77,
            "leftthighmass": 4.53,
        },
    },
    "InvertedPendulum": {
        "ONE_DIM_PARAMS_BOUND_31": {"cartmass": 4.90},
        "ONE_DIM_PARAMS_BOUND_9": {
            "cartmass": 4.90,
        },
        "TWO_DIM_PARAMS_BOUND_31_11": {"cartmass": 4.90, "polemass": 9.42},
    },
    "Walker": {
        "ONE_DIM_PARAMS_BOUND_4": {"worldfriction": 0.7},
        "TWO_DIM_PARAMS_BOUND_4_5": {"worldfriction": 0.7, "torsomass": 3.53},
        "THREE_DIM_PARAMS_BOUND_4_5_6": {
            "worldfriction": 0.7,
            "torsomass": 3.53,
            "thighmass": 3.93,
        },
    },
}


def _get_center_bounds(bounds: Dict[str, Annotated[List[float], 2]]):
    return {k: (v[0] + v[1]) / 2 for k, v in bounds.items()}


def get_radius_bounds(
    parameters: Dict[str, float],
    radius: float,
    bounds: Dict[str, Annotated[List[float], 2]],
) -> Dict[str, List[float]]:
    """Return a box of half-width ``radius`` around ``parameters``, clipped to ``bounds``.

    For each key ``k`` of ``parameters`` with centre ``p`` and ``[lo, hi] = bounds[k]``,
    the result holds ``[max(p - radius, lo), min(p + radius, hi)]``. Both ends are
    additionally clamped into ``[lo, hi]``. The returned interval therefore always
    lies inside the original bounds and is never inverted, even when ``p`` itself
    lies outside ``bounds`` by more than ``radius``.

    Args:
        parameters: Centre value for each parameter name.
        radius: Non-negative half-width of the box.
        bounds: ``[lower, upper]`` limits per parameter name; must contain every
            key of ``parameters``. Not modified.

    Returns:
        A new dict mapping each parameter name to a fresh ``[lower, upper]`` list.

    Raises:
        ValueError: If ``radius`` is negative.
        KeyError: If a key of ``parameters`` is missing from ``bounds``.
    """
    if radius < 0:
        # A negative radius would silently yield inverted intervals.
        raise ValueError(f"radius must be non-negative, got {radius}")
    clipped = {}
    for name, centre in parameters.items():
        low, high = bounds[name][0], bounds[name][1]
        # Clamp each end into [low, high]. Clamping only the side that overflows
        # could produce an interval outside the bounds (and inverted) when the
        # centre lies far outside them.
        new_low = min(max(centre - radius, low), high)
        new_high = max(min(centre + radius, high), low)
        clipped[name] = [new_low, new_high]
    return clipped
