import argparse
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import gymnasium as gym
import numpy as np
import optuna
from gymnasium.envs.mujoco.mujoco_env import MujocoEnv

from open_loop.cpg import Oscillators
from open_loop.external_force import ExternalForce
from open_loop.pd_controller import PDController


class BaseModel:
    """
    Base class for open-loop, CPG-driven controllers of gymnasium envs.

    Subclasses set the class attributes below and implement :meth:`sample_params`.
    Desired joint positions are generated by ``Oscillators`` and sent through a
    ``PDController`` wrapper around the env (see :meth:`evaluate`).
    """

    # Gymnasium id passed to ``gym.make`` (also names the params file, see ``read_params``)
    env_id: str
    # PD gains forwarded to ``PDController``
    kp: float
    kd: float
    n_joints: int
    # Number of oscillators == size of the desired position vector passed to ``env.step``
    n_dim: int
    name: str
    xml_name: Optional[str] = None  # use default one
    mujoco_env: MujocoEnv
    desired_q_pos: np.ndarray
    cpg_dt: float = 0.001  # 1kHz, integration step of the oscillators
    dt: float  # model (control) dt, set in ``init_model``
    # When True, ``init_model`` treats the env as the Bert env (no MujocoEnv check)
    skip_pd: bool = False

    @property
    def xml_file(self) -> Optional[str]:
        """Absolute path to the custom xml model, or None to use the env default."""
        if self.xml_name is None:
            return None
        # Custom xml files live in the local ``models`` folder; they are patched to
        # contain sites for the foot tips (only useful for visualization and IK).
        return str(Path(__file__).parent / "models" / self.xml_name)

    @property
    def env_kwargs(self) -> Dict[str, Any]:
        """Keyword arguments for ``gym.make`` (only ``xml_file`` when a custom model is set)."""
        if self.xml_file is None:
            return {}
        return dict(xml_file=self.xml_file)

    def make_env(
        self,
        render: bool = False,
        additional_env_kwargs: Optional[Dict[str, Any]] = None,
        apply_external_force: bool = False,
    ) -> PDController:
        """
        Create the gymnasium env and wrap it with the PD controller.

        :param render: Whether to create the env with ``render_mode="human"``
        :param additional_env_kwargs: Extra kwargs for ``gym.make``,
            they take precedence over :attr:`env_kwargs`
        :param apply_external_force: Wrap the env with ``ExternalForce``
            (proba=0.05, force=[5, 5, 5]) before adding the PD controller
        :return: The PD-controller-wrapped env
        """
        render_mode = "human" if render else None
        # Build a fresh dict so a subclass returning a shared dict is never mutated
        env_kwargs = {**self.env_kwargs, **(additional_env_kwargs or {})}
        env = gym.make(self.env_id, render_mode=render_mode, **env_kwargs)
        if apply_external_force:
            env = ExternalForce(env, proba=0.05, force=np.array([5, 5, 5]))
        return self.init_model(env)

    def init_model(self, env: gym.Env) -> PDController:
        """
        Read timing/initial state from the unwrapped env, then create the PD controller.

        :param env: The (possibly wrapped) env
        :return: The PD-controller-wrapped env
        """
        # Remove time limit and other wrappers
        mujoco_env = env.unwrapped
        if self.skip_pd:
            # Bert env
            self.dt = mujoco_env.wanted_dt  # type: ignore[attr-defined]
        else:
            assert isinstance(mujoco_env, MujocoEnv)
            self.mujoco_env = mujoco_env
            self.dt = mujoco_env.dt
            self.desired_q_pos = mujoco_env.init_qpos
        return self.init_pd_controller(env)

    def init_pd_controller(self, env: gym.Env) -> PDController:
        """Wrap ``env`` with a ``PDController`` using this model's gains."""
        return PDController(env, n_joints=self.n_joints, kp=self.kp, kd=self.kd, skip_pd=self.skip_pd)

    def get_values(self, params: Dict[str, float], name: str, default: np.ndarray) -> np.ndarray:
        """
        Collect per-dimension values ``{name}_0 ... {name}_{n_dim - 1}`` from ``params``.

        :param params: Hyperparameters
        :param name: Prefix of the per-dimension keys
        :param default: Returned as-is when ``{name}_0`` is not in ``params``
        :return: Array of the per-dimension values, or ``default``
        """
        if f"{name}_0" in params:
            return np.array([params[f"{name}_{idx}"] for idx in range(self.n_dim)])
        return default

    def _get_values_or_shared(self, params: Dict[str, float], name: str) -> np.ndarray:
        """
        Per-dimension ``{name}_{idx}`` values if present, else ``params[name]`` broadcast to ``n_dim``.

        The shared ``params[name]`` is only looked up when needed, so params that
        contain only per-dimension entries do not raise a ``KeyError``.
        """
        if f"{name}_0" not in params:
            return np.ones(self.n_dim) * params[name]
        return self.get_values(params, name, np.ones(self.n_dim))  # default unused here

    def create_oscillators(self, params: Dict[str, float]) -> Oscillators:
        """
        Create the oscillators from the hyperparameters.

        Frequencies and phase shifts in ``params`` are expressed in cycles and
        converted to radians here.

        :param params: Hyperparameters (``omega_swing``/``omega_stance`` either shared
            or per dimension, optional per-dimension ``phase_shift``)
        :return: The oscillators
        """
        phase_shifts = self.get_values(params, "phase_shift", np.zeros(self.n_dim))
        omega_swing = self._get_values_or_shared(params, "omega_swing")
        omega_stance = self._get_values_or_shared(params, "omega_stance")

        return Oscillators(
            omega_swing=omega_swing * 2 * np.pi,
            omega_stance=omega_stance * 2 * np.pi,
            phase_shifts=phase_shifts * 2 * np.pi,
            time_step=self.cpg_dt,
            n_dim=self.n_dim,
        )

    def sample_params(self, trial: optuna.Trial, sample_coupling: bool = False) -> Dict[str, Any]:
        """
        Sampler for oscillators hyperparameters.

        :param trial: Optuna trial object
        :param sample_coupling: Whether to additionally sample the coupling (phase shift)
            between the oscillators
        :return: The sampled hyperparameters for the given trial.
        """
        raise NotImplementedError()

    def read_params(self) -> Dict[str, float]:
        """
        Load hyperparameters from ``params/{env_id}.txt`` (one ``key:value`` per line).

        Blank lines are ignored; surrounding whitespace around keys is stripped.

        :return: Mapping from parameter name to float value
        """
        params_path = Path(__file__).parent / "params" / f"{self.env_id}.txt"
        params: Dict[str, float] = {}
        with open(params_path, encoding="utf-8") as file:
            for raw_line in file:
                line = raw_line.strip()
                if not line:
                    continue
                key, _, value = line.partition(":")
                params[key.strip()] = float(value)
        return params

    def evaluate(
        self,
        params: Dict[str, float],
        env: PDController,
        n_eval_episodes: int,
        render: bool = False,
        verbose: int = 0,
        log_path: Optional[str] = None,
    ) -> Tuple[float, int]:
        """
        Run the open-loop controller for ``n_eval_episodes`` episodes.

        :param params: Oscillator hyperparameters (optionally ``kp``/``kd`` gains)
        :param env: PD-controller-wrapped env
        :param n_eval_episodes: Number of episodes to run
        :param render: Whether to call ``env.render()`` at each step
        :param verbose: 1 prints the final mean reward, 2 also prints each episode reward
        :param log_path: If set, episode rewards are saved there with ``np.savez``
        :return: Mean episode reward and total number of env steps
        """
        env.reset()

        # Optionally override the PD gains
        if "kp" in params:
            env.kp = params["kp"]
        if "kd" in params:
            env.kd = params["kd"]

        # Account for the difference between control frequency and dt used to
        # integrate the oscillator equations. The epsilon keeps floating point
        # error from truncating the ratio (e.g. 0.3 / 0.1 == 2.9999999999999996).
        n_cpg_step_per_control_step = max(int(self.dt / self.cpg_dt + 1e-6), 1)

        # Amplitudes and offsets only depend on ``params``: compute them once
        if "amplitude_first" in params:
            # Even oscillators use ``amplitude_first``, odd ones ``amplitude_second``
            amplitudes = np.array(
                [
                    params["amplitude_first"] if idx % 2 == 0 else params["amplitude_second"]
                    for idx in range(self.n_dim)
                ],
                dtype=float,
            )
        else:
            amplitudes = self.get_values(params, "amplitude", np.ones(self.n_dim))
        offsets = self.get_values(params, "offset", np.zeros(self.n_dim))

        n_episodes = 0
        episode_reward = 0.0
        episode_rewards: List[float] = []
        total_steps = 0
        oscillators = self.create_oscillators(params)

        while n_episodes < n_eval_episodes:
            if render:
                env.render()

            oscillators.update(n_cpg_step_per_control_step)
            desired_qpos = amplitudes * np.sin(oscillators.theta) + offsets

            _, reward, terminated, truncated, _ = env.step(desired_qpos)
            episode_reward += float(reward)
            total_steps += 1

            if terminated or truncated:
                episode_rewards.append(episode_reward)
                if verbose >= 2:
                    print(f"Episode finished with total reward {episode_reward:.2f}")
                env.reset()
                # Restart the oscillators from their initial phase for the new episode
                oscillators = self.create_oscillators(params)
                episode_reward = 0.0
                n_episodes += 1

        mean_episode_reward = np.mean(episode_rewards)
        std_reward = np.std(episode_rewards)
        if verbose >= 1:
            print(f"Evaluation finished with mean reward {mean_episode_reward:.2f} +/- {std_reward:.2f}")

        if log_path is not None:
            np.savez(log_path, episode_rewards=episode_rewards)

        return float(mean_episode_reward), total_steps


class StoreDict(argparse.Action):
    """
    Custom argparse action for storing dict.

    Each value is a ``key:value`` pair. Everything after the first ``:`` is
    evaluated as a Python expression.

    In: args1:0.0 args2:"dict(a=1)"
    Out: {'args1': 0.0, 'args2': {'a': 1}}
    """

    def __call__(
        self,
        parser: argparse.ArgumentParser,
        namespace: argparse.Namespace,
        values: Union[str, Sequence, None],
        option_string: Optional[str] = None,
    ) -> None:
        """
        Parse ``key:value`` items and store the resulting dict on ``namespace``.

        :raises argparse.ArgumentError: if an item has no ``:`` separator
            (argparse turns this into a usage error)
        """
        if values is None:
            # Nothing to parse: keep the current (default) value
            return
        # Without ``nargs``, argparse passes a single string instead of a list
        items = [values] if isinstance(values, str) else values

        arg_dict: Dict[str, Any] = {}
        for item in items:
            key, separator, expression = item.partition(":")
            if not separator:
                raise argparse.ArgumentError(self, f"expected 'key:value', got {item!r}")
            # NOTE(security): eval executes arbitrary code. This is acceptable only because
            # the input comes from the command line of the user running the script.
            arg_dict[key] = eval(expression)
        setattr(namespace, self.dest, arg_dict)
