"""
Baseline configurations
"""

import copy
from dataclasses import dataclass, field
from typing import List, Tuple

@dataclass
class BaselineSpec:
    """Describe a baseline environment: its id, kwargs, wrappers and metrics."""

    # Environment id handed to ``gym.make``.
    id: str
    # Extra keyword arguments forwarded to ``gym.make``.
    kwargs: dict = field(default_factory=dict)
    # Wrappers as ``[(wrapper, wrapper_kwargs), ...]``; ``wrapper`` is either a
    # callable taking ``(env, **wrapper_kwargs)`` or a known wrapper name.
    wrappers: List[Tuple[object, dict]] = field(default_factory=list)
    # Metric labels associated with this baseline (not used by ``make``).
    metrics: List[str] = field(default_factory=list)

    def make(self):
        """Build the environment and apply the configured wrappers.

        The base environment comes from ``gym.make(self.id, **self.kwargs)``.
        ``self.wrappers`` is walked back to front, so the last entry wraps the
        base environment first and the first entry ends up outermost.

        String wrapper names are matched case-insensitively with spaces
        removed; "noisyactions" and "actionnoise" select
        ``latency_env.delayed_mdp.NoisyActionWrapper``.

        Returns:
            The (possibly wrapped) environment.

        Raises:
            ValueError: If a wrapper is given as a string that is not a
                recognised wrapper name.
        """
        # Imported here rather than at module level, so merely importing this
        # module does not import gymnasium.
        import gymnasium as gym

        env = gym.make(self.id, **self.kwargs)
        for factory, factory_kwargs in reversed(self.wrappers):
            if not isinstance(factory, str):
                # Anything that is not a string is called directly on the env.
                env = factory(env, **factory_kwargs)
                continue

            normalized = factory.lower().replace(" ", "")
            if normalized not in ("noisyactions", "actionnoise"):
                raise ValueError(f'Unknown wrapper name "{factory}"')

            # Resolved lazily: only needed when a noisy wrapper is requested.
            from latency_env.delayed_mdp import NoisyActionWrapper
            env = NoisyActionWrapper(env, **factory_kwargs)
        return env


# Registry of continuous-control baselines, keyed by a short name.
# ``metrics`` is a list of "euclidean"/"radial" labels; presumably one label
# per observation dimension -- confirm against the code that consumes it.
CONTINUOUS_ENVS = {
    "bipedal": BaselineSpec(id="BipedalWalker-v3"),
    "lunarcont": BaselineSpec(
        id="LunarLander-v2",
        kwargs={"continuous": True},
        metrics=["euclidean"] * 4 + ["radial"] + ["euclidean"] * 3,
    ),
    "pendulum": BaselineSpec(
        id="Pendulum-v1",
        kwargs={"g": 9.81},
        metrics=["euclidean"] * 3,
    ),
    # NOTE(review): this spec ends with 8 "radial" labels, whereas the other
    # "-v4" specs below end with "euclidean" labels -- confirm this is intended.
    "ant": BaselineSpec(
        id="Ant-v4",
        metrics=(
            ["euclidean"] + ["radial"] * 12 + ["euclidean"] * 6 + ["radial"] * 8
        ),
    ),
    # 376 labels in total: 1 euclidean, 21 radial, the remainder euclidean.
    "humanoid": BaselineSpec(
        id="Humanoid-v4",
        metrics=["euclidean"] + ["radial"] * 21 + ["euclidean"] * (376 - 1 - 21),
    ),
    "humanoidstandup": BaselineSpec(id="HumanoidStandup-v4"),
    "spider": BaselineSpec(id="Spider-v1"),
    "walker2d": BaselineSpec(
        id="Walker2d-v4",
        metrics=["euclidean"] + ["radial"] * 7 + ["euclidean"] * 9,
    ),
    "furuta": BaselineSpec(
        id="Qube2ODE-v1",
        metrics=["radial"] * 2 + ["euclidean"] * 2,
    ),
    "hopper": BaselineSpec(
        id="Hopper-v4",
        metrics=["euclidean"] + ["radial"] * 4 + ["euclidean"] * 6,
    ),
    "halfcheetah": BaselineSpec(
        id="HalfCheetah-v4",
        metrics=["euclidean"] + ["radial"] * 7 + ["euclidean"] * 9,
    ),
    "invertedpendulum": BaselineSpec(
        id="InvertedPendulum-v4",
        metrics=["euclidean"] + ["radial"] + ["euclidean"] * 2,
    ),
    "swimmer": BaselineSpec(
        id="Swimmer-v4",
        metrics=["radial"] * 3 + ["euclidean"] * 5,
    ),
    "reacher": BaselineSpec(
        id="Reacher-v4",
        metrics=["euclidean"] * 11,
    ),
}

# Register two fixed-noise variants of every environment defined so far:
# "noisy-<key>" (noise 0.05) and "verynoisy-<key>" (noise 0.20).
# The keys are snapshotted up front because the dict grows inside the loop.
for k in tuple(CONTINUOUS_ENVS):
    for _prefix, _noise in (("noisy", 0.05), ("verynoisy", 0.20)):
        # Deep copy so each variant owns its wrappers list and kwargs.
        spec = copy.deepcopy(CONTINUOUS_ENVS[k])
        spec.wrappers.append(("NoisyActions", {"noise": _noise}))
        CONTINUOUS_ENVS[f"{_prefix}-{k}"] = spec

# Register a noise sweep: for each key currently in CONTINUOUS_ENVS, add
# "noisepromille<N>-<key>" with action noise N/1000, for N = 25, 50, ..., 250.
# NOTE(review): the key snapshot is taken after the "noisy-"/"verynoisy-"
# entries were registered, so those entries get sweep variants as well (with a
# second "NoisyActions" wrapper appended) -- confirm this is intended.
for k in tuple(CONTINUOUS_ENVS):
    for noisepromille in range(25, 251, 25):
        # Deep copy so each variant owns its wrappers list and kwargs.
        spec = copy.deepcopy(CONTINUOUS_ENVS[k])
        spec.wrappers.append(
            ("NoisyActions", {"noise": noisepromille / 1000.0})
        )
        CONTINUOUS_ENVS[f"noisepromille{noisepromille}-{k}"] = spec
