import os
import time
from abc import ABC
from abc import abstractmethod

import gym
import matplotlib
import numpy as np
import torch
from gym import utils
from gym.envs.mujoco import mujoco_env
from matplotlib import animation
from matplotlib import pyplot as plt

from ..utils.env_utils import create_xml
from ..utils.env_utils import load_default_params
from ..utils.env_utils import save_metrics
from ..utils.o2m_utils import set_parameters


# Optional cluster integration: CLUSTER is 1 when the `cluster` package and
# its progress-announcement helpers could be imported, otherwise 0.
try:
    from cluster import announce_early_results, announce_fraction_finished

    CLUSTER = 1
except ImportError:
    CLUSTER = 0


class AbstractMuscleEnv(mujoco_env.MujocoEnv, utils.EzPickle, ABC):
    """Abstract base listing the hooks a concrete muscle environment must provide."""

    @abstractmethod
    def reset_model(self):
        """Environment-specific model reset; must be overridden."""

    @abstractmethod
    def viewer_setup(self):
        """Environment-specific viewer configuration; must be overridden."""

    @abstractmethod
    def set_target(self):
        """Environment-specific target setter; must be overridden.

        NOTE(review): MuscleEnv.quick_settings calls ``set_target(args.target)``,
        so overrides are presumably expected to accept a target argument --
        confirm against the concrete environments.
        """


class MuscleEnv(AbstractMuscleEnv):
    def __init__(self, identifier=None):
        """Load the default parameters of this model type and build the environment.

        Args:
            identifier: Stored on the instance and forwarded to
                ``quick_settings`` together with the loaded arguments.
        """
        # Both flags have to exist before quick_settings/reset run: the
        # action_space setter reads manually_set_action_space and reset reads
        # render_substep_bool.
        self.manually_set_action_space = 0
        self.render_substep_bool = 0
        params_path = self.get_default_params_path()
        _unused, default_args = load_default_params(path=params_path)
        self.args = default_args
        self.identifier = identifier
        self.frameskip = default_args.frameskip
        self.quick_settings(default_args, identifier)
        self.reset()

    def render_substep(self):
        self.render_substep_bool = 1

    def do_simulation(self, ctrl, n_frames):
        if not hasattr(self, "action_multiplier"):
            if np.array(ctrl).shape != self.action_space.shape:
                raise ValueError("Action dimension mismatch")

        self.sim.data.ctrl[:] = ctrl
        for _ in range(n_frames):
            if self.render_substep_bool:
                # self.render('rgb_array')
                self.render("human")
            self.sim.step()

    def get_default_params_path(self):
        default_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
        return os.path.join(default_path, f"param_files/{self.model_type}.json")

    def set_episode_length(self, episode_length: int):
        """Set maximum episode length. Have to use this and NOT
        the gym registration mechanism as that triggers terminal timeouts
        which are NOT correct."""
        self._max_episode_steps = episode_length

    def merge_args(self, args):
        for k, v in args.items():
            setattr(self.args, k, v)

    def apply_args(self):
        self.quick_settings(self.args, self.identifier)

    def quick_settings(self, args, identifier):
        """Apply all settings found in ``args`` to the environment.

        Don't change the order of the calls if you don't know what you are
        doing: some settings are used by the mujoco_env initialization that
        ``reinitialise`` performs.

        Args:
            args: Settings object; every attribute read below must exist,
                except ``force_scale``, which is optional.
            identifier: Forwarded to ``reinitialise``.
        """
        # Plain attribute settings that have to be in place before the
        # simulation is (re)built.
        self.set_target(args.target)
        self.set_random_goals(args.random_goals)
        self.set_sparse_reward(args.sparse_reward)
        self.set_termination(args.termination)
        self.set_termination_distance(args.termination_distance)
        self.set_dynamic_load(args.dynamic_load)
        self.set_actuator(args.actuator)
        self.set_episode_length(args.episode_length)
        # Builds the simulation from a generated XML file. The calls below read
        # self.action_space, self.sim or self.model, which are presumably
        # created by the MujocoEnv initialisation.
        self.reinitialise(args, identifier)
        self.set_action_space_multiplier(args.action_multiplier)
        self.set_muscle_control(args.muscle_morphology, args.morph_settings)
        self.maybe_set_o2m_params()
        self.set_gravity(args.gravity)
        # force_scale is optional; when absent, self.force_scale is not set here.
        if hasattr(args, "force_scale"):
            self.set_force_scale(args.force_scale)

    def set_force_scale(self, scale=1.0):
        self.force_scale = scale

    def maybe_set_o2m_params(self):
        if hasattr(self, "o2m_params"):
            params = torch.load(self.o2m_params)
            set_parameters(
                self.model,
                params["parameters"],
                params["muscle_idxs"],
                params["joint_idxs"],
            )

    def apply_muscle_settings(self):
        """
        Apply settings=[FL, FV, DYN] as [FL_active, FV, FL_passive, DYN]
        in environment.
        """
        self.sim.data.userdata[: len(self.settings) + 1] = [
            self.settings[0],
            self.settings[1],
            self.settings[0],
            self.settings[2],
        ]

    def set_muscle_control(self, actuator="mujoco", settings=[1, 1, 1]):
        """
        <settings> describes [Force_length, Force_velocity, activation_dynamics]
        Which one of these things do you want to activate? For mujoco-defaults instead
        of user-written functions, use <actuator='mujoco'>, else <actuator='user'>.
        The mujoco defaults ignore the <settings> attribute.
        """
        # Attention, mutable default argument, don't change!
        self.settings = np.array(settings).copy()
        if actuator == "mujoco":
            pass

        elif actuator == "user":
            try:
                from mujoco_py import cymj

                cymj.set_muscle_control(self.sim.model, self.sim.data)
            except Exception:
                raise Exception(
                    "User-defined muscle actuator <{args.muscle_morphology}> could not be loaded! Is the\
                                custom mujoco-py branch\
                                installed?"
                )

    def set_random_goals(self, random_goals=False):
        """Should the goals be randomly sampled."""
        self.random_goals = random_goals

    def set_termination_distance(self, distance=0.08):
        """Set endeffector to goal distance at which the episode is considered to be solved."""
        self.termination_distance = distance

    def set_termination(self, termination=False):
        """Decide wether the episode will be prematurely terminated when achieving the goal.
        A <done> signal will be emitted."""
        self.termination = termination

    def set_dynamic_load(self, dynamic_load=False):
        self.dynamic_load = dynamic_load

    def set_action_space_multiplier(self, n):
        """Creates a virtual large action space. Shadow action
        get averaged back into original action during exectution.
        Original action space gets multiplied by n."""
        assert type(n) is int
        assert n != 0
        # dont change anything if n==1
        if n != 1:
            previous_shape = self.action_space.shape[-1]
            low = self._create_bounds(self.action_space.low, n)
            high = self._create_bounds(self.action_space.high, n)
            self.internal_action = np.zeros(self.action_space.shape)
            self.action_space = gym.spaces.Box(
                low=low, high=high, shape=(previous_shape * n,)
            )
            self.action_multiplier = n
            self.observation_space = gym.spaces.Box(
                -np.inf, np.inf, shape=self._get_obs().shape
            )
            self.manually_set_action_space = 1

    def _create_bounds(self, bound, n):
        """
        Create action_space bound vector after action space
        has been multiplied.
        """
        new_bound = []
        for b in bound:
            for _ in range(n):
                new_bound.append(b)
        return np.array(new_bound)

    def set_actuator(self, actuator="muscle"):
        """
        Set actuatore type, e.g. <muscle>, <motor>. Remember, motor
        is just an ideal torque generator in mujoco in most cases.
        """
        self.actuator = actuator
        self.need_reinit = 1

    def reinitialise(self, args, identifier):
        """Build the simulation from a freshly generated XML file, then delete the file.

        ``create_xml`` and the ``MujocoEnv`` initialisation are retried for as
        long as the latter raises ``FileNotFoundError``. A missing file during
        the final clean-up is only reported, not raised.
        """
        self.need_reinit = 0
        while True:
            xml_path = create_xml(args, identifier)
            try:
                # The second positional argument is the frameskip.
                mujoco_env.MujocoEnv.__init__(self, xml_path, self.frameskip)
            except FileNotFoundError:
                print("xml file not found, reentering loop.")
            else:
                break
        utils.EzPickle.__init__(self)
        try:
            os.remove(xml_path)
        except FileNotFoundError:
            print(
                "Interference with model_<id>.xml file removing. Check generated_xml folder."
            )

    def set_sparse_reward(self, sparse_reward=False):
        """
        Do you want a sparse reward.
        """
        self.sparse_reward = sparse_reward

    def set_gravity(self, value):
        """
        Set gravity, orientation depends on envrionment.
        """
        for idx, val in enumerate(value):
            self.model.opt.gravity[idx] = val

    def redistribute_action(self, action):
        """
        Imagine an action space where part of the actions are summed up to contribute to a single
        (muscle) stimulation. This is the simplest form of a pseudo-higher dimensional action_space
        Assume multiplied actions are side-by-side
        e.g. [1,2,3,4] -> [1,1,2,2,3,3,4,4]
        """
        action = np.clip(action.copy(), self.action_space.low, self.action_space.high)
        for k in range(0, self.internal_action.shape[-1]):
            self.internal_action[k] = (
                np.sum(
                    action[
                        k * self.action_multiplier : k * self.action_multiplier
                        + self.action_multiplier
                    ]
                )
                / self.action_multiplier
            )
        return self.internal_action

    def randomise_init_state(self, diff=0.01):
        """
        Randomises initial joint positions slightly.
        """
        # TODO remove uncontrollable joint from randomisation
        qpos = self.init_qpos
        qvel = self.init_qvel
        qvel = np.random.normal(0.0, diff, size=(self.model.nq))
        self.set_state(qpos, qvel)

    def seed(self, seed=None):
        """Seed the parent environment and NumPy's global random generator.

        The value passed in is kept in ``self._seed``; ``None`` is replaced by
        0 for the actual seeding.
        """
        self._seed = seed
        if seed is None:
            seed = 0
        super().seed(seed)
        np.random.seed(seed)

    @property
    def dt(self):
        return self.model.opt.timestep * self.frameskip

    @property
    def max_episode_steps(self):
        return self._max_episode_steps

    @property
    def process_state(self):
        return (self.full_state, self.joint_state)

    @property
    def full_state(self):
        return self._get_obs()

    @property
    def joint_state(self):
        return self.data.qpos[: self.unwrapped.nq]

    def display_video(self, framerate=60):
        """Encode ``self.frames`` into an mp4 file named ``video_render_<timestamp>.mp4``.

        The figure is sized from the first frame. The file name is relative,
        so it resolves against the current working directory.

        Args:
            framerate: Frames per second, used both for the animation interval
                and for the FFmpeg writer.
        """
        first_frame = self.frames[0]
        frame_height, frame_width, _ = first_frame.shape
        dots_per_inch = 100
        # Create the figure under the headless 'Agg' backend to inhibit figure
        # rendering, then switch back to the backend that was active before.
        previous_backend = matplotlib.get_backend()
        matplotlib.use("Agg")
        figure, axes = plt.subplots(
            1,
            1,
            figsize=(frame_width / dots_per_inch, frame_height / dots_per_inch),
            dpi=dots_per_inch,
        )
        matplotlib.use(previous_backend)
        axes.set_axis_off()
        axes.set_aspect("equal")
        axes.set_position([0, 0, 1, 1])
        image = axes.imshow(first_frame)

        def draw(frame):
            image.set_data(frame)
            return [image]

        movie = animation.FuncAnimation(
            fig=figure,
            func=draw,
            frames=self.frames,
            interval=1000 / framerate,
            blit=True,
            repeat=False,
        )
        ffmpeg_writer = animation.FFMpegWriter(fps=framerate)
        movie.save(f"video_render_{time.time()}.mp4", writer=ffmpeg_writer)

    def reset(self, *args, **kwargs):
        """Reset the environment through the parent class.

        While substep rendering is enabled, any collected frames are first
        written out with ``display_video`` and the frame buffer is emptied.
        """
        if self.render_substep_bool:
            has_frames = hasattr(self, "frames") and len(self.frames) != 0
            if has_frames:
                self.display_video()
            self.frames = []
        return super().reset(*args, **kwargs)

    def _get_obs(self):
        """get_obs() implicitly calls this function to get the
        MDP state in the default case."""
        act = self.muscle_activations()
        if act is None:
            act = np.zeros_like(self.muscle_lengths())
        return np.concatenate(
            [
                self.sim.data.qpos[: self.nq],
                self.sim.data.qvel[: self.nq],
                self.muscle_lengths(),
                self.muscle_forces(),
                self.muscle_velocities(),
                act,
                self.target,
                self.sim.data.get_site_xpos(self.tracking_str),
            ]
        )

    @property
    def muscles_dep(self):
        lce = self.muscle_lengths()
        f = self.muscle_forces()
        if not hasattr(self, "max_muscle"):
            self.max_muscle = np.zeros_like(lce)
            self.min_muscle = np.ones_like(lce) * 100.0
            self.max_force = -np.ones_like(f) * 100.0
            self.min_force = np.ones_like(f) * 100.0
        self.max_muscle = np.maximum(lce, self.max_muscle)
        self.min_muscle = np.minimum(lce, self.min_muscle)
        self.max_force = np.maximum(f, self.max_force)
        self.min_force = np.minimum(f, self.min_force)
        return 1.0 * (
            ((lce - self.min_muscle) / (self.max_muscle - self.min_muscle + 0.1)) - 0.5
        ) * 2.0 + self.force_scale * (
            (f - self.min_force) / (self.max_force - self.min_force + 0.1)
        )

    def muscle_lengths(self):
        if not hasattr(self, "action_multiplier") or self.action_multiplier == 1:
            return self.data.actuator_length.copy()
        return np.repeat(self.data.actuator_length.copy(), self.action_multiplier)

    def muscle_velocities(self):
        if not hasattr(self, "action_multiplier") or self.action_multiplier == 1:
            return np.clip(self.data.actuator_velocity, -100, 100).copy()
        return np.repeat(
            np.clip(self.data.actuator_velocity, -100, 100).copy(),
            self.action_multiplier,
        )

    def muscle_activations(self):
        if not hasattr(self, "action_multiplier") or self.action_multiplier == 1:
            return np.clip(self.data.act, -100, 100).copy()
        return np.repeat(
            np.clip(self.data.act, -100, 100).copy(), self.action_multiplier
        )

    def muscle_forces(self):
        if not hasattr(self, "action_multiplier") or self.action_multiplier == 1:
            return np.clip(self.data.actuator_force / 1000, -100, 100).copy()
        return np.repeat(
            np.clip(self.data.actuator_force / 1000, -100, 100).copy(),
            self.action_multiplier,
        )

    @property
    def action_space(self):
        return self._action_space

    @action_space.setter
    def action_space(self, val):
        if not self.manually_set_action_space:
            self._action_space = val
        else:
            pass
