import multiprocessing as mp
import os
from pathlib import Path

import pandas as pd
import torch
import torchaudio
from asr.dictionary import Dictionary
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm

# Cap OpenMP and MKL at one thread per process (original note: "for running on single GPU").
# NOTE(review): presumably this avoids CPU oversubscription when several loader
# workers / ranks share a host — confirm. These variables are set after
# `import torch` above; verify they still take effect in that order.
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
# Sample rate every loaded waveform must have; SequenceDataset._load_wav asserts on it.
SAMPLE_RATE = 16000
# Threshold on the bucket-file `length` column: a full bucket whose longest item
# exceeds it is split into two halves (units are whatever the bucket csv uses).
HALF_BATCHSIZE_TIME = 2000


def identity(x):
    """Return ``x`` unchanged.

    Used as the ``collate_fn`` of the ``DataLoader`` built in this module so that
    the loader hands batches through without collating them.
    """
    return x


class DataloaderFactory:
    """Build bucketed ASR dataloaders for the train / validation / test splits.

    Args:
        args: Data configuration. It is read with item access for
            ``train_folder``, ``eval_folder``, ``test_folder``, ``libri_root``
            and ``bucket_file``, and with attribute access for ``num_workers``,
            so it must support both access styles.
        train_args: Training configuration providing ``batch_size`` and
            ``val_batch_size`` via item access.
    """

    # state -> (bucket-size key in ``train_args``, default-split key in ``args``)
    _STATE_KEYS = {
        "train": ("batch_size", "train_folder"),
        "val": ("val_batch_size", "eval_folder"),
    }
    # Any state other than "train" / "val" is treated as the test split.
    _FALLBACK_KEYS = ("val_batch_size", "test_folder")

    def __init__(self, args, train_args):
        self.args = args
        self.train_args = train_args

    def build(self, decoder, vocab_size, state: str = "train", rank: int = 0, split=None, size=None):
        """Create the dataset for ``state`` and wrap it in a distributed dataloader.

        The created dataset is also stored on ``self.dataset`` (overwritten on
        every call).

        Args:
            decoder: Forwarded to ``SequenceDataset`` to build its token dictionary.
            vocab_size: Number of tokens in the vocabulary.
            state: ``"train"``, ``"val"``, or anything else for the test split.
                Only ``"train"`` shuffles.
            rank: Forwarded to the dataset (it only shows progress on rank 0).
            split: Optional override for the split name(s) taken from ``args``.
            size: Optional override for the reported dataset length.

        Returns:
            A ``DistributedDalaloaderWrapper`` yielding one pre-built bucket per step.
        """
        bs_key, split_key = self._STATE_KEYS.get(state, self._FALLBACK_KEYS)
        bs = self.train_args[bs_key]
        self.dataset = SequenceDataset(
            vocab_size=vocab_size,
            decoder=decoder,
            split=self.args[split_key] if split is None else split,
            bucket_size=bs,
            libri_root=self.args["libri_root"],
            bucket_file=self.args["bucket_file"],
            rank=rank,
            size=size,
        )
        collate_fn = self.dataset.collate_fn

        # NOTE(review): the sampler gets no explicit num_replicas / rank, so it
        # presumably relies on an initialised torch.distributed process group — confirm.
        sampler = DistributedSampler(self.dataset, shuffle=state == "train")

        num_workers = self.args.num_workers
        loader_kwargs = {}
        if num_workers > 0:
            # DataLoader rejects a multiprocessing context when it loads in the
            # main process, so only request "fork" when workers are really used.
            loader_kwargs["multiprocessing_context"] = mp.get_context("fork")

        # Each dataset item is already a whole bucket, hence batch_size=1 and a
        # pass-through collate in the workers; the real collate runs in the wrapper.
        dataloader = DataLoader(
            dataset=self.dataset,
            batch_size=1,
            drop_last=False,
            num_workers=num_workers,
            collate_fn=identity,
            sampler=sampler,
            pin_memory=True,
            **loader_kwargs,
        )
        return DistributedDalaloaderWrapper(dataloader, collate_fn)


class DistributedDalaloaderWrapper:
    """Adapter that applies ``collate_fn`` to every batch of a wrapped dataloader.

    It exposes the small surface used by training loops: iteration, ``len()``,
    the underlying ``dataset`` and ``set_epoch`` (delegated to the sampler).
    """

    def __init__(self, dataloader: DataLoader, collate_fn):
        self.collate_fn = collate_fn
        self.dataloader = dataloader

    def __len__(self):
        """Number of batches reported by the wrapped dataloader."""
        return len(self.dataloader)

    def __iter__(self):
        """Start a new pass over the dataloader and return the collating iterator."""
        # The underlying iterator is created right here, not lazily on first next().
        return self._epoch_iterator(iter(self.dataloader))

    def _epoch_iterator(self, it):
        """Yield ``collate_fn(batch)`` for each raw batch produced by ``it``."""
        yield from map(self.collate_fn, it)

    @property
    def dataset(self):
        """Dataset object held by the wrapped dataloader."""
        loader = self.dataloader
        return loader.dataset

    def set_epoch(self, epoch: int):
        """Forward the epoch number to the wrapped dataloader's sampler."""
        sampler = self.dataloader.sampler
        sampler.set_epoch(epoch)


####################
# Sequence Dataset #
####################
class SequenceDataset(Dataset):
    """ASR dataset whose items are whole, pre-built buckets (batches) of utterances.

    Utterances listed in the bucket csv file(s) are sorted by decreasing
    ``length`` and grouped into buckets of ``bucket_size``. ``__getitem__``
    loads one bucket and returns ``(wavs, lengths, labels, filenames)``.
    """

    def __init__(
        self, decoder, vocab_size, split, bucket_size, libri_root, bucket_file, rank, size, dataset_index=None, **kwargs
    ):
        """
        Args:
            decoder: Object exposing ``idxs_to_tokens``; used once to build the
                token -> id dictionary for ids ``0 .. vocab_size - 1``.
            vocab_size: int
                Number of token ids in the vocabulary.
            split: string or iterable of strings
                Name(s) of the split(s); each is looked up as ``<name>.csv``
                inside ``bucket_file``.
            bucket_size: int
                Batch size.
            libri_root: str
                LibriSpeech root path.
            bucket_file: str
                Directory holding the csv files with ``file_path`` and ``length`` columns.
            rank: int
                The progress bar is only shown when this is 0.
            size: int or None
                Optional length reported by ``__len__`` (capped at the number of buckets).
            dataset_index, **kwargs: accepted and ignored.

        Raises:
            AssertionError: ``bucket_file`` is not a directory, the tables are
                empty, or a transcript file is missing.
            ValueError: none of the requested csv files exists.
        """
        super().__init__()

        self.libri_root = libri_root
        self.sample_rate = SAMPLE_RATE
        self.split_sets = split
        self.size = size
        self.dictionary = dict(
            zip(decoder.idxs_to_tokens(torch.tensor(list(range(vocab_size))).long()), range(vocab_size))
        )

        # Read table for bucketing
        assert os.path.isdir(bucket_file), (
            "Please first run `python3 preprocess/generate_len_for_bucket.py -h` to " "get bucket file."
        )

        # Wavs: a single name and a collection of names are handled uniformly.
        split_names = [split] if isinstance(split, str) else list(split)
        tables = []
        for name in split_names:
            csv_path = os.path.join(bucket_file, name + ".csv")
            if os.path.exists(csv_path):
                tables.append(pd.read_csv(csv_path))
            else:
                print(f"{name} is not found in bucket_file: {bucket_file}, skipping it.")
                print(f"Filepath does not exist: {csv_path}")

        if not tables:
            # pd.concat([]) would fail with an unhelpful "No objects to concatenate".
            raise ValueError(f"0 data found for {split}: no bucket csv found in {bucket_file}")

        table = pd.concat(tables).sort_values(by=["length"], ascending=False)
        X = table["file_path"].tolist()
        X_lens = table["length"].tolist()

        assert len(X) != 0, f"0 data found for {split}"

        # Transcripts: keep only utterances that have both audio and text.
        Y = self._load_transcript(X)
        usable = {self._parse_x_name(x) for x in X} & Y.keys()  # set => O(1) membership below
        self.Y = {name: self.token_to_ids(Y[name]) for name in usable}

        # Use bucketing to allow different batch sizes at run time
        progress = tqdm(
            zip(X, X_lens),
            total=len(X),
            desc=f"ASR dataset {split}",
            disable=rank != 0,
            dynamic_ncols=True,
        )
        samples = ((x, x_len) for x, x_len in progress if self._parse_x_name(x) in usable)
        self.X, self.X_lens = self._make_buckets(samples, bucket_size, HALF_BATCHSIZE_TIME)
        assert len(self.X_lens) == len(self.X), "Expected to have the same len of X and X len"

    @staticmethod
    def _make_buckets(samples, bucket_size, half_threshold):
        """Group ``(file, length)`` pairs into buckets of ``bucket_size``.

        A full bucket whose longest item exceeds ``half_threshold`` is split
        into two halves. The final partial bucket is kept as is.

        Returns:
            ``(buckets, bucket_lens)``: parallel lists of file lists and length lists.
        """
        buckets, bucket_lens = [], []
        batch_x, batch_len = [], []
        half = bucket_size // 2

        for x, x_len in samples:
            batch_x.append(x)  # filename, str
            batch_len.append(x_len)  # int
            if len(batch_x) != bucket_size:
                continue
            # Half the batch size if seq too long
            if bucket_size >= 2 and max(batch_len) > half_threshold:
                buckets.extend((batch_x[:half], batch_x[half:]))
                bucket_lens.extend((batch_len[:half], batch_len[half:]))
            else:
                buckets.append(batch_x)
                bucket_lens.append(batch_len)
            batch_x, batch_len = [], []

        # Gather the last batch: keep whatever is left, even a single utterance,
        # so that no sample is silently dropped.
        if batch_x:
            buckets.append(batch_x)
            bucket_lens.append(batch_len)
        return buckets, bucket_lens

    def token_to_ids(self, tokens) -> torch.IntTensor:
        """Map a whitespace-separated token string to a 1-D int32 tensor of ids.

        Raises:
            KeyError: a token is not in ``self.dictionary``.
        """
        return torch.tensor([self.dictionary[token] for token in tokens.split()], dtype=torch.int32)

    @staticmethod
    def _parse_x_name(x):
        """Return the utterance id: the file name of ``x`` up to its first dot."""
        return x.split("/")[-1].split(".")[0]

    def _load_wav(self, wav_path):
        """Load ``wav_path`` (relative to ``libri_root``) and flatten it to 1-D.

        Raises:
            AssertionError: the file's sample rate differs from ``self.sample_rate``.
        """
        wav, sr = torchaudio.load(os.path.join(self.libri_root, wav_path))
        assert sr == self.sample_rate, f"Sample rate mismatch: real {sr}, config {self.sample_rate}"
        return wav.view(-1)

    def _load_transcript(self, x_list):
        """Load the transcripts for Librispeech

        For every directory referenced by ``x_list`` the file
        ``<parent>-<dir>.trans.txt`` is read; each non-empty line is
        ``<utterance-id> <words...>``.

        Returns:
            dict mapping utterance id to its upper-cased, character-level
            transcript with ``|`` as the word boundary.
        """

        def process_trans(transcript):
            # TODO: support character / bpe
            transcript = transcript.upper()
            return "| " + " ".join(list(transcript.replace(" ", "|"))) + " |"

        trsp_sequences = {}
        # Sorted so the read order (and thus any duplicate-id resolution) is deterministic.
        chapter_dirs = sorted({"/".join(x.split("/")[:-1]) for x in x_list})

        for rel_dir in chapter_dirs:
            parts = rel_dir.split("/")
            trans_path = f"{parts[-2]}-{parts[-1]}.trans.txt"
            path = os.path.join(self.libri_root, rel_dir, trans_path)
            assert os.path.exists(path), f"Path doesn't exist: {path}"

            with open(path, "r", encoding="utf-8") as trans_f:
                for line in trans_f:
                    fields = line.split()
                    if not fields:
                        continue  # tolerate blank lines
                    trsp_sequences[fields[0]] = process_trans(" ".join(fields[1:]))

        return trsp_sequences

    def __len__(self):
        """Number of buckets, optionally limited by ``size``."""
        if self.size is None:
            return len(self.X)
        # Never report more items than exist, otherwise __getitem__ would go out of range.
        return min(self.size, len(self.X))

    def __getitem__(self, index):
        """Load bucket ``index`` and return ``(wavs, lengths, labels, filenames)``.

        ``wavs`` and ``labels`` are lists of numpy arrays, ``lengths`` are the
        values from the bucket csv, ``filenames`` are the file stems.
        """
        files = self.X[index]
        wav_batch = [self._load_wav(x_file).numpy() for x_file in files]
        label_batch = [self.Y[self._parse_x_name(x_file)].numpy() for x_file in files]
        filename_batch = [Path(x_file).stem for x_file in files]
        # bucketing, return (wavs, lens, labels, filenames)
        return wav_batch, self.X_lens[index], label_batch, filename_batch

    def collate_fn(self, items):
        """Unwrap the single pre-built bucket produced by a ``batch_size=1`` loader."""
        assert len(items) == 1, f"Expected len of items 1, got: {len(items)}"
        bucket = items[0]
        # hack bucketing, return (wavs, lens, labels, filenames)
        return bucket[0], bucket[1], bucket[2], bucket[3]
