# import matplotlib as mpl
import random
import numpy as np
# import torch

# import matplotlib.pyplot as plt
# import seaborn as sns

from gymnasium import Env
from gymnasium import spaces

from src.utils.misc import set_seed


# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def semi_circle_goal_sampler():
    """Draw a goal position on the upper half of the unit circle.

    The polar angle is sampled uniformly from [0, pi] with the standard
    library ``random`` module, so seeding ``random`` makes the draw
    reproducible.

    Returns:
        np.ndarray of shape (2,) holding ``(cos(angle), sin(angle))``.
    """
    radius = 1.0
    theta = random.uniform(0, np.pi)
    direction = np.array([np.cos(theta), np.sin(theta)])
    return radius * direction


class PointEnv(Env):
    """
    point robot on a 2-D plane with position control
    tasks (aka goals) are positions on the plane

     - tasks are drawn by ``semi_circle_goal_sampler`` (upper half of the
       unit circle)
     - reward is the negative L2 distance between the robot and the goal
     - the robot starts every episode at the origin
    """

    def __init__(self, goal_pos=None):
        """
        :param goal_pos: goal position to use as the task; when None a goal
            is drawn from the goal sampler
        """
        self.goal_sampler = semi_circle_goal_sampler

        # give the robot a valid position right away so that step() does not
        # fail with an AttributeError when it is called before reset()
        self._state = np.zeros(2)

        self.reset_task(goal_pos)
        self.task_dim = 2
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(2,))
        # we convert the actions from [-1, 1] to [-0.1, 0.1] in the step() function
        self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(2,))

    def sample_task(self):
        """ draw a new goal from the goal sampler (does not change the current task) """
        goal = self.goal_sampler()
        return goal

    def set_task(self, task):
        """ make `task` the current goal; it is stored as given (no copy) """
        self.goal = task

    def get_task(self):
        """ return the current goal (the stored object itself, not a copy) """
        return self.goal

    def reset_task(self, task=None):
        """
        set the current task and return it

        :param task: goal to use; when None a fresh goal is sampled
        """
        if task is None:
            task = self.sample_task()
        self.set_task(task)
        return task

    def reset_model(self):
        """ put the robot back at the origin and return the observation """
        self._state = np.zeros(2)
        return self._get_obs()

    def reset(self, seed=0, options=None):
        """
        start a new episode; the task (goal) is left unchanged

        NOTE(review): the default seed is 0 rather than gymnasium's None, so a
        direct call without arguments re-seeds the base-class generator every
        time. Nothing in this class draws from that generator (goals come from
        the `random` module), so the default is kept for backward compatibility.

        :return: (observation, empty info dict)
        """
        super().reset(seed=seed, options=options)
        return self.reset_model(), {}

    def pos_to_state(self, goal):
        """ convert a position array into a (hashable) tuple of floats """
        return tuple(goal.tolist())

    def _get_obs(self):
        # copy so that callers cannot modify the internal state through the observation
        return np.copy(self._state)

    def step(self, action):
        """
        move the robot by 0.1 * action (after clipping the action to the action space)

        :return: (observation, reward, terminated, truncated, info) where the
            reward is the negative L2 distance to the goal, both flags are
            always False (the environment itself never ends an episode) and
            info['task'] holds the current goal
        """
        action = np.clip(action, self.action_space.low, self.action_space.high)

        self._state = self._state + 0.1 * action
        reward = - np.linalg.norm(self._state - self.goal, ord=2)
        done = False
        ob = self._get_obs()
        info = {'task': self.get_task()}
        return ob, reward, done, done, info

    # def visualise_behaviour(self,
    #                         env,
    #                         args,
    #                         policy,
    #                         iter_idx,
    #                         encoder=None,
    #                         image_folder=None,
    #                         return_pos=False,
    #                         **kwargs,
    #                         ):
    #
    #     num_episodes = args.max_rollouts_per_task
    #
    #     # --- initialise things we want to keep track of ---
    #
    #     episode_prev_obs = [[] for _ in range(num_episodes)]
    #     episode_next_obs = [[] for _ in range(num_episodes)]
    #     episode_actions = [[] for _ in range(num_episodes)]
    #     episode_rewards = [[] for _ in range(num_episodes)]
    #
    #     episode_returns = []
    #     episode_lengths = []
    #
    #     if encoder is not None:
    #         episode_hidden_states = [[] for _ in range(num_episodes)]
    #     else:
    #         episode_hidden_states = None
    #
    #     # --- roll out policy ---
    #
    #     # (re)set environment
    #     env.reset_task()
    #     state, task = utl.reset_env(env, args)
    #     start_obs_raw = state.clone()
    #     task = task.view(-1) if task is not None else None
    #
    #     # initialise actions and rewards (used as initial input to policy if we have a recurrent policy)
    #     if hasattr(args, 'hidden_size'):
    #         hidden_state = torch.zeros((1, args.hidden_size)).to(device)
    #     else:
    #         hidden_state = None
    #
    #     # keep track of what task we're in and the position of the cheetah
    #     pos = [[] for _ in range(args.max_rollouts_per_task)]
    #     start_pos = state
    #
    #     for episode_idx in range(num_episodes):
    #
    #         curr_rollout_rew = []
    #         pos[episode_idx].append(start_pos[0])
    #
    #         if episode_idx == 0:
    #             if encoder is not None:
    #                 # reset to prior
    #                 current_hidden_state = encoder.prior(1)
    #                 current_hidden_state.to(device)
    #         episode_hidden_states[episode_idx].append(current_hidden_state[0].clone())
    #
    #         for step_idx in range(1, env._max_episode_steps + 1):
    #
    #             if step_idx == 1:
    #                 episode_prev_obs[episode_idx].append(start_obs_raw.clone())
    #             else:
    #                 episode_prev_obs[episode_idx].append(state.clone())
    #                 # act
    #             _, action = utl.select_action_cpc(args=args, policy=policy, deterministic=True,
    #                                               hidden_latent=current_hidden_state.squeeze(0), state=state, task=task)
    #
    #             (state, task), (rew, rew_normalised), done, info = utl.env_step(env, action, args)
    #             state = state.float().reshape((1, -1)).to(device)
    #             task = task.view(-1) if task is not None else None
    #
    #             # keep track of position
    #             pos[episode_idx].append(state[0])
    #
    #             if encoder is not None:
    #                 # update task embedding
    #                 current_hidden_state = encoder(
    #                     action.reshape(1, -1).float().to(device), state, rew.reshape(1, -1).float().to(device),
    #                     current_hidden_state, return_prior=False)
    #
    #                 episode_hidden_states[episode_idx].append(current_hidden_state[0].clone())
    #
    #             episode_next_obs[episode_idx].append(state.clone())
    #             episode_rewards[episode_idx].append(rew.clone())
    #             episode_actions[episode_idx].append(action.clone())
    #
    #             if info[0]['done_mdp'] and not done:
    #                 start_obs_raw = info[0]['start_state']
    #                 start_obs_raw = torch.from_numpy(start_obs_raw).float().reshape((1, -1)).to(device)
    #                 start_pos = start_obs_raw
    #                 break
    #
    #         episode_returns.append(sum(curr_rollout_rew))
    #         episode_lengths.append(step_idx)
    #
    #     # clean up
    #     if encoder is not None:
    #         episode_hidden_states = [torch.stack(e) for e in episode_hidden_states]
    #     episode_prev_obs = [torch.cat(e) for e in episode_prev_obs]
    #     episode_next_obs = [torch.cat(e) for e in episode_next_obs]
    #     episode_actions = [torch.stack(e) for e in episode_actions]
    #     episode_rewards = [torch.cat(e) for e in episode_rewards]
    #
    #     figsize = (2.5 * 3, 2.5 * 3)
    #     figure, axis = plt.subplots(1, 1, figsize=figsize)
    #     xlim = (-1.3, 1.3)
    #     # if self.goal_sampler == semi_circle_goal_sampler:
    #     #     ylim = (-0.3, 1.3)
    #     # else:
    #     ylim = (-1.3, 1.3)
    #     color_map = mpl.colors.ListedColormap(sns.color_palette("husl", num_episodes))
    #
    #     observations = torch.stack([episode_prev_obs[i] for i in range(num_episodes)]).cpu().numpy()
    #     curr_task = env.get_task()[0]
    #
    #     # plot goal
    #     axis.scatter(curr_task[1],curr_task[0], marker='x', color='k', s=50)
    #     # radius where we get reward
    #     if hasattr(self, 'goal_radius'):
    #         circle1 = plt.Circle(curr_task[::-1], self.goal_radius, color='c', alpha=0.2, edgecolor='none')
    #         plt.gca().add_artist(circle1)
    #
    #     for i in range(num_episodes):
    #         color = color_map(i)
    #         path = observations[i]
    #
    #         # plot (semi-)circle
    #         r = 1.0
    #         if self.goal_sampler == semi_circle_goal_sampler:
    #             angle = np.linspace(0, np.pi, 100)
    #         else:
    #             angle = np.linspace(0, 2 * np.pi, 100)
    #         goal_range = r * np.array((np.cos(angle), np.sin(angle)))
    #         # plt.plot(goal_range[0], goal_range[1], 'k--', alpha=0.1)
    #         plt.plot(goal_range[1],goal_range[0], 'k--', alpha=0.1)
    #
    #         # plot trajectory
    #         axis.plot(path[:, 1],path[:, 0], '-',marker='o', color=color, label=f'Episode {i}')
    #         axis.scatter(path[0, 1],path[0, 0], marker='.', color=color, s=50)
    #         # if i == 1 and kwargs['belief_evaluator'] is not None:
    #         #     r = 1
    #         #     num_points_semicircle = 50
    #         #     angles = np.linspace(0, np.pi, num=num_points_semicircle)
    #         #     x, y = r * np.cos(angles), r * np.sin(angles)
    #         #     belief = 1. - torch.sigmoid(
    #         #         kwargs['belief_evaluator'](episode_hidden_states[i][0])).detach().cpu().numpy().flatten()
    #         #     plt.scatter(x, y, c=belief, cmap='gray')
    #     plt.ylim(xlim)
    #     plt.xlim(ylim)
    #     plt.xticks([])
    #     plt.yticks([])
    #     # from matplotlib import transforms
    #     # base = plt.gca().transData
    #     # rot = transforms.Affine2D().rotate_deg(90)
    #     plt.title('Sparse Pointrobot')
    #     plt.legend()
    #     plt.tight_layout()
    #     # kwargs['logger'].add_figure('belief', figure, iter_idx)
    #     if image_folder is not None:
    #         plt.savefig('{}/{}_behaviour'.format(image_folder, iter_idx))#, dpi=300, bbox_inches='tight')
    #         plt.close()
    #     else:
    #         plt.show()
    #
    #     plt_rew = [episode_rewards[i][:episode_lengths[i]] for i in range(len(episode_rewards))]
    #     plt.plot(torch.cat(plt_rew).view(-1).cpu().numpy())
    #     plt.xlabel('env step')
    #     plt.ylabel('reward per step')
    #     plt.tight_layout()
    #     # if image_folder is not None:
    #     #     plt.savefig('{}/{}_rewards.png'.format(image_folder, iter_idx), dpi=300, bbox_inches='tight')
    #     #     plt.close()
    #     # else:
    #     #     plt.show()
    #
    #     if not return_pos:
    #         return episode_hidden_states, episode_prev_obs, episode_next_obs, episode_actions, episode_rewards, \
    #                episode_returns
    #     else:
    #         return episode_hidden_states, \
    #                episode_prev_obs, episode_next_obs, episode_actions, episode_rewards, \
    #                episode_returns, pos


class SparsePointEnv(PointEnv):
    """
    Reward is given only within the goal radius

     - inside the radius: dense reward (negative L2 distance) shifted by +1
     - outside the radius: zero
    """

    def __init__(self, goal_radius=0.2, goal_pos=None):
        """
        :param goal_radius: distance to the goal below which reward is given
        :param goal_pos: goal position to use as the task; when None a goal is sampled
        """
        super().__init__(goal_pos=goal_pos)
        self.goal_radius = goal_radius
        # the parent constructor has already set the task; this second call is
        # kept so that, when goal_pos is None, the same number of random draws
        # is consumed as before (the second draw is the goal that is used)
        self.reset_task(goal_pos)

    def sparsify_rewards(self, r):
        ''' zero out rewards when outside the goal radius '''
        # r is a dense reward (<= 0); r >= -goal_radius means "within the radius"
        mask = (r >= -self.goal_radius).astype(np.float32)
        r = r * mask
        return r

    def reset_model(self):
        """ put the robot back at the origin and return the observation """
        # float zeros, like the parent class: an integer array here would make
        # the first observation of an episode integer-typed while every later
        # one is float
        self._state = np.zeros(2)
        return self._get_obs()

    def step(self, action):
        """
        take a dense step and replace the reward by its sparse version

        :return: (observation, sparse reward, terminated, truncated, info);
            info additionally holds 'sparse_reward' and 'dense_reward'
        """
        ob, reward, terminate, trunc, d = super().step(action)
        sparse_reward = self.sparsify_rewards(reward)
        # make sparse rewards positive
        if reward >= -self.goal_radius:
            sparse_reward += 1

        d.update({'sparse_reward': sparse_reward})
        d.update({'dense_reward': reward})
        return ob, sparse_reward, terminate, trunc, d


def train_test_goals_sc(num_test_goals, seed, num_train_goals=1000):
    """Sample separate sets of train and test goals on the semi-circle.

    Train goals are drawn after calling ``set_seed(seed)`` and test goals
    after calling ``set_seed(seed + 1)``.

    NOTE(review): this is only reproducible if ``set_seed`` seeds Python's
    ``random`` module, which the goal sampler draws from - confirm in
    ``src.utils.misc``.

    Args:
        num_test_goals: number of test goals to sample.
        seed: seed for the train goals; the test goals use ``seed + 1``.
        num_train_goals: number of train goals to sample (defaults to the
            previously hard-coded 1000).

    Returns:
        Tuple ``(train_goals, test_goals)`` of arrays built from the sampled
        goals, one row per goal.
    """
    set_seed(seed)
    train_goals = np.array([semi_circle_goal_sampler() for _ in range(num_train_goals)])
    # re-seed with a different value so the test draws do not replay the train draws
    set_seed(seed + 1)
    test_goals = np.array([semi_circle_goal_sampler() for _ in range(num_test_goals)])
    return train_goals, test_goals


if __name__ == "__main__":
    # Interactive smoke test: drive the sparse environment by typing actions
    # (two whitespace-separated floats per step) until the time limit is hit.
    #
    # The environment classes derive from gymnasium.Env, so registration and
    # construction must go through gymnasium as well, not the legacy `gym` package.
    import gymnasium as gym
    gym.register(
        id="Semi-Circle-Sparse-v0",
        entry_point="src.envs.semi_circle:SparsePointEnv",
        max_episode_steps=60,  # from VariBAD
        kwargs={
            "goal_radius": 0.2
        }
    )

    # fixed goal at the top of the semi-circle
    r = 1.0
    angle = np.pi / 2
    goal = r * np.array((np.cos(angle), np.sin(angle)))
    env = gym.make("Semi-Circle-Sparse-v0", goal_pos=goal)

    obs, _ = env.reset()

    # make() returns the environment wrapped; read the goal from the underlying
    # environment instead of relying on wrappers forwarding attribute access
    print("GOAL", goal, env.unwrapped.goal)
    done = False
    trunc = False

    while not done and not trunc:
        print("OBS:", obs)
        action = np.array(list(map(float, input("ACT:").split())))
        obs, reward, done, trunc, info = env.step(action)
        print("REW:", reward)


