import random
import numpy as np
import os
import pickle

class ReplayMemory:
    """Fixed-capacity ring buffer of transitions with uniform random sampling.

    Each stored entry is the tuple
    ``(state, action, reward, radius_value, next_state, done)``. Once
    ``capacity`` entries are held, every further ``push`` overwrites the slot
    at ``position``, which advances cyclically.
    """

    def __init__(self, capacity, seed):
        """Create an empty buffer.

        Args:
            capacity: Maximum number of transitions held at once.
            seed: Forwarded to ``random.seed``. This seeds the module-level
                ``random`` generator, which ``sample`` draws from.
        """
        random.seed(seed)
        self.capacity = capacity
        self.buffer = []
        # Index of the slot the next push() writes to.
        self.position = 0

    def push(self, state, action, reward, radius_value, next_state, done):
        """Store one transition, overwriting an old slot once the buffer is full."""
        if len(self.buffer) < self.capacity:
            # Grow by one placeholder so the indexed write below is valid.
            self.buffer.append(None)
        self.buffer[self.position] = (state, action, reward, radius_value, next_state, done)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns:
            A 6-tuple ``(state, action, reward, radius, next_state, done)``;
            each element is the stacked field across the batch, cast to
            ``float32``.

        Raises:
            ValueError: Raised by ``random.sample`` when ``batch_size`` exceeds
                the number of stored transitions.
        """
        batch = random.sample(self.buffer, batch_size)
        # zip(*batch) regroups the transitions field by field.
        state, action, reward, radius, next_state, done = (
            np.stack(column).astype(np.float32) for column in zip(*batch)
        )
        return state, action, reward, radius, next_state, done

    def __len__(self):
        """Return the number of transitions currently stored."""
        return len(self.buffer)

    def save_buffer(self, env_name, suffix="", save_path=None):
        """Pickle the stored transitions to disk.

        Args:
            env_name: Used to build the default file name.
            suffix: Appended to the default file name.
            save_path: Destination file. When ``None`` the file is written to
                ``checkpoints/sac_buffer_<env_name>_<suffix>``.

        The parent directory of the destination is created if it is missing.
        """
        if save_path is None:
            save_path = "checkpoints/sac_buffer_{}_{}".format(env_name, suffix)
        print('Saving buffer to {}'.format(save_path))

        # Create the directory that will actually receive the file; a bare
        # file name has no directory component and needs nothing created.
        parent_dir = os.path.dirname(save_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        with open(save_path, 'wb') as f:
            pickle.dump(self.buffer, f)

    def load_buffer(self, save_path):
        """Replace the buffer contents with transitions pickled at ``save_path``.

        If the file holds more entries than ``capacity``, only the last
        ``capacity`` entries of the stored list are kept so the buffer never
        exceeds its capacity.
        """
        print('Loading buffer from {}'.format(save_path))

        # NOTE: pickle.load can execute arbitrary code embedded in the file;
        # only load buffers that come from a trusted source.
        with open(save_path, "rb") as f:
            loaded = list(pickle.load(f))

        overflow = len(loaded) - self.capacity
        if overflow > 0:
            loaded = loaded[overflow:]

        self.buffer = loaded
        # A full buffer wraps to slot 0; a partial one continues appending.
        self.position = len(self.buffer) % self.capacity


class TrajectoryReplayMemory:
    """Fixed-capacity ring buffer of whole trajectories.

    A trajectory is a sequence of ``(state, radius)`` pairs. ``sample``
    returns a batch in which shorter trajectories are padded up to
    ``max_trajectory_length``.
    """

    # Radius value written into padded time steps by sample(); it is also the
    # default marker that get_actual_lengths() treats as padding.
    PADDING_VALUE = -1

    def __init__(self, capacity, max_trajectory_length, seed=None):
        """Create an empty buffer.

        Args:
            capacity: Maximum number of trajectories held at once.
            max_trajectory_length: Length that sampled trajectories are padded to.
            seed: Optional seed. When given, it seeds the module-level
                ``random`` generator used by ``sample`` and, as before,
                NumPy's global generator.
        """
        if seed is not None:
            # sample() draws with random.sample, so the stdlib generator is
            # the one that must be seeded for reproducible batches.
            random.seed(seed)
            np.random.seed(seed)
        self.capacity = capacity
        self.max_trajectory_length = max_trajectory_length
        self.buffer = []
        # Index of the slot the next push() writes to.
        self.position = 0

    def push(self, trajectory):
        """Store one trajectory, a list of ``(state, radius)`` tuples."""
        if len(self.buffer) < self.capacity:
            # Grow by one placeholder so the indexed write below is valid.
            self.buffer.append(None)
        self.buffer[self.position] = trajectory
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct trajectories and pad them to equal length.

        Padded steps get an all-zero state (shaped like the trajectory's first
        state) and a radius of ``PADDING_VALUE``. Trajectories longer than
        ``max_trajectory_length`` are not truncated here.

        Returns:
            ``(states, radius)`` as ``float32`` arrays with dimensions
            ``[minibatch, len_traj, feature]`` and ``[minibatch, len_traj]``.

        Raises:
            ValueError: Raised by ``random.sample`` when ``batch_size`` exceeds
                the number of stored trajectories.
        """
        trajectories = random.sample(self.buffer, batch_size)

        states, radius = [], []
        for traj in trajectories:
            # Split the list of (state, radius) pairs into two sequences.
            traj_states, traj_radius = (list(column) for column in zip(*traj))

            pad_length = self.max_trajectory_length - len(traj_states)
            if pad_length > 0:
                traj_states += [np.zeros_like(traj_states[0])] * pad_length
                traj_radius += [self.PADDING_VALUE] * pad_length

            states.append(traj_states)
            radius.append(traj_radius)

        states = np.array(states, dtype=np.float32)
        radius = np.array(radius, dtype=np.float32)
        return states, radius

    @staticmethod
    def get_actual_lengths(radius, padding_value=PADDING_VALUE):
        """Calculate the actual lengths of each trajectory in the batch.

        Counts, per row, the entries of ``radius`` that differ from
        ``padding_value``. A genuine radius equal to ``padding_value`` is
        therefore counted as padding.

        Args:
            radius: 2-D array-like of radius values, one row per trajectory.
            padding_value: Marker identifying padded steps.
        """
        lengths = np.sum(np.asarray(radius) != padding_value, axis=1)
        return lengths

    def __len__(self):
        """Return the number of trajectories currently stored."""
        return len(self.buffer)
    

