"""
EditDataset adapted from SampleDataset in stable_audio_tools/data/dataset.py
"""
import json
import os
import random
import time
import typing as t

import torch
import torchaudio
from torchaudio import transforms as T

from .utils import PhaseFlipper, PadCrop_Normalized_T, Stereo, Mono


class EditDataset(torch.utils.data.Dataset):
    """Paired (input, output) audio dataset for instruction-driven editing.

    Every dataset root listed in ``datasets`` must contain a ``samples.json``
    file mapping a sample uid to its metadata. For each uid the clips
    ``<root>/<uid>/input.wav`` and ``<root>/<uid>/output.wav`` are loaded.
    Each metadata dict must provide ``metadata["data"]["input"]``,
    ``metadata["data"]["output"]`` and ``metadata["data"]["instruction"]``.

    ``__getitem__`` returns ``(output_audio, info)``. The processed input clip
    is stored in ``info["input_audio"]``.
    """

    # Upper bound on how many samples ``__getitem__`` tries (the requested one
    # plus random replacements) before giving up with a RuntimeError.
    max_load_attempts = 100

    def __init__(
        self,
        datasets: t.List,
        sample_size=65536,
        sample_rate=48000,
        keywords=None,
        random_crop=False,
        force_channels="stereo"
    ):
        """
        Args:
            datasets: list of dataset config dicts. Each needs a ``"path"``
                (directory containing ``samples.json``) and an ``"id"`` (used
                only in the log line). An optional ``"max_items"`` keeps only
                the first that many entries of ``samples.json``.
            sample_size: target clip length handed to ``PadCrop_Normalized_T``.
            sample_rate: target sample rate; files at other rates are resampled.
            keywords: not used by this class.
            random_crop: forwarded as ``randomize`` to ``PadCrop_Normalized_T``.
            force_channels: ``"stereo"`` applies ``Stereo()``, ``"mono"``
                applies ``Mono()``; any other value leaves channels unchanged.
        """
        super().__init__()
        # Not used by this class; kept so external code reading them still works.
        self.sample_paths = []
        self.root_paths = []

        self.augs = torch.nn.Sequential(
            PhaseFlipper(),
        )

        self.pad_crop = PadCrop_Normalized_T(sample_size, sample_rate, randomize=random_crop)

        self.force_channels = force_channels

        self.encoding = torch.nn.Sequential(
            Stereo() if self.force_channels == "stereo" else torch.nn.Identity(),
            Mono() if self.force_channels == "mono" else torch.nn.Identity(),
        )

        self.sr = sample_rate

        self.samples = []  # (sample directory, metadata) for each sample
        for dataset_config in datasets:
            base_path = dataset_config["path"]
            with open(os.path.join(base_path, "samples.json"), "r", encoding="utf-8") as f:
                samples_dict = json.load(f)
            max_items = dataset_config.get("max_items", len(samples_dict))
            dataset_samples = [
                (os.path.join(base_path, uid), metadata)
                for uid, metadata in list(samples_dict.items())[:max_items]
            ]
            print(f"Loaded {len(dataset_samples)} samples from {base_path} of id {dataset_config['id']}")
            self.samples.extend(dataset_samples)

        print(f'Using {len(self.samples)} samples')
        print(self.samples[0:10])

    def load_file(self, filename):
        """Load a WAV file with torchaudio, resampling it to ``self.sr`` if needed.

        Raises:
            ValueError: if ``filename`` does not end in ``.wav``.
        """
        ext = filename.split(".")[-1]
        if ext != "wav":
            # An explicit raise (unlike `assert`) still fires under `python -O`.
            raise ValueError(f"Only .wav files are supported, got {filename!r}")
        audio, in_sr = torchaudio.load(filename, format=ext)

        if in_sr != self.sr:
            print("resampling")
            audio = T.Resample(in_sr, self.sr)(audio)

        return audio

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(output_audio, info)`` for sample ``idx``.

        If loading the sample raises, the error is printed and a uniformly
        random other sample is tried instead. This repeats at most
        ``max_load_attempts`` times in total.

        Raises:
            IndexError: if ``idx`` is out of range (not retried).
            RuntimeError: if every attempt failed; chained from the last error.
        """
        last_error = None
        for _ in range(self.max_load_attempts):
            # Indexing stays outside the try so a bad index is not masked by retries.
            sample_path, sample_metadata = self.samples[idx]
            try:
                return self._load_pair(sample_path, sample_metadata)
            except Exception as e:
                print(f"Couldn't load path {sample_path}: {e}")
                last_error = e
            idx = random.randrange(len(self))
        raise RuntimeError(
            f"Failed to load a sample after {self.max_load_attempts} attempts"
        ) from last_error

    def _load_pair(self, sample_path, sample_metadata):
        """Load, crop, augment and encode the input/output clips of one sample.

        Returns ``(output_audio, info)``. ``info`` contains:
          - ``info["input"]`` / ``info["output"]``: dicts with ``"path"``,
            ``"timestamps"`` (``(t_start, t_end)`` from ``pad_crop``) and
            ``"prompt"`` (``sample_metadata["data"][name]``);
          - ``"seconds_start"``, ``"seconds_total"``, ``"padding_mask"``: the
            ``pad_crop`` results of the last clip processed (the output);
          - ``"prompt"``: ``sample_metadata["data"]["instruction"]``;
          - ``"load_time"``: wall-clock seconds spent loading both clips;
          - ``"input_audio"``: the processed input clip.
        """
        audios = []
        info = {}
        start_time = time.time()

        # The crop (when random_crop=True) and PhaseFlipper are random. Without
        # this, the input and output clips would get independent random draws
        # and the pair would no longer line up. Snapshot the global RNG states
        # once and restore them before each clip so both see identical draws.
        # NOTE(review): assumes the helpers draw from Python's `random` and/or
        # torch's global CPU RNG; crop offsets only match if both files have
        # the same length.
        py_rng_state = random.getstate()
        torch_rng_state = torch.get_rng_state()

        for audio_name in ("input", "output"):
            random.setstate(py_rng_state)
            torch.set_rng_state(torch_rng_state)

            audio_path = os.path.join(sample_path, audio_name + ".wav")
            audio = self.load_file(audio_path)
            audio, t_start, t_end, seconds_start, seconds_total, padding_mask = self.pad_crop(audio)

            # Run augmentations on this clip
            if self.augs is not None:
                audio = self.augs(audio)

            audio = audio.clamp(-1, 1)

            # Force the requested channel layout
            if self.encoding is not None:
                audio = self.encoding(audio)

            audios.append(audio)

            info[audio_name] = {
                "path": audio_path,
                "timestamps": (t_start, t_end),
                "prompt": sample_metadata["data"][audio_name],
            }
            info["seconds_start"] = seconds_start
            info["seconds_total"] = seconds_total
            info["padding_mask"] = padding_mask

        end_time = time.time()

        info["prompt"] = sample_metadata["data"]["instruction"]
        info["load_time"] = end_time - start_time
        info["input_audio"] = audios[0]
        return audios[1], info
