import numpy as np
from scipy.linalg import block_diag
from utils import get_idxs_per_relation, get_idxs_per_object


class her_sampler:
    """Sample training transitions from stored episodes, optionally relabeled with HER.

    Goals are compared per semantic sub-goal: ``self.semantic_ids`` holds one
    index selector per sub-goal, and a transition's reward is the number of
    sub-goals whose achieved values exactly equal the goal's values.
    """

    def __init__(self, args, reward_func=None):
        """Configure the sampler from ``args``.

        Args:
            args: configuration object; reads ``replay_strategy``, ``replay_k``,
                ``multi_criteria_her`` and ``n_blocks``.
            reward_func: stored as ``self.reward_func``; not used by the methods
                of this class (rewards come from ``compute_reward_masks``).
        """
        self.replay_strategy = args.replay_strategy
        self.replay_k = args.replay_k
        # Relabel with probability k / (k + 1): in expectation, replay_k
        # relabeled transitions for every one that keeps its original goal.
        self.future_p = 1 - (1. / (1 + args.replay_k))
        self.reward_func = reward_func
        self.multi_criteria_her = args.multi_criteria_her

        # Three consecutive indices per block; not used within this class.
        self.obj_ind = np.array([np.arange(i * 3, (i + 1) * 3) for i in range(args.n_blocks)])
        self.semantic_ids = get_idxs_per_object(n=args.n_blocks)

    def sample_her_transitions(self, episode_batch, batch_size_in_transitions):
        """Sample transitions and relabel a random fraction of their goals (HER).

        Each transition is relabeled with probability ``self.future_p``. With
        ``multi_criteria_her`` every semantic sub-goal is relabeled
        independently (its own random subset of transitions and its own
        relabeling time step); otherwise the whole goal vector is replaced at
        once. The relabeling goal follows ``self.replay_strategy``: ``'final'``
        uses the episode's last stored achieved goal, any other value a
        uniformly sampled future achieved goal.

        Args:
            episode_batch: dict of arrays indexed ``[episode, timestep, ...]``;
                must contain ``'actions'``, ``'g'``, ``'ag'`` and ``'ag_next'``.
            batch_size_in_transitions: number of transitions to sample.

        Returns:
            dict with one entry per key of ``episode_batch`` plus
            ``'anchor_g'`` (goal before relabeling), ``'r'`` (reward for the
            possibly relabeled goal) and ``'anchor_r'`` (reward for the
            original goal). Rewards are float arrays of shape ``(batch, 1)``.
        """
        T = episode_batch['actions'].shape[1]
        rollout_batch_size = episode_batch['actions'].shape[0]
        batch_size = batch_size_in_transitions

        # select which rollouts and which timesteps to be used
        episode_idxs = np.random.randint(0, rollout_batch_size, batch_size)
        t_samples = np.random.randint(T, size=batch_size)
        transitions = {key: episode_batch[key][episode_idxs, t_samples].copy() for key in episode_batch.keys()}

        transitions['anchor_g'] = transitions['g'].copy()
        if self.multi_criteria_her:
            for sub_goal in self.semantic_ids:
                her_mask = np.random.uniform(size=batch_size) < self.future_p
                future_ag = self._sample_relabel_goals(episode_batch, episode_idxs, t_samples, her_mask)
                # Boolean indexing returns a copy: edit the selected rows, then write them back.
                goals = transitions['g'][her_mask]
                goals[:, sub_goal] = future_ag[:, sub_goal]
                transitions['g'][her_mask] = goals
        else:
            her_mask = np.random.uniform(size=batch_size) < self.future_p
            transitions['g'][her_mask] = self._sample_relabel_goals(
                episode_batch, episode_idxs, t_samples, her_mask)

        transitions['r'] = self._compute_batch_rewards(transitions['ag_next'], transitions['g'])
        transitions['anchor_r'] = self._compute_batch_rewards(transitions['ag_next'], transitions['anchor_g'])
        return transitions

    def _sample_relabel_goals(self, episode_batch, episode_idxs, t_samples, her_mask):
        """Return the achieved goals used to relabel the transitions selected by ``her_mask``.

        ``'final'`` takes the last stored achieved goal of each selected
        episode. Any other strategy takes the achieved goal at a random step in
        ``[t + 1, T]``. Because that step can equal ``T``,
        ``episode_batch['ag']`` presumably stores ``T + 1`` steps per episode;
        confirm against the replay buffer.
        """
        if self.replay_strategy == 'final':
            return episode_batch['ag'][episode_idxs[her_mask], -1]
        T = episode_batch['actions'].shape[1]
        # uniform in [0, 1) times (T - t), truncated to int, lies in [0, T - t - 1].
        future_offset = (np.random.uniform(size=t_samples.shape[0]) * (T - t_samples)).astype(int)
        future_t = (t_samples + 1 + future_offset)[her_mask]
        return episode_batch['ag'][episode_idxs[her_mask], future_t]

    def sample_transitions(self, episode_batch, batch_size_in_transitions):
        """ Sample transitions from batch of episodes WITHOUT applying Hindsight Experience Replay

        Returns the same dict layout as ``sample_her_transitions``, except that
        the ``'her'`` key of ``episode_batch`` is dropped. Goals are never
        relabeled, so ``'anchor_g'``/``'anchor_r'`` equal ``'g'``/``'r'``.
        """
        T = episode_batch['actions'].shape[1]
        rollout_batch_size = episode_batch['actions'].shape[0]
        batch_size = batch_size_in_transitions

        # select which rollouts and which timesteps to be used
        episode_idxs = np.random.randint(0, rollout_batch_size, batch_size)
        t_samples = np.random.randint(T, size=batch_size)
        transitions = {key: episode_batch[key][episode_idxs, t_samples].copy() for key in episode_batch.keys() if key != 'her'}

        transitions['anchor_g'] = transitions['g'].copy()
        transitions['r'] = self._compute_batch_rewards(transitions['ag_next'], transitions['g'])
        # No relabeling happened, so the anchor reward is identical to the reward.
        transitions['anchor_r'] = transitions['r'].copy()
        return transitions

    def _compute_batch_rewards(self, ag_next, g):
        """Vectorized ``compute_reward_masks`` over a batch; returns shape ``(batch, 1)``.

        Equivalent to calling ``compute_reward_masks`` row by row, but with one
        NumPy comparison per sub-goal instead of a Python loop per transition.
        """
        ag_next = np.asarray(ag_next)
        g = np.asarray(g)
        rewards = np.zeros(ag_next.shape[0])
        for subgoal in self.semantic_ids:
            matches = ag_next[:, subgoal] == g[:, subgoal]
            # Reduce every non-batch axis. axis=() leaves a 1-D result unchanged,
            # which covers a scalar sub-goal index and an empty batch.
            rewards += matches.all(axis=tuple(range(1, matches.ndim)))
        return rewards[:, None]

    def compute_reward_masks(self, ag, g):
        """Return the number of semantic sub-goals on which ``ag`` exactly equals ``g``.

        Args:
            ag: achieved goal of a single transition.
            g: desired goal of a single transition.

        Returns:
            float count in ``[0, len(self.semantic_ids)]``.
        """
        reward = 0.
        for subgoal in self.semantic_ids:
            if (ag[subgoal] == g[subgoal]).all():
                reward = reward + 1.
        return reward