import json
import sys

import torch
import torchaudio as ta

from audio_utils.schemas import EditModelWrapper

from stable_audio_tools import create_model_from_config
from stable_audio_tools.models.utils import load_ckpt_state_dict
from stable_audio_tools.inference.generation import generate_diffusion_cond
from stable_audio_tools.data.utils import PadCrop_Normalized_T


class SAOInstruct(EditModelWrapper):
    """Instruction-driven audio editor built on a Stable Audio Open model.

    Each call to :meth:`edit_audio` pairs a text instruction with a source audio
    file. The source audio is passed to the model as ``input_audio``
    conditioning, and the generated clips are returned.
    """

    # Fixed number of samples every source clip is padded/cropped to before it
    # is handed to the model as ``input_audio`` conditioning.
    PAD_CROP_SAMPLES = 2097152

    def __init__(self, config_path: str, checkpoint_path: str, device: str):
        """Load the model from a JSON config plus a checkpoint onto ``device``.

        Args:
            config_path: Path to the model's JSON config file.
            checkpoint_path: Path to the checkpoint whose state dict is loaded.
            device: Device string forwarded to ``EditModelWrapper`` and used to
                place the model.
        """
        super().__init__(device)

        self.model = self._load_sao_model(config_path, checkpoint_path, device)
        self.sample_rate = self.model.sample_rate

    @staticmethod
    def _load_sao_model(config_path: str, checkpoint_path: str, device: str):
        """Build the model from ``config_path``, load its weights, move it to ``device``.

        The model is only used for inference, so it is switched to eval mode.
        """
        with open(config_path, encoding="utf-8") as f:
            model_config = json.load(f)

        model = create_model_from_config(model_config)
        model.load_state_dict(load_ckpt_state_dict(checkpoint_path))
        # Inference only: disable training-time module behaviour (e.g. dropout).
        return model.to(device).eval()

    def edit_audio(
        self,
        instructions: list[str],
        audio_paths: list[str],
        cfg_scale: float = 1.5,
        steps: int = 100,
        seed: int = 1,
        use_init_audio: bool = False,
        init_noise_level: float = 1.0,
    ) -> list[torch.Tensor]:
        """Edit each audio file according to its paired text instruction.

        Args:
            instructions: One text instruction per source clip.
            audio_paths: Paths of the clips to edit; ``audio_paths[i]`` is
                edited according to ``instructions[i]``.
            cfg_scale: Classifier-free guidance scale forwarded to the sampler.
            steps: Number of sampling steps forwarded to the sampler.
            seed: Seed forwarded to the sampler.
            use_init_audio: If True, the (single) resampled source clip is also
                forwarded to the sampler as ``init_audio``. Only one
                instruction is allowed in this mode.
            init_noise_level: Forwarded to the sampler as ``init_noise_level``;
                presumably only relevant when ``use_init_audio`` is True.

        Returns:
            One CPU tensor per input. Each is trimmed along its last (time)
            axis to the source duration rounded to whole seconds. Returns an
            empty list when no inputs are given.

        Raises:
            ValueError: If ``instructions`` and ``audio_paths`` differ in
                length, or if ``use_init_audio`` is set with a number of
                instructions other than one.
        """
        # Explicit checks instead of ``assert`` so validation survives ``python -O``.
        if len(instructions) != len(audio_paths):
            raise ValueError(
                f"Instructions ({len(instructions)}) and audio_paths ({len(audio_paths)}) must have same length"
            )
        if use_init_audio and len(instructions) != 1:
            raise ValueError("Only one instruction is allowed when using init audio")

        # Nothing to edit: do not invoke the sampler with a batch size of 0.
        if not instructions:
            return []

        sample_rate = self.model.sample_rate
        pad_crop = PadCrop_Normalized_T(self.PAD_CROP_SAMPLES, sample_rate, randomize=False)

        # Build one conditioning dict per (instruction, clip) pair.
        conditionings = []
        init_audio = None
        for instruction, audio_path in zip(instructions, audio_paths):
            audio, in_sr = ta.load(audio_path)
            if in_sr != sample_rate:
                audio = ta.functional.resample(audio, orig_freq=in_sr, new_freq=sample_rate)

            if use_init_audio:
                # Exactly one clip exists in this mode (checked above); keep the
                # resampled, un-padded audio for the sampler.
                init_audio = (sample_rate, audio)

            # Source duration in whole seconds. It is used both as the
            # ``seconds_total`` conditioning and to trim the generated output.
            orig_audio_length = round(audio.shape[-1] / sample_rate)

            # Pad/crop to a fixed length; the other values pad_crop returns are unused.
            audio = pad_crop(audio.to(self.device))[0].clamp(min=-1, max=1)

            # Force exactly two channels: duplicate mono, drop any extra channels.
            if audio.shape[0] == 1:
                audio = audio.repeat(2, 1)
            elif audio.shape[0] > 2:
                audio = audio[0:2, ...]

            conditionings.append({
                "prompt": instruction,
                "seconds_start": 0,
                "seconds_total": orig_audio_length,
                "input_audio": audio,
            })

        with torch.no_grad():
            audio_clips = generate_diffusion_cond(
                self.model,
                steps=steps,
                cfg_scale=cfg_scale,
                conditioning=conditionings,
                batch_size=len(conditionings),
                sampler_type="dpmpp-3m-sde",
                init_audio=init_audio,
                init_noise_level=init_noise_level,
                device=self.device,
                seed=seed,
            ).cpu()

        # Trim each generated clip back to its source's (rounded) duration.
        return [
            clip[..., 0:cond["seconds_total"] * sample_rate]
            for clip, cond in zip(audio_clips, conditionings)
        ]
