from contextlib import redirect_stdout
from importlib_metadata import Deprecated
import progressbar
import matplotlib.pyplot as plt
import numpy as np
import torch
from gym.envs.mujoco import HalfCheetahEnv as HalfCheetahEnv_

from utils import helpers as utl
from scipy.stats import norm
from scipy.stats import multivariate_normal

from utils import helpers as utl

# Preferred GPU index for this module. The index is only used when the host
# actually exposes that many CUDA devices; otherwise fall back to the first
# GPU (or the CPU) so tensors are never placed on a non-existent device.
_PREFERRED_GPU_INDEX = 3
if not torch.cuda.is_available():
    device = torch.device("cpu")
elif torch.cuda.device_count() > _PREFERRED_GPU_INDEX:
    device = torch.device("cuda:{}".format(_PREFERRED_GPU_INDEX))
else:
    device = torch.device("cuda:0")


class HalfCheetahEnv(HalfCheetahEnv_):
    """Half-cheetah MuJoCo environment with a custom observation and camera.

    Subclasses the gym ``HalfCheetahEnv`` and overrides the observation
    vector, the viewer camera configuration and the rendering entry point.
    """

    def _get_obs(self):
        """Return the current observation as a flat ``float32`` array.

        The observation is the concatenation of:

        * ``qpos`` without its first entry (presumably the root x-position,
          which is dropped -- TODO confirm against the MuJoCo model),
        * the full ``qvel`` vector,
        * the centre of mass of the ``"torso"`` body.
        """
        return np.concatenate([
            self.sim.data.qpos.flat[1:],
            self.sim.data.qvel.flat,
            self.get_body_com("torso").flat,
        ]).astype(np.float32).flatten()

    def viewer_setup(self):
        """Point the viewer at the model's ``'track'`` camera and hide the overlay."""
        camera_id = self.model.camera_name2id('track')
        # NOTE(review): 2 is presumably MuJoCo's "fixed camera" type constant,
        # which is what makes ``fixedcamid`` take effect -- confirm.
        self.viewer.cam.type = 2
        self.viewer.cam.fixedcamid = camera_id
        # Camera distance is scaled relative to the overall model extent.
        self.viewer.cam.distance = self.model.stat.extent * 0.35
        # Hide the overlay
        self.viewer._hide_overlay = True

    def render(self, mode='human'):
        """Render the current simulation state.

        Args:
            mode: ``'rgb_array'`` renders and returns the pixels read back
                from the viewer at a fixed 500x500 size; ``'human'`` only
                renders to the viewer. Any other value does nothing.

        Returns:
            The pixel data from ``read_pixels`` for ``'rgb_array'``,
            otherwise ``None``.
        """
        if mode == 'rgb_array':
            self._get_viewer().render()
            # window size used for old mujoco-py:
            width, height = 500, 500
            data = self._get_viewer().read_pixels(width, height, depth=False)
            return data
        elif mode == 'human':
            self._get_viewer().render()

    @staticmethod
    def visualise_behaviour(env,
                            args,
                            policy,
                            iter_idx,
                            encoder=None,
                            image_folder=None,
                            return_pos=False,
                            **kwargs,
                            ):
        """Deprecated entry point; always raises.

        The former implementation (policy rollouts over
        ``args.max_rollouts_per_task`` episodes plus a per-episode plot of
        the torso position) has been removed. The signature is kept so that
        existing call sites fail with a clear error instead of an
        ``AttributeError``.

        Raises:
            NotImplementedError: always.
        """
        # ``importlib_metadata.Deprecated`` is a compatibility mix-in, not an
        # exception type, so raising it produced an unrelated ``TypeError``.
        # A built-in exception reports the deprecation as intended.
        # NOTE(review): the file-level ``from importlib_metadata import
        # Deprecated`` is no longer needed by this method.
        raise NotImplementedError('visualise_behaviour is deprecated.')
