import numpy as np
import torch
import utils

class ReplayBuffer(object):
    """Circular buffer of environment transitions.

    Each slot stores obs, action, extrinsic and intrinsic rewards, next_obs,
    two not-done flags derived from the caller's ``done`` / ``done_no_max``
    (presumably the latter ignores time-limit terminations), the caller's
    ``info`` object, and a ``div`` flag. While ``training`` is True, ``div``
    is set for every step of an episode whose summed extrinsic reward reached
    ``threshold_ratio * anchor_return`` (or for every episode while no anchor
    is known); the ``*_smerl`` samplers return it alongside the transitions.
    """

    def __init__(self, obs_shape, action_shape, capacity, device, window=1, max_episode_len=1000):
        """Allocate storage for ``capacity`` transitions.

        Args:
            obs_shape: shape of one observation; 1-D observations are stored
                as float32, anything else (e.g. pixels) as uint8.
            action_shape: shape of one action (stored as float32).
            capacity: number of slots; the oldest data is overwritten once full.
            device: torch device that sampled tensors are placed on.
            window: number of rows written by each ``add_batch`` call.
            max_episode_len: steps per episode assumed by the on-policy
                samplers when converting an episode count into a step count.
        """
        self.capacity = capacity
        self.device = device

        # the proprioceptive obs is stored as float32, pixels obs as uint8
        obs_dtype = np.float32 if len(obs_shape) == 1 else np.uint8

        self.obses = np.empty((capacity, *obs_shape), dtype=obs_dtype)
        self.next_obses = np.empty((capacity, *obs_shape), dtype=obs_dtype)
        self.actions = np.empty((capacity, *action_shape), dtype=np.float32)
        self.extrinsic_rewards = np.empty((capacity, 1), dtype=np.float32)
        self.intrinsic_rewards = np.empty((capacity, 1), dtype=np.float32)
        self.not_dones = np.empty((capacity, 1), dtype=np.float32)
        self.not_dones_no_max = np.empty((capacity, 1), dtype=np.float32)
        self.div = np.empty((capacity, 1), dtype=np.float32)
        self.infos = []
        self.window = window
        self.max_episode_len = max_episode_len

        # Running state of the episode currently being written by add().
        self.current_episode_reward = 0
        self.current_episode_start_idx = 0
        self.anchor_return = None
        self.threshold_ratio = 0.9

        self.idx = 0  # next slot to write
        self.last_save = 0
        self.full = False  # True once the buffer has wrapped at least once

        self.training = True  # div flags / episode counters only update while True
        self.is_diverse = 0  # number of counted episodes flagged as diverse
        self.num_episodes = 0  # episodes counted toward is_diverse
        self.total_episodes = 0

    def __len__(self):
        """Number of valid (written) transitions."""
        return self.capacity if self.full else self.idx

    def add(self, obs, action, ext_reward, int_reward, next_obs, done, done_no_max, info, anchor_return=None):
        """Store one transition and update the episode/diversity bookkeeping.

        Args:
            obs, action, next_obs: transition arrays copied into the buffer.
            ext_reward, int_reward: extrinsic / intrinsic rewards.
            done, done_no_max: termination flags; stored negated as not-done.
            info: arbitrary object kept in ``self.infos`` at the same slot.
            anchor_return: if not None, replaces ``self.anchor_return``.
                NOTE(review): the per-episode counters are skipped when this
                argument equals 1e3 -- presumably a caller-side sentinel.
        """
        if anchor_return is not None:
            self.anchor_return = anchor_return

        i = self.idx
        np.copyto(self.obses[i], obs)
        np.copyto(self.actions[i], action)
        np.copyto(self.extrinsic_rewards[i], ext_reward)
        np.copyto(self.intrinsic_rewards[i], int_reward)
        np.copyto(self.next_obses[i], next_obs)
        np.copyto(self.not_dones[i], not done)
        np.copyto(self.not_dones_no_max[i], not done_no_max)

        # Diversity flag defaults to False; it is rewritten for the whole
        # episode once that episode terminates.
        np.copyto(self.div[i], 0)

        if i < len(self.infos):
            self.infos[i] = info
        else:
            # add_batch() does not record infos, so the list can lag behind
            # idx; pad with None instead of raising IndexError.
            self.infos.extend([None] * (i - len(self.infos)))
            self.infos.append(info)

        # A new episode starts on the very first write or right after a
        # terminal transition. Once the buffer has wrapped, slot i-1 == -1 is
        # the last slot, which holds the preceding transition, so wrapping
        # alone must not start a new episode.
        if (i == 0 and not self.full) or self.not_dones[i - 1] == 0:
            self.current_episode_start_idx = i
            self.current_episode_reward = 0

        self.current_episode_reward += ext_reward

        if done and self.training:
            is_diverse = (self.anchor_return is None or
                          self.current_episode_reward >= self.anchor_return * self.threshold_ratio)

            # Measure the episode modulo capacity so episodes that wrap past
            # the end of the buffer are flagged in full.
            start = self.current_episode_start_idx
            episode_length = (i - start) % self.capacity + 1
            episode_idxs = (start + np.arange(episode_length)) % self.capacity
            self.div[episode_idxs] = float(is_diverse)

            # Reset episode tracking
            self.current_episode_reward = 0

            if anchor_return != 1e3:
                self.num_episodes += 1
                self.is_diverse += is_diverse
            self.total_episodes += 1

        self.idx = (i + 1) % self.capacity
        self.full = self.full or self.idx == 0

    def add_batch(self, obs, action, ext_reward, int_reward, next_obs, done, done_no_max):
        """Store ``self.window`` consecutive transitions, wrapping at the end.

        Every array argument is expected to have ``self.window`` rows
        (NOTE(review): not checked). Not-done flags are computed as
        ``done <= 0``. Infos are not recorded, and the diversity flags of the
        written slots are cleared.
        """
        start = self.idx
        end = start + self.window
        batch = (obs, action, ext_reward, int_reward, next_obs, done, done_no_max)
        if end < self.capacity:
            self._copy_batch(start, end, slice(None), *batch)
            self.idx = end
            return

        self.full = True
        split = self.capacity - start  # rows that fit before the end
        self._copy_batch(start, self.capacity, slice(None, split), *batch)
        remain = self.window - split
        if remain > 0:
            self._copy_batch(0, remain, slice(split, None), *batch)
        self.idx = remain

    def _copy_batch(self, start, end, src, obs, action, ext_reward, int_reward, next_obs, done, done_no_max):
        """Copy rows ``src`` of the batch arrays into buffer slots [start, end)."""
        np.copyto(self.obses[start:end], obs[src])
        np.copyto(self.actions[start:end], action[src])
        np.copyto(self.extrinsic_rewards[start:end], ext_reward[src])
        np.copyto(self.intrinsic_rewards[start:end], int_reward[src])
        np.copyto(self.next_obses[start:end], next_obs[src])
        np.copyto(self.not_dones[start:end], done[src] <= 0)
        np.copyto(self.not_dones_no_max[start:end], done_no_max[src] <= 0)
        # add_batch never computes diversity, so clear the flags instead of
        # leaving uninitialised (np.empty) or stale values behind.
        self.div[start:end] = 0.0

    def relabel_with_predictor(self, predictor, discriminator, agent_index, batch_size=200):
        """Overwrite every stored extrinsic reward with ``predictor``'s estimate.

        The predictor is fed ``concat([obs, action], axis=-1)`` in chunks of
        ``batch_size`` rows via ``predictor.r_hat_batch``. Intrinsic rewards
        of all stored transitions are reset to 0. ``discriminator`` and
        ``agent_index`` are accepted for interface compatibility but unused.
        """
        n = len(self)  # covers the whole buffer once it has wrapped
        for start in range(0, n, batch_size):
            end = min(start + batch_size, n)
            inputs = np.concatenate([self.obses[start:end], self.actions[start:end]], axis=-1)

            pred_reward = predictor.r_hat_batch(inputs)
            self.extrinsic_rewards[start:end] = pred_reward
            # Exploration bonus is set to 0 for old trajectories.
            self.intrinsic_rewards[start:end] = np.zeros_like(pred_reward)

    def _uniform_idxs(self, batch_size):
        """Draw ``batch_size`` indices uniformly over the stored transitions."""
        return np.random.randint(0, len(self), size=batch_size)

    def _recent_idxs(self, batch_size, size):
        """Draw indices from the last ``size * max_episode_len`` stored steps.

        The span is clamped to the number of stored transitions, and negative
        offsets wrap to the tail of a full buffer via the modulo.
        """
        span = min(size * self.max_episode_len, len(self))
        return np.random.randint(self.idx - span, self.idx, size=batch_size) % self.capacity

    def _gather(self, idxs):
        """Return the seven standard transition fields at ``idxs`` as tensors."""
        dev = self.device
        obses = torch.as_tensor(self.obses[idxs], device=dev).float()
        actions = torch.as_tensor(self.actions[idxs], device=dev)
        ext_rewards = torch.as_tensor(self.extrinsic_rewards[idxs], device=dev)
        int_rewards = torch.as_tensor(self.intrinsic_rewards[idxs], device=dev)
        next_obses = torch.as_tensor(self.next_obses[idxs], device=dev).float()
        not_dones = torch.as_tensor(self.not_dones[idxs], device=dev)
        not_dones_no_max = torch.as_tensor(self.not_dones_no_max[idxs], device=dev)
        return obses, actions, ext_rewards, int_rewards, next_obses, not_dones, not_dones_no_max

    def sample(self, batch_size):
        """Uniformly sample ``batch_size`` transitions (with replacement)."""
        return self._gather(self._uniform_idxs(batch_size))

    def sample_combine(self, batch_size, full_size=512):
        """Sample transitions plus up to ``full_size`` distinct stored observations.

        Returns ``(obses, full_obs, actions, ext_rewards, int_rewards,
        next_obses, not_dones, not_dones_no_max)``; ``full_obs`` keeps the
        storage dtype.
        """
        (obses, actions, ext_rewards, int_rewards,
         next_obses, not_dones, not_dones_no_max) = self._gather(self._uniform_idxs(batch_size))

        stored = self.obses[:len(self)]
        # Drawn without replacement, so never ask for more rows than exist.
        n_full = min(full_size, stored.shape[0])
        full_idxs = np.random.choice(stored.shape[0], size=n_full, replace=False)
        full_obs = torch.as_tensor(stored[full_idxs], device=self.device)

        return obses, full_obs, actions, ext_rewards, int_rewards, next_obses, not_dones, not_dones_no_max

    def sample_onpolicy(self, batch_size, size):
        """Sample from the most recent ``size * max_episode_len`` transitions."""
        return self._gather(self._recent_idxs(batch_size, size))

    def sample_smerl(self, batch_size):
        """Like ``sample`` but also returns the per-transition ``div`` flags."""
        idxs = self._uniform_idxs(batch_size)
        div = torch.as_tensor(self.div[idxs], device=self.device)
        return self._gather(idxs) + (div,)

    def sample_onpolicy_smerl(self, batch_size, size):
        """Like ``sample_onpolicy`` but also returns the ``div`` flags."""
        idxs = self._recent_idxs(batch_size, size)
        div = torch.as_tensor(self.div[idxs], device=self.device)
        return self._gather(idxs) + (div,)

    def sample_state_ent(self, batch_size):
        """Sample transitions plus every stored observation as one tensor.

        Returns ``(obses, full_obs, actions, ext_rewards, int_rewards,
        next_obses, not_dones, not_dones_no_max)``; ``full_obs`` keeps the
        storage dtype.
        """
        (obses, actions, ext_rewards, int_rewards,
         next_obses, not_dones, not_dones_no_max) = self._gather(self._uniform_idxs(batch_size))

        full_obs = torch.as_tensor(self.obses[:len(self)], device=self.device)

        return obses, full_obs, actions, ext_rewards, int_rewards, next_obses, not_dones, not_dones_no_max