import sys
import warnings

import torch
from audio_utils.schemas import EditModelWrapper

sys.path.insert(1, "audio_generation/AudioEditingCode/code")
from ddm_inversion.inversion_utils import inversion_forward_process, inversion_reverse_process
from ddm_inversion.ddim_inversion import ddim_inversion, text2image_ldm_stable
from models import load_model
from utils import set_reproducability, load_audio, get_spec

try:
    from torch import inference_mode
except ImportError:
    inference_mode = torch.no_grad


class EditStableAudioOpen(EditModelWrapper):
    """Text-guided audio editing with the stable-audio-open-1.0 model.

    The source audio is encoded with the model's VAE, inverted into the
    diffusion latent space under a source prompt, regenerated under a target
    prompt, and decoded back to a waveform.
    """

    # Values accepted for the ``mode`` argument of ``edit_audio``.
    _SUPPORTED_MODES = ("ours", "ddim")

    def __init__(self, device, steps: int, half_precision=False):
        """Load the stable-audio-open model.

        Args:
            device: Device the model and the loaded audio are placed on
                (a device string such as "cuda" / "cuda:0" / "cpu", or a
                ``torch.device``).
            steps: Number of diffusion steps forwarded to ``load_model``.
            half_precision: When False (default) ``load_model`` is called with
                ``double_precision=True``; when True, with
                ``double_precision=False``.
        """
        super().__init__(device=device)

        self.device = device
        self.model_id = "stabilityai/stable-audio-open-1.0"
        self.ldm_stable = load_model(self.model_id, self.device, steps, double_precision=not half_precision)
        # NOTE(review): hard-coded; edit_audio itself resamples to
        # self.ldm_stable.get_sr() — confirm the two always agree.
        self.sample_rate = 44100

    def _autocast_device_type(self) -> str:
        """Return the bare device type ("cuda", "cpu", ...) for ``torch.autocast``.

        ``torch.autocast`` only accepts a plain device-type string, so indexed
        strings like "cuda:0" and ``torch.device`` objects are normalised here.
        """
        return torch.device(self.device).type

    def edit_audio(
        self,
        input_caption: str,
        output_caption: str,
        seed: int,
        cfg_scale: float,
        steps: int,
        negative_output_caption: str,
        init_aud: str = None,
        cfg_tar: float = 12.0,
        tstart: int = 100,
        eta: float = 1.0,
        mode: str = "ours",
        fix_alpha: float = 0.1,
        cutoff_points=None,
        **kwargs
    ) -> torch.Tensor:
        """Edit an audio file by inverting it and regenerating it with a new prompt.

        Args:
            input_caption: Prompt describing the source audio; conditions the
                inversion.
            output_caption: Target prompt used for regeneration.
            seed: Seed passed to ``set_reproducability``.
            cfg_scale: Guidance scale used during inversion.
            steps: Number of diffusion steps for this edit.
            negative_output_caption: Negative prompt for regeneration (used
                only in ``"ours"`` mode).
            init_aud: Source audio handed to ``load_audio``. Required despite
                the ``None`` default.
            cfg_tar: Guidance scale used during regeneration.
            tstart: Number of steps that are regenerated; ``steps - tstart``
                steps are skipped (``tstart == steps`` means the full chain).
            eta: Passed as ``etas`` to the forward/reverse processes (``"ours"``
                mode only).
            mode: ``"ours"`` (``inversion_forward_process`` /
                ``inversion_reverse_process``) or ``"ddim"`` (plain DDIM
                inversion followed by ``text2image_ldm_stable``).
            fix_alpha: Passed to ``inversion_reverse_process`` (``"ours"`` only).
            cutoff_points: Passed to the forward/reverse processes (``"ours"``
                only).
            **kwargs: Accepted and ignored.

        Returns:
            ``torch.stack([source, edited], dim=0)``: the source waveform and
            the edited waveform, both on the CPU and trimmed to a common length.

        Raises:
            ValueError: If ``mode`` is not supported or ``init_aud`` is None.
        """
        if mode not in self._SUPPORTED_MODES:
            raise ValueError(
                f"Unsupported mode {mode!r}; expected one of {self._SUPPORTED_MODES}."
            )
        if init_aud is None:
            raise ValueError("init_aud is required: no source audio was given.")

        set_reproducability(seed, extreme=False)
        with torch.autocast(device_type=self._autocast_device_type()):
            # 1) Load the source audio (stable-audio-open is used with stft=False).
            x0, _, duration = load_audio(
                init_aud,
                self.ldm_stable.get_fn_STFT(),
                device=self.device,
                stft=False,
                model_sr=self.ldm_stable.get_sr(),
            )

            # Steps skipped by a partial inversion; 0 means the full chain.
            skip = steps - tstart

            with inference_mode():
                # 2) Encode to latents, 3) invert + regenerate, 4) decode.
                w0 = self.ldm_stable.vae_encode(x0)
                if mode == "ddim":
                    w_edit = self._edit_latents_ddim(
                        w0,
                        input_caption=input_caption,
                        output_caption=output_caption,
                        cfg_scale=cfg_scale,
                        cfg_tar=cfg_tar,
                        steps=steps,
                        skip=skip,
                    )
                else:  # mode == "ours" (validated above)
                    w_edit = self._edit_latents_ours(
                        w0,
                        input_caption=input_caption,
                        output_caption=output_caption,
                        negative_output_caption=negative_output_caption,
                        cfg_scale=cfg_scale,
                        cfg_tar=cfg_tar,
                        steps=steps,
                        tstart=tstart,
                        eta=eta,
                        fix_alpha=fix_alpha,
                        cutoff_points=cutoff_points,
                        duration=duration,
                    )

                x0_dec = self.ldm_stable.vae_decode(w_edit)
                # Drop the leading dim of the decoded batch and move to the CPU.
                edited_audio = x0_dec.detach().clone().cpu().squeeze(0)

            # 5) Return the source and the edit stacked together.
            return self._pair_with_source(x0, edited_audio)

    def _edit_latents_ddim(self, w0, *, input_caption, output_caption, cfg_scale, cfg_tar, steps, skip):
        """Plain DDIM: invert ``w0`` under the source prompt, regenerate under the target prompt."""
        if skip != 0:
            warnings.warn(
                "Plain DDIM Inversion is usually run with tstart == num_diffusion_steps. "
                "Now running partial DDIM inversion."
            )
        # A single source prompt and a single guidance scale are used.
        wT = ddim_inversion(
            self.ldm_stable,
            w0,
            [input_caption],
            cfg_scale,
            num_inference_steps=steps,
            skip=skip,
        )
        return text2image_ldm_stable(
            self.ldm_stable,
            [output_caption],
            num_inference_steps=steps,
            cfg_scale=cfg_tar,
            latents=wT,  # latents produced by the inversion above
            skip=skip,
        )

    def _edit_latents_ours(
        self,
        w0,
        *,
        input_caption,
        output_caption,
        negative_output_caption,
        cfg_scale,
        cfg_tar,
        steps,
        tstart,
        eta,
        fix_alpha,
        cutoff_points,
        duration,
    ):
        """Run ``inversion_forward_process`` then ``inversion_reverse_process`` on ``w0``."""
        _, zs, wts, extra_info = inversion_forward_process(
            self.ldm_stable,
            w0,
            etas=eta,
            prompts=[input_caption],
            cfg_scales=[cfg_scale],
            prog_bar=False,
            num_inference_steps=steps,
            cutoff_points=cutoff_points,
            numerical_fix=True,
            duration=duration,
        )
        w_edit, _ = inversion_reverse_process(
            self.ldm_stable,
            xT=wts,
            tstart=torch.tensor(tstart, dtype=torch.int),
            fix_alpha=fix_alpha,
            etas=eta,
            prompts=[output_caption],
            neg_prompts=[negative_output_caption],
            cfg_scales=[cfg_tar],
            prog_bar=False,
            # Only the noise maps for the regenerated steps (steps - skip == tstart).
            zs=zs[:tstart],
            cutoff_points=cutoff_points,
            duration=duration,
            extra_info=extra_info,
        )
        return w_edit

    @staticmethod
    def _pair_with_source(source, edited):
        """Stack ``source`` and ``edited`` along a new dim 0 on ``edited``'s device.

        ``source`` comes from ``load_audio`` on the model device while
        ``edited`` has been moved to the CPU; ``torch.stack`` requires one
        device and matching shapes, so the source is moved and both tensors
        are trimmed to the shorter last-dimension length.
        """
        source = source.to(edited.device)
        length = min(source.shape[-1], edited.shape[-1])
        return torch.stack([source[..., :length], edited[..., :length]], dim=0)
