import json
from torch.utils.data import Dataset
from tqdm import tqdm
from joblib import Parallel, delayed
import json
import glob
import gzip
import hashlib
import os
from pathlib import Path
from torch.utils.data import Dataset
from tqdm import tqdm
from joblib import Parallel, delayed
from IPython import embed



class NHIRDDataset(Dataset):
    """Map-style dataset over JSON-lines records stored in gzip files.

    On construction every file matching the ``input_path`` glob is scanned
    once to record the byte offset of the start of each line; the resulting
    ``(path, offset)`` index is cached as JSON in a ``.cache`` directory next
    to the data. ``__getitem__`` seeks to a stored offset and parses that one
    line with ``json.loads``.
    """

    # Mixed into the cache key; bump whenever the on-disk index format changes
    # so previously written caches are ignored. "v2": offsets are byte
    # positions in the decompressed stream (earlier caches held text-mode
    # character counts, which are not valid seek targets).
    _CACHE_VERSION = "v2-byte-offsets"

    def __init__(self, input_path, n_jobs=10):
        """Build (or load from cache) the line index.

        Args:
            input_path: glob pattern selecting the gzip-compressed JSON-lines
                files, e.g. ``"data/*"``.
            n_jobs: number of joblib workers used to scan the files when no
                cache exists. Negative values follow joblib's convention
                (``-1`` = all CPUs, ``-2`` = all but one, ...).
        """
        self.all_lines = self.preprocessing(input_path, n_jobs)

    def batch(self, path_list, n_jobs):
        """Yield contiguous slices of ``path_list``, at most ``n_jobs`` of them.

        Ceiling division folds the remainder into the slices instead of
        producing an extra short one, and the chunk size is kept >= 1 so that
        fewer paths than jobs (including an empty list) is handled.
        """
        if n_jobs < 0:
            # joblib convention: -1 = all CPUs, -2 = all but one, ...
            n_jobs = (os.cpu_count() or 1) + 1 + n_jobs
        n_chunks = max(1, n_jobs)
        total = len(path_list)
        size = max(1, -(-total // n_chunks))  # ceil(total / n_chunks), >= 1
        for start in range(0, total, size):
            # Slicing past the end is safe, so no explicit min() is needed.
            yield path_list[start:start + size]

    def get_meta_cache(self, file_paths, input_path):
        """Return the cached line index, or the path where it should be written.

        The cache file name is a SHA-256 digest of the index format version
        plus each file's path, size and modification time, so the cache is
        invalidated when the set of files or their contents change.

        Args:
            file_paths: sorted list of data file paths.
            input_path: directory under which the ``.cache`` folder lives.

        Returns:
            The deserialized index (a list of ``[path, offset]`` pairs, since
            JSON turns tuples into lists) on a cache hit; otherwise the cache
            file path as a ``str``, after ensuring the ``.cache`` directory
            exists.
        """
        h = hashlib.sha256()
        h.update(self._CACHE_VERSION.encode())
        for file_path in file_paths:
            try:
                st = os.stat(file_path)
                stamp = f"{st.st_size}:{st.st_mtime_ns}"
            except OSError:
                stamp = "missing"
            # NUL separators keep distinct path lists from hashing alike.
            h.update(f"\0{file_path}\0{stamp}".encode("utf-8", "surrogateescape"))
        cache_dir = os.path.join(input_path, ".cache")
        cache_path = os.path.join(cache_dir, h.hexdigest())
        try:
            with open(cache_path, encoding="utf-8") as f:
                return json.load(f)
        except (OSError, ValueError):
            # Missing, unreadable or corrupt (JSONDecodeError is a ValueError)
            # cache: tell the caller where to write a fresh one.
            Path(cache_dir).mkdir(parents=True, exist_ok=True)
            return cache_path

    @staticmethod
    def _cache_root(input_path):
        """Return the deepest leading directory of ``input_path`` free of glob magic.

        ``"data/*"`` and ``"data/*.json.gz"`` both map to ``"data"``; a pattern
        with no directory part maps to ``"."``.
        """
        root = os.path.dirname(input_path)
        while root and glob.escape(root) != root:
            root = os.path.dirname(root)
        return root or "."

    def preprocessing(self, input_path, n_jobs):
        """Return the ``(path, offset)`` index for every line of every matched file.

        Loads the index from cache when available; otherwise scans the files
        in parallel with joblib, writes the cache, and returns the fresh index.
        Entries are ordered by sorted file path, then by line order.
        """
        file_paths = sorted(glob.glob(input_path))
        cache = self.get_meta_cache(file_paths, self._cache_root(input_path))
        if not isinstance(cache, str):
            return cache
        cache_path = cache

        file_path_batches = list(self.batch(file_paths, n_jobs))
        results = Parallel(n_jobs=n_jobs)(
            delayed(self.get_line_positions)(i, file_path_batch)
            for i, file_path_batch in enumerate(file_path_batches)
        )
        offsets_by_path = {
            path: offsets
            for batch_result in results
            for path, offsets in batch_result
        }
        all_lines = [
            (file_path, offset)
            for file_path in file_paths
            for offset in offsets_by_path[file_path]
        ]

        # Write to a temp file and rename so an interrupted run cannot leave a
        # truncated cache behind.
        tmp_path = cache_path + ".tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(all_lines, f)
        os.replace(tmp_path, cache_path)
        return all_lines

    def get_line_positions(self, i, paths):
        """Scan ``paths`` and return ``[(path, [line_start_offset, ...]), ...]``.

        Offsets are byte positions in the decompressed stream. Files are read
        in binary mode so offsets are exact regardless of text encoding or line
        endings and can be passed directly to ``GzipFile.seek``.

        Args:
            i: worker index, used as the tqdm progress-bar position.
            paths: gzip file paths to scan.
        """
        line_positions = []
        for path in tqdm(paths, position=i):
            offsets = []
            position = 0
            with gzip.open(path, "rb") as f:
                for line in f:
                    offsets.append(position)
                    position += len(line)
            line_positions.append((path, offsets))
        return line_positions

    def __getitem__(self, idx):
        """Return the JSON value stored on the ``idx``-th indexed line."""
        path, offset = self.all_lines[idx]
        # NOTE(review): gzip streams are not randomly seekable; GzipFile.seek
        # decompresses from the start of the file up to ``offset`` on each
        # call, so access cost grows with the line's position in its file.
        with gzip.open(path, "rb") as f:
            f.seek(offset)
            # json.loads accepts bytes and detects UTF-8/16/32 itself.
            return json.loads(f.readline())

    def __len__(self):
        """Return the number of indexed lines across all files."""
        return len(self.all_lines)
