import math
import random
from typing import Callable, List, Optional, Sequence, Tuple, Union
import os
import torch
import torchaudio
import json
import warnings
from glob import glob
from torch import Tensor
from torch.utils.data import Dataset
from util import read_txt
from audioldm.audio import wav_to_fbank, TacotronSTFT, read_wav_file
# Silence warnings whose message starts with the text below (the `message`
# argument is matched as a regular expression against the start of the
# warning message).
# NOTE(review): the library that emits this warning is not visible from this
# file — presumably part of the audio pipeline; confirm before removing.
warnings.filterwarnings("ignore", message="Possible clipped samples in output.")

class WAVDataset(Dataset):
    """Dataset yielding two mel-spectrograms and one waveform per metadata entry.

    For the entry at ``idx`` the dataset reads three chunks of ``crop_size``
    frames:

    1. a chunk of the separated-speech file derived from the entry's file
       name (optionally summed with a random excerpt of a sound-effect file),
    2. the chunk at the same position in the entry's original file,
    3. a second chunk: from the same file (preferring a window that does not
       overlap chunk 1) in training mode, or from a randomly chosen metadata
       entry otherwise.

    Chunks 1 and 2 are converted with ``wav_to_fbank``; chunk 3 is returned
    as a waveform resampled to ``cond_sample_rate``.

    The metadata JSON is indexed by integer and every entry is read through
    the keys ``"name"`` and ``"speech_timestamps"`` (a list of dicts with
    ``"start"`` and ``"end"``).
    """

    def __init__(
        self,
        path: Union[str, Sequence[str]],
        transforms: Optional[Callable] = None,
        sample_rate: Optional[int] = None,
        random_crop_size: Optional[int] = None,
        is_training: bool = True,
        requires_mix: bool = False,
        audio_root: str = "/mnt/localssd/actionkid/audio",
        sfx_list_path: str = "/mnt/localssd/actionkid/silence_list.txt",
        cond_sample_rate: int = 48000,
    ):
        """Load the metadata and configure cropping / resampling.

        Args:
            path: Path of the metadata JSON file. It is passed straight to
                ``open``, so only a single path actually works despite the
                wider annotation.
            transforms: Optional callable applied to each of the three
                waveforms before the mel conversion.
            sample_rate: Target sample rate for chunks 1 and 2. Also the rate
                in which ``random_crop_size`` and the speech timestamps are
                interpreted (both are scaled by native_rate / sample_rate).
            random_crop_size: Crop length in frames at ``sample_rate``;
                ``None`` means "the whole file".
            is_training: Selects where chunk 3 is taken from (see class doc).
            requires_mix: When true, chunk 1 is mixed with a sound effect for
                roughly 80% of the items.
            audio_root: Directory that contains the audio files; only the
                base name of each metadata ``"name"`` is used.
            sfx_list_path: Text file read with ``read_txt``; its entries are
                used as sound-effect file paths.
            cond_sample_rate: Sample rate chunk 3 is resampled to.

        Raises:
            AssertionError: If ``random_crop_size`` is set without
                ``sample_rate``.
        """
        assert (
            not random_crop_size or sample_rate
        ), "Optimized random crop requires sample_rate to be set."

        with open(path, "r", encoding="utf-8") as f:
            self.metadata = json.load(f)
        self.sfxdata = sorted(read_txt(sfx_list_path))
        self.transforms = transforms
        self.sample_rate = sample_rate
        self.random_crop_size = random_crop_size
        self.is_training = is_training
        self.requires_mix = requires_mix
        self.audio_root = audio_root
        self.cond_sample_rate = cond_sample_rate
        # NOTE(review): the STFT is fixed at 16 kHz regardless of
        # `sample_rate` — confirm callers always pass sample_rate=16000.
        self.fn_STFT = TacotronSTFT(
            filter_length=1024,
            hop_length=160,
            win_length=1024,
            n_mel_channels=64,
            sampling_rate=16000,
            mel_fmin=0,
            mel_fmax=8000,
        )

    @staticmethod
    def _pad_to_length(waveform: Tensor, size: int) -> Tensor:
        """Right-pad the last dimension of ``waveform`` with zeros up to ``size``."""
        missing = size - waveform.shape[-1]
        if missing <= 0:
            return waveform
        return torch.nn.functional.pad(
            waveform, pad=(0, missing), mode="constant", value=0
        )

    @staticmethod
    def _second_chunk_start(start_frame_1: int, crop_size: int, length: int) -> int:
        """Pick the start frame of the second chunk inside a file of ``length`` frames.

        Windows that lie entirely before or entirely after the first chunk
        are preferred. When the file is too short for either, any start that
        keeps the crop inside the file is used (the two chunks may overlap)
        instead of raising on an empty candidate list.
        """
        last_start = max(length - crop_size, 0)
        windows = []
        if start_frame_1 >= crop_size:
            windows.append((0, min(start_frame_1 - crop_size, last_start)))
        if start_frame_1 + crop_size <= length - crop_size:
            windows.append((start_frame_1 + crop_size, length - crop_size))
        if not windows:
            windows.append((0, last_start))
        low, high = random.choice(windows)
        return random.randint(low, high)

    def _resample(self, waveform: Tensor, orig_freq: int, new_freq: int) -> Tensor:
        """Resample ``waveform``, reusing one ``Resample`` module per rate pair."""
        # Created lazily so the cache also exists on instances that were
        # built without running __init__.
        cache = self.__dict__.setdefault("_resamplers", {})
        key = (orig_freq, new_freq)
        if key not in cache:
            cache[key] = torchaudio.transforms.Resample(
                orig_freq=orig_freq, new_freq=new_freq
            )
        return cache[key](waveform)

    def random_mix(self, crop_size, speech, sfx_path):
        """Return ``speech`` summed with a random excerpt of the file at ``sfx_path``.

        Up to ``crop_size`` frames are read from a random offset; the excerpt
        is then trimmed or zero-padded to the length of ``speech`` so the sum
        is always well defined along the time axis.

        NOTE(review): the excerpt is not resampled, so ``crop_size`` is
        counted in the sound-effect file's own sample rate; a differing
        channel count relies on broadcasting.
        """
        info = torchaudio.info(sfx_path)
        # Files shorter than the crop are read from the beginning.
        start_frame = random.randint(0, max(info.num_frames - crop_size, 0))
        sfx, _ = torchaudio.load(
            sfx_path, frame_offset=start_frame, num_frames=crop_size
        )
        target = speech.shape[-1]
        sfx = self._pad_to_length(sfx[..., :target], target)
        return speech + sfx

    def optimized_random_crop(
        self, idx: int
    ) -> Tuple[Tuple[Tensor, int], Tuple[Tensor, int], Tuple[Tensor, int]]:
        """Read the three chunks for entry ``idx`` without loading whole files.

        Returns:
            Three ``(waveform, sample_rate)`` pairs: separated speech
            (possibly mixed with a sound effect), the original audio at the
            same position, and the second chunk. Every waveform is
            zero-padded to ``crop_size`` frames.

        Raises:
            IndexError: If the entry has no speech timestamps.
        """
        entry = self.metadata[idx]
        wav_path = os.path.join(self.audio_root, os.path.basename(entry["name"]))
        speech_intervals = entry["speech_timestamps"]
        info = torchaudio.info(wav_path)
        length = info.num_frames
        native_rate = info.sample_rate

        test_wav_path = None
        test_length = 0
        if not self.is_training:
            test_entry = random.choice(self.metadata)
            test_wav_path = os.path.join(
                self.audio_root, os.path.basename(test_entry["name"])
            )
            test_length = torchaudio.info(test_wav_path).num_frames

        # Number of native frames that corresponds to one frame at the
        # target rate; used to scale both the crop size and the timestamps.
        ratio = 1 if self.sample_rate is None else native_rate / self.sample_rate
        if self.random_crop_size is None:
            crop_size = length
        else:
            crop_size = math.ceil(self.random_crop_size * ratio)

        # The first chunk starts somewhere inside a random speech interval.
        interval = random.choice(speech_intervals)
        start_frame_1 = random.randint(interval["start"], interval["end"])
        start_frame_1 = int(start_frame_1 * ratio)

        if self.is_training:
            start_frame_2 = self._second_chunk_start(start_frame_1, crop_size, length)
        else:
            # NOTE(review): crop_size is derived from the primary file's
            # sample rate; confirm all files share one rate.
            start_frame_2 = random.randint(0, max(test_length - crop_size, 0))

        # The separated speech lives in a sibling folder of the audio folder,
        # under "<name without 4-char extension>_speech.wav".
        if "unbalanced_audio" in wav_path:
            folder_1 = "separated_unbalanced_audio"
        elif "unbalanced_wav" in wav_path:
            folder_1 = "separated_unbalanced_wav"
        else:
            folder_1 = "separated_audio"
        path_parts = wav_path.split("/")
        speech_path = os.path.join(
            "/".join(path_parts[:-2]), folder_1, f"{path_parts[-1][:-4]}_speech.wav"
        )
        waveform_1, sample_rate_1 = torchaudio.load(
            filepath=speech_path, frame_offset=start_frame_1, num_frames=crop_size
        )

        # Sum a random sound effect onto the speech for ~80% of the items.
        # random.random() is evaluated first on purpose so the random stream
        # does not depend on `requires_mix`.
        if random.random() > 0.2 and self.requires_mix:
            waveform_1 = self.random_mix(
                crop_size, waveform_1, random.choice(self.sfxdata).strip()
            )

        waveform_2, sample_rate_2 = torchaudio.load(
            filepath=wav_path, frame_offset=start_frame_1, num_frames=crop_size
        )

        second_path = wav_path if self.is_training else test_wav_path
        waveform_3, sample_rate_3 = torchaudio.load(
            filepath=second_path, frame_offset=start_frame_2, num_frames=crop_size
        )

        # Reads near the end of a file come back short; pad with zeros.
        waveform_1 = self._pad_to_length(waveform_1, crop_size)
        waveform_2 = self._pad_to_length(waveform_2, crop_size)
        waveform_3 = self._pad_to_length(waveform_3, crop_size)

        return (
            (waveform_1, sample_rate_1),
            (waveform_2, sample_rate_2),
            (waveform_3, sample_rate_3),
        )

    def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, Tensor]:
        """Return ``(mel_1, mel_2, waveform_3)`` for entry ``idx``.

        ``mel_1`` / ``mel_2`` are the first values returned by
        ``wav_to_fbank`` for channel 0 of chunk 1 / chunk 2 (assumed to be
        tensors — TODO confirm against ``audioldm.audio``). ``waveform_3`` is
        the second chunk at ``cond_sample_rate`` after ``transforms``.
        """
        (
            (waveform_1, sample_rate_1),
            (waveform_2, sample_rate_2),
            (waveform_3, sample_rate_3),
        ) = self.optimized_random_crop(int(idx))

        # Bring chunks 1 and 2 to the target rate; resampling can produce a
        # few extra frames, hence the trim to random_crop_size.
        if self.sample_rate and sample_rate_1 != self.sample_rate:
            waveform_1 = self._resample(waveform_1, sample_rate_1, self.sample_rate)
            waveform_1 = waveform_1[:, : self.random_crop_size]

        if self.sample_rate and sample_rate_2 != self.sample_rate:
            waveform_2 = self._resample(waveform_2, sample_rate_2, self.sample_rate)
            waveform_2 = waveform_2[:, : self.random_crop_size]

        if sample_rate_3 != self.cond_sample_rate:
            waveform_3 = self._resample(
                waveform_3, sample_rate_3, self.cond_sample_rate
            )

        if self.transforms:
            waveform_1 = self.transforms(waveform_1)
            waveform_2 = self.transforms(waveform_2)
            waveform_3 = self.transforms(waveform_3)

        mel_1, _, _ = wav_to_fbank(waveform_1[0], target_length=256, fn_STFT=self.fn_STFT)
        mel_2, _, _ = wav_to_fbank(waveform_2[0], target_length=256, fn_STFT=self.fn_STFT)
        return mel_1, mel_2, waveform_3

    def __len__(self) -> int:
        """Return the number of metadata entries."""
        return len(self.metadata)