import collections
import glob
import logging
import os
# Environment-driven switches, read once at import time.
# IF_PREFER_LOSS selects the pair-construction strategy in Se2Dataset.get_entry
# (values -1 and 0..6 are handled there; anything else raises NotImplementedError).
IF_PREFER_LOSS = int(os.environ.get('IF_PREFER_LOSS', 0))
# SCORE_RATIO is the fraction of each entry's "ctxs" that Se2Dataset.get_entry
# randomly keeps before building training pairs.
SCORE_RATIO = float(os.environ.get('SCORE_RATIO', 1))
# The ratio must be a fraction in [0, 1]; importing the module fails otherwise.
assert 0<=SCORE_RATIO<=1
import random
from typing import Dict, List, Tuple

import hydra
import jsonlines
import numpy as np
import torch
from omegaconf import DictConfig
from torch import Tensor as T
from dpr.data.tables import Table
from dpr.utils.data_utils import read_data_from_json_files, Tensorizer
from dpr.utils.tasks import (
    task_map,
    get_prompt_files,
)

logger = logging.getLogger(__name__)

# Record describing one passage/demonstration handed to the bi-encoder.
# `meta_data` defaults to None so that the two-argument call sites in this
# module (text, title) build a valid passage instead of raising TypeError.
BiEncoderPassage = collections.namedtuple(
    "BiEncoderPassage", ["text", "title", "meta_data"], defaults=(None,)
)


class BiEncoderSample:
    """Plain holder for one bi-encoder example.

    Only attribute annotations are declared; instances start empty and the
    dataset ``__getitem__`` methods in this module assign the fields.
    """

    # Query text of the example.
    query: str
    # Passages grouped by role.
    positive_passages: List[BiEncoderPassage]
    negative_passages: List[BiEncoderPassage]
    hard_negative_passages: List[BiEncoderPassage]


class RepTokenSelector:
    """Interface for strategies that pick token position(s) inside ``input_ids``."""

    def get_positions(self, input_ids: T, tenzorizer: Tensorizer):
        """Return the selected position(s); concrete selectors must override this."""
        raise NotImplementedError


class RepStaticPosTokenSelector(RepTokenSelector):
    """Selector that always answers with one fixed position (0 by default)."""

    def __init__(self, static_position: int = 0):
        self.static_position = static_position

    def get_positions(self, input_ids: T, tenzorizer: Tensorizer):
        """Return the configured position; both arguments are ignored."""
        position = self.static_position
        return position


class RepSpecificTokenSelector(RepTokenSelector):
    """Selector that locates a specific token (``[CLS]`` by default) in each sample."""

    def __init__(self, token: str = "[CLS]"):
        self.token = token
        # Resolved lazily on the first call because it needs the tensorizer.
        self.token_id = None

    def get_positions(self, input_ids: T, tenzorizer: Tensorizer):
        """Return one ``[sample_idx, token_idx]`` row per sample.

        For every sample the first occurrence of the token is reported. A
        sample that does not contain the token gets position 0 and a warning
        is logged.

        Args:
            input_ids: token id tensor; assumed to be 2-D (batch, sequence) —
                TODO confirm against callers.
            tenzorizer: used once to translate ``self.token`` into its id.

        Returns:
            Tensor with ``input_ids.size(0)`` rows, each ``[sample_idx, token_idx]``.
        """
        # `is None` (not truthiness) so a legitimate token id of 0 is cached too.
        if self.token_id is None:
            self.token_id = tenzorizer.get_token_id(self.token)
        token_indexes = (input_ids == self.token_id).nonzero()
        bsz = input_ids.size(0)
        num_found = token_indexes.size(0)

        # Fast path: exactly one match per sample. Comparing counts alone is
        # not enough (one sample with two matches plus one with none has the
        # same count), so the sample indices are checked as well.
        if num_found == bsz and torch.equal(
            token_indexes[:, 0],
            torch.arange(bsz, device=token_indexes.device),
        ):
            return token_indexes

        rows = []
        cursor = 0
        for i in range(bsz):
            # Skip surplus matches that belong to earlier samples so the
            # cursor stays aligned with the current sample.
            while cursor < num_found and token_indexes[cursor][0] < i:
                cursor += 1
            if cursor < num_found and token_indexes[cursor][0] == i:
                # this sample has the special token
                rows.append(token_indexes[cursor])
                cursor += 1
            else:
                logger.warning("missing special token %s", input_ids[i])
                # fall back to the 0-th token, i.e. CLS for BERT
                rows.append(torch.tensor([i, 0]).to(input_ids.device))
        return torch.stack(rows, dim=0)


# Shared fallback selector used by Dataset when no selector config is given;
# constructed with the default static position (0).
DEFAULT_SELECTOR = RepStaticPosTokenSelector()


class Dataset(torch.utils.data.Dataset):
    """Common base for the datasets in this module.

    Stores the shared configuration and offers query post-processing;
    ``load_data`` and ``__getitem__`` are left to subclasses.
    """

    def __init__(
        self,
        selector: DictConfig = None,
        special_token: str = None,
        shuffle_positives: bool = False,
        query_special_suffix: str = None,
        encoder_type: str = None,
    ):
        # A selector config is instantiated through hydra; without one the
        # module-level default selector is shared.
        self.selector = (
            hydra.utils.instantiate(selector) if selector else DEFAULT_SELECTOR
        )
        self.query_special_suffix = query_special_suffix
        self.shuffle_positives = shuffle_positives
        self.encoder_type = encoder_type
        self.special_token = special_token

    def load_data(self):
        """Load the underlying data; must be provided by subclasses."""
        raise NotImplementedError

    def __getitem__(self, index) -> BiEncoderSample:
        """Return the sample at ``index``; must be provided by subclasses."""
        raise NotImplementedError

    def _process_query(self, query: str):
        """Normalize ``query`` and append the configured suffix if it is missing."""
        # as of now, the query is always normalized
        normalized = normalize_question(query)
        suffix = self.query_special_suffix
        if not suffix or normalized.endswith(suffix):
            return normalized
        return normalized + suffix


def get_dpr_files(sources) -> List[str]:
    """Resolve data source names into a flat list of local file paths.

    Each source is first treated as a glob pattern. If nothing matches but
    the path itself exists (e.g. its name contains glob metacharacters), the
    path is used as-is. Otherwise the project's data downloader is asked to
    fetch the source.

    Args:
        sources: a single source name/pattern or an iterable of them.

    Returns:
        All resolved file paths, in the order the sources were given.
    """
    if isinstance(sources, str):
        sources = [sources]
    res = []
    for source_name in sources:
        matches = glob.glob(source_name)  # evaluated once per source
        if matches:
            res.extend(matches)
        elif os.path.exists(source_name):
            # Existing path that glob could not match literally; keep it
            # instead of silently dropping the source.
            res.append(source_name)
        else:
            # try to use data downloader
            from dpr.data.download_data import download

            res.extend(download(source_name))
    logger.info("Total files num %d", len(res))
    return res


def reformat(text):
    """Turn a ';'-separated string into a numbered, space-joined list.

    Each piece is stripped and prefixed with its 1-based position,
    e.g. ``"a; b"`` becomes ``"1#) a 2#) b"``.
    """
    numbered = []
    for position, piece in enumerate(text.split(";"), start=1):
        numbered.append(f"{position}#) {piece.strip()}")
    return " ".join(numbered)


import re
def remove_double_space(string):
    """Collapse every run of two or more spaces into a single space."""
    space_run = re.compile("[ ]{2,}")
    return space_run.sub(" ", string)

def create_dict(raw_data: List[dict]) -> dict:
    """Index records by their task name concatenated with the stringified id.

    When several records produce the same key, the last one wins.
    """
    return {
        record["task_name"] + str(record["id"]): record for record in raw_data
    }

class Se2Dataset(Dataset):
    """Training/validation examples built from scored candidate files.

    Every raw entry carries a list of candidates under ``"ctxs"`` (each with a
    ``"loss"`` and a ``"one_shot_acc"``) plus the indices of already chosen
    demonstrations under ``"choosen"``. ``get_entry`` turns one raw entry into
    one or more DPR-style items of the form::

        {"question", "answers", "positive_ctxs", "negative_ctxs"[, "hard_negative_ctxs"]}

    The construction strategy is picked by the module-level ``IF_PREFER_LOSS``.
    """

    def __init__(
        self,
        file: str,
        top_k,
        loss_type: str = "dpr",
        train_clusters: str = None,
        multi_task: bool = False,
        split: str = None,
        hard_neg=False,
        selector: DictConfig = None,
        special_token: str = None,
        encoder_type: str = None,
        shuffle_positives: bool = False,
        query_special_suffix: str = None,
        prompt_pool_path: str = None,
        prompt_setup_type: str = "q",
        task_setup_type: str = "q",
    ):
        """Store the configuration and read the prompt pool.

        Args:
            file: data file/pattern, or (when ``train_clusters`` is set) the
                base directory holding ``<cluster>/*_scored_<split>.json``.
            top_k: cap used when selecting positives/negatives.
            loss_type: only ``"dpr"`` is accepted.
            train_clusters: ``'+'``-separated cluster names, or None.
            multi_task: selects the negative sampling rule in ``get_old_entry``.
            split: ``"train"`` or ``"valid"`` (checked in ``load_data``).
            hard_neg: whether built items carry ``"hard_negative_ctxs"``.
            prompt_pool_path: location of the prompt pool; resolved through
                ``get_prompt_files`` when ``train_clusters`` is set.
            prompt_setup_type: ``setup_type`` used to format demonstrations.
            task_setup_type: ``setup_type`` used to format the entry itself.
        """
        super().__init__(
            selector,
            special_token=special_token,
            encoder_type=encoder_type,
            shuffle_positives=shuffle_positives,
            query_special_suffix=query_special_suffix,
        )
        assert loss_type in ['dpr']
        logger.info("loss_type: %s", loss_type)

        self.split = split
        self.top_k = top_k
        self.file = file
        self.hard_neg = hard_neg
        self.data = []
        self.loss_type = loss_type

        if train_clusters is not None:
            prompt_pool_path = get_prompt_files(prompt_pool_path, train_clusters)
        logger.info("prompt files: %s", prompt_pool_path)
        self.prompt_pool = read_data_from_json_files(prompt_pool_path)
        logger.info("prompt passages num : %d", len(self.prompt_pool))
        self.TaskandId_to_data = create_dict(self.prompt_pool)
        self.train_clusters = train_clusters
        self.multi_task = multi_task
        self.prompt_setup_type = prompt_setup_type
        self.task_setup_type = task_setup_type

    def format_example(self, entry, setup_type):
        """Render ``entry`` as text using its task's formatter.

        Args:
            entry: record with a ``"task_name"`` key known to ``task_map``.
            setup_type: ``"q"`` (question), ``"a"`` (stripped answer) or
                ``"qa"`` (question followed by answer).

        Raises:
            ValueError: if ``setup_type`` is not one of the values above.
        """
        task = task_map.cls_dic[entry["task_name"]]()
        if setup_type == "qa":
            sent = task.get_question(entry) + task.get_answer(entry)
        elif setup_type == "q":
            sent = task.get_question(entry)
        elif setup_type == "a":
            sent = task.get_answer(entry).strip()
        else:
            # Previously fell through and failed with UnboundLocalError.
            raise ValueError(f"unsupported setup_type: {setup_type!r}")
        return remove_double_space(sent)

    # ------------------------------------------------------------------ helpers

    def _require_dpr(self):
        """Raise NotImplementedError unless the configured loss type is 'dpr'."""
        if self.loss_type != 'dpr':
            raise NotImplementedError

    @staticmethod
    def _accurate_ctxs(entry):
        """Return the candidates of ``entry`` whose ``one_shot_acc`` is positive."""
        # `True > 0` holds, so this covers both boolean and numeric scores.
        return [ctx for ctx in entry["ctxs"] if ctx["one_shot_acc"] > 0]

    def _as_ctxs(self, examples):
        """Format each example into a fresh ``{"demonstration": text}`` dict."""
        return [
            {"demonstration": self.format_example(example, self.prompt_setup_type)}
            for example in examples
        ]

    def _build_question(self, entry):
        """Concatenate the already chosen demonstrations with the entry itself."""
        parts = [
            self.format_example(self.prompt_pool[pool_idx], self.prompt_setup_type)
            + " \n "
            for pool_idx in entry["choosen"]
        ]
        parts.append(self.format_example(entry, self.task_setup_type))
        return "".join(parts)

    def _make_item(self, question, positives, negatives, hard_negatives):
        """Assemble one item; hard negatives are attached only if enabled."""
        item = {
            "question": question,
            "answers": [],
            "positive_ctxs": positives,
            "negative_ctxs": negatives,
        }
        if self.hard_neg:
            item["hard_negative_ctxs"] = hard_negatives
        return item

    def _build_pair_items(self, entry, true_cntxs, false_cntxs):
        """Create one item per (positive, negative) pair, paired in order."""
        pairs = list(zip(true_cntxs, false_cntxs))
        if not pairs:
            return []
        # The question does not depend on the pair, so it is built once.
        question = self._build_question(entry)
        return [
            self._make_item(
                question,
                self._as_ctxs([p_example]),
                [],
                self._as_ctxs([n_example]),
            )
            for p_example, n_example in pairs
        ]

    # ----------------------------------------------------------- entry builders

    def get_entry(self, entry):
        """Convert one raw entry into a list of items.

        ``entry["ctxs"]`` is first replaced by a random ``SCORE_RATIO``
        fraction of itself, sorted by ascending ``"loss"``. Then the strategy
        selected by ``IF_PREFER_LOSS`` is applied.

        Raises:
            NotImplementedError: for an unknown ``IF_PREFER_LOSS`` value.
        """
        entry["ctxs"] = sorted(
            random.sample(entry["ctxs"], k=round(len(entry["ctxs"]) * SCORE_RATIO)),
            key=lambda x: x['loss'],
        )

        if IF_PREFER_LOSS == 0:
            return self.get_old_entry(entry)
        elif IF_PREFER_LOSS == 1:
            return self.get_multi_prefer_entry(
                entry, bottom_sample=True, positive_chosen=False
            )
        elif IF_PREFER_LOSS == -1:
            return self.get_udr_rank_entry(entry)
        elif IF_PREFER_LOSS == 2:
            return self.get_multi_prefer_entry(entry, bottom_sample=True)
        elif IF_PREFER_LOSS == 3:
            return self.get_multi_prefer_entry(entry, bottom_sample=False)
        elif IF_PREFER_LOSS == 4:
            return self.get_top_prefer_entry(entry, bottom_sample=True)
        elif IF_PREFER_LOSS == 5:
            return self.get_top_prefer_entry(entry, bottom_sample=False)
        elif IF_PREFER_LOSS == 6:
            # validation keeps the classic single-item construction
            if self.split == "valid":
                return self.get_old_entry(entry)
            return self.get_multi_prefer_entry(entry, bottom_sample=True)
        else:
            raise NotImplementedError

    def get_multi_prefer_entry(self, entry, bottom_sample=True, positive_chosen=True):
        """Build up to ``top_k`` (positive, hard negative) pair items.

        The number of pairs is capped by ``top_k``, by half the number of
        candidates and by the size of the positive pool (candidates with
        positive ``one_shot_acc`` when ``positive_chosen``, else all).

        With ``bottom_sample`` the positives come from the front of the pool
        and the negatives from the tail of ``entry["ctxs"]``. Otherwise both
        are drawn at random from ``entry["ctxs"]`` such that each negative
        index lies after its positive index; the accuracy-filtered pool then
        only influences the number of pairs.
        """
        self._require_dpr()
        ctxs = entry["ctxs"]
        true_cntxs = self._accurate_ctxs(entry) if positive_chosen else ctxs
        num_pair = min(len(true_cntxs), self.top_k, len(ctxs) // 2)
        if bottom_sample:
            # slices copy, so the shuffles below never touch entry["ctxs"]
            true_cntxs = true_cntxs[:num_pair]
            false_cntxs = ctxs[-num_pair:]
        else:
            last = len(ctxs) - 1
            true_cntxs_ids = random.sample(range(0, last), k=num_pair)
            false_cntxs_ids = [random.randint(a + 1, last) for a in true_cntxs_ids]
            true_cntxs = [ctxs[a] for a in true_cntxs_ids]
            false_cntxs = [ctxs[a] for a in false_cntxs_ids]
        random.shuffle(true_cntxs)
        random.shuffle(false_cntxs)
        return self._build_pair_items(entry, true_cntxs, false_cntxs)

    def get_top_prefer_entry(self, entry, bottom_sample=True):
        """Pair the first positive-accuracy candidate with ``top_k`` negatives.

        Negatives are the last ``top_k`` candidates (``bottom_sample``) or a
        random sample of ``top_k`` candidates excluding the first one.
        Returns an empty list when no candidate has positive accuracy.
        """
        self._require_dpr()
        num_pair = self.top_k
        true_cntxs = self._accurate_ctxs(entry)[:1] * num_pair
        if bottom_sample:
            false_cntxs = entry["ctxs"][-num_pair:]
        else:
            false_cntxs = random.sample(entry["ctxs"][1:], k=num_pair)
        return self._build_pair_items(entry, true_cntxs, false_cntxs)

    def get_udr_rank_entry(self, entry):
        """Build one item whose negatives are ``top_k`` random candidates.

        The sampled candidates are sorted by ascending ``"loss"`` and stored
        as ``negative_ctxs``; the positive is the first candidate with
        positive accuracy (if any). Hard negatives are empty.
        """
        self._require_dpr()
        true_cntxs = self._accurate_ctxs(entry)
        false_cntxs = sorted(
            random.sample(entry["ctxs"], k=self.top_k), key=lambda x: x['loss']
        )
        positive_cntx = self._as_ctxs(true_cntxs[:1])
        negative_cntx = self._as_ctxs(false_cntxs)
        question = self._build_question(entry)
        return [self._make_item(question, positive_cntx, negative_cntx, [])]

    def get_prefer_entry(self, entry):
        """Build one item: first positive-accuracy candidate vs. the last ``top_k``.

        The last ``top_k`` candidates become hard negatives; regular
        negatives are empty.
        """
        self._require_dpr()
        positive_cntx = self._as_ctxs(self._accurate_ctxs(entry)[:1])
        hard_negative_ctxs = self._as_ctxs(entry["ctxs"][-self.top_k:])
        question = self._build_question(entry)
        return [self._make_item(question, positive_cntx, [], hard_negative_ctxs)]

    def get_old_entry(self, entry):
        """Build one classic item with positive, random and hard negatives.

        * positive: first candidate with positive accuracy (may be absent);
        * hard negatives: the last ``top_k`` candidates — keep ``top_k`` equal
          to the number of negatives if they should be the k worst-ranked;
        * negatives: with ``multi_task``, ``top_k`` random pool prompts minus
          those of the entry's own task (so possibly fewer than ``top_k``);
          otherwise ``top_k`` random pool prompts whose id is not among the
          positive/hard-negative ids.
        """
        self._require_dpr()
        task_name = entry["task_name"]
        top_positive = self._accurate_ctxs(entry)[:1]
        tail = entry["ctxs"][-self.top_k:]

        positive_cntx = self._as_ctxs(top_positive)
        hard_negative_ctxs = self._as_ctxs(tail)

        if self.multi_task:
            sampled = [
                n_example
                for n_example in random.choices(self.prompt_pool, k=self.top_k)
                if n_example["task_name"] != task_name
            ]
        else:
            # Built once per entry rather than once per pool prompt.
            excluded_ids = [ctx['id'] for ctx in top_positive] + [
                ctx['id'] for ctx in tail
            ]
            filtered_prompt_pool = [
                prompt for prompt in self.prompt_pool if prompt['id'] not in excluded_ids
            ]
            sampled = random.choices(filtered_prompt_pool, k=self.top_k)
        negative_cntx = self._as_ctxs(sampled)

        question = self._build_question(entry)
        return [
            self._make_item(question, positive_cntx, negative_cntx, hard_negative_ctxs)
        ]

    # ------------------------------------------------------------ dataset API

    def load_data(self):
        """Read the scored files and expand every raw entry into items.

        Items without a positive context are dropped. Calling this method
        again reloads the data from the same file patterns.
        """
        assert self.split in ["train", "valid"]
        # Expand the base directory into per-cluster patterns only once, so a
        # second call does not format the already expanded list into a path.
        if self.train_clusters is not None and isinstance(self.file, str):
            clusters = self.train_clusters.split('+')
            self.file = [
                f"{self.file}/{cluster}/*_scored_{self.split}.json"
                for cluster in clusters
            ]
        logger.info("cluster files: %s", self.file)

        self.data_files = get_dpr_files(self.file)
        logger.info("dpr files: %s", self.data_files)

        raw_data = read_data_from_json_files(self.data_files)
        logger.info("********len(raw_data): %d", len(raw_data))
        self.data = []
        for entry in raw_data:
            self.data.extend(self.get_entry(entry))
        # filter out those without positive ctx
        if self.loss_type == 'dpr':
            # Count against the expanded items: one raw entry may yield
            # several items, so len(raw_data) is not a valid baseline.
            built = len(self.data)
            self.data = [r for r in self.data if len(r["positive_ctxs"]) > 0]
            logger.info("filter out data for : {}".format(built - len(self.data)))
            logger.info("Total filtered data size: {}".format(len(self.data)))
        else:
            logger.error("loss type error")

    def __getitem__(self, index) -> BiEncoderSample:
        """Wrap the stored item at ``index`` into a ``BiEncoderSample``.

        Contexts lacking a ``"title"`` get ``None`` stored in place.
        """
        self._require_dpr()
        json_sample = self.data[index]
        r = BiEncoderSample()
        r.query = json_sample["question"]

        positive_ctxs = json_sample["positive_ctxs"]
        negative_ctxs = json_sample.get("negative_ctxs", [])
        hard_negative_ctxs = json_sample.get("hard_negative_ctxs", [])

        for ctx in positive_ctxs + negative_ctxs + hard_negative_ctxs:
            ctx.setdefault("title", None)

        def create_passage(ctx: dict):
            return BiEncoderPassage(
                ctx["demonstration"], ctx["title"], None  # meta_data=None
            )

        r.positive_passages = [create_passage(ctx) for ctx in positive_ctxs]
        r.negative_passages = [create_passage(ctx) for ctx in negative_ctxs]
        r.hard_negative_passages = [create_passage(ctx) for ctx in hard_negative_ctxs]
        return r

    def __len__(self):
        """Number of built items."""
        return len(self.data)

    def get_qas(self) -> Tuple[List[str], List[str]]:
        """Return all questions and all answers as two parallel lists."""
        return self.get_qas_range(0, len(self.data))

    def get_qas_range(
        self, start_idx: int, end_idx: int
    ) -> Tuple[List[str], List[str]]:
        """Return questions and answers of the items in ``[start_idx, end_idx)``."""
        window = self.data[start_idx:end_idx]
        return (
            [s["question"] for s in window],
            [s["answers"] for s in window],
        )


class JsonQADataset(Dataset):
    """Question/passage dataset read from DPR-style JSON files.

    Each record is expected to provide ``"question"``, ``"answers"`` and
    ``"positive_ctxs"``; ``"negative_ctxs"`` and ``"hard_negative_ctxs"``
    are optional. Contexts are read through their ``"text"`` and optional
    ``"title"`` keys.
    """

    def __init__(
        self,
        file: str,
        selector: DictConfig = None,
        special_token: str = None,
        encoder_type: str = None,
        shuffle_positives: bool = False,
        normalize: bool = False,
        query_special_suffix: str = None,
    ):
        """Store the configuration; files are resolved later in ``load_data``.

        Args:
            file: source name/pattern handed to ``get_dpr_files``.
            normalize: whether passage text goes through ``normalize_passage``.
            (remaining arguments are forwarded to ``Dataset``)
        """
        super().__init__(
            selector,
            special_token=special_token,
            encoder_type=encoder_type,
            shuffle_positives=shuffle_positives,
            query_special_suffix=query_special_suffix,
        )
        self.file = file
        self.data_files = []
        self.data = []
        self.normalize = normalize
        logger.info("Data files: %s", self.data_files)

    def load_data(self):
        """Resolve the data files and keep the records that have a positive ctx."""
        logger.info("Data source: %s", self.file)
        self.data_files = get_dpr_files(self.file)
        logger.info("Data files: %s", self.data_files)

        data = read_data_from_json_files(self.data_files)
        # filter those without positive ctx
        self.data = [r for r in data if len(r["positive_ctxs"]) > 0]
        logger.info("Total cleaned data size: {}".format(len(self.data)))

    def __getitem__(self, index) -> BiEncoderSample:
        """Wrap the record at ``index`` into a ``BiEncoderSample``.

        The query is post-processed with ``_process_query``. Contexts lacking
        a ``"title"`` get ``None`` stored in place.
        """
        json_sample = self.data[index]
        r = BiEncoderSample()
        r.query = self._process_query(json_sample["question"])

        positive_ctxs = json_sample["positive_ctxs"]
        negative_ctxs = json_sample.get("negative_ctxs", [])
        hard_negative_ctxs = json_sample.get("hard_negative_ctxs", [])

        for ctx in positive_ctxs + negative_ctxs + hard_negative_ctxs:
            ctx.setdefault("title", None)

        def create_passage(ctx: dict):
            text = normalize_passage(ctx["text"]) if self.normalize else ctx["text"]
            # meta_data is passed explicitly: the passage tuple has three fields
            # and the former two-argument call raised TypeError.
            return BiEncoderPassage(text, ctx["title"], None)

        r.positive_passages = [create_passage(ctx) for ctx in positive_ctxs]
        r.negative_passages = [create_passage(ctx) for ctx in negative_ctxs]
        r.hard_negative_passages = [create_passage(ctx) for ctx in hard_negative_ctxs]
        return r

    def __len__(self):
        """Number of loaded records."""
        return len(self.data)

    def get_qas(self) -> Tuple[List[str], List[str]]:
        """Return all questions and all answers as two parallel lists."""
        return [s["question"] for s in self.data], [s["answers"] for s in self.data]

    def get_qas_range(
        self, start_idx: int, end_idx: int
    ) -> Tuple[List[str], List[str]]:
        """Return questions and answers of the records in ``[start_idx, end_idx)``."""
        window = self.data[start_idx:end_idx]
        return (
            [s["question"] for s in window],
            [s["answers"] for s in window],
        )


def normalize_passage(ctx_text: str):
    """Replace newlines with spaces and typographic apostrophes with ASCII ones."""
    single_line = ctx_text.replace("\n", " ")
    return single_line.replace("’", "'")


def normalize_question(question: str) -> str:
    """Replace typographic apostrophes in ``question`` with ASCII ones."""
    return question.replace("’", "'")


class Cell:
    """One table cell: its tokens, a type tag and any tables nested inside it."""

    def __init__(self):
        self.value_tokens: List[str] = []
        self.type: str = ""
        self.nested_tables: List[Table] = []

    def __str__(self):
        """Tokens joined by single spaces."""
        separator = " "
        return separator.join(self.value_tokens)

    def to_dpr_json(self, cell_idx: int):
        """Return ``{"col": cell_idx, "value": <cell text>}``."""
        return {"col": cell_idx, "value": str(self)}


class Row:
    """One table row, i.e. an ordered list of cells."""

    def __init__(self):
        self.cells: List[Cell] = []

    def __str__(self):
        """Cell texts joined with ``"| "``."""
        return "| ".join(map(str, self.cells))

    def visit(self, tokens_function, row_idx: int):
        """Call ``tokens_function(tokens, row_idx, col_idx)`` for each non-empty cell."""
        for col_idx, cell in enumerate(self.cells):
            if not cell.value_tokens:
                continue
            tokens_function(cell.value_tokens, row_idx, col_idx)

    def to_dpr_json(self, row_idx: int):
        """Return ``{"row": row_idx, "columns": [<cell json>, ...]}``."""
        columns = [cell.to_dpr_json(col_idx) for col_idx, cell in enumerate(self.cells)]
        return {"row": row_idx, "columns": columns}


class Table:
    """A parsed table: caption, rows, a lazily cached key and a gold-match flag."""

    def __init__(self, caption=""):
        self.caption = caption
        self.body: List[Row] = []
        self.key = None
        self.gold_match = False

    def __str__(self):
        """Multi-line rendering: caption line, a " rows:" line, then one line per row."""
        pieces = ["<T>: {}\n".format(self.caption), " rows:\n"]
        pieces.extend(
            " row #{}: {}\n".format(row_idx, str(row))
            for row_idx, row in enumerate(self.body)
        )
        return "".join(pieces)

    def get_key(self) -> str:
        """Return the string rendering, computed on first use and then cached."""
        if self.key:
            return self.key
        self.key = str(self)
        return self.key

    def visit(self, tokens_function, include_caption: bool = False) -> bool:
        """Feed the caption (optionally) and every row to ``tokens_function``.

        The caption is reported with row and column index -1.
        NOTE(review): despite the annotation nothing is returned.
        """
        if include_caption:
            tokens_function(self.caption, -1, -1)
        for row_idx, row in enumerate(self.body):
            row.visit(tokens_function, row_idx)

    def to_dpr_json(self):
        """Return the caption and rows as a dict; ``gold_match`` is added only when set."""
        payload = {
            "caption": self.caption,
            "rows": [row.to_dpr_json(row_idx) for row_idx, row in enumerate(self.body)],
        }
        if self.gold_match:
            payload["gold_match"] = 1
        return payload


class NQTableParser(object):
    """Builds Table objects from a token stream with an HTML-tag mask.

    Tokens flagged in ``is_html_mask`` are treated as tags and drive the
    table/row/cell state machine; all other tokens are treated as content.
    """

    def __init__(self, tokens, is_html_mask, title):
        # tokens and is_html_mask are indexed in parallel by parse().
        self.tokens = tokens
        self.is_html_mask = is_html_mask
        # Number of tokens to scan, fixed at construction time.
        self.max_idx = len(self.tokens)
        self.all_tables = []

        # Table currently being filled; None until the first table starts.
        self.current_table: Table = None
        # Enclosing tables of the current one (innermost last).
        self.tables_stack = collections.deque()
        # Used as the caption of a table that has no parent table.
        self.title = title

    def parse(self) -> List[Table]:
        """Scan all tokens and return every table encountered, in start order.

        ``all_tables`` and ``tables_stack`` are reset on each call;
        ``current_table`` is not.
        """
        self.all_tables = []
        self.tables_stack = collections.deque()

        for i in range(self.max_idx):

            t = self.tokens[i]

            if not self.is_html_mask[i]:
                # cell content
                self._on_content(t)
                continue

            # Tag dispatch; the order of these checks matters because start
            # tags are matched by substring. Unrecognised tags are ignored.
            if "<Table" in t:
                self._on_table_start()
            elif t == "</Table>":
                self._on_table_end()
            elif "<Tr" in t:
                self._onRowStart()
            elif t == "</Tr>":
                self._onRowEnd()
            elif "<Td" in t or "<Th" in t:
                self._onCellStart()
            elif t in ["</Td>", "</Th>"]:
                self._on_cell_end()

        return self.all_tables

    def _on_table_start(self):
        """Open a new table and make it current.

        Without a current table the caption is the page title. Otherwise the
        current table is pushed as parent and the caption becomes the parent's
        caption, extended by the text of the parent's latest cell if it has one.
        """
        caption = self.title
        parent_table = self.current_table
        if parent_table:
            self.tables_stack.append(parent_table)

            caption = parent_table.caption
            if parent_table.body and parent_table.body[-1].cells:
                current_cell = self.current_table.body[-1].cells[-1]
                caption += " | " + " ".join(current_cell.value_tokens)

        t = Table()
        t.caption = caption
        self.current_table = t
        self.all_tables.append(t)

    def _on_table_end(self):
        """Close the current table.

        If a parent is on the stack it becomes current again and, when it has
        rows, the closed table is attached to the last cell of its last row.
        With an empty stack ``current_table`` is left unchanged.
        """
        t = self.current_table
        if t:
            if self.tables_stack:  # t is a nested table
                self.current_table = self.tables_stack.pop()
                if self.current_table.body:
                    # NOTE(review): assumes the parent's last row has at least one cell.
                    current_cell = self.current_table.body[-1].cells[-1]
                    current_cell.nested_tables.append(t)
        else:
            logger.error("table end without table object")

    def _onRowStart(self):
        """Append a new empty row to the current table."""
        self.current_table.body.append(Row())

    def _onRowEnd(self):
        """Row end tags need no action."""
        pass

    def _onCellStart(self):
        """Append a new empty cell to the last row of the current table."""
        current_row = self.current_table.body[-1]
        current_row.cells.append(Cell())

    def _on_cell_end(self):
        """Cell end tags need no action."""
        pass

    def _on_content(self, token):
        """Add a content token to the latest cell, or to the caption if no row exists yet."""
        # NOTE(review): assumes a table has already been started and, once a
        # row exists, that its last row has at least one cell.
        if self.current_table.body:
            current_row = self.current_table.body[-1]
            current_cell = current_row.cells[-1]
            current_cell.value_tokens.append(token)
        else:  # tokens outside of row/cells. Just append to the table caption.
            self.current_table.caption += " " + token


def read_nq_tables_jsonl(path: str) -> Dict[str, Table]:
    """Parse a JSON-lines file of tokenized pages into a dict of unique tables.

    Each record supplies ``"tokens"``, ``"html_mask"`` and ``"title"``.
    Records whose text contains the "multiple issues" banner are skipped.
    Only tables with more than one non-empty row are kept, keyed by
    ``Table.get_key()``; the first table seen for a key wins.
    """
    tables_with_issues = 0
    single_row_tables = 0
    nested_tables = 0
    regular_tables = 0
    total_tables = 0
    total_rows = 0
    tables_dict = {}

    with jsonlines.open(path, mode="r") as jsonl_reader:
        for jline in jsonl_reader:
            tokens = jline["tokens"]

            if "( hide ) This section has multiple issues" in " ".join(tokens):
                tables_with_issues += 1
                continue

            parsed = NQTableParser(tokens, jline["html_mask"], jline["title"]).parse()

            # every table after the first one of a record is counted here
            nested_tables += len(parsed[1:])

            for table in parsed:
                total_tables += 1

                # rows that have at least one cell with tokens
                non_empty_rows = sum(
                    1
                    for row in table.body
                    if row.cells and any(cell.value_tokens for cell in row.cells)
                )

                if non_empty_rows <= 1:
                    single_row_tables += 1
                    continue

                regular_tables += 1
                total_rows += len(table.body)

                key = table.get_key()
                if key not in tables_dict:
                    tables_dict[key] = table

            if len(tables_dict) % 1000 == 0:
                logger.info("tables_dict %d", len(tables_dict))

    logger.info("regular tables %d", regular_tables)
    logger.info("tables_with_issues %d", tables_with_issues)
    logger.info("single_row_tables %d", single_row_tables)
    logger.info("nested_tables %d", nested_tables)
    return tables_dict


def get_table_string_for_answer_check(table: Table):  # this doesn't use caption
    """Flatten the table body into one string for answer matching.

    Cells of a row are joined with " . "; rows are concatenated directly
    (no separator between rows) and a single " . " is appended at the end.
    """
    row_texts = [
        " . ".join(" ".join(cell.value_tokens) for cell in row.cells)
        for row in table.body
    ]
    return "".join(row_texts) + " . "


class JsonLTablesQADataset(Dataset):
    """Bi-encoder dataset of questions paired with tables, read from JSONL.

    Every record is expected to provide ``question``, ``positive_ctxs`` and
    ``hard_negative_ctxs``. Every context (table) provides ``caption`` and
    ``rows``; a row is a dict with ``columns`` and a column is a dict with
    ``value``. Positive contexts additionally provide ``answer_pos``, whose
    entries start with the index of a row that contains an answer.
    """

    def __init__(
        self,
        file: str,
        is_train_set: bool,
        selector: DictConfig = None,
        shuffle_positives: bool = False,
        max_negatives: int = 1,
        seed: int = 0,
        max_len=100,
        split_type: str = "type1",
    ):
        """Configure the dataset; records are only read by ``load_data``.

        Args:
            file: glob pattern of the JSONL files to read.
            is_train_set: enables shuffling of hard negatives and random
                sampling of filler rows during table linearization.
            selector: forwarded to the base class.
            shuffle_positives: forwarded to the base class.
            max_negatives: maximum number of hard negatives per sample.
            seed: seed of the private ``random.Random`` used for shuffling.
            max_len: word budget of a linearized table.
            split_type: key of the linearization function (see
                ``get_lin_func``).
        """
        super().__init__(selector, shuffle_positives=shuffle_positives)
        self.data_files = glob.glob(file)
        self.data = []
        self.is_train_set = is_train_set
        self.max_negatives = max_negatives
        self.rnd = random.Random(seed)
        self.max_len = max_len
        self.linearize_func = JsonLTablesQADataset.get_lin_func(split_type)

    def load_data(self):
        """Read all data files and keep records with a positive context."""
        data = []
        for path in self.data_files:
            with jsonlines.open(path, mode="r") as jsonl_reader:
                data.extend(jsonl_reader)

        # filter those without positive ctx
        self.data = [r for r in data if len(r["positive_ctxs"]) > 0]
        logger.info("Total cleaned data size: {}".format(len(self.data)))

    def __getitem__(self, index) -> BiEncoderSample:
        """Build the sample at ``index``.

        The sample holds one positive table and up to ``max_negatives`` hard
        negative tables, each linearized to text with its caption as title.
        """
        json_sample = self.data[index]
        r = BiEncoderSample()
        r.query = json_sample["question"]
        positive_ctxs = json_sample["positive_ctxs"]
        hard_negative_ctxs = json_sample["hard_negative_ctxs"]

        # NOTE: both shuffles reorder the lists stored in self.data in place.
        if self.shuffle_positives:
            self.rnd.shuffle(positive_ctxs)

        if self.is_train_set:
            self.rnd.shuffle(hard_negative_ctxs)
        positive_ctxs = positive_ctxs[0:1]
        hard_negative_ctxs = hard_negative_ctxs[0 : self.max_negatives]

        # BiEncoderPassage has a third required field, meta_data; table
        # contexts carry none, so it is set to None explicitly.
        r.positive_passages = [
            BiEncoderPassage(
                text=self.linearize_func(self, ctx, True),
                title=ctx["caption"],
                meta_data=None,
            )
            for ctx in positive_ctxs
        ]
        r.negative_passages = []
        r.hard_negative_passages = [
            BiEncoderPassage(
                text=self.linearize_func(self, ctx, False),
                title=ctx["caption"],
                meta_data=None,
            )
            for ctx in hard_negative_ctxs
        ]
        return r

    def __len__(self):
        """Return the number of loaded records."""
        return len(self.data)

    @classmethod
    def get_lin_func(cls, split_type: str):
        """Return the linearization function registered for ``split_type``.

        The result is a plain function that must be called with the dataset
        instance as its first argument.

        Raises:
            KeyError: if ``split_type`` is not registered.
        """
        f = {
            "type1": JsonLTablesQADataset._linearize_table,
        }
        return f[split_type]

    @classmethod
    def split_table(cls, t: dict, max_length: int):
        """Split a table into text chunks that all start with the header row.

        The first non-empty row is the header. The following non-empty rows
        are accumulated, and a chunk is closed as soon as its word count
        (header included) reaches ``max_length``.

        Args:
            t: table dict with a ``rows`` list.
            max_length: word count at which a chunk is closed.

        Returns:
            List of chunk strings, one row per line, each line terminated by
            a newline. Empty if the table has no non-empty row or no
            non-empty row after the header.
        """
        rows = t["rows"]
        header = None
        header_len = 0
        start_row = 0

        # get the first non empty row as the "header"
        for i, r in enumerate(rows):
            row_lin, row_len = cls._linearize_row(r)
            if len(row_lin) > 1:  # TODO: change to checking cell value tokens
                header = row_lin
                header_len = row_len
                start_row = i
                break

        if header is None:
            # no usable row at all, hence nothing to chunk
            return []

        chunks = []
        current_rows = [header]
        current_len = header_len

        for row in rows[start_row + 1 :]:
            row_lin, row_len = cls._linearize_row(row)
            if len(row_lin) <= 1:  # TODO: change to checking cell value tokens
                continue
            current_rows.append(row_lin)
            current_len += row_len
            # The limit is tested only after a row was added, so a chunk
            # never consists of the header alone.
            if current_len >= max_length:
                chunks.append("\n".join(current_rows) + "\n")
                current_rows = [header]
                current_len = header_len

        if len(current_rows) > 1:
            chunks.append("\n".join(current_rows) + "\n")
        return chunks

    def _linearize_table(self, t: dict, is_positive: bool) -> str:
        """Linearize a table to text within a budget of ``self.max_len`` words.

        Rows are selected in this order: the first non-empty row (header);
        for positive tables, the rows listed in ``t["answer_pos"]``; then
        further non-empty rows (random order for the train set, table order
        otherwise) until the word budget is reached.

        Args:
            t: table dict with ``rows`` and, for positives, ``answer_pos``.
            is_positive: whether the rows containing answers are included.

        Returns:
            The selected rows, one per line, each terminated by a newline.
        """
        rows = t["rows"]
        selected_rows = set()
        rows_linearized = []
        total_words_len = 0

        # get the first non empty row as the "header"
        for i, r in enumerate(rows):
            row_lin, row_len = self._linearize_row(r)
            if len(row_lin) > 1:  # TODO: change to checking cell value tokens
                selected_rows.add(i)
                rows_linearized.append(row_lin)
                total_words_len += row_len
                break

        # rows that contain an answer come right after the header
        if is_positive:
            row_idx_with_answers = [ap[0] for ap in t["answer_pos"]]

            if self.shuffle_positives:
                self.rnd.shuffle(row_idx_with_answers)
            for i in row_idx_with_answers:
                if i not in selected_rows:
                    row_lin, row_len = self._linearize_row(rows[i])
                    selected_rows.add(i)
                    rows_linearized.append(row_lin)
                    total_words_len += row_len
                if total_words_len >= self.max_len:
                    break

        if total_words_len < self.max_len:  # append random rows
            if self.is_train_set:
                # NOTE: this draws from NumPy's global RNG, not from self.rnd,
                # so the order is not controlled by the `seed` argument.
                rows_indexes = np.random.permutation(range(len(rows)))
            else:
                rows_indexes = range(len(rows))

            for i in rows_indexes:
                if i in selected_rows:
                    continue
                row_lin, row_len = self._linearize_row(rows[i])
                if len(row_lin) > 1:  # TODO: change to checking cell value tokens
                    selected_rows.add(i)
                    rows_linearized.append(row_lin)
                    total_words_len += row_len
                if total_words_len >= self.max_len:
                    break

        return "".join(row_lin + "\n" for row_lin in rows_linearized)

    @classmethod
    def _linearize_row(cls, row: dict) -> Tuple[str, int]:
        """Join the cell values of ``row`` and count their words.

        Returns:
            ``(text, total_words)``: ``text`` is the cell values joined by
            ``", "``; ``total_words`` sums the single-space-separated pieces
            of every cell (an empty cell counts as one piece).
        """
        cell_values = [c["value"] for c in row["columns"]]
        total_words = sum(len(c.split(" ")) for c in cell_values)
        return ", ".join(cell_values), total_words


def split_tables_to_chunks(
    tables_dict: Dict[str, Table], max_table_len: int, split_type: str = "type1"
) -> List[Tuple[int, str, str, int]]:
    """Split every table into text chunks with globally increasing ids.

    All tables are first converted with ``to_dpr_json()`` and then split by
    ``JsonLTablesQADataset.split_table``.

    Args:
        tables_dict: tables to split; only the values are used.
        max_table_len: word limit passed to ``split_table``.
        split_type: only ``"type1"`` is supported (asserted for every table).

    Returns:
        A list of ``(chunk_id, chunk_text, caption, table_index)`` tuples.
        ``chunk_id`` runs over all chunks of all tables, and ``table_index``
        is the position of the table in the iteration order of
        ``tables_dict``.
    """
    json_tables = [table.to_dpr_json() for table in tables_dict.values()]
    all_chunks = []
    for table_idx, json_table in enumerate(json_tables):
        # TODO: support other types
        assert split_type == "type1"
        pieces = JsonLTablesQADataset.split_table(json_table, max_table_len)
        caption = json_table["caption"]
        # ids continue where the previous table stopped
        first_id = len(all_chunks)
        # chunk id , text, title, external_id
        all_chunks.extend(
            (first_id + offset, text, caption, table_idx)
            for offset, text in enumerate(pieces)
        )
        if table_idx % 1000 == 0:
            logger.info(
                "Splitted %d tables to %d chunks", table_idx, len(all_chunks)
            )
    return all_chunks
