import torch
import requests
import os
import pickle
import numpy as np


class FileDataset(object):
    """Separator-delimited text file whose rows are streamed sequentially.

    On construction the file is scanned once (or a cached ``<file_path>.index``
    pickle is loaded) to build a line-number -> byte-offset table. The rows are
    then split into contiguous slices, one per ``torch.distributed`` rank when
    ``data_slice`` is True. ``__getitem__`` ignores its index and returns the
    next row of this rank's slice, wrapping back to the slice start once all of
    the slice's rows have been returned.

    Args:
        file_path: path to the local data file; must exist.
        selected_col_ids: comma-separated column indices (e.g. ``"0,2"``) to
            return; defaults to every column found in the first line.
        dtypes: comma-separated Python expressions (e.g. ``"str,int"``), one per
            selected column, each evaluated to a callable that converts the
            field; defaults to ``str`` for every column.
        separator: field separator.
        cached_index: load ``(total_row_count, lineid_to_offset)`` from
            ``<file_path>.index`` instead of scanning the file.
        data_slice: shard rows by distributed rank / world size; falls back to
            a single slice when distributed info is unavailable.
    """

    def __init__(
        self,
        file_path,
        selected_col_ids=None,
        dtypes=None,
        separator="\t",
        cached_index=False,
        data_slice=True,
    ):
        self.file_path = file_path
        assert os.path.exists(
            self.file_path
        ), "Error: The local datafile {} not exists!".format(self.file_path)

        self.separator = separator
        if selected_col_ids is None:
            # Default to every column present in the first line of the file.
            with open(self.file_path) as header_fp:
                first_line = header_fp.readline()
            num_cols = len(first_line.rstrip("\n").split(self.separator))
            self.selected_col_ids = list(range(num_cols))
        else:
            self.selected_col_ids = [
                int(col_id) for col_id in selected_col_ids.split(",")
            ]
        if dtypes is None:
            # Default every selected column to str.
            self.dtypes = [str] * len(self.selected_col_ids)
        else:
            # NOTE(security): eval() executes arbitrary expressions; `dtypes`
            # must come from trusted configuration, never from user data.
            self.dtypes = [eval(col_dtype) for col_dtype in dtypes.split(",")]
            assert len(self.dtypes) == len(self.selected_col_ids)

        # Number of rows already returned from the current pass over the slice.
        self.data_cnt = 0
        if data_slice:
            try:
                self.slice_id = torch.distributed.get_rank()
                self.slice_count = torch.distributed.get_world_size()
            except Exception:
                # Distributed not available / not initialized: single slice.
                self.slice_id = 0
                self.slice_count = 1
        else:
            self.slice_id = 0
            self.slice_count = 1
        self.cached_index = cached_index
        self._init_seek_index()
        self._reader = self._get_reader()
        print(
            "file {} slice_id {} row count {} total row count {}".format(
                self.file_path, self.slice_id, self.row_count, self.total_row_count
            )
        )

    def _init_seek_index(self):
        """Populate ``total_row_count`` and ``lineid_to_offset``, then slice bounds."""
        if self.cached_index:
            cache_path = "{}.index".format(self.file_path)
            assert os.path.exists(cache_path), "cache file {} not exists!".format(
                cache_path
            )
            # NOTE(security): pickle.load can execute arbitrary code; only load
            # index files produced by a trusted source.
            with open(cache_path, "rb") as cache_fp:
                self.total_row_count, self.lineid_to_offset = pickle.load(cache_fp)
            print(
                "local datafile {} slice_id {} use cached row_count and line_idx-to-offset mapping".format(
                    self.file_path, self.slice_id
                )
            )
        else:
            # Make one pass over the file to get row_count and the
            # line_idx-to-offset mapping.
            print(
                "local datafile {} slice_id {} begin to initialize row_count and line_idx-to-offset mapping".format(
                    self.file_path, self.slice_id
                )
            )
            self.total_row_count = 0
            offset = 0
            self.lineid_to_offset = []
            # newline="" still splits lines on "\n", "\r" and "\r\n" exactly like
            # the text-mode reader does, but does NOT translate "\r\n" to "\n",
            # so the re-encoded length is the true on-disk byte length.
            # surrogateescape makes decode+encode round-trip any byte sequence.
            with open(
                self.file_path,
                "r",
                encoding="utf-8",
                errors="surrogateescape",
                newline="",
            ) as index_fp:
                for line in index_fp:
                    self.lineid_to_offset.append(offset)
                    self.total_row_count += 1
                    offset += len(line.encode("utf-8", "surrogateescape"))
        self._compute_start_pos_and_row_count()
        print(
            "local datafile {} slice_id {} finished initializing row_count and line_idx-to-offset mapping".format(
                self.file_path, self.slice_id
            )
        )

    def _compute_start_pos_and_row_count(self):
        """Set ``row_count`` and ``start_pos`` for this slice.

        Rows are split into contiguous slices; the first
        ``total_row_count % slice_count`` slices get one extra row.
        """
        self.row_count = self.total_row_count // self.slice_count
        remainder = self.total_row_count - self.row_count * self.slice_count
        if self.slice_id < remainder:
            self.row_count += 1
            self.start_pos = self.row_count * self.slice_id
        else:
            self.start_pos = self.row_count * self.slice_id + remainder

    def _get_reader(self):
        """Open a new text reader positioned at this slice's first row."""
        fp = open(self.file_path, "r")
        if self.start_pos < len(self.lineid_to_offset):
            fp.seek(self.lineid_to_offset[self.start_pos])
        else:
            # Empty slice (empty file, or fewer rows than slices): park at EOF
            # instead of indexing past the end of the offset table.
            fp.seek(0, os.SEEK_END)
        return fp

    def _seek(self, offset=0):
        """Reposition the reader to row ``offset`` within this slice.

        If ``start_pos + offset`` is beyond the offset table, falls back to the
        absolute row ``offset`` of the file.
        """
        try:
            print(
                "slice_id {} seek offset {}".format(
                    self.slice_id, self.start_pos + offset
                )
            )
            self._reader.seek(self.lineid_to_offset[self.start_pos + offset])
            self.data_cnt = offset
        except IndexError:
            print("slice_id {} seek offset {}".format(self.slice_id, offset))
            self._reader.seek(self.lineid_to_offset[offset])
            self.data_cnt = offset

    def __del__(self):
        # __init__ may have failed before the reader was created.
        reader = getattr(self, "_reader", None)
        if reader is not None:
            reader.close()

    def __len__(self):
        # Total rows in the whole file, not just this slice's row_count.
        return self.total_row_count

    def get_total_row_count(self):
        """Return the number of rows in the whole file."""
        return self.total_row_count

    def __getitem__(self, index):
        """Return the next row of this slice as a list of converted fields.

        ``index`` is ignored; rows are read sequentially, wrapping to the slice
        start after ``row_count`` rows.

        Raises:
            ValueError: if the row lacks a selected column or a field cannot
                be converted by its dtype.
        """
        if self.data_cnt == self.row_count:
            print("reach the end of datafile, start a new reader")
            self.data_cnt = 0
            # Close the exhausted handle instead of leaking one per pass.
            self._reader.close()
            self._reader = self._get_reader()
        cur_line = self._reader.readline()
        column_l = cur_line.rstrip("\n").split(self.separator)
        self.data_cnt += 1
        try:
            return [
                dtype(column_l[col_id])
                for col_id, dtype in zip(self.selected_col_ids, self.dtypes)
            ]
        except (IndexError, ValueError, TypeError) as err:
            raise ValueError(
                "malformed row in datafile {} (slice_id {}): {!r}".format(
                    self.file_path, self.slice_id, cur_line
                )
            ) from err


def collate_fn(samples, pad_idx, eos_idx):
    """Collate a list of per-example dicts into one batch dict.

    Args:
        samples: list of dicts. Each must contain ``"id"``, ``"source"`` and
            ``"text_mask"``; optional keys (``features``, ``patch_image``,
            ``conf``, ``constraint_mask``, ...) are batched only when present
            (non-None) on the first sample.
        pad_idx: padding value for ``source`` and ``constraint_mask``.
        eos_idx: forwarded to ``collate_tokens``.

    Returns:
        ``{}`` for an empty list. Otherwise a dict with ``"id"`` (numpy array),
        ``"nsentences"`` and ``"net_input"`` (``input_ids`` /
        ``attention_masks`` right-padded to the longest ``source``), plus any
        optional entries below.
    """
    if len(samples) == 0:
        return {}

    def merge(key, pad_value=pad_idx, pading_size=None):
        # Pad the 1-D/2-D tensors stored under `key` into one tensor.
        return collate_tokens(
            [s[key] for s in samples],
            pad_value,
            eos_idx=eos_idx,
            pad_to_length=pading_size,
        )

    first = samples[0]
    larger_size = max(s["source"].size(0) for s in samples)

    ids = np.array([s["id"] for s in samples])
    src_tokens = merge("source", pad_value=pad_idx, pading_size=larger_size)
    # Masks are padded with 0 to the same length as the source tokens.
    src_tokens_masks = merge("text_mask", pad_value=0, pading_size=larger_size)

    batch = {
        "id": ids,
        "nsentences": len(samples),
        "net_input": {
            "input_ids": src_tokens,
            "attention_masks": src_tokens_masks,
        },
    }
    net_input = batch["net_input"]
    if first.get("features", None) is not None:
        net_input["features"] = torch.stack(
            [sample["features"] for sample in samples], dim=0
        )
    if first.get("med_patch_image", None) is not None:
        net_input["med_patch_images"] = torch.stack(
            [sample["med_patch_image"] for sample in samples], dim=0
        )
    if first.get("patch_image", None) is not None:
        net_input["patch_images"] = torch.stack(
            [sample["patch_image"] for sample in samples], dim=0
        )
    if first.get("patch_mask", None) is not None:
        net_input["patch_masks"] = torch.cat(
            [sample["patch_mask"] for sample in samples]
        )
    # image generation
    if first.get("code_mask", None) is not None:
        net_input["code_masks"] = torch.cat(
            [sample["code_mask"] for sample in samples]
        )
    if first.get("code_image", None) is not None:
        batch["code_images"] = torch.cat([sample["code_image"] for sample in samples])
    # For classification tasks (i.e., VQA, SNLI-VE, GLUE)
    if first.get("conf", None) is not None:
        batch["conf"] = torch.cat([s["conf"] for s in samples], dim=0)
    if first.get("ref_dict", None) is not None:
        batch["ref_dict"] = np.array([s["ref_dict"] for s in samples])
    if first.get("constraint_mask", None) is not None:
        # Previously called without a pad value, which always raised TypeError.
        batch["constraint_masks"] = merge("constraint_mask")
    if first.get("decoder_prompt", None) is not None:
        batch["decoder_prompts"] = np.array(
            [s["decoder_prompt"].tolist() for s in samples]
        )
    # For detection and visual grounding
    if first.get("w_resize_ratio", None) is not None:
        batch["w_resize_ratios"] = torch.stack(
            [s["w_resize_ratio"] for s in samples], dim=0
        )
    if first.get("h_resize_ratio", None) is not None:
        batch["h_resize_ratios"] = torch.stack(
            [s["h_resize_ratio"] for s in samples], dim=0
        )
    if first.get("region_coord", None) is not None:
        batch["region_coords"] = torch.stack(
            [s["region_coord"] for s in samples], dim=0
        )

    return batch


def collate_tokens(
    values,
    pad_idx,
    eos_idx=None,
    left_pad=False,
    move_eos_to_beginning=False,
    pad_to_length=None,
    pad_to_multiple=1,
    pad_to_bsz=None,
):
    """Convert a list of 1d tensors into a padded 2d tensor.

    A list of 2d tensors is also accepted (giving a padded 3d tensor) when
    ``move_eos_to_beginning`` is False; the first dimension is the one padded.

    Args:
        values: non-empty list of tensors of equal rank (1 or 2).
        pad_idx: fill value for padded positions.
        eos_idx: value written at position 0 when ``move_eos_to_beginning``;
            if None, the last element of each source tensor is used instead.
        left_pad: right-align each sequence (pad on the left).
        move_eos_to_beginning: shift each sequence right by one, putting the
            EOS value first and dropping the last element.
        pad_to_length: minimum padded length.
        pad_to_multiple: round the padded length up to a multiple of this.
        pad_to_bsz: minimum number of rows; extra rows are all ``pad_idx``.
            ``None`` (default) uses exactly ``len(values)`` rows.

    Returns:
        Tensor of shape ``(batch, size)`` or ``(batch, size, values[0].size(1))``
        with the dtype and device of ``values[0]``.
    """
    size = max(v.size(0) for v in values)
    if pad_to_length is not None:
        size = max(size, pad_to_length)
    if pad_to_multiple != 1 and size % pad_to_multiple != 0:
        # Round up to the next multiple of pad_to_multiple.
        size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple)

    batch_size = len(values) if pad_to_bsz is None else max(len(values), pad_to_bsz)

    def copy_tensor(src, dst):
        assert dst.numel() == src.numel()
        if move_eos_to_beginning:
            if eos_idx is None:
                # if no eos_idx is specified, then use the last token in src
                dst[0] = src[-1]
            else:
                dst[0] = eos_idx
            dst[1:] = src[:-1]
        else:
            dst.copy_(src)

    if values[0].dim() == 1:
        res = values[0].new_full((batch_size, size), pad_idx)
    elif values[0].dim() == 2:
        assert move_eos_to_beginning is False
        res = values[0].new_full((batch_size, size, values[0].size(1)), pad_idx)
    else:
        raise NotImplementedError

    for i, v in enumerate(values):
        copy_tensor(v, res[i][size - len(v) :] if left_pad else res[i][: len(v)])
    return res


def pad_or_cut_img_tensors(img_tensors, img_size, num_imgs):
    """Return exactly ``num_imgs`` images along dim 0.

    Fewer images are padded with float32 zero images of shape
    ``(3, img_size, img_size)``; more are truncated to the first ``num_imgs``.
    If the count already matches, the input object is returned unchanged.

    Args:
        img_tensors: image batch; for the padding concat to succeed it must be
            ``(N, 3, img_size, img_size)`` (or an empty tensor).
        img_size: spatial size of the zero padding images.
        num_imgs: target number of images.
    """
    if len(img_tensors) < num_imgs:
        # Allocate on the input's device so torch.cat works for CUDA inputs too.
        zero_padding = torch.zeros(
            (num_imgs - len(img_tensors), 3, img_size, img_size),
            dtype=torch.float,
            device=img_tensors.device,
        )
        img_tensors = torch.cat((img_tensors, zero_padding), dim=0)
    elif len(img_tensors) > num_imgs:
        img_tensors = img_tensors[:num_imgs, :, :, :]
    return img_tensors

