

from verl.utils.hdfs_io import makedirs, copy
import torch
import os
import ray
import io
import tempfile

from collections import deque

# Ray remote-function wrapper around `copy` (imported from verl.utils.hdfs_io).
# It is not referenced by the other definitions visible in this file.
remote_copy = ray.remote(copy)


@ray.remote
def save_to_hdfs(data: io.BytesIO, name, hdfs_dir, verbose):
    """Ray task: stage ``data`` in a temporary file and copy it to ``hdfs_dir``.

    The buffer contents are written to ``<name>.pth`` inside a throwaway
    temporary directory, and that file is then handed to ``copy``.

    Args:
        data: In-memory buffer whose bytes are written to the staged file.
        name: File stem; ``'.pth'`` is appended to form the file name.
        hdfs_dir: Destination directory passed to ``copy``.
        verbose: When truthy, print the staged path and destination first.

    Any exception raised by ``copy`` is printed and swallowed, so a failed
    upload does not fail the task.
    """
    with tempfile.TemporaryDirectory() as scratch_dir:
        staged_path = os.path.join(scratch_dir, name + '.pth')
        with open(staged_path, 'wb') as sink:
            sink.write(data.getbuffer())

        if verbose:
            print(f'Saving {staged_path} to {hdfs_dir}')

        # Best effort: report the failure but never propagate it.
        try:
            copy(staged_path, hdfs_dir)
        except Exception as err:
            print(err)


@ray.remote
class TrajectoryTracker:
    """Ray actor that launches ``save_to_hdfs`` tasks and tracks their futures."""

    def __init__(self, hdfs_dir, verbose) -> None:
        """Remember the destination, create it via ``makedirs``, start empty.

        Args:
            hdfs_dir: Destination directory forwarded to every upload task.
            verbose: Forwarded to every upload task to toggle its logging.
        """
        self.hdfs_dir = hdfs_dir
        makedirs(hdfs_dir)
        self.verbose = verbose

        # FIFO of object refs for uploads that have been launched but not
        # yet awaited by wait_for_hdfs().
        self.handle = deque()

    def dump(self, data: io.BytesIO, name):
        """Launch an upload task for ``data`` and enqueue its future.

        Returns as soon as the task is submitted; it does not wait for the
        upload itself.
        """
        upload_ref = save_to_hdfs.remote(data, name, self.hdfs_dir, self.verbose)
        self.handle.append(upload_ref)

    def wait_for_hdfs(self):
        """Block until every queued upload task has finished, oldest first."""
        pending = self.handle
        while pending:
            ray.get(pending.popleft())


def dump_data(data, name):
    """Serialize ``data`` with ``torch.save`` and hand it to the tracker actor.

    Does nothing unless the environment variable ``VERL_ENABLE_TRACKER`` is
    exactly ``'1'``.

    Args:
        data: Object to serialize with ``torch.save``.
        name: Name forwarded to the tracker's ``dump`` method.

    The call blocks until the actor's ``dump`` method has returned. That
    method only enqueues an upload task, so the upload itself may still be
    in flight afterwards.
    """
    if os.getenv('VERL_ENABLE_TRACKER', '0') != '1':
        return

    payload = io.BytesIO()
    torch.save(data, payload)
    ray.get(get_trajectory_tracker().dump.remote(payload, name))


def get_trajectory_tracker():
    """Return a handle to the named ``global_tracker`` actor.

    Configuration is read from the environment:

    * ``VERL_TRACKER_HDFS_DIR`` (required): destination directory.
    * ``VERL_TRACKER_VERBOSE``: verbose logging when exactly ``'1'``.

    The actor is requested with ``get_if_exists=True`` and
    ``lifetime="detached"``. Per Ray's documented semantics for
    ``get_if_exists``, an already-existing actor with this name is reused,
    in which case the arguments read here are not used to build a new one.

    Returns:
        The Ray actor handle for ``TrajectoryTracker``.

    Raises:
        AssertionError: If ``VERL_TRACKER_HDFS_DIR`` is not set.
    """
    hdfs_dir = os.getenv('VERL_TRACKER_HDFS_DIR')
    if hdfs_dir is None:
        # Raised explicitly rather than via `assert` so the check survives
        # `python -O`; the exception type stays AssertionError for callers
        # that already catch it.
        raise AssertionError('environment variable VERL_TRACKER_HDFS_DIR must be set '
                             'to use the trajectory tracker')
    verbose = os.getenv('VERL_TRACKER_VERBOSE', default='0') == '1'
    tracker = TrajectoryTracker.options(name="global_tracker", get_if_exists=True,
                                        lifetime="detached").remote(hdfs_dir, verbose)
    return tracker


if __name__ == '__main__':
    # Manual smoke test: enable the tracker, dump one random tensor from each
    # of ten Ray tasks, then wait for all queued uploads to finish.
    os.environ['VERL_ENABLE_TRACKER'] = '1'
    os.environ['VERL_TRACKER_HDFS_DIR'] = '~/debug/test'

    @ray.remote
    def process(step):
        """Dump a random tensor under a name derived from ``step``."""
        dump_data({'obs': torch.randn(10, 20)}, f'process_{step}_obs')

    ray.init()

    # Fan out the tasks, then block until every one has called dump_data.
    pending = [process.remote(step) for step in range(10)]
    ray.get(pending)

    # dump_data only enqueues uploads; drain the queue before exiting.
    ray.get(get_trajectory_tracker().wait_for_hdfs.remote())
