from typing import Any, Dict

import numpy as np
import optuna

from open_loop.base_model import BaseModel


class HalfCheetah(BaseModel):
    """Open-loop model configuration and Optuna search space for ``HalfCheetah-v4``.

    Only the environment-specific constants and ``sample_params`` are defined
    here; everything else comes from ``BaseModel``.
    """

    env_id: str = "HalfCheetah-v4"
    # NOTE(review): presumably proportional / derivative gains consumed by BaseModel — confirm.
    kp: float = 1.0
    kd: float = 0.05
    # Not read in this class; presumably used by BaseModel — confirm.
    n_joints: int = 6
    # Number of per-dimension parameter groups (phase shift / amplitude / offset)
    # sampled in ``sample_params``.
    n_dim: int = 6

    def sample_params(self, trial: optuna.Trial, sample_coupling: bool = False) -> Dict[str, Any]:
        """Sample one set of oscillator parameters from an Optuna trial.

        Args:
            trial: Optuna trial used to draw every sampled parameter.
            sample_coupling: When True, also include a phase shift per dimension.
                ``phase_shift_0`` is fixed at 0.0 (not sampled). Shifts
                ``phase_shift_1`` .. ``phase_shift_{n_dim-1}`` are drawn from [0, 1].

        Returns:
            A dict containing:
            - ``phase_shift_{i}`` (only when ``sample_coupling`` is True),
            - ``amplitude_{i}`` in [-2, 2] and ``offset_{i}`` in [-1, 1] for each
              ``i < n_dim``,
            - ``omega_swing`` and ``omega_stance`` in [0.4, 5].
            All values are plain Python floats.
        """
        # The order of ``suggest_float`` calls is kept as before (omegas, phase
        # shifts, then amplitude/offset pairs). Samplers drawing from a shared
        # RNG may depend on this order.
        omega_swing = trial.suggest_float("omega_swing", 0.4, 5)
        omega_stance = trial.suggest_float("omega_stance", 0.4, 5)

        params: Dict[str, Any] = {}
        if sample_coupling:
            # Phase shifts are relative: dimension 0 is the fixed reference, so
            # only the remaining dimensions are sampled.
            for idx in range(self.n_dim):
                params[f"phase_shift_{idx}"] = (
                    0.0 if idx == 0 else trial.suggest_float(f"phase_shift_{idx}", 0.0, 1.0)
                )

        for idx in range(self.n_dim):
            params[f"amplitude_{idx}"] = trial.suggest_float(f"amplitude_{idx}", -2.0, 2.0)
            params[f"offset_{idx}"] = trial.suggest_float(f"offset_{idx}", -1.0, 1.0)

        params["omega_swing"] = omega_swing
        params["omega_stance"] = omega_stance
        return params
