import copy
import os
import queue

import torch
import torch.multiprocessing as mp

from prover.utils import load_jsonl_objects


class DataLoader(object):
    """Build this node's queue of (problem, run) work items that are not finished yet.

    Every record of the JSONL dataset at ``data_path`` is expanded into
    ``data_repeat`` runs named ``<prob['name']>/run<i>``. Runs already marked
    finished under ``log_dir`` and problems outside ``data_split`` are skipped.
    The remaining runs are sharded round-robin across ``world_size`` nodes, and
    only the share belonging to ``node_rank`` is enqueued.
    """

    def __init__(self, data_path, data_split, data_repeat, node_rank, world_size, log_dir):
        """
        Args:
            data_path: path handed to ``load_jsonl_objects``. Each record must
                have a ``'name'`` key, and a ``'split'`` key when ``data_split``
                is not None.
            data_split: a single split name, a collection of split names, or
                None to keep every problem.
            data_repeat: number of runs to schedule per problem.
            node_rank: this node's index; items with
                ``todo_index % world_size == node_rank`` are enqueued.
            world_size: total number of nodes sharing the work.
            log_dir: directory scanned for previously finished runs. If it does
                not exist, no run is considered finished.
        """
        self.manager = mp.Manager()
        self.queue = self.manager.Queue()
        self.lock = mp.Lock()
        self.finished_flag_filename = 'finished_running.txt'

        done_set = self._collect_finished_runs(log_dir, self.finished_flag_filename)

        if isinstance(data_split, str):
            data_split = [data_split]
        dataset = load_jsonl_objects(data_path)

        # todo_count advances only for runs that survive filtering. Assuming all
        # nodes see the same dataset and log_dir state, the modulo shards are
        # disjoint and together cover every remaining run.
        todo_count = 0
        for _repeat in range(data_repeat):
            for prob_idx, prob in enumerate(dataset):
                prob_runname = os.path.join(prob['name'], f'run{_repeat}')
                # Finished runs are recognised by the "<prob_idx>_<name>/run<i>"
                # relative path collected from log_dir.
                if f'{prob_idx}_{prob_runname}' in done_set:
                    continue
                if data_split is not None and prob['split'] not in data_split:
                    continue
                if todo_count % world_size == node_rank:
                    self.queue.put((prob_idx, prob_runname, copy.deepcopy(prob)))
                todo_count += 1
        print('Number of TODO Problems: {}'.format(self.queue.qsize()))

    @staticmethod
    def _collect_finished_runs(log_dir, flag_filename):
        """Return the set of ``<dirname>/<run*>`` paths under ``log_dir`` that contain ``flag_filename``.

        Paths are relative to ``log_dir``. A ``log_dir`` that does not exist
        yields an empty set rather than raising ``FileNotFoundError``.
        """
        done_set = set()
        try:
            top_entries = os.listdir(log_dir)
        except FileNotFoundError:
            return done_set
        for dirname in top_entries:
            run_dir = os.path.join(log_dir, dirname)
            if not os.path.isdir(run_dir):
                continue
            for subdirname in os.listdir(run_dir):
                if subdirname.startswith('run') and os.path.exists(
                        os.path.join(run_dir, subdirname, flag_filename)):
                    done_set.add(os.path.join(dirname, subdirname))
        return done_set

    def size(self):
        """Return ``qsize()`` of the work queue (approximate per the ``queue`` docs)."""
        return self.queue.qsize()

    def get(self):
        """Pop the next ``(prob_idx, prob_runname, prob)`` work item.

        Returns:
            The next queued tuple, or ``(None, None, None)`` when the queue is empty.
        """
        with self.lock:
            try:
                # A single non-blocking call replaces qsize() followed by a
                # blocking get(). That saves one manager round-trip, and the
                # call can never hang if the queue empties between the two.
                return self.queue.get_nowait()
            except queue.Empty:
                return None, None, None
