import json
import random
import sys
from typing import Tuple, Iterable, List

import numpy as np
import torch
from datasets import Dataset
from torch.utils.data import Sampler, RandomSampler, BatchSampler, DataLoader, SequentialSampler
from transformers import AutoTokenizer, DataCollatorForSeq2Seq, PreTrainedTokenizerFast

import datasets

import tqdm

class AllInOneBatch(Sampler):
    """Batch sampler that yields exactly one batch holding every position of ``source``.

    The single batch is the list ``[0, 1, ..., n - 1]`` where ``n`` is the
    number of elements in ``source``.
    """

    def __init__(self, source: Dataset):
        # Only the number of elements in `source` is ever used.
        self.source = source

    def __iter__(self):
        try:
            # Only positions are needed, so avoid touching the elements themselves.
            size = len(self.source)
        except TypeError:
            # Unsized iterable: fall back to counting by iteration.
            size = sum(1 for _ in self.source)
        return iter([list(range(size))])

    def __len__(self):
        # Iteration always produces exactly one batch.
        return 1


class TaskSampler(Sampler):
    """
    Batch sampler whose batches always consist of examples from a single task.

    Every element of ``source`` must provide an ``"is_train"`` entry and a task
    entry (``task_key``). Each yielded batch is a list of positions in
    ``source``: the positions of one train batch followed by the positions of
    one test batch of the same task.

    Example:
    TaskSampler(dataset, SequentialSampler,
                lambda s: BatchSampler(SequentialSampler(s), batch_size=2, drop_last=False),
                lambda s: AllInOneBatch(s))

    Arguments:
        source: iterable of examples (mappings).
        across_task_sampler: callable taking the list of tasks and returning an
            iterable of positions in that list (the order in which tasks are visited).
        within_task_sampler: callable taking the train positions of one task and
            returning an iterable of batches (lists of positions within that list).
        within_task_sampler_test: same as ``within_task_sampler`` for the test positions.
        task_key: key under which an example stores its task.

    Raises:
        ValueError: if the sets of tasks seen in the train and test splits differ.
    """

    def __init__(self, source, across_task_sampler, within_task_sampler, within_task_sampler_test, task_key="task"):
        self.source = source
        self.task_key = task_key
        self.across_task_sampler = across_task_sampler
        self.within_task_sampler = within_task_sampler
        self.within_task_sampler_test = within_task_sampler_test

        # split name -> task -> positions in `source` belonging to that task and split
        self.task2id = {"train": dict(), "test": dict()}
        for position, example in enumerate(source):
            split = "train" if example["is_train"] else "test"
            self.task2id[split].setdefault(example[self.task_key], []).append(position)

        train_tasks = self.task2id["train"].keys()
        test_tasks = self.task2id["test"].keys()
        if train_tasks != test_tasks:
            raise ValueError("There is a task for which there is a train but not a test split or vice versa")

        # Tasks in order of first appearance in the train split.
        self.task_ids = list(train_tasks)

    def __iter__(self):
        for task_position in self.across_task_sampler(self.task_ids):
            task = self.task_ids[task_position]
            train_pool = self.task2id["train"][task]
            test_pool = self.task2id["test"][task]

            train_batches = self.within_task_sampler(train_pool)
            test_batches = self.within_task_sampler_test(test_pool)
            # zip stops at the shorter of the two, so a task yields as many
            # batches as the sampler producing fewer batches.
            for train_batch, test_batch in zip(train_batches, test_batches):
                # TODO: we might want to handle train and test differently here...
                if not (isinstance(train_batch, list) and isinstance(test_batch, list)):
                    raise ValueError("Expect batches")
                # Translate positions within the pools back to positions in `source`.
                batch = [train_pool[k] for k in train_batch]
                batch.extend(test_pool[k] for k in test_batch)
                yield batch

    def __len__(self):
        # Number of tasks; a task may contribute more than one batch when iterating.
        return len(self.task_ids)


def load_tsv(fname, expect_first_line = None, lenient: bool = False):
    """Lazily read a tab-separated file, yielding each non-empty line as a list of fields.

    Arguments:
        fname: path of the file to read.
        expect_first_line: if given, the header the first line must equal
            (compared after stripping surrounding whitespace). A matching
            header line is skipped.
        lenient: if True, a first line that does not match ``expect_first_line``
            is treated as ordinary data instead of being an error; an empty
            file then simply yields nothing.

    Yields:
        The tab-separated fields of each line that is non-empty after removing
        line-break characters.

    Raises:
        ValueError: if ``expect_first_line`` is given, ``lenient`` is False and
            the first line is missing or different.
    """
    with open(fname) as f:
        it = iter(f)
        if expect_first_line is not None:
            # Use a default: a bare next() on an empty file would raise
            # StopIteration, which turns into RuntimeError inside a generator.
            first_line = next(it, None)
            if first_line is None or first_line.strip() != expect_first_line:
                if not lenient:
                    raise ValueError(f"First line must be: '{expect_first_line}'")
                if first_line is not None:
                    # Treat it like any other data line: only drop line breaks so
                    # that leading/trailing empty fields are preserved.
                    line = first_line.strip("\n").strip("\r")
                    if line:
                        yield line.split("\t")
        for line in it:
            line = line.strip("\n").strip("\r")
            if line:
                yield line.split("\t")

def prepare_meta_dataset(path:str, tokenizer: AutoTokenizer, batch_size: int) -> DataLoader:
    """Build a DataLoader whose batches each contain examples of a single task.

    The file at ``path`` is read with ``load_tsv`` and must start with the
    header ``input<TAB>output<TAB>task<TAB>is_train``; the last column of each
    row is parsed as an integer train flag.

    Batches are produced by a ``TaskSampler``: tasks are visited in random
    order, train examples are drawn in random batches of ``batch_size`` and
    the test examples of the task are appended as one ``AllInOneBatch``.

    Arguments:
        path: path of the TSV file.
        tokenizer: tokenizer applied to the inputs and, as ``text_target``, to the outputs.
        batch_size: number of train examples per batch.

    Returns:
        A DataLoader over the tokenized dataset, collated with ``DataCollatorForSeq2Seq``.
    """
    text_columns = ["input", "output", "task"]

    def tokenize(example):
        encoded = tokenizer(example["input"])
        if "output" in example:
            encoded["labels"] = tokenizer(text_target=example["output"])["input_ids"]
        return encoded

    # Column-wise table; "is_train" comes first, followed by the text columns.
    table = {"is_train": []}
    table.update({name: [] for name in text_columns})
    for fields in load_tsv(path, "input\toutput\ttask\tis_train"):
        for name, value in zip(text_columns, fields):
            table[name].append(value)
        table["is_train"].append(int(fields[-1]))
    raw = Dataset.from_dict(table)

    # The sampler is built from the untokenized data because it needs the "task" column.
    def train_batches(positions):
        return BatchSampler(RandomSampler(positions), batch_size=batch_size, drop_last=False)

    def test_batches(positions):
        return AllInOneBatch(positions)

    task_batches = TaskSampler(raw, RandomSampler, train_batches, test_batches)

    # "is_train" is not removed, so it is still present in the features given to the collator.
    tokenized = raw.map(tokenize, batched=False, remove_columns=["input", "output", "task"])
    return DataLoader(tokenized, collate_fn=DataCollatorForSeq2Seq(tokenizer), batch_sampler=task_batches)



def prepare_task_dataset(path:str, tokenizer: AutoTokenizer, batch_size: int, random_order: bool = True, lenient: bool=False) -> DataLoader:
    """Build a DataLoader over the input/output pairs of a single TSV file.

    The file is read with ``load_tsv`` using the expected header
    ``input<TAB>output``; ``lenient`` is forwarded to ``load_tsv``.

    Arguments:
        path: path of the TSV file.
        tokenizer: tokenizer applied to the inputs and, as ``text_target``, to the outputs.
        batch_size: number of examples per batch (the last batch may be smaller).
        random_order: sample examples randomly if True, sequentially otherwise.
        lenient: passed through to ``load_tsv``.

    Returns:
        A DataLoader over the tokenized dataset, collated with ``DataCollatorForSeq2Seq``.
    """
    columns = ["input", "output"]

    def tokenize(example):
        encoded = tokenizer(example["input"])
        if "output" in example:
            encoded["labels"] = tokenizer(text_target=example["output"])["input_ids"]
        return encoded

    # Collect the file column-wise; fields beyond the known columns are ignored by zip.
    table = {name: [] for name in columns}
    for fields in load_tsv(path, "input\toutput", lenient=lenient):
        for name, value in zip(columns, fields):
            table[name].append(value)
    raw = Dataset.from_dict(table)

    order = RandomSampler(raw) if random_order else SequentialSampler(raw)
    batches = BatchSampler(order, batch_size=batch_size, drop_last=False)
    tokenized = raw.map(tokenize, batched=False, remove_columns=["input", "output"])
    return DataLoader(tokenized, collate_fn=DataCollatorForSeq2Seq(tokenizer), batch_sampler=batches)


def prepare_multi_meta(path:str, tokenizer: AutoTokenizer, train_batch_size: int, test_batch_size: int) -> List[Tuple[DataLoader, DataLoader]]:
    """Build one (train, test) pair of DataLoaders for every task in a TSV file.

    The file is read with ``load_tsv`` and must start with the header
    ``input<TAB>output<TAB>task<TAB>is_train``; every row must have exactly
    these four fields. Rows whose ``is_train`` field is a non-zero integer go
    to the train split of their task, the others to its test split.

    Arguments:
        path: path of the TSV file.
        tokenizer: tokenizer called with the inputs and the outputs as ``text_target``.
        train_batch_size: batch size of the train loaders.
        test_batch_size: batch size of the test loaders.

    Returns:
        A list with one ``(train_loader, test_loader)`` tuple per task, in
        order of first appearance of the task in the file. Both loaders sample
        randomly and keep the last, possibly smaller, batch.

    Raises:
        ValueError: if a task has no train data or no test data.
    """
    datasets.disable_progress_bar()

    def tokenize(batch):
        targets = batch["output"] if "output" in batch else None
        return tokenizer(batch["input"], text_target=targets)

    def make_loader(split, batch_size):
        tokenized = split.map(tokenize, batched=True, remove_columns=["input", "output"])
        batches = BatchSampler(RandomSampler(tokenized), batch_size=batch_size, drop_last=False)
        return DataLoader(tokenized, collate_fn=DataCollatorForSeq2Seq(tokenizer), batch_sampler=batches)

    # task -> split name -> column-wise examples
    per_task = dict()
    for source_text, target_text, task_name, train_flag in load_tsv(path, "input\toutput\ttask\tis_train"):
        if task_name not in per_task:
            per_task[task_name] = {"train": {"input": [], "output": []}, "test": {"input": [], "output": []}}
        bucket = per_task[task_name]["train" if int(train_flag) else "test"]
        bucket["input"].append(source_text)
        bucket["output"].append(target_text)

    loader_pairs = []
    for task_name in tqdm.tqdm(per_task):
        train_split = Dataset.from_dict(per_task[task_name]["train"])
        test_split = Dataset.from_dict(per_task[task_name]["test"])
        if len(train_split) == 0:
            raise ValueError(f"Task {task_name} has no train data")
        if len(test_split) == 0:
            raise ValueError(f"Task {task_name} has no test data")
        train_loader = make_loader(train_split, train_batch_size)
        test_loader = make_loader(test_split, test_batch_size)
        loader_pairs.append((train_loader, test_loader))

    return loader_pairs


def fst_to_vector(fst_tokenizer, num_states, fst: List[Tuple[int, str, str, int]]) -> np.array:
    """Encode the transitions of an FST as an integer matrix.

    Each transition is ``(state, input symbol, output symbol, next state)``
    with an optional fifth entry (final state indicator). Row ``j`` of the
    result holds, in that order, the state, the token id of the input symbol,
    the token id of the output symbol, the next state and, if present, the
    fifth entry.

    Arguments:
        fst_tokenizer: callable returning a mapping with ``"input_ids"`` for a
            symbol; every symbol must be encoded as exactly one id.
        num_states: number of states; the last state (``num_states - 1``) is
            reserved for padding and must not occur in ``fst``.
        fst: non-empty list of transitions with 4 or 5 entries each.

    Returns:
        An int64 array of shape ``(len(fst), len(fst[0]))``.
    """
    width = len(fst[0])
    assert width in (4, 5)

    def single_id(symbol):
        # A symbol must correspond to exactly one token.
        ids = fst_tokenizer(symbol)["input_ids"]
        assert len(ids) == 1
        return ids[0]

    padding_state = num_states - 1  # last state is reserved for padding
    encoded = np.zeros((len(fst), width), dtype=np.int64)
    for row, transition in zip(encoded, fst):
        state, in_symbol, out_symbol, next_state = transition[:4]
        assert state < padding_state
        assert next_state < padding_state

        row[0] = state
        row[1] = single_id(in_symbol)
        row[2] = single_id(out_symbol)
        row[3] = next_state

        if len(transition) == 5:
            # for final state indicator
            row[4] = transition[4]
    return encoded

def batch_fsts(fst_reps: List[np.ndarray], num_states, max_len=None) -> np.ndarray:
    """Stack FST encodings of different lengths into one padded array.

    Arguments:
        fst_reps: one encoding per FST, each a sequence of transitions (rows)
            of equal width; the width is taken from the first transition of
            the first FST and must be at least 4.
        num_states: number of states; ``num_states - 1`` is used as the
            padding state.
        max_len: number of transitions per FST in the result. Longer FSTs are
            truncated, shorter ones are padded. Defaults to the length of the
            longest FST in ``fst_reps``.

    Returns:
        An int64 array of shape ``(len(fst_reps), max_len, width)``. Padding
        rows hold the padding state in columns 0 and 3 and zeros elsewhere.
    """
    if max_len is None:
        max_len = max(len(x) for x in fst_reps)
    width = len(fst_reps[0][0])
    batched_fst_reps = np.zeros((len(fst_reps), max_len, width), dtype=np.int64)
    # Set states to a padding index (last state)
    batched_fst_reps[:, :, 0] = num_states - 1
    batched_fst_reps[:, :, 3] = num_states - 1
    for i, x in enumerate(fst_reps):
        kept = x[:max_len]
        # An empty FST has no rows to copy (and no width to broadcast from).
        if len(kept):
            # Copy all kept transitions at once instead of row by row.
            batched_fst_reps[i, :len(kept)] = kept
    return batched_fst_reps


def load_fst_jsonl(path: str, tokenizer: AutoTokenizer, fst_tokenizer_path: str, batch_size:int, num_states: int, random_order: bool = True,
                   max_len: int = None, max_n:int=None):
    """Build a DataLoader from a JSON-lines file of examples annotated with FSTs.

    Every non-blank line must be a JSON object with the keys ``"input"``,
    ``"output"`` and ``"FST"`` (a list of transitions as accepted by
    ``fst_to_vector``) and optionally ``"task_id"``.

    Arguments:
        path: path of the JSON-lines file.
        tokenizer: tokenizer applied to the inputs and, as ``text_target``, to the outputs.
        fst_tokenizer_path: tokenizer file used to encode the FST symbols.
        batch_size: number of examples per batch (the last batch may be smaller).
        num_states: number of FST states; the last state is used for padding.
        random_order: sample examples randomly if True, sequentially otherwise.
        max_len: number of transitions per FST in a batch, see ``batch_fsts``.
        max_n: if given, at most this many examples are read from the file.

    Returns:
        A DataLoader whose batches are the output of ``DataCollatorForSeq2Seq``
        extended with ``"fst_rep"`` (padded FST encodings) and, if the file
        provides task ids, ``"task_ids"``.

    Raises:
        ValueError: if only some of the examples carry a ``"task_id"``.
    """
    fst_tokenizer = PreTrainedTokenizerFast(tokenizer_file=fst_tokenizer_path)

    def mapper(examples):
        d = tokenizer(examples["input"])
        if "output" in examples:
            d["labels"] = tokenizer(text_target=examples["output"])["input_ids"]
        return d

    data = {"input": [], "output": [], "fst_rep": [], "task_ids": []}
    with open(path) as f:
        for line in f:
            # Check the limit before appending so that exactly max_n examples are kept.
            if max_n is not None and len(data["input"]) >= max_n:
                break
            if not line.strip():
                # Tolerate blank lines instead of failing in json.loads.
                continue
            example = json.loads(line)
            data["input"].append(example["input"])
            data["output"].append(example["output"])

            if "task_id" in example:
                data["task_ids"].append(example["task_id"])

            data["fst_rep"].append(fst_to_vector(fst_tokenizer, num_states, example["FST"]))

    if len(data["task_ids"]) == 0:
        del data["task_ids"]
    elif len(data["task_ids"]) != len(data["input"]):
        # Columns of different lengths cannot form a dataset; fail with a clear message.
        raise ValueError("Either all or none of the examples must have a 'task_id'")

    dataset = Dataset.from_dict(data)
    if random_order:
        sampler = RandomSampler(dataset)
    else:
        sampler = SequentialSampler(dataset)
    ts = BatchSampler(sampler, batch_size=batch_size, drop_last=False)
    dataset = dataset.map(mapper, batched=False, remove_columns=["input", "output"])

    seq2seq_collator = DataCollatorForSeq2Seq(tokenizer)
    def collator_fn(features):
        # Take the extra columns out of the features so that the seq2seq
        # collator only has to deal with the tokenizer outputs.
        fst_reps = [x.pop("fst_rep") for x in features]
        # The dataset column is called "task_ids" (see `data` above).
        task_ids = None
        if "task_ids" in features[0]:
            task_ids = [x.pop("task_ids") for x in features]

        d = seq2seq_collator(features)
        d["fst_rep"] = torch.from_numpy(batch_fsts(fst_reps, num_states, max_len=max_len))

        if task_ids is not None:
            d["task_ids"] = torch.from_numpy(np.array(task_ids))

        return d

    return DataLoader(dataset, collate_fn=collator_fn, batch_sampler=ts)



class RandomSplit:
    """Random split of a TSV file of input/output pairs into a train part and the rest.

    The rows are read with ``load_tsv`` (expected header ``input<TAB>output``),
    shuffled with the ``random`` module, and divided into the first
    ``num_train`` rows and the remaining rows. Both parts are tokenized and
    wrapped in DataLoaders that iterate sequentially over the shuffled rows.
    """

    def __init__(self, path: str, tokenizer: AutoTokenizer, num_train:int, train_batch_size, test_batch_size = None, lenient=True):
        """
        Arguments:
            path: path of the TSV file.
            tokenizer: tokenizer applied to the inputs and, as ``text_target``, to the outputs.
            num_train: number of rows in the train part.
            train_batch_size: batch size of the train loader.
            test_batch_size: batch size of the loader for the rest; defaults to ``train_batch_size``.
            lenient: passed through to ``load_tsv``.
        """
        columns = ["input", "output"]

        def tokenize(batch):
            encoded = tokenizer(batch["input"])
            if "output" in batch:
                encoded["labels"] = tokenizer(text_target=batch["output"])["input_ids"]
            return encoded

        def as_dataset(rows):
            return Dataset.from_list([dict(zip(columns, row)) for row in rows])

        def as_loader(raw, batch_size):
            batches = BatchSampler(SequentialSampler(raw), batch_size=batch_size, drop_last=False)
            tokenized = raw.map(tokenize, batched=True, remove_columns=["input", "output"])
            return DataLoader(tokenized, collate_fn=DataCollatorForSeq2Seq(tokenizer), batch_sampler=batches)

        rows = list(load_tsv(path, "input\toutput", lenient=lenient))
        # Drawing this number also advances the random state before shuffling.
        print("Random number to verify seed", random.randint(0, 100_000_000), file=sys.stderr)
        random.shuffle(rows)

        train_dataset = as_dataset(rows[:num_train])
        rest_dataset = as_dataset(rows[num_train:])

        rest_batch_size = train_batch_size if test_batch_size is None else test_batch_size
        self.train_loader = as_loader(train_dataset, train_batch_size)
        self.rest_loader = as_loader(rest_dataset, rest_batch_size)

    def get_train_loader(self):
        """Return the DataLoader over the train part."""
        return self.train_loader

    def get_rest_loader(self):
        """Return the DataLoader over the rows not used for training."""
        return self.rest_loader




