import torch
from tensordict.tensordict import TensorDict
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
from torchrl.data.replay_buffers.samplers import SliceSampler


class Buffer():
	"""
	Replay buffer for TD-MPC2 training. Based on torchrl.
	Uses CUDA memory if available, and CPU memory otherwise.

	Every stored step carries an integer `episode` key, which the slice
	sampler uses as its trajectory key. The underlying torchrl buffer is
	created lazily, either when the first episode is added/loaded or when a
	checkpoint is restored from `cfg.work_dir`.

	Config attributes read by this class: `buffer_size`, `steps`,
	`batch_size`, `horizon`, `multitask`, `buffer_prefetch` and `work_dir`
	(used with the `/` operator and `.exists()`, i.e. a path-like object).
	"""

	def __init__(self, cfg):
		"""Set up the sampler and restore a checkpoint if one is present."""
		self.cfg = cfg
		# Fall back to CPU so that the buffer also works without CUDA.
		self._device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
		self._capacity = min(cfg.buffer_size, cfg.steps)
		self._sampler = SliceSampler(
			num_slices=self.cfg.batch_size,
			end_key=None,
			traj_key='episode',
			truncated_key=None,
			strict_length=True,
			cache_values=cfg.multitask,
		)
		# Each sampled slice holds horizon+1 consecutive steps.
		self._batch_size = cfg.batch_size * (cfg.horizon+1)
		self._num_eps = 0
		self.maybe_load_checkpoint()

	@property
	def capacity(self):
		"""Return the capacity of the buffer."""
		return self._capacity

	@property
	def num_eps(self):
		"""Return the number of episodes in the buffer."""
		return self._num_eps

	def _reserve_buffer(self, storage):
		"""
		Reserve a buffer with the given storage.
		"""
		return ReplayBuffer(
			storage=storage,
			sampler=self._sampler,
			pin_memory=False,
			prefetch=self.cfg.buffer_prefetch,
			batch_size=self._batch_size,
		)

	def _select_storage_device(self, total_bytes):
		"""
		Pick the device name used for storage.

		Returns 'cuda:0' when CUDA is available and 2.5x `total_bytes` is
		smaller than the currently free CUDA memory, and 'cpu' otherwise.
		"""
		if not torch.cuda.is_available():
			return 'cpu'
		mem_free, _ = torch.cuda.mem_get_info()
		# Heuristic: decide whether to use CUDA or CPU memory
		return 'cuda:0' if 2.5*total_bytes < mem_free else 'cpu'

	def _init(self, tds):
		"""
		Initialize the replay buffer. Use the first episode to estimate storage requirements.

		`tds` is a TensorDict holding one episode (or a representative sample);
		its per-step size is estimated as total bytes divided by `len(tds)`.
		Nested TensorDicts are accounted for one level deep.

		Returns the newly created buffer; the caller is responsible for
		assigning it to `self._buffer`.
		"""
		print(f'Buffer capacity: {self._capacity:,}')
		bytes_per_step = sum(
			sum(x.numel()*x.element_size() for x in v.values()) if isinstance(v, TensorDict)
			else v.numel()*v.element_size()
			for v in tds.values()
		) / len(tds)
		total_bytes = bytes_per_step*self._capacity
		print(f'Storage required: {total_bytes/1e9:.2f} GB')
		storage_device = self._select_storage_device(total_bytes)
		print(f'Using {storage_device.upper()} memory for storage.')
		self._storage_device = torch.device(storage_device)
		return self._reserve_buffer(LazyTensorStorage(self._capacity, device=self._storage_device))

	def save_checkpoint(self):
		"""
		Save the buffer to `cfg.work_dir`.

		Writes a sampled batch to `sample_buffer_td` (only if it does not exist
		yet), the episode count and capacity to `buffer_info.pth`, and the
		underlying torchrl buffer itself. Does nothing for an empty buffer.
		"""
		if self._num_eps == 0:
			print('Buffer is empty. No checkpoint saved.')
			return
		# save sample instance for future reconstruction
		if not (self.cfg.work_dir / 'sample_buffer_td').exists():
			sample = self._buffer.sample()
			sample.dumps(self.cfg.work_dir / 'sample_buffer_td')
		# save my buffer class info
		d = {
			'_num_eps': self._num_eps,
			'_capacity': self._capacity,
		}
		torch.save(d, self.cfg.work_dir / 'buffer_info.pth')
		# save the underlying buffer
		self._buffer.save(self.cfg.work_dir)

	def delete_checkpoint(self):
		"""
		Delete the contents of `cfg.work_dir / 'storage'`.

		The `storage` directory itself and the other checkpoint files
		(`buffer_info.pth`, `sample_buffer_td`) are left in place.
		"""
		import shutil

		# delete the storage
		buffer_path = self.cfg.work_dir / 'storage'
		if buffer_path.exists():
			for file in buffer_path.iterdir():
				# unlink() cannot remove directories, so sub-directories
				# need a recursive delete.
				if file.is_dir() and not file.is_symlink():
					shutil.rmtree(file)
				else:
					file.unlink()
				print(f'Deleted buffer file: {file}')

	def maybe_load_checkpoint(self):
		"""
		Restore the buffer from `cfg.work_dir` if `buffer_info.pth` exists.

		The saved sample TensorDict is used to size and allocate the storage
		before the stored data is loaded into it. The capacity stored in the
		checkpoint overrides the one derived from the config.
		"""
		buffer_info_path = self.cfg.work_dir / 'buffer_info.pth'
		if not buffer_info_path.exists():
			print('No buffer checkpoint info found. Initializing new buffer.')
			return
		d = torch.load(buffer_info_path, map_location='cpu')
		self._num_eps = d['_num_eps']
		self._capacity = d['_capacity']
		print(f'Loaded buffer info: {self._num_eps} episodes, capacity {self._capacity:,}')
		tds = TensorDict.load(self.cfg.work_dir / 'sample_buffer_td', map_location='cpu')
		# _init only returns the buffer; it has to be stored before loading into it.
		self._buffer = self._init(tds)
		self._buffer.load(self.cfg.work_dir)
		print(f'Buffer checkpoint loaded successfully of size {len(self._buffer.storage)}.')

	def load(self, td):
		"""
		Load a batch of episodes into the buffer. This is useful for loading data from disk,
		and is more efficient than adding episodes one by one.

		`td` is indexed as (episode, step); it is tagged in place with an
		`episode` key and then flattened to one entry per step.
		Returns the updated number of episodes.
		"""
		num_new_eps = len(td)
		episode_idx = torch.arange(self._num_eps, self._num_eps+num_new_eps, dtype=torch.int64)
		td['episode'] = episode_idx.unsqueeze(-1).expand(-1, td['reward'].shape[1])
		if self._num_eps == 0:
			self._buffer = self._init(td[0])
		td = td.reshape(td.shape[0]*td.shape[1])
		self._buffer.extend(td)
		self._num_eps += num_new_eps
		return self._num_eps

	def add(self, td):
		"""
		Add an episode to the buffer.

		`td` is tagged in place with an `episode` key shaped like `td['reward']`.
		Returns the updated number of episodes.
		"""
		td['episode'] = torch.full_like(td['reward'], self._num_eps, dtype=torch.int64)
		if self._num_eps == 0:
			self._buffer = self._init(td)
		self._buffer.extend(td)
		self._num_eps += 1
		return self._num_eps

	def _prepare_batch(self, td):
		"""
		Prepare a sampled batch for training (post-processing).
		Expects `td` to be a TensorDict with batch size TxB.

		Returns `(obs, action, reward, terminated, task)`. `obs` keeps all T
		steps, while `action`, `reward` and `terminated` drop the first step;
		`reward` and `terminated` get a trailing singleton dimension. `task`
		is taken from the first step, or is None when the key is absent.
		"""
		td = td.select("obs", "action", "reward", "terminated", "task", strict=False).to(self._device, non_blocking=True)
		obs = td.get('obs').contiguous()
		action = td.get('action')[1:].contiguous()
		reward = td.get('reward')[1:].unsqueeze(-1).contiguous()
		terminated = td.get('terminated')[1:].unsqueeze(-1).contiguous()
		task = td.get('task', None)
		if task is not None:
			task = task[0].contiguous()
		return obs, action, reward, terminated, task

	def sample(self):
		"""
		Sample a batch of subsequences from the buffer.

		The flat sample is viewed as (slices, horizon+1) and permuted so that
		the step dimension comes first before post-processing.
		"""
		td = self._buffer.sample().view(-1, self.cfg.horizon+1).permute(1, 0)
		return self._prepare_batch(td)