from typing import Optional, Protocol, Dict, Tuple

import gym
from gym.envs.mujoco import HalfCheetahEnv
from gym.envs.mujoco.inverted_pendulum import InvertedPendulumEnv
from gym.envs.mujoco.walker2d import Walker2dEnv
from gym.envs.mujoco.hopper import HopperEnv
from gym.envs.mujoco.ant import AntEnv
from gym.envs.mujoco.humanoidstandup import HumanoidStandupEnv

import numpy as np


class ModifiedPhysics(Protocol):
    """Structural interface for objects whose physics parameters can be altered.

    An implementer exposes three operations: storing parameter values,
    reporting them, and applying them to the underlying simulation.
    """

    def change_physics(self, **kwargs):
        """Apply the currently stored parameter values to the simulation."""
        ...

    def set_params(self, **kwargs):
        """Store new parameter values, passed as keyword arguments."""
        ...

    def get_params(self):
        """Return the currently stored parameter values."""
        ...


# Python doesn't support type addition (intersection types), so we need to create a new class,
# unlike Rust, where we can just put `ModifiedPhysics + gym.Env` in the function signature.
class ModifiedPhysicsEnv(ModifiedPhysics, gym.Env):
    """Nominal type that is both a ``gym.Env`` and a ``ModifiedPhysics``.

    It exists so that a single name can be used in type annotations for
    "an environment with modifiable physics". The three methods below are
    placeholders that do nothing and return ``None``.
    """

    def change_physics(self, **kwargs):
        """Placeholder: apply the stored parameters to the simulation."""

    def set_params(self, **kwargs):
        """Placeholder: store new parameter values."""

    def get_params(self):
        """Placeholder: report the stored parameter values."""


class DomainRandomizationBenchmarkWrapper(gym.Wrapper):
    """Re-sample the wrapped environment's physics parameters on every reset.

    Each call to :meth:`reset` draws one value per parameter uniformly from
    its ``(low, high)`` range, stores the values on the wrapped environment
    with ``set_params``, applies them with ``change_physics`` and then resets
    the environment.

    Args:
        env: Environment exposing ``set_params`` and ``change_physics``.
        params_bound: Mapping from parameter name to a ``(low, high)`` pair.

    Attributes:
        params: The most recently drawn parameters. A first draw happens in
            the constructor, but it is only pushed to ``env`` by ``reset``.
    """

    def __init__(
        self, env: ModifiedPhysicsEnv, params_bound: Dict[str, Tuple[float, float]]
    ):
        super().__init__(env)
        self.env = env
        self.params_bound = params_bound
        self.params = self.draw_params()

    def reset(self, **kwargs):
        """Draw fresh parameters, apply them, and reset the wrapped env.

        Keyword arguments are forwarded unchanged to ``env.reset`` so that
        callers can keep using whatever options the wrapped environment's
        ``reset`` accepts.

        Returns:
            Whatever the wrapped environment's ``reset`` returns.
        """
        self.params = self.draw_params()
        self.env.set_params(**self.params)
        self.env.change_physics()
        return self.env.reset(**kwargs)

    def step(self, action):
        """Forward ``action`` to the wrapped environment unchanged."""
        return self.env.step(action)

    def draw_params(self):
        """Sample one value per parameter and remember it in ``self.params``.

        The bounds are read from ``self.params_bound`` on every call, so
        changes made to that mapping after construction are honoured. Values
        come from NumPy's global random state (``np.random.uniform``).

        Returns:
            Dict mapping each parameter name to its sampled value.

        Raises:
            ValueError: If any parameter's lower bound exceeds its upper
                bound; ``np.random.uniform`` would not reject such a range.
        """
        names = list(self.params_bound)
        # reshape keeps the (n, 2) layout even when there are no parameters.
        bounds = np.asarray(
            [(bound[0], bound[1]) for bound in self.params_bound.values()],
            dtype=float,
        ).reshape(-1, 2)
        low, high = bounds[:, 0], bounds[:, 1]

        if np.any(low > high):
            inverted = [name for name, lo, hi in zip(names, low, high) if lo > hi]
            raise ValueError(
                f"Lower bound exceeds upper bound for parameter(s): {inverted}"
            )

        self.params = dict(zip(names, np.random.uniform(low, high)))
        return self.params


class HalfCheetahBenchmark(HalfCheetahEnv):
    """``HalfCheetahEnv`` with optional overrides for friction and two masses.

    Each override is optional; ``None`` means the corresponding model entry
    is left untouched when :meth:`change_physics` runs.

    The ``*_PARAMS_BOUND_*`` class attributes map parameter names to
    ``[low, high]`` range pairs.
    """

    ONE_DIM_PARAMS_BOUND_3 = {"worldfriction": [0.1, 3.0]}
    ONE_DIM_PARAMS_BOUND_4 = {"worldfriction": [0.1, 4.0]}

    TWO_DIM_PARAMS_BOUND_4_7 = {
        "worldfriction": [0.1, 4.0],
        "torsomass": [0.1, 7.0],
    }
    # NOTE(review): the suffix reads 3_7_4, but the upper bounds in key order
    # are 4.0, 7.0, 3.0. The name is kept as-is for existing callers.
    THREE_DIM_PARAMS_BOUND_3_7_4 = {
        "worldfriction": [0.1, 4.0],
        "torsomass": [0.1, 7.0],
        "backthighmass": [0.1, 3.0],
    }

    def __init__(
        self,
        worldfriction: Optional[float] = None,
        torsomass: Optional[float] = None,
        backthighmass: Optional[float] = None,
    ):
        """Store the overrides, build the base environment, then apply them."""
        # The overrides are assigned before the base constructor runs;
        # presumably it may already call step()/get_params() - TODO confirm.
        self.worldfriction, self.torsomass, self.backthighmass = (
            worldfriction,
            torsomass,
            backthighmass,
        )
        super().__init__()

        # self.model is not created in this class, so apply only afterwards.
        self.change_physics()

    def set_params(
        self,
        worldfriction: Optional[float] = None,
        torsomass: Optional[float] = None,
        backthighmass: Optional[float] = None,
    ):
        """Replace all three stored overrides; omitted ones become ``None``.

        Only the attributes are updated. Call :meth:`change_physics` to write
        them into the model. An override reset to ``None`` is simply skipped
        there, so the model keeps whatever value it had before.
        """
        self.worldfriction, self.torsomass, self.backthighmass = (
            worldfriction,
            torsomass,
            backthighmass,
        )

    def get_params(self):
        """Return the stored overrides as a dict keyed by parameter name."""
        names = ("worldfriction", "torsomass", "backthighmass")
        return {name: getattr(self, name) for name in names}

    def step(self, action):
        """Step the base environment and add the stored overrides to ``info``."""
        observation, reward, finished, extras = super().step(action)
        extras.update(self.get_params())
        return observation, reward, finished, extras

    def change_physics(self):
        """Write every non-``None`` override into ``self.model``.

        ``worldfriction`` goes to column 0 of every row of ``geom_friction``;
        ``torsomass`` and ``backthighmass`` go to ``body_mass[1]`` and
        ``body_mass[2]`` respectively.
        """
        if self.worldfriction is not None:
            self.model.geom_friction[:, 0] = self.worldfriction

        mass_overrides = ((1, self.torsomass), (2, self.backthighmass))
        for body_index, mass in mass_overrides:
            if mass is not None:
                self.model.body_mass[body_index] = mass


class InvertedPendulumBenchmark(InvertedPendulumEnv):
    """``InvertedPendulumEnv`` with optional overrides for two body masses.

    Each override is optional; ``None`` means the corresponding model entry
    is left untouched when :meth:`change_physics` runs.

    The ``*_PARAMS_BOUND_*`` class attributes map parameter names to
    ``[low, high]`` range pairs.
    """

    ONE_DIM_PARAMS_BOUND_31 = {"polemass": [1, 31]}
    ONE_DIM_PARAMS_BOUND_9 = {"polemass": [1, 9]}

    TWO_DIM_PARAMS_BOUND_31_11 = {
        "polemass": [1, 31],
        "cartmass": [1, 11],
    }

    def __init__(
        self, polemass: Optional[float] = None, cartmass: Optional[float] = None
    ):
        """Store the overrides, build the base environment, then apply them."""
        # The overrides are assigned before the base constructor runs;
        # presumably it may already call step()/get_params() - TODO confirm.
        self.polemass, self.cartmass = polemass, cartmass
        super().__init__()

        # self.model is not created in this class, so apply only afterwards.
        self.change_physics()

    def set_params(
        self, polemass: Optional[float] = None, cartmass: Optional[float] = None
    ):
        """Replace both stored overrides; an omitted one becomes ``None``.

        Only the attributes are updated. Call :meth:`change_physics` to write
        them into the model. An override reset to ``None`` is simply skipped
        there, so the model keeps whatever value it had before.
        """
        self.polemass, self.cartmass = polemass, cartmass

    def get_params(self):
        """Return the stored overrides as a dict keyed by parameter name."""
        names = ("polemass", "cartmass")
        return {name: getattr(self, name) for name in names}

    def step(self, a):
        """Step the base environment and add the stored overrides to ``info``."""
        observation, reward, finished, extras = super().step(a)
        extras.update(self.get_params())
        return observation, reward, finished, extras

    def change_physics(self):
        """Write every non-``None`` override into ``self.model.body_mass``.

        ``polemass`` goes to index 2 and ``cartmass`` to index 1.
        """
        mass_overrides = ((2, self.polemass), (1, self.cartmass))
        for body_index, mass in mass_overrides:
            if mass is not None:
                self.model.body_mass[body_index] = mass


class Walker2dBenchmark(Walker2dEnv):
    """``Walker2dEnv`` with optional overrides for friction and two masses.

    Each override is optional; ``None`` means the corresponding model entry
    is left untouched when :meth:`change_physics` runs.

    The ``*_PARAMS_BOUND_*`` class attributes map parameter names to
    ``[low, high]`` range pairs.
    """

    ONE_DIM_PARAMS_BOUND_4 = {"worldfriction": [0.1, 4.0]}
    TWO_DIM_PARAMS_BOUND_4_5 = {
        "worldfriction": [0.1, 4.0],
        "torsomass": [0.1, 5.0],
    }
    THREE_DIM_PARAMS_BOUND_4_5_6 = {
        "worldfriction": [0.1, 4.0],
        "torsomass": [0.1, 5.0],
        "thighmass": [0.1, 6.0],
    }

    def __init__(
        self,
        worldfriction: Optional[float] = None,
        torsomass: Optional[float] = None,
        thighmass: Optional[float] = None,
    ):
        """Store the overrides, build the base environment, then apply them."""
        # The overrides are assigned before the base constructor runs;
        # presumably it may already call step()/get_params() - TODO confirm.
        self.worldfriction, self.torsomass, self.thighmass = (
            worldfriction,
            torsomass,
            thighmass,
        )
        super().__init__()

        # self.model is not created in this class, so apply only afterwards.
        self.change_physics()

    def set_params(
        self,
        worldfriction: Optional[float] = None,
        torsomass: Optional[float] = None,
        thighmass: Optional[float] = None,
    ):
        """Replace all three stored overrides; omitted ones become ``None``.

        Only the attributes are updated. Call :meth:`change_physics` to write
        them into the model. An override reset to ``None`` is simply skipped
        there, so the model keeps whatever value it had before.
        """
        self.worldfriction, self.torsomass, self.thighmass = (
            worldfriction,
            torsomass,
            thighmass,
        )

    def get_params(self):
        """Return the stored overrides as a dict keyed by parameter name."""
        names = ("worldfriction", "torsomass", "thighmass")
        return {name: getattr(self, name) for name in names}

    def step(self, action):
        """Step the base environment and add the stored overrides to ``info``."""
        observation, reward, finished, extras = super().step(action)
        extras.update(self.get_params())
        return observation, reward, finished, extras

    def change_physics(self):
        """Write every non-``None`` override into ``self.model``.

        ``worldfriction`` goes to the single entry ``geom_friction[0, 0]``;
        ``torsomass`` and ``thighmass`` go to ``body_mass[1]`` and
        ``body_mass[2]`` respectively.
        """
        if self.worldfriction is not None:
            self.model.geom_friction[0, 0] = self.worldfriction

        mass_overrides = ((1, self.torsomass), (2, self.thighmass))
        for body_index, mass in mass_overrides:
            if mass is not None:
                self.model.body_mass[body_index] = mass


class HopperBenchmark(HopperEnv):
    """``HopperEnv`` with optional overrides for friction and two masses.

    Each override is optional; ``None`` means the corresponding model entry
    is left untouched when :meth:`change_physics` runs.

    The ``*_PARAMS_BOUND_*`` class attributes map parameter names to
    ``[low, high]`` range pairs.
    """

    ONE_DIM_PARAMS_BOUND_2 = {"worldfriction": [0.1, 2.0]}
    ONE_DIM_PARAMS_BOUND_3 = {"worldfriction": [0.1, 3.0]}
    TWO_DIM_PARAMS_BOUND_3_3 = {
        "worldfriction": [0.1, 3.0],
        "torsomass": [0.1, 3.0],
    }
    THREE_DIM_PARAMS_BOUND_3_3_4 = {
        "worldfriction": [0.1, 3.0],
        "torsomass": [0.1, 3.0],
        "thighmass": [0.1, 4.0],
    }

    def __init__(
        self,
        worldfriction: Optional[float] = None,
        torsomass: Optional[float] = None,
        thighmass: Optional[float] = None,
    ):
        """Store the overrides, build the base environment, then apply them."""
        # The overrides are assigned before the base constructor runs;
        # presumably it may already call step()/get_params() - TODO confirm.
        self.worldfriction, self.torsomass, self.thighmass = (
            worldfriction,
            torsomass,
            thighmass,
        )
        super().__init__()

        # self.model is not created in this class, so apply only afterwards.
        self.change_physics()

    def set_params(
        self,
        worldfriction: Optional[float] = None,
        torsomass: Optional[float] = None,
        thighmass: Optional[float] = None,
    ):
        """Replace all three stored overrides; omitted ones become ``None``.

        Only the attributes are updated. Call :meth:`change_physics` to write
        them into the model. An override reset to ``None`` is simply skipped
        there, so the model keeps whatever value it had before.
        """
        self.worldfriction, self.torsomass, self.thighmass = (
            worldfriction,
            torsomass,
            thighmass,
        )

    def get_params(self):
        """Return the stored overrides as a dict keyed by parameter name."""
        names = ("worldfriction", "torsomass", "thighmass")
        return {name: getattr(self, name) for name in names}

    def step(self, action):
        """Step the base environment and add the stored overrides to ``info``."""
        observation, reward, finished, extras = super().step(action)
        extras.update(self.get_params())
        return observation, reward, finished, extras

    def change_physics(self):
        """Write every non-``None`` override into ``self.model``.

        ``worldfriction`` goes to the single entry ``geom_friction[0, 0]``;
        ``torsomass`` and ``thighmass`` go to ``body_mass[1]`` and
        ``body_mass[2]`` respectively.
        """
        if self.worldfriction is not None:
            self.model.geom_friction[0, 0] = self.worldfriction

        mass_overrides = ((1, self.torsomass), (2, self.thighmass))
        for body_index, mass in mass_overrides:
            if mass is not None:
                self.model.body_mass[body_index] = mass


class AntBenchmark(AntEnv):
    """``AntEnv`` with optional overrides for three body masses.

    Each override is optional; ``None`` means the corresponding model entry
    is left untouched when :meth:`change_physics` runs.

    The ``*_PARAMS_BOUND_*`` class attributes map parameter names to
    ``[low, high]`` range pairs.
    """

    ONE_DIM_PARAMS_BOUND_3 = {"torsomass": [0.1, 3.0]}
    TWO_DIM_PARAMS_BOUND_3_3 = {
        "torsomass": [0.1, 3.0],
        "frontleftlegmass": [0.01, 3.0],
    }
    THREE_DIM_PARAMS_BOUND_3_3_3 = {
        "torsomass": [0.1, 3.0],
        "frontleftlegmass": [0.01, 3.0],
        "frontrightlegmass": [0.01, 3.0],
    }

    def __init__(
        self,
        torsomass: Optional[float] = None,
        frontleftlegmass: Optional[float] = None,
        frontrightlegmass: Optional[float] = None,
    ):
        """Store the overrides, build the base environment, then apply them."""
        # The overrides are assigned before the base constructor runs;
        # presumably it may already call step()/get_params() - TODO confirm.
        self.torsomass, self.frontleftlegmass, self.frontrightlegmass = (
            torsomass,
            frontleftlegmass,
            frontrightlegmass,
        )
        super().__init__()

        # self.model is not created in this class, so apply only afterwards.
        self.change_physics()

    def set_params(
        self,
        torsomass: Optional[float] = None,
        frontleftlegmass: Optional[float] = None,
        frontrightlegmass: Optional[float] = None,
    ):
        """Replace all three stored overrides; omitted ones become ``None``.

        Only the attributes are updated. Call :meth:`change_physics` to write
        them into the model. An override reset to ``None`` is simply skipped
        there, so the model keeps whatever value it had before.
        """
        self.torsomass, self.frontleftlegmass, self.frontrightlegmass = (
            torsomass,
            frontleftlegmass,
            frontrightlegmass,
        )

    def get_params(self):
        """Return the stored overrides as a dict keyed by parameter name."""
        names = ("torsomass", "frontleftlegmass", "frontrightlegmass")
        return {name: getattr(self, name) for name in names}

    def step(self, action):
        """Step the base environment and add the stored overrides to ``info``."""
        observation, reward, finished, extras = super().step(action)
        extras.update(self.get_params())
        return observation, reward, finished, extras

    def change_physics(self):
        """Write every non-``None`` override into ``self.model.body_mass``.

        ``torsomass`` goes to index 1, ``frontleftlegmass`` to index 2 and
        ``frontrightlegmass`` to index 4.
        """
        # NOTE(review): confirm against the Ant model's body ordering that
        # index 4 really is the front-right leg body and not a child body of
        # the front-left leg.
        mass_overrides = (
            (1, self.torsomass),
            (2, self.frontleftlegmass),
            (4, self.frontrightlegmass),
        )
        for body_index, mass in mass_overrides:
            if mass is not None:
                self.model.body_mass[body_index] = mass


class HumanoidStandUpBenchmark(HumanoidStandupEnv):
    """``HumanoidStandupEnv`` with optional overrides for three body masses.

    Each override is optional; ``None`` means the corresponding model entry
    is left untouched when :meth:`change_physics` runs.

    The ``*_PARAMS_BOUND_*`` class attributes map parameter names to
    ``[low, high]`` range pairs.
    """

    ONE_DIM_PARAMS_BOUND_16 = {"torsomass": [0.1, 16.0]}
    TWO_DIM_PARAMS_BOUND_16_8 = {
        "torsomass": [0.1, 16.0],
        "rightfootmass": [0.1, 8.0],
    }
    THREE_DIM_PARAMS_BOUND_16_5_8 = {
        "torsomass": [0.1, 16.0],
        "leftthighmass": [0.1, 5.0],
        "rightfootmass": [0.1, 8.0],
    }

    def __init__(
        self,
        torsomass: Optional[float] = None,
        leftthighmass: Optional[float] = None,
        rightfootmass: Optional[float] = None,
    ):
        """Store the overrides, build the base environment, then apply them."""
        # The overrides are assigned before the base constructor runs;
        # presumably it may already call step()/get_params() - TODO confirm.
        self.torsomass, self.leftthighmass, self.rightfootmass = (
            torsomass,
            leftthighmass,
            rightfootmass,
        )
        super().__init__()

        # self.model is not created in this class, so apply only afterwards.
        self.change_physics()

    def set_params(
        self,
        torsomass: Optional[float] = None,
        leftthighmass: Optional[float] = None,
        rightfootmass: Optional[float] = None,
    ):
        """Replace all three stored overrides; omitted ones become ``None``.

        Only the attributes are updated. Call :meth:`change_physics` to write
        them into the model. An override reset to ``None`` is simply skipped
        there, so the model keeps whatever value it had before.
        """
        self.torsomass, self.leftthighmass, self.rightfootmass = (
            torsomass,
            leftthighmass,
            rightfootmass,
        )

    def get_params(self):
        """Return the stored overrides as a dict keyed by parameter name."""
        names = ("torsomass", "leftthighmass", "rightfootmass")
        return {name: getattr(self, name) for name in names}

    def step(self, action):
        """Step the base environment and add the stored overrides to ``info``."""
        observation, reward, finished, extras = super().step(action)
        extras.update(self.get_params())
        return observation, reward, finished, extras

    def change_physics(self):
        """Write every non-``None`` override into ``self.model.body_mass``.

        ``torsomass`` goes to index 1, ``leftthighmass`` to index 7 and
        ``rightfootmass`` to index 6.
        """
        mass_overrides = (
            (1, self.torsomass),
            (7, self.leftthighmass),
            (6, self.rightfootmass),
        )
        for body_index, mass in mass_overrides:
            if mass is not None:
                self.model.body_mass[body_index] = mass
