import json
import sys
import typing as t

import torch

from audio_utils.schemas import ModelWrapper


sys.path.insert(1, "audio_generation/p2p/stable-audio-tools")
from stable_audio_tools import create_model_from_config, get_pretrained_model
from stable_audio_tools.models.utils import load_ckpt_state_dict
from stable_audio_tools.inference.generation import generate_diffusion_cond
from stable_audio_tools.models.transformer import Attention


class P2PStableAudioOpen(ModelWrapper):
    """Prompt-to-prompt (P2P) editing wrapper around a Stable Audio Open model.

    Generation runs the diffusion sampler on a batch of conditionings and
    returns peak-normalized float32 audio on the CPU. P2P behaviour is driven by
    the ``prompt_to_prompt`` / ``attn_reweighting`` attributes this class sets
    on the ``Attention`` modules of the vendored ``stable_audio_tools``; what
    those flags do is implemented there (presumably attention injection from
    the input-caption batch item into the output-caption item — confirm in the
    vendored ``Attention``).
    """

    # Duration (seconds) used for every conditioning built here and for trimming
    # the generated audio.
    _SECONDS_TOTAL = 10
    # Sampler passed to ``generate_diffusion_cond`` for every generation.
    _SAMPLER_TYPE = "dpmpp-3m-sde"
    # Lower bound on the normalization divisor so silent output does not become NaN.
    _NORM_EPS = 1e-8

    def __init__(self, device: str, model_path: t.Optional[str] = None, config_path: t.Optional[str] = None, half_precision: bool = False):
        """Load the model and move it to ``device``.

        Args:
            device: Torch device string the model is moved to (e.g. ``"cuda"``,
                ``"cuda:0"``, ``"cpu"``); also forwarded to ``ModelWrapper``.
            model_path: Checkpoint path read with ``load_ckpt_state_dict``.
            config_path: JSON model config used to build the model for ``model_path``.
                Both ``model_path`` and ``config_path`` must be given to load a
                custom model; otherwise the pretrained
                ``stabilityai/stable-audio-open-1.0`` model is loaded.
            half_precision: If True, cast the model weights to ``torch.float16``.
        """
        super().__init__(device)
        if model_path is None or config_path is None:
            if model_path is not None or config_path is not None:
                # Previously the lone argument was silently ignored.
                print("Warning: both model_path and config_path are required for a custom model; ignoring the one provided.")
            print("Loading default SAO model...")
            model, model_config = get_pretrained_model("stabilityai/stable-audio-open-1.0")
        else:
            print(f"Loading SAO model from {model_path} with config from {config_path}...")
            with open(config_path, encoding="utf-8") as f:
                model_config = json.load(f)

            model = create_model_from_config(model_config)
            model.load_state_dict(load_ckpt_state_dict(model_path))

        if half_precision:
            model = model.to(torch.float16)

        self.model = model.to(device)
        self.sample_rate = model_config["sample_rate"]
        self.sample_size = model_config["sample_size"]

    @staticmethod
    def _set_p2p_prop_for_model(model: torch.nn.Module, attn_inject_frac: float, attn_inject_delay: float, attn_reweighting: float = 1):
        """Set prompt-to-prompt flags on every ``Attention`` module of ``model``.

        Attention modules are taken in ``model.modules()`` order. The first
        ``round(n * attn_inject_delay)`` modules are skipped, the following
        ``round(n * attn_inject_frac)`` get ``prompt_to_prompt = True``, and all
        others get ``prompt_to_prompt = False``. Every module gets
        ``attn_reweighting``.

        Raises:
            ValueError: if either fraction is outside [0, 1] or their sum exceeds 1.
        """
        fractions_valid = 0 <= attn_inject_frac <= 1 and 0 <= attn_inject_delay <= 1
        if not fractions_valid or attn_inject_frac + attn_inject_delay > 1:
            raise ValueError(f"{attn_inject_frac=}, {attn_inject_delay=}")

        attention_layers = [m for m in model.modules() if isinstance(m, Attention)]
        num_attention = len(attention_layers)

        n_skip_attention = round(num_attention * attn_inject_delay)
        # Rounding both counts independently can overshoot by one (e.g. 3 layers
        # with 0.5/0.5 gives 2 + 2), so cap the injected span at the last layer.
        n_inject_attention = min(round(num_attention * attn_inject_frac), num_attention - n_skip_attention)
        inject_range = range(n_skip_attention, n_skip_attention + n_inject_attention)

        for idx, attention_layer in enumerate(attention_layers):
            attention_layer.prompt_to_prompt = idx in inject_range
            attention_layer.attn_reweighting = attn_reweighting

    def set_p2p_prop(self, attn_inject_frac: float, attn_inject_delay: float, attn_reweighting: float = 1):
        """Apply ``_set_p2p_prop_for_model`` to this wrapper's model."""
        self._set_p2p_prop_for_model(
            model=self.model,
            attn_inject_frac=attn_inject_frac,
            attn_inject_delay=attn_inject_delay,
            attn_reweighting=attn_reweighting,
        )

    def _generate(self, conditioning: list, negative_conditioning: list, seed: int, cfg_scale: float, steps: int, seconds_total: int) -> torch.Tensor:
        """Run the sampler, trim to ``seconds_total`` and peak-normalize.

        The result is divided by the single largest absolute sample across the
        whole batch (so relative loudness between batch items is kept), clamped
        to [-1, 1], cast to float32 and moved to the CPU.
        """
        # torch.autocast accepts only a bare device type ("cuda", "cpu", ...),
        # not an indexed device such as "cuda:0".
        autocast_device = torch.device(self.device).type
        with torch.autocast(device_type=autocast_device):
            output = generate_diffusion_cond(
                model=self.model,
                steps=steps,
                cfg_scale=cfg_scale,
                conditioning=conditioning,
                negative_conditioning=negative_conditioning,
                batch_size=len(conditioning),
                sampler_type=self._SAMPLER_TYPE,
                device=self.device,
                seed=seed,
            )
        # Keep only the first ``seconds_total`` seconds along the last (sample) axis.
        output = output[..., : self.model.sample_rate * seconds_total].to(torch.float32)
        peak = output.abs().max().clamp_min(self._NORM_EPS)
        return output.div(peak).clamp(-1, 1).cpu()

    def generate_edited_audio(self, input_caption: str, output_caption: str, seed: int, cfg_scale: float, steps: int, negative_output_caption: str, negative_input_caption: str = "", **kwargs) -> torch.Tensor:
        """Generate the input-caption clip and the edited output-caption clip in one batch.

        Batch item 0 is conditioned on ``input_caption``, item 1 on
        ``output_caption``. Falsy negative captions (e.g. ``None``) become ``""``.
        Extra ``**kwargs`` are accepted for interface compatibility and ignored.

        Returns:
            The peak-normalized float32 CPU tensor produced by the sampler,
            trimmed to ``_SECONDS_TOTAL`` seconds.
        """
        length = self._SECONDS_TOTAL
        conditioning = [
            {"prompt": input_caption, "seconds_start": 0, "seconds_total": length},
            {"prompt": output_caption, "seconds_start": 0, "seconds_total": length},
        ]
        negative_conditioning = [
            {"prompt": negative_input_caption or "", "seconds_start": 0, "seconds_total": length},
            {"prompt": negative_output_caption or "", "seconds_start": 0, "seconds_total": length},
        ]
        return self._generate(conditioning, negative_conditioning, seed, cfg_scale, steps, length)

    def generate_audio(
            self,
            conditioning: list,
            negative_conditioning: list,
            seed: int,
            cfg_scale: float,
            steps: int):
        """Generate one clip per entry of ``conditioning``.

        ``conditioning`` / ``negative_conditioning`` are lists of dicts such as
        ``{"prompt": ..., "seconds_start": 0, "seconds_total": 10}`` (the form
        built by ``generate_edited_audio``); the batch size is
        ``len(conditioning)``. Output is trimmed to ``_SECONDS_TOTAL`` seconds and
        returned as a peak-normalized float32 CPU tensor.
        """
        return self._generate(conditioning, negative_conditioning, seed, cfg_scale, steps, self._SECONDS_TOTAL)
