import glob
import random

import torch
import torchaudio
from torch.utils.data import Dataset
from pathlib import Path

import random
import pdb

from .feature_extractors import FBankGen, SeamlessM4TGen


class FlexiCodecDataset(Dataset):
    """Fixed-length waveform dataset with a deterministic train/valid split.

    The given file list is shuffled with a fixed seed and split into a
    validation part (the first ``valid_ratio`` fraction) and a training part
    (the remainder). Each item is a single-channel waveform that is resampled
    to ``sr``, peak-normalized, and randomly cropped or zero-padded to
    ``seg_len`` samples.

    Args:
        audio_dir: Iterable of audio file paths (despite the name, this is a
            collection of files, not a directory). It is copied, so the
            caller's list is never reordered.
        valid_ratio: Fraction of the shuffled files assigned to validation.
        valid: If True, expose the validation part; otherwise the training
            part.
        sr: Target sample rate; audio at another rate is resampled.
        seg_len: Number of samples in every returned segment.
        audio_processor: Accepted for interface compatibility; not used by
            this class.
        feature_type: ``"fbank"``, ``"seamless"`` or ``None``. Selects the
            feature extractor applied to each segment.
        frontend_path: Forwarded to ``SeamlessM4TGen`` when
            ``feature_type == "seamless"``.

    Raises:
        TypeError: If ``audio_dir`` is a single path instead of a collection.
    """

    def __init__(
        self,
        audio_dir,
        valid_ratio=0.1,
        valid=False,
        sr=24000,
        seg_len=25600,
        audio_processor=None,
        feature_type=None,
        frontend_path=None,
    ):
        super().__init__()
        assert feature_type in ["fbank", "seamless", None], (
            f"unsupported feature_type: {feature_type!r}"
        )
        if isinstance(audio_dir, (str, bytes, Path)):
            # list("some/dir") would silently become a list of characters.
            raise TypeError(
                "audio_dir must be a collection of audio file paths, "
                f"got {type(audio_dir).__name__}"
            )

        # Shuffle a private copy. Shuffling the caller's list in place meant
        # that building the train and valid datasets from the same list
        # re-shuffled already shuffled data, so the two splits could overlap.
        filenames = list(audio_dir)
        # NOTE: this reseeds the *global* RNG, which also makes the random
        # crops in __getitem__ reproducible; kept for backward compatibility.
        random.seed(0)
        random.shuffle(filenames)

        self.sr = sr
        self.feature_type = feature_type

        self.feature_extractor = None
        if self.feature_type == "fbank":
            self.feature_extractor = FBankGen(sr=self.sr)
        elif self.feature_type == "seamless":
            self.feature_extractor = SeamlessM4TGen(
                src_sr=self.sr, frontend_path=frontend_path
            )

        n_valid = int(valid_ratio * len(filenames))
        self.filenames = filenames[:n_valid] if valid else filenames[n_valid:]

        self.max_len = seg_len

    def __len__(self):
        """Return the number of files in the selected split."""
        return len(self.filenames)

    def __getitem__(self, index):
        """Load, normalize and crop/pad the clip at ``index``.

        If the requested file cannot be loaded, the next loadable file (with
        wrap-around) is returned in its place.

        Returns:
            The ``(1, seg_len)`` waveform tensor when no feature extractor is
            configured, otherwise a dict with keys ``"audio"`` and
            ``"features"``.

        Raises:
            IndexError: If ``index`` is out of range.
            RuntimeError: If no file in the split can be loaded.
        """
        audio, sr = self._load_audio(index)

        # Resample to the target rate if needed.
        if sr != self.sr:
            audio = torchaudio.transforms.Resample(sr, self.sr)(audio)

        audio = self._peak_normalize(audio)

        # For multi-channel audio, odd indices take the second channel and
        # even indices the first.
        if audio.shape[0] > 1:
            audio = audio[index % 2].unsqueeze(0)

        # Silence trimming is currently disabled; see _remove_silence.

        input_audio = self._crop_or_pad(audio)

        if self.feature_extractor is not None:
            feat = self.feature_extractor.extract_features(input_audio, fs=self.sr)
            return {"audio": input_audio, "features": feat}
        return input_audio

    def _load_audio(self, index):
        """Return ``(waveform, sample_rate)`` for ``index`` or a fallback file."""
        n_files = len(self.filenames)
        # Keep the sequence protocol intact: out-of-range must raise
        # IndexError rather than wrap around.
        if not -n_files <= index < n_files:
            raise IndexError(
                f"index {index} is out of range for {n_files} audio files"
            )

        # The file list is never mutated here: deleting entries mid-epoch
        # shifts indices and shrinks __len__ underneath the sampler.
        for offset in range(n_files):
            path = self.filenames[(index + offset) % n_files]
            try:
                return torchaudio.load(path)
            except Exception:
                print(f"Error loading {path}")
        raise RuntimeError(f"None of the {n_files} audio files could be loaded")

    @staticmethod
    def _peak_normalize(audio):
        """Scale ``audio`` to a peak magnitude of 1; leave silence untouched."""
        if audio.numel() == 0:
            return audio
        peak = audio.abs().max()
        # An all-zero clip would otherwise be divided by zero and become NaN.
        if peak > 0:
            audio = audio / peak
        return audio

    def _crop_or_pad(self, audio):
        """Return a ``max_len``-sample segment: random crop or zero-padded."""
        n_samples = audio.shape[1]
        if n_samples > self.max_len:
            # random.randint is inclusive on both ends, so the last valid
            # start offset is exactly n_samples - max_len.
            st = random.randint(0, n_samples - self.max_len)
            return audio[:, st : st + self.max_len]

        padded = torch.zeros(1, self.max_len)
        padded[:, :n_samples] = audio
        return padded

    def _remove_silence(self, audio, threshold=20):
        """Drop 500 ms windows whose absolute sum is below ``threshold``.

        Not called by ``__getitem__`` at the moment. The logic is unchanged
        from the original inline helper.

        NOTE(review): windows are offset by one sample from the end, so the
        final sample is never inspected or removed, and because ``audio`` is
        shortened while iterating, later windows are positioned relative to
        the already shortened signal. Confirm the intent before enabling.
        """
        frame_length = int(0.5 * self.sr)  # 500ms frame length
        # Sliding window, starting from the back.
        for i in range(int(audio.shape[1] // frame_length)):
            frame = audio[:, -frame_length * (i + 1) - 1 : -frame_length * i - 1]
            if torch.abs(frame).sum() < threshold:
                audio = torch.cat(
                    [
                        audio[:, : -frame_length * (i + 1) - 1],
                        audio[:, -frame_length * i - 1 :],
                    ],
                    dim=1,
                )
        return audio
