import numpy as np
import torch

from utils import helpers as utl

from utils.helpers import get_device


class RolloutStorageVAE(object):
    """
    Trajectory storage for the VAE update.

    Keeps one "running" trajectory per process (on ``get_device()``) and, when a process finishes an
    episode, copies that trajectory into a permanent CPU buffer holding up to ``max_num_rollouts``
    trajectories. Trajectory tensors are laid out as (time, trajectory, feature), i.e. the number of
    trajectories is along dim=1 and the trajectory length along dim=0, to match the PyTorch RNN interface.
    """

    def __init__(self, num_processes, max_trajectory_len, zero_pad, max_num_rollouts,
                 state_dim, action_dim, vae_buffer_add_thresh, task_dim, error_handling=False):
        """
        Store everything that is needed for the VAE update.

        :param num_processes: number of parallel environments; one running trajectory is kept per process
        :param max_trajectory_len: maximum number of steps a stored trajectory can have
        :param zero_pad: stored as ``self.zero_pad`` (not read inside this class)
        :param max_num_rollouts: capacity of the permanent buffer, in trajectories; no permanent
                                 buffer is allocated when this is <= 0
        :param state_dim: dimension of the observations
        :param action_dim: dimension of the actions
        :param vae_buffer_add_thresh: probability of adding finished trajectories to the permanent buffer
        :param task_dim: dimension of the task descriptor, or None if tasks are not stored
        :param error_handling: if True, allow one extra timestep per trajectory (max_trajectory_len + 1)
        """

        self.obs_dim = state_dim
        self.action_dim = action_dim
        self.task_dim = task_dim

        self.vae_buffer_add_thresh = vae_buffer_add_thresh  # prob of adding new trajectories
        self.max_buffer_size = max_num_rollouts  # maximum buffer len (number of trajectories)
        self.insert_idx = 0  # at which index we're currently inserting new data
        self.buffer_len = 0  # how much of the buffer has been filled

        # how long a trajectory can be at max (horizon)
        self.max_traj_len = max_trajectory_len+1 if error_handling else max_trajectory_len
        # whether to zero-pad to maximum length (zero's at the end!)
        self.zero_pad = zero_pad

        # buffers for completed rollouts (stored on CPU)
        if self.max_buffer_size > 0:
            self.prev_state = torch.zeros((self.max_traj_len, self.max_buffer_size, state_dim))
            self.next_state = torch.zeros((self.max_traj_len, self.max_buffer_size, state_dim))
            self.actions = torch.zeros((self.max_traj_len, self.max_buffer_size, action_dim))
            self.rewards = torch.zeros((self.max_traj_len, self.max_buffer_size, 1))
            if task_dim is not None:
                self.tasks = torch.zeros((self.max_buffer_size, task_dim))
            else:
                self.tasks = None
            self.trajectory_lens = [0] * self.max_buffer_size

        # storage for each running process (stored on get_device())
        self.num_processes = num_processes
        self.curr_timestep = torch.zeros((num_processes)).long()  # count environment steps so we know where to insert
        self.running_prev_state = torch.zeros((self.max_traj_len, num_processes, state_dim)).to(get_device())  # for each episode will have obs 0...N-1
        self.running_next_state = torch.zeros((self.max_traj_len, num_processes, state_dim)).to(get_device())  # for each episode will have obs 1...N
        self.running_rewards = torch.zeros((self.max_traj_len, num_processes, 1)).to(get_device())
        self.running_actions = torch.zeros((self.max_traj_len, num_processes, action_dim)).to(get_device())
        if task_dim is not None:
            self.running_tasks = torch.zeros((num_processes, task_dim)).to(get_device())
        else:
            self.running_tasks = None

    def get_running_batch(self):
        """
        Returns the batch of data from the current running environments
        (zero-padded to maximal trajectory length since different processes can have different trajectory lengths)
        :return: (running_prev_state, running_next_state, running_actions, running_rewards, curr_timestep)
        """
        return self.running_prev_state, self.running_next_state, self.running_actions, self.running_rewards, self.curr_timestep

    def insert(self, prev_state, actions, next_state, rewards, done, task):
        """
        Append one environment step (for every process) to the running trajectories, then move
        every finished trajectory into the permanent buffer and reset it.

        The transition is always written BEFORE any reset, so the final step of an episode is part of
        the stored trajectory even when processes are at different timesteps.

        :param prev_state: observations before the step, indexed by process along dim 0
        :param actions: actions taken, indexed by process along dim 0
        :param next_state: observations after the step, indexed by process along dim 0
        :param rewards: rewards received, indexed by process along dim 0
        :param done: per-process termination flags (tensor). It is never modified in place; a local
                     copy is made when the overflow guard-rail has to force a reset.
        :param task: per-process task descriptors (tensor), or None
        """
        # 1) write the new transition into the running buffers
        if len(np.unique(self.curr_timestep)) == 1:
            # all processes are at the same timestep: insert the whole batch at once
            t = int(self.curr_timestep[0])
            self.running_prev_state[t] = prev_state
            self.running_next_state[t] = next_state
            self.running_rewards[t] = rewards
            self.running_actions[t] = actions
            if task is not None:
                # copy, so resetting the running buffer (`*= 0`) never zeroes the caller's tensor
                self.running_tasks = task.clone()
            self.curr_timestep += 1
        else:
            for i in range(self.num_processes):
                t = int(self.curr_timestep[i])
                self.running_prev_state[t, i] = prev_state[i]
                self.running_next_state[t, i] = next_state[i]
                self.running_rewards[t, i] = rewards[i]
                self.running_actions[t, i] = actions[i]
                if (self.running_tasks is not None) and (task is not None):
                    self.running_tasks[i] = task[i]
                self.curr_timestep[i] += 1

        # 2) guard-rail: a process that reached max_traj_len must be reset even if done[i] is 0,
        #    otherwise the next insert would index past the end of the running buffers
        overflow_mask = self.curr_timestep >= self.max_traj_len
        if overflow_mask.any():
            done = done.clone()  # do not mutate the caller's tensor
            done[overflow_mask] = 1.0

        # 3) move finished trajectories into the permanent buffer
        if done.sum() == self.num_processes:  # check if we can process the entire batch at once
            self._flush_all_processes()
        else:
            for i in range(self.num_processes):
                if done[i]:
                    self._flush_process(i)

    def _reserve_buffer_slots(self, num_slots):
        """Return the start index of `num_slots` contiguous buffer slots, wrapping to 0 when needed."""
        if self.insert_idx + num_slots > self.max_buffer_size:
            # keep track of how much we filled the buffer (for sampling from it)
            self.buffer_len = self.insert_idx
            # this will keep some entries at the end of the buffer without overwriting them,
            # but the buffer is large enough to make this negligible
            self.insert_idx = 0
        start = self.insert_idx
        self.insert_idx += num_slots
        # count the slots being written right now, so they are immediately available for sampling
        self.buffer_len = max(self.buffer_len, self.insert_idx)
        return start

    def _flush_all_processes(self):
        """Store all running trajectories (with prob. vae_buffer_add_thresh) and reset every process."""
        if self.max_buffer_size > 0 and self.vae_buffer_add_thresh >= np.random.uniform(0, 1):
            start = self._reserve_buffer_slots(self.num_processes)
            end = start + self.num_processes
            self.prev_state[:, start:end] = self.running_prev_state
            self.next_state[:, start:end] = self.running_next_state
            self.actions[:, start:end] = self.running_actions
            self.rewards[:, start:end] = self.running_rewards
            if (self.tasks is not None) and (self.running_tasks is not None):
                insert_shape = self.tasks[start:end].shape
                self.tasks[start:end] = self.running_tasks.reshape(insert_shape)
            self.trajectory_lens[start:end] = self.curr_timestep.tolist()

        # empty running buffer
        self.running_prev_state *= 0
        self.running_next_state *= 0
        self.running_rewards *= 0
        self.running_actions *= 0
        if self.running_tasks is not None:
            self.running_tasks *= 0
        self.curr_timestep *= 0

    def _flush_process(self, i):
        """Store the running trajectory of process `i` (with prob. vae_buffer_add_thresh) and reset it."""
        if self.max_buffer_size > 0 and self.vae_buffer_add_thresh >= np.random.uniform(0, 1):
            idx = self._reserve_buffer_slots(1)
            self.prev_state[:, idx] = self.running_prev_state[:, i].to('cpu')
            self.next_state[:, idx] = self.running_next_state[:, i].to('cpu')
            self.actions[:, idx] = self.running_actions[:, i].to('cpu')
            self.rewards[:, idx] = self.running_rewards[:, i].to('cpu')
            if (self.tasks is not None) and (self.running_tasks is not None):
                self.tasks[idx] = self.running_tasks[i].to('cpu')
            self.trajectory_lens[idx] = int(self.curr_timestep[i])

        # empty running buffer
        self.running_prev_state[:, i] *= 0
        self.running_next_state[:, i] *= 0
        self.running_rewards[:, i] *= 0
        self.running_actions[:, i] *= 0
        if self.running_tasks is not None:
            self.running_tasks[i] *= 0
        self.curr_timestep[i] = 0

    def ready_for_update(self):
        """Return True once at least one trajectory can be sampled from the permanent buffer."""
        return len(self) > 0

    def __len__(self):
        """Number of trajectories that get_batch samples from."""
        return self.buffer_len

    def get_batch(self, batchsize=5, replace=False):
        """
        Sample stored trajectories uniformly from the filled part of the permanent buffer.

        :param batchsize: number of trajectories to sample (capped at the number stored)
        :param replace: whether to sample with replacement
        :return: (prev_obs, next_obs, actions, rewards, tasks, trajectory_lens). The first four are
                 (max_traj_len, batch, dim) tensors on get_device(); tasks is a (batch, task_dim)
                 tensor on get_device() or None; trajectory_lens is a numpy array of lengths.
        """
        # TODO: check if we can get rid of num_enc_len and num_rollouts (call it batchsize instead)

        batchsize = min(self.buffer_len, batchsize)

        # select the indices for the processes from which we pick
        rollout_indices = np.random.choice(range(self.buffer_len), batchsize, replace=replace)
        # trajectory length of the individual rollouts we picked
        trajectory_lens = np.array(self.trajectory_lens)[rollout_indices]

        # select the rollouts we want
        prev_obs = self.prev_state[:, rollout_indices, :]
        next_obs = self.next_state[:, rollout_indices, :]
        actions = self.actions[:, rollout_indices, :]
        rewards = self.rewards[:, rollout_indices, :]
        if self.tasks is not None:
            tasks = self.tasks[rollout_indices].to(get_device())
        else:
            tasks = None

        return prev_obs.to(get_device()), next_obs.to(get_device()), actions.to(get_device()), \
               rewards.to(get_device()), tasks, trajectory_lens


# Simplified version of RolloutStorage designed for storing individual transitions, used for SAC/other basic off-policy methods
class TransitionStorage(object):
    def __init__(self, args, num_processes, max_buffer_size,
                 state_dim, action_dim, latent_dim, vae_buffer_add_thresh=1.0, batch_size=128):
        """
        Off-policy replay buffer based on the storage_vae.py code structure.

        Instead of storing complete trajectories, this class stores individual transitions in flat
        ring buffers (one row per transition) and allows sampling of random batches for off-policy
        training.

        Parameters:
        - args: experiment arguments; stored as ``self.args`` (not read inside this class)
        - num_processes: number of parallel environments; ``insert`` reads this many rows from its inputs
        - max_buffer_size: total capacity, in transitions (converted to int)
        - state_dim: dimension of state
        - action_dim: dimension of action
        - latent_dim: dimension of the stored latent sample / mean / log-variance
        - vae_buffer_add_thresh: probability of adding each new transition to the buffer
        - batch_size: default batch size of ``get_batch``; ``ready_for_update`` requires more
          transitions than this
        """
        self.args = args
        self.num_processes = num_processes
        self.action_dim = action_dim
        self.obs_dim = state_dim
        self.latent_dim = latent_dim
        self.batch_size = batch_size

        self.vae_buffer_add_thresh = vae_buffer_add_thresh
        self.max_buffer_size = int(max_buffer_size)  # interpret as total capacity (in transitions)
        self.buffer_len = 0
        self.insert_idx = 0

        # create storage buffers on CPU
        self.prev_state = torch.zeros((self.max_buffer_size, state_dim))
        self.next_state = torch.zeros((self.max_buffer_size, state_dim))
        self.actions = torch.zeros((self.max_buffer_size, action_dim))
        self.latent_samples = torch.zeros((self.max_buffer_size, latent_dim))
        self.latent_means = torch.zeros((self.max_buffer_size, latent_dim))
        self.latent_logvars = torch.zeros((self.max_buffer_size, latent_dim))
        self.rewards = torch.zeros((self.max_buffer_size, 1))
        self.dones = torch.zeros((self.max_buffer_size, 1))

    @property
    def buffer(self):
        """
        Convert the filled portion of the replay buffer into a list of tuples:

            (state, action, next_state, reward, latent_sample, latent_mean, latent_logvar)

        where each element is a 1-D **cpu numpy** array, for easy pickling (``dones`` are not
        included). Only the first ``self.buffer_len`` entries are returned, so the list size grows
        with the buffer.
        """
        n = self.buffer_len
        # one device->host transfer per column instead of one per element
        columns = [
            tensor[:n].detach().cpu().numpy()
            for tensor in (
                self.prev_state,
                self.actions,
                self.next_state,
                self.rewards,
                self.latent_samples,
                self.latent_means,    # μ  (style)
                self.latent_logvars,  # log σ² (style)
            )
        ]
        # iterating a 2-D array yields its rows, so zip builds one tuple per stored transition
        return list(zip(*columns))

    def to_device(self):
        """Move every storage tensor to get_device()."""
        self.prev_state = self.prev_state.to(get_device())
        self.next_state = self.next_state.to(get_device())
        self.actions = self.actions.to(get_device())
        self.latent_samples = self.latent_samples.to(get_device())
        self.latent_means = self.latent_means.to(get_device())
        self.latent_logvars = self.latent_logvars.to(get_device())
        self.rewards = self.rewards.to(get_device())
        self.dones = self.dones.to(get_device())

    def insert(self, prev_state, action, next_state, reward, latent_sample, latent_mean, latent_logvar, done):
        """
        Insert a batch of transitions (one per process) into the replay buffer.

        All arguments are torch tensors indexed by process along dim 0 (shape [num_processes, ...]).
        The latent arguments may be None, in which case the corresponding stored rows are left
        unchanged. Each transition is kept with probability ``vae_buffer_add_thresh``; once the
        buffer is full, the oldest entries are overwritten (ring buffer).
        """
        # The buffer stores plain data: never record these copies in an autograd graph. Copying a
        # tensor that requires grad into a view would otherwise make the whole storage require grad
        # and chain every insertion into one ever-growing graph.
        with torch.no_grad():
            for i in range(self.num_processes):
                if not np.random.uniform() < self.vae_buffer_add_thresh:
                    continue
                idx = self.insert_idx
                self.prev_state[idx].copy_(prev_state[i])
                self.next_state[idx].copy_(next_state[i])
                self.actions[idx].copy_(action[i])
                self.rewards[idx].copy_(reward[i])
                self.dones[idx].copy_(done[i])
                if latent_sample is not None:
                    self.latent_samples[idx].copy_(latent_sample[i])
                if latent_mean is not None:
                    self.latent_means[idx].copy_(latent_mean[i])
                if latent_logvar is not None:
                    self.latent_logvars[idx].copy_(latent_logvar[i])

                self.insert_idx = (idx + 1) % self.max_buffer_size
                self.buffer_len = min(self.buffer_len + 1, self.max_buffer_size)

    def before_update(self, policy):
        """No-op hook; nothing needs to be prepared before a policy update."""
        pass

    def ready_for_update(self):
        """Return True once the buffer holds strictly more than ``batch_size`` transitions."""
        return self.buffer_len > self.batch_size

    def after_update(self):
        """No-op hook; nothing needs to be cleaned up after a policy update."""
        pass

    def __len__(self):
        """Number of stored transitions."""
        return self.buffer_len

    def get_batch(self, batchsize=None, replace=False):
        """
        Sample a batch of transitions uniformly from the filled part of the buffer.

        :param batchsize: number of transitions to sample; defaults to ``self.batch_size``
        :param replace: sample with replacement (forced to True when fewer than ``batchsize``
                        transitions are stored)
        :return: (prev_obs, next_obs, actions, rewards, latent_samples, latent_means,
                  latent_logvars, dones), each a tensor on get_device() with the batch along dim 0
        :raises ValueError: if the buffer is empty
        """
        if self.buffer_len == 0:
            raise ValueError("Replay buffer is empty!")

        if batchsize is None:
            batchsize = self.batch_size

        if self.buffer_len < batchsize:
            replace = True  # just allow replacement if buffer size is smaller than batch size

        # If buffer is not full, only sample from [0, buffer_len)
        max_index = self.buffer_len
        indices = np.random.choice(max_index, batchsize, replace=replace)

        prev_obs = self.prev_state[indices].to(get_device())
        next_obs = self.next_state[indices].to(get_device())
        actions = self.actions[indices].to(get_device())
        rewards = self.rewards[indices].to(get_device())
        dones = self.dones[indices].to(get_device())
        latent_samples = self.latent_samples[indices].to(get_device())
        latent_means = self.latent_means[indices].to(get_device())
        latent_logvars = self.latent_logvars[indices].to(get_device())

        return prev_obs, next_obs, actions, rewards, latent_samples, latent_means, latent_logvars, dones

    def clear(self):
        """Zero every storage tensor and mark the buffer as empty."""
        self.prev_state *= 0
        self.next_state *= 0
        self.actions *= 0
        self.rewards *= 0
        self.dones *= 0
        self.latent_samples *= 0
        self.latent_means *= 0
        self.latent_logvars *= 0
        self.buffer_len = 0
        self.insert_idx = 0