from typing import Optional, Union, Tuple
from collections import OrderedDict
import os

import torch
import numpy as np
from torch.utils.data.dataset import Dataset

# Phoneme inventory used as classification targets.
#
# Symbols are ARPAbet-style with the stress digit dropped (e.g. AA0/AA1/AA2
# are all "AA"), and a few sounds are merged into a single class:
#   "AA" also covers AO*, "O" covers OW* and OY*, "SH" also covers ZH.
#
# The position of a symbol in this list is its integer class id
# (VectorDataset maps symbol -> list index), so the order must not change.
# Index 0 is the empty symbol, which VectorDataset uses for silence /
# non-phone frames.
POSSIBLE_PHONES = (
    # 0: silence / non-phone
    [""]
    # 1-13: vowels
    + ["AA", "AE", "AW", "AY", "AH", "EH", "ER"]
    + ["EY", "IY", "IH", "O", "UH", "UW"]
    # 14-15: L / R
    + ["L", "R"]
    # 16-18: M / N
    + ["M", "N", "NG"]
    # 19-25: consonants (DH as in "father", G as in "forget")
    + ["B", "D", "DH", "G", "K", "P", "T"]
    # 26-32: H family (CH as in "lecture", SH as in "selfish" / ZH as in
    # "submersion", TH as in "sunbath")
    + ["F", "CH", "SH", "TH", "S", "Z", "V"]
    # 33-36: everything else (JH as in "fragile", HH as in "freehand")
    + ["JH", "W", "Y", "HH"]
)
NUM_PHONES = len(POSSIBLE_PHONES)


class VectorDataset(Dataset):
    """Frame-level dataset of hidden vectors paired with phoneme class ids.

    Every ``*.npy`` file found (recursively) under ``data_dir`` is matched
    with ``<label_dir>/<key>.align.txt``, where ``key`` is the file name up to
    its first ``"."``. The array is indexed as ``array[layer_id]`` and its
    second axis is taken as the number of frames (the original code comments
    describe the layout as ``(num_layers + 1, seq_length, hidden_dim)``).

    The alignment file is read one line per frame; the last space-separated
    token of each line is the phone. Tokens that are not fully upper-case are
    treated as silence (``""``), vowels lose their stress digit, and a few
    phones are merged (see ``_normalize_phone``). The label sequence is then
    truncated, or padded with silence, to the number of frames.

    All kept frames are concatenated into ``self.data`` / ``self.label`` and
    ``self.data`` is standardised per feature dimension using the mean and
    standard deviation computed over every frame in this dataset.

    NOTE: utterances are concatenated in ``os.walk`` order, which depends on
    the file system.

    Args:
        data_dir: Root directory containing the ``.npy`` hidden-state files.
        label_dir: Directory containing the ``<key>.align.txt`` files.
        layer_id: Index along the first axis of each array to keep.
        speaker_id: If given, only files whose key starts with
            ``"<speaker_id>-"`` (text before the first ``"-"``) are used.
        exclude_silence: If True, frames labelled as silence are dropped.
        length_min: Exclusive lower bound on the number of frames.
        length_max: Exclusive upper bound on the number of frames.

    Raises:
        ValueError: If an alignment file contains a phone that is not in
            ``POSSIBLE_PHONES`` after normalisation, or if no utterance
            passes the filters.
    """

    # Two-letter vowel prefix -> collapsed class. Looking up only the first
    # two characters is what drops the stress digit (AA0/AA1/AA2 -> "AA").
    _VOWEL_CLASSES = {
        "AA": "AA", "AO": "AA",
        "AE": "AE",
        "AH": "AH",
        "AW": "AW",
        "AY": "AY",
        "EH": "EH",
        "ER": "ER",
        "EY": "EY",
        "IH": "IH",
        "IY": "IY",
        "OW": "O", "OY": "O",
        "UH": "UH",
        "UW": "UW",
    }

    def __init__(self,
                 data_dir: str,  # .../hiddens
                 label_dir: str,  # ../alignments
                 layer_id: int = 0,
                 speaker_id: Optional[Union[str, int]] = None,
                 exclude_silence: bool = False,
                 length_min: int = 0,
                 length_max: int = 1000):
        super().__init__()
        self.data_dir = data_dir
        self.label_dir = label_dir
        self.layer_id = layer_id
        self.speaker_id = str(speaker_id) if (speaker_id is not None) else None

        # symbol -> class id, following the order of POSSIBLE_PHONES
        self.dictionary = OrderedDict(
            (phone, idx) for idx, phone in enumerate(POSSIBLE_PHONES))
        silence_id = self.dictionary[""]

        features = []  # one (frames, hidden_dim) array per utterance
        labels = []  # one (frames,) int64 array per utterance
        for root, _dirs, files in os.walk(data_dir):
            for file_name in files:
                if not file_name.endswith(".npy"):
                    continue
                key = file_name.split(".")[0]

                # Apply the cheap filters first so that the (potentially
                # large) array is only read for files that can be used.
                if (self.speaker_id is not None) and (key.split("-")[0] != self.speaker_id):
                    continue
                align_path = os.path.join(label_dir, key + ".align.txt")
                if not os.path.isfile(align_path):
                    continue

                hidden = np.load(os.path.join(root, file_name))
                num_frames = hidden.shape[1]
                if not (length_min < num_frames < length_max):
                    continue
                # Copy the selected layer: a plain view would keep the whole
                # all-layer array alive until the final concatenation.
                frames = hidden[self.layer_id].copy()

                phones = self._read_alignment(align_path, self.dictionary)
                # Truncate, or pad with silence, to exactly one label per frame.
                ids = [self.dictionary[p] for p in phones[:num_frames]]
                ids.extend([silence_id] * (num_frames - len(ids)))
                frame_labels = np.array(ids, dtype=np.int64)

                if exclude_silence:
                    keep = frame_labels != silence_id
                    frames, frame_labels = frames[keep], frame_labels[keep]
                features.append(frames)
                labels.append(frame_labels)

        print(f"Total outputs loaded: {len(features)}")
        print(f"Phonemes: {list(self.dictionary.keys())}")

        if not features:
            # np.concatenate([]) would fail with an unhelpful message.
            raise ValueError(
                f"No usable utterances found under {data_dir!r} "
                f"(label_dir={label_dir!r}, speaker_id={self.speaker_id!r}, "
                f"frame range=({length_min}, {length_max}) exclusive)")

        self.data = torch.from_numpy(np.concatenate(features, axis=0))
        self.label = torch.from_numpy(np.concatenate(labels, axis=0))

        data_mean = torch.mean(self.data, dim=0)
        # small epsilon guards against division by zero for constant features
        data_std = torch.std(self.data, dim=0).add_(1e-6)

        # normalize to N(0, 1)
        self.data = self.data.sub_(data_mean).div_(data_std)

        print(f"Data shape: {self.data.shape}, label shape: {self.label.shape}")
        print(f"#Phonemes: {self.num_classes}")

    @classmethod
    def _normalize_phone(cls, token: str) -> str:
        """Map a raw alignment token to its collapsed phone symbol."""
        if token != token.upper():
            # anything containing lower-case characters counts as silence
            return ""
        vowel = cls._VOWEL_CLASSES.get(token[:2])
        if vowel is not None:
            return vowel
        return "SH" if token == "ZH" else token

    @classmethod
    def _read_alignment(cls, path: str, vocabulary) -> list:
        """Read one alignment file and return its normalised phone per line.

        ``vocabulary`` is any container supporting ``in`` that holds the
        accepted phone symbols. Raises ValueError on an unknown phone.
        """
        phones = []
        with open(path, "r") as f:
            for line_no, line in enumerate(f, start=1):
                token = line.split(" ")[-1].replace("\n", "")
                phone = cls._normalize_phone(token)
                if phone not in vocabulary:
                    raise ValueError(
                        f"{path}:{line_no}: unknown phoneme {token!r}")
                phones.append(phone)
        return phones

    @property
    def num_classes(self) -> int:
        """Number of phone classes in the dictionary (silence included)."""
        return len(self.dictionary)

    def __len__(self):
        """Total number of frames in the dataset."""
        return len(self.data)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the (normalised hidden vector, class id) pair at ``index``."""
        return self.data[index], self.label[index]
