from typing import Any, Dict

import numpy as np
import optuna

from open_loop.base_model import BaseModel


class Swimmer(BaseModel):
    """Open-loop model configuration for the ``Swimmer-v4`` environment.

    This class sets per-environment constants and defines the Optuna search
    space in :meth:`sample_params`. All other behaviour comes from ``BaseModel``.
    """

    env_id: str = "Swimmer-v4"
    # NOTE(review): presumably proportional / derivative gains used by
    # BaseModel's joint controller — confirm against the base class.
    kp: float = 6.8
    kd: float = 0.74
    n_joints: int = 2
    # Number of dimensions for which ``sample_params`` emits a
    # ``phase_shift_{i}`` entry when coupling is sampled.
    n_dim: int = 2

    def sample_params(self, trial: optuna.Trial, sample_coupling: bool = False) -> Dict[str, Any]:
        """Sample oscillator hyperparameters from an Optuna trial.

        Args:
            trial: Optuna trial used to suggest the values.
            sample_coupling: When True, also produce one relative phase shift
                per dimension. ``phase_shift_0`` is the reference and is fixed
                at 0.0. ``phase_shift_1`` … ``phase_shift_{n_dim-1}`` are
                sampled in ``[0.0, 1.0]``.

        Returns:
            A dict with ``omega_swing`` and ``omega_stance`` (each sampled in
            ``[0.1, 2]``). When ``sample_coupling`` is True, it also holds
            ``phase_shift_{i}`` for every ``i`` in ``range(n_dim)``, placed
            before the omega entries. All values are built-in Python floats.
        """
        # Keep the original suggest order so existing studies/samplers see
        # the same sequence of parameters.
        omega_swing = trial.suggest_float("omega_swing", 0.1, 2)
        omega_stance = trial.suggest_float("omega_stance", 0.1, 2)

        params: Dict[str, Any] = {}
        if sample_coupling:
            # Phase shifts are relative to dimension 0, which is therefore
            # fixed at 0.0 and not sampled. The buffer is sized from n_dim
            # (previously hard-coded to 4, which broke for n_dim > 4).
            phase_shifts = [0.0] * self.n_dim
            for idx in range(1, self.n_dim):
                phase_shifts[idx] = trial.suggest_float(f"phase_shift_{idx}", 0.0, 1.0)

            params = {f"phase_shift_{idx}": shift for idx, shift in enumerate(phase_shifts)}

        params.update(
            {
                "omega_swing": omega_swing,
                "omega_stance": omega_stance,
            }
        )
        return params
