import numpy as np
import torch
from collections import deque, namedtuple
import pickle
import time

# One environment step. This is the typical replay-memory record, which
# adapts to the RL setting used here as well.
Transition = namedtuple(
    'Transition',
    ['state', 'action', 'next_state', 'reward'],
)

class Memory:
    """Fixed-capacity circular replay buffer that stores whole trajectories.

    Unlike a typical RL replay buffer, ``size`` counts trajectories
    (episodes), not individual transitions.  For every stored trajectory the
    buffer keeps, at the same slot index:

    * ``memory``            - the list of ``Transition`` tuples,
    * ``rewards``           - the final reward of the trajectory,
    * ``augmented_rewards`` - a caller-supplied scalar,
    * ``losses``            - a priority value used by the prioritized samplers.

    Once ``size`` trajectories have been pushed, the oldest slot is
    overwritten.  No de-duplication of trajectories is performed here.

    The ``visited_*`` and ``train_*`` lists are append-only logs used for
    validation; they are NOT bounded by ``size``.
    """

    def __init__(self, size, device='cpu'):
        """Create an empty buffer.

        Args:
            size: Maximum number of trajectories held at once.
            device: Torch device on which sampled tensors are created.
        """
        self.size = size
        self.device = device

        # Per-slot storage, preallocated to the fixed capacity.
        self.memory = [None] * size
        # float64 so that no precision is lost at this stage.
        self.rewards = np.zeros(size, dtype=np.float64)
        self.augmented_rewards = np.ones(size, dtype=np.float64)
        self.losses = np.ones(size, dtype=np.float64)  # initial priority of 1

        # Note: these are used for validation and are not bounded by size!
        self.visited_end_states = []
        self.visited_rewards = []
        self.visited_act_seqs = []  # never written by any method of this class
        self.train_PF = []
        self.train_PB = []
        self.train_losses = []
        self.train_rews = []
        self.train_Z = []
        self.visited_timestamp = []

        self.max_priority = 10
        self.threshold_for_biased_sampling = 0
        self.threshold_for_priority_sampling = 0
        self.threshold_needs_update = False

        # Index pools for biased_sample(); filled by update_threshold().
        # Defined here so biased_sample() is safe to call before that.
        self.top_idx = []
        self.bottom_idx = []

        self.current_index = 0  # next slot to write (circular position)
        self.num_elements = 0   # number of valid slots

    # ------------------------------------------------------------------
    # Writing
    # ------------------------------------------------------------------
    def push_traj(self, state_buffer, action_buffer, next_state_buffer, reward_buffer, augmented_reward):
        """Store one trajectory at the current write slot.

        The per-step buffers are zipped into ``Transition`` tuples.  The last
        entry of ``reward_buffer`` is recorded as the trajectory reward and the
        slot's priority is reset to 1.  The final next-state, final reward and
        a wall-clock timestamp are appended to the validation logs.

        Args:
            state_buffer, action_buffer, next_state_buffer, reward_buffer:
                Equal-length, non-empty per-step sequences.
            augmented_reward: Scalar stored alongside the trajectory.
        """
        slot = self.current_index
        self.memory[slot] = [
            Transition(state, action, next_state, reward)
            for state, action, next_state, reward
            in zip(state_buffer, action_buffer, next_state_buffer, reward_buffer)
        ]

        final_reward = reward_buffer[-1]
        self.rewards[slot] = final_reward
        self.losses[slot] = 1
        self.augmented_rewards[slot] = augmented_reward

        # Advance the circular write position.
        self.current_index = (slot + 1) % self.size
        if self.num_elements < self.size:
            self.num_elements += 1

        self.visited_end_states.append(next_state_buffer[-1])
        self.visited_rewards.append(final_reward)
        self.visited_timestamp.append(time.time())

        # The reward distribution changed; update_priority() will recompute
        # its percentile threshold on the next call.
        self.threshold_needs_update = True

    def push_trajs(self, state_buffer, action_buffer, next_state_buffer, reward_buffer, length_buffer, augmented_reward_buffer):
        """Store several trajectories from step-major batched buffers.

        Each ``*_buffer[j][i]`` is read as step ``j`` of trajectory ``i``;
        trajectory ``i`` contributes its first ``length_buffer[i]`` steps.
        """
        for i, length in enumerate(length_buffer):
            steps = range(int(length))
            self.push_traj(
                [state_buffer[j][i] for j in steps],
                [action_buffer[j][i] for j in steps],
                [next_state_buffer[j][i] for j in steps],
                [reward_buffer[j][i] for j in steps],
                augmented_reward_buffer[i],
            )

    def push_train_logs(self, pf, pb, loss, rews, Z):
        """Append one entry to each of the training logs."""
        self.train_PF.append(pf)
        self.train_PB.append(pb)
        self.train_losses.append(loss)
        self.train_rews.append(rews)
        self.train_Z.append(Z)

    # ------------------------------------------------------------------
    # Sampling
    # ------------------------------------------------------------------
    def update_threshold(self, batch_size):
        """Recompute the index pools used by ``biased_sample``.

        With fewer than 10 stored trajectories both pools are left empty, so
        ``biased_sample`` falls back to uniform sampling.  Otherwise
        ``top_idx`` holds the slots whose reward is at or above the 90th
        percentile and ``bottom_idx`` the rest.  If the top pool is smaller
        than ``batch_size // 2`` it is tiled so that half a batch can be drawn
        from it without replacement.
        """
        self.top_idx, self.bottom_idx = [], []
        if self.num_elements < 10:
            return

        rewards = self.rewards[:self.num_elements]
        threshold = np.percentile(rewards, 90)
        self.threshold_for_biased_sampling = threshold
        self.top_idx = np.where(rewards >= threshold)[0]
        self.bottom_idx = np.where(rewards < threshold)[0]

        half = batch_size // 2
        if 0 < len(self.top_idx) < half:
            self.top_idx = np.tile(self.top_idx, half // len(self.top_idx) + 1)

    def sample(self, batch_size):
        """Uniformly sample ``batch_size`` distinct stored trajectories."""
        idx = np.random.choice(self.num_elements, batch_size, replace=False)
        return self.sample_from_indices(idx)

    @staticmethod
    def _choice_from_pool(pool, count):
        """Draw ``count`` entries from ``pool``, with replacement only if the pool is too small."""
        return np.random.choice(pool, count, replace=len(pool) < count)

    def biased_sample(self, batch_size):
        """Sample half the batch from the top pool and the rest from the bottom pool.

        The pools are those computed by the last ``update_threshold`` call
        (top ~10% of rewards vs. the remainder).  For an odd ``batch_size`` the
        extra item comes from the bottom pool.  If either pool is empty the
        whole batch is sampled uniformly.
        """
        if len(self.top_idx) > 0 and len(self.bottom_idx) > 0:
            n_top = batch_size // 2
            n_bottom = batch_size - n_top
            idx = np.concatenate([
                self._choice_from_pool(self.top_idx, n_top),
                self._choice_from_pool(self.bottom_idx, n_bottom),
            ])
        else:
            idx = np.random.choice(self.num_elements, batch_size, replace=False)

        return self.sample_from_indices(idx)

    def generalized_biased_sample(self, batch_size):
        """Reward-biased sampling that also works when the 90th percentile reward is zero.

        Trajectories with reward strictly above ``threshold_for_priority_sampling``
        get weight 1; the others share a weight chosen so that both groups
        carry (approximately) the same total probability mass.  The threshold
        is the one last computed by ``update_priority``.
        """
        rewards = self.rewards[:self.num_elements]
        above = rewards > self.threshold_for_priority_sampling
        not_above = rewards <= self.threshold_for_priority_sampling

        priority = np.zeros(self.num_elements)
        priority[above] = 1
        # Spread the same total mass over the remaining records.
        priority[not_above] = (np.sum(above) + 1e-8) / (np.sum(not_above) + 1e-8)
        priority /= priority.sum()

        idx = np.random.choice(self.num_elements, batch_size, p=priority, replace=False)
        return self.sample_from_indices(idx)

    def mixed_priority_sample(self, batch_size):
        """Reward-biased, loss-prioritized replay sampling.

        The weight of a trajectory is ``reward * loss`` plus the mean of those
        products over the buffer.  When the weights sum to exactly zero (e.g.
        every stored reward is 0) sampling is uniform instead of dividing by
        zero.
        """
        priority = self.rewards[:self.num_elements] * self.losses[:self.num_elements]
        priority += np.mean(priority)

        total = priority.sum()
        probabilities = priority / total if total != 0 else None  # None -> uniform

        idx = np.random.choice(self.num_elements, batch_size, p=probabilities, replace=False)
        return self.sample_from_indices(idx)

    def _to_tensor(self, values):
        """Convert a Python sequence to a float tensor on ``self.device``."""
        return torch.tensor(np.array(values), dtype=torch.float, device=self.device)

    def _collate(self, trajectories):
        """Turn trajectories into per-trajectory state/action/next-state/reward tensors.

        The reward tensor of each trajectory holds only the reward of its last
        transition.
        """
        states = [self._to_tensor([t[0] for t in traj]) for traj in trajectories]
        actions = [self._to_tensor([t[1] for t in traj]) for traj in trajectories]
        next_states = [self._to_tensor([t[2] for t in traj]) for traj in trajectories]
        rewards = [self._to_tensor([traj[-1][3]]) for traj in trajectories]
        return states, actions, next_states, rewards

    def sample_from_indices(self, idx):
        """Materialize the trajectories stored at slots ``idx`` as tensors.

        Returns:
            ``(state_batches, action_batches, next_state_batches,
            reward_batches, augmented_reward_batches, idx)`` where each
            ``*_batches`` is a list with one tensor per requested slot.
        """
        states, actions, next_states, rewards = self._collate([self.memory[i] for i in idx])
        augmented = [self._to_tensor([self.augmented_rewards[i]]) for i in idx]
        return states, actions, next_states, rewards, augmented, idx

    def latest_sample(self, n=1, s=0):  # get the latest n trajectories
        """Return ``n`` consecutive trajectories starting ``s`` slots before the write position.

        Slots are read in write order with wrap-around.  Slots that have never
        been written hold ``None`` and cannot be collated.

        NOTE(review): with the default ``s=0`` the window starts at the next
        write slot (the oldest entries once the buffer has wrapped); callers
        presumably pass ``s >= n`` to get the most recent ones - confirm.

        Returns:
            ``(state_batches, action_batches, next_state_batches, reward_batches)``.
        """
        start = (self.current_index - s) % self.size
        slots = [(start + offset) % self.size for offset in range(n)]
        return self._collate([self.memory[i] for i in slots])

    # ------------------------------------------------------------------
    # Priorities
    # ------------------------------------------------------------------
    def update_priority(self, idx, priority):
        """Refresh the priorities of the slots ``idx`` from raw ``priority`` values.

        Steps:
            1. If trajectories were pushed since the last call, recompute
               ``threshold_for_priority_sampling`` as the 90th percentile of
               the stored rewards.
            2. Move the priority of every trajectory whose reward is strictly
               above that threshold 5% of the way towards 1.
            3. Min-max normalize ``priority`` into [0.1, 10] and write it to
               ``losses[idx]``.  If all values are equal they all map to 0.1.

        Args:
            idx: Slot indices, as returned by the sampling methods.
            priority: One value per index; a NumPy array, a sequence, or a
                torch tensor (detached and moved to CPU here).  Empty input
                skips step 3.
        """
        rewards = self.rewards[:self.num_elements]
        losses = self.losses[:self.num_elements]  # view: writes reach self.losses

        if self.threshold_needs_update:
            self.threshold_for_priority_sampling = np.percentile(rewards, 90)
            self.threshold_needs_update = False

        high_reward = rewards > self.threshold_for_priority_sampling
        losses[high_reward] = losses[high_reward] * 0.95 + 0.05

        if torch.is_tensor(priority):
            priority = priority.detach().cpu().numpy()
        priority = np.asarray(priority, dtype=np.float64)
        if priority.size == 0:
            return

        lowest, highest = priority.min(), priority.max()
        self.losses[idx] = ((priority - lowest) / (highest - lowest + 1e-8)) * 9.9 + 0.1

    def set_max_priority(self, idx):
        """Assign ``max_priority`` to the slots ``idx``."""
        self.losses[idx] = self.max_priority

    def clear(self):
        """Empty the buffer and reset all per-slot state.

        Counters, rewards, priorities, thresholds and index pools are reset
        together with the trajectories so that no sampler can select an empty
        slot afterwards.  The validation/training logs are kept.
        """
        self.memory = [None] * self.size
        self.rewards = np.zeros(self.size, dtype=np.float64)
        self.augmented_rewards = np.ones(self.size, dtype=np.float64)
        self.losses = np.ones(self.size, dtype=np.float64)

        self.top_idx, self.bottom_idx = [], []
        self.threshold_for_biased_sampling = 0
        self.threshold_for_priority_sampling = 0
        self.threshold_needs_update = False

        self.current_index = 0
        self.num_elements = 0

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------
    def save(self, directory):
        """Pickle the validation and training logs to the file path ``directory``.

        Resuming training is not supported, so the trajectories themselves
        (and ``visited_act_seqs``) are not saved.
        """
        logs = [
            self.visited_end_states,
            self.visited_rewards,
            self.train_PF,
            self.train_PB,
            self.train_losses,
            self.train_rews,
            self.train_Z,
            self.visited_timestamp,
        ]
        with open(directory, 'wb') as f:
            pickle.dump(logs, f)

    def load(self, directory):
        """Restore the logs written by ``save`` from the file path ``directory``."""
        # SECURITY: pickle.load executes arbitrary code from the file; only
        # load files produced by a trusted run of save().
        with open(directory, 'rb') as f:
            (
                self.visited_end_states,
                self.visited_rewards,
                self.train_PF,
                self.train_PB,
                self.train_losses,
                self.train_rews,
                self.train_Z,
                self.visited_timestamp,
            ) = pickle.load(f)

    def __len__(self):
        """Return the number of trajectories currently stored."""
        return self.num_elements
      