# Copyright 2022 The OFA-Sys Team. 
# All rights reserved.
# This source code is licensed under the Apache 2.0 license 
# found in the LICENSE file in the root directory.

import os
import torch
import pickle


class FileDataset:
    """Sequential reader over a delimited text file, sliced across ranks.

    On construction the file is scanned once (or a pickled ``<file>.index``
    is loaded) to obtain the total row count and a line-index -> byte-offset
    table.  The rows are then divided into ``slice_count`` contiguous slices
    and this instance reads the slice numbered ``slice_id``; both values come
    from ``torch.distributed`` when it is usable and default to a single
    slice otherwise.

    Args:
        file_path: Path of the data file; it must exist.
        selected_col_ids: Comma-separated string of column indices to return
            (e.g. ``"0,2"``).  ``None`` selects every column of the first line.
        dtypes: Comma-separated string of type expressions (e.g. ``"int,str"``),
            one per selected column.  ``None`` keeps every column as ``str``.
        separator: Column separator inside a line.
        cached_index: When true, load ``(total_row_count, lineid_to_offset)``
            from ``"<file_path>.index"`` instead of scanning the file.

    Raises:
        AssertionError: If the data file (or the requested index cache) does
            not exist, or ``dtypes`` and ``selected_col_ids`` differ in length.
    """

    def __init__(self, file_path, selected_col_ids=None, dtypes=None, separator="\t", cached_index=False):
        self.file_path = file_path
        assert os.path.exists(self.file_path), "Error: The local datafile {} not exists!".format(self.file_path)

        self.separator = separator
        if selected_col_ids is None:
            # default to all fields: count the columns of the first line
            with open(self.file_path) as fp:
                first_line = fp.readline()
            self.selected_col_ids = list(range(len(first_line.rstrip("\n").split(self.separator))))
        else:
            self.selected_col_ids = [int(col_id) for col_id in selected_col_ids.split(",")]
        if dtypes is None:
            # default to str
            self.dtypes = [str] * len(self.selected_col_ids)
        else:
            # SECURITY: eval() executes arbitrary expressions; `dtypes` must
            # only ever come from trusted configuration, never from user input.
            self.dtypes = [eval(col_dtype) for col_dtype in dtypes.split(",")]
            assert len(self.dtypes) == len(self.selected_col_ids), \
                "dtypes and selected_col_ids must have the same number of entries"

        # number of rows already served from the current pass over the slice
        self.data_cnt = 0
        try:
            self.slice_id = torch.distributed.get_rank()
            self.slice_count = torch.distributed.get_world_size()
        except Exception:
            # torch.distributed is not usable here: read the whole file
            self.slice_id = 0
            self.slice_count = 1
        self.cached_index = cached_index
        self._init_seek_index()
        self._reader = self._get_reader()
        print("file {} slice_id {} row count {} total row count {}".format(
            self.file_path, self.slice_id, self.row_count, self.total_row_count)
        )

    def _init_seek_index(self):
        """Populate ``total_row_count`` / ``lineid_to_offset`` and this slice's bounds."""
        if self.cached_index:
            cache_path = "{}.index".format(self.file_path)
            assert os.path.exists(cache_path), "cache file {} not exists!".format(cache_path)
            # SECURITY: pickle can execute arbitrary code while loading; only
            # use index files produced by a trusted process.
            with open(cache_path, "rb") as cache_fp:
                self.total_row_count, self.lineid_to_offset = pickle.load(cache_fp)
            print("local datafile {} slice_id {} use cached row_count and line_idx-to-offset mapping".format(
                self.file_path, self.slice_id))
        else:
            # make an iteration over the file to get row_count and line_idx-to-offset mapping
            # for index out of size issue, ref: https://github.com/OFA-Sys/OFA/issues/91
            # NOTE(review): the file is decoded with the platform default
            # encoding while offsets assume UTF-8 byte lengths; this presumably
            # relies on the data being UTF-8 with "\n" line endings - confirm.
            with open(self.file_path, "r") as fp:
                print("local datafile {} slice_id {} begin to initialize row_count and line_idx-to-offset mapping".format(
                    self.file_path, self.slice_id))
                self.total_row_count = 0
                offset = 0
                self.lineid_to_offset = []
                for line in fp:
                    self.lineid_to_offset.append(offset)
                    self.total_row_count += 1
                    offset += len(line.encode('utf-8'))
        self._compute_start_pos_and_row_count()
        print("local datafile {} slice_id {} finished initializing row_count and line_idx-to-offset mapping".format(
            self.file_path, self.slice_id))

    def _compute_start_pos_and_row_count(self):
        """Set ``row_count`` and ``start_pos`` for this slice.

        Rows are split evenly; the first ``total_row_count % slice_count``
        slices receive one extra row each.
        """
        base_count, remainder = divmod(self.total_row_count, self.slice_count)
        if self.slice_id < remainder:
            self.row_count = base_count + 1
            self.start_pos = self.row_count * self.slice_id
        else:
            self.row_count = base_count
            # skip the extra rows handed to the first `remainder` slices
            self.start_pos = base_count * self.slice_id + remainder

    def _get_reader(self):
        """Open the data file positioned at the first row of this slice."""
        fp = open(self.file_path, "r")
        try:
            fp.seek(self.lineid_to_offset[self.start_pos])
        except Exception:
            # do not leak the handle if the slice start cannot be reached
            fp.close()
            raise
        return fp

    def _seek(self, offset=0):
        """Move the reader to row ``offset`` and set ``data_cnt`` to ``offset``.

        ``offset`` is first interpreted relative to the slice start; if that
        fails for any reason it is retried as an absolute row index.
        """
        try:
            print("slice_id {} seek offset {}".format(self.slice_id, self.start_pos + offset))
            self._reader.seek(self.lineid_to_offset[self.start_pos + offset])
            self.data_cnt = offset
        except Exception:
            print("slice_id {} seek offset {}".format(self.slice_id, offset))
            self._reader.seek(self.lineid_to_offset[offset])
            self.data_cnt = offset

    def __del__(self):
        # __init__ may have failed before the reader was created
        reader = getattr(self, "_reader", None)
        if reader is not None:
            reader.close()

    def __len__(self):
        """Return the number of rows in this slice."""
        return self.row_count

    def get_total_row_count(self):
        """Return the number of rows in the whole file (all slices)."""
        return self.total_row_count

    def __getitem__(self, index):
        """Return the next row of this slice as a list of converted columns.

        Reading is purely sequential: ``index`` is not used.  After
        ``row_count`` rows have been served, the reader restarts from the
        beginning of the slice.
        """
        if self.data_cnt == self.row_count:
            print("reach the end of datafile, start a new reader")
            self.data_cnt = 0
            # open the replacement first, then release the exhausted handle
            new_reader = self._get_reader()
            self._reader.close()
            self._reader = new_reader
        column_l = self._reader.readline().rstrip("\n").split(self.separator)
        self.data_cnt += 1
        column_l = [dtype(column_l[col_id]) for col_id, dtype in zip(self.selected_col_ids, self.dtypes)]
        return column_l
