"""
Streamlit utilities (mostly cached wrappers around riffusion code).
"""
import io
import threading
import typing as T

import pydub
import streamlit as st
import torch
from diffusers import DiffusionPipeline, StableDiffusionImg2ImgPipeline, StableDiffusionPipeline
from PIL import Image

from riffusion.audio_splitter import AudioSplitter
from riffusion.riffusion_pipeline import RiffusionPipeline
from riffusion.spectrogram_image_converter import SpectrogramImageConverter
from riffusion.spectrogram_params import SpectrogramParams


# TODO(hayk): Add URL params

# DEFAULT_CHECKPOINT = "riffusion/riffusion-model-v1"
DEFAULT_CHECKPOINT = "/root/Audio_tests/riffusion/Riffusion_hf"

AUDIO_EXTENSIONS = ["mp3", "wav", "flac", "webm", "m4a", "ogg"]
IMAGE_EXTENSIONS = ["png", "jpg", "jpeg"]

SCHEDULER_OPTIONS = [
    "DPMSolverMultistepScheduler",
    "PNDMScheduler",
    "DDIMScheduler",
    "LMSDiscreteScheduler",
    "EulerDiscreteScheduler",
    "EulerAncestralDiscreteScheduler",
]


@st.cache_resource
def load_riffusion_checkpoint(
    checkpoint: str = DEFAULT_CHECKPOINT,
    no_traced_unet: bool = False,
    device: str = "cuda",
) -> RiffusionPipeline:
    """
    Load the riffusion pipeline.
    """
    return RiffusionPipeline.load_checkpoint(
        checkpoint=checkpoint,
        use_traced_unet=not no_traced_unet,
        device=device,
    )


@st.cache_resource
def load_stable_diffusion_pipeline(
    checkpoint: str = DEFAULT_CHECKPOINT,
    device: str = "cuda",
    dtype: torch.dtype = torch.float16,
    scheduler: str = SCHEDULER_OPTIONS[0],
) -> StableDiffusionPipeline:
    """
    Load the riffusion pipeline.

    TODO(hayk): Merge this into RiffusionPipeline to just load one model.
    """
    if device == "cpu" or device.lower().startswith("mps"):
        print(f"WARNING: Falling back to float32 on {device}, float16 is unsupported")
        dtype = torch.float32

    pipeline = StableDiffusionPipeline.from_pretrained(
        checkpoint,
        revision="main",
        torch_dtype=dtype,
        safety_checker=lambda images, **kwargs: (images, False),
    ).to(device)

    pipeline.scheduler = get_scheduler(scheduler, config=pipeline.scheduler.config)

    return pipeline


def get_scheduler(scheduler: str, config: T.Any) -> T.Any:
    """
    Construct a denoising scheduler from a string.
    """
    if scheduler == "PNDMScheduler":
        from diffusers import PNDMScheduler

        return PNDMScheduler.from_config(config)
    elif scheduler == "DPMSolverMultistepScheduler":
        from diffusers import DPMSolverMultistepScheduler

        return DPMSolverMultistepScheduler.from_config(config)
    elif scheduler == "DDIMScheduler":
        from diffusers import DDIMScheduler

        return DDIMScheduler.from_config(config)
    elif scheduler == "LMSDiscreteScheduler":
        from diffusers import LMSDiscreteScheduler

        return LMSDiscreteScheduler.from_config(config)
    elif scheduler == "EulerDiscreteScheduler":
        from diffusers import EulerDiscreteScheduler

        return EulerDiscreteScheduler.from_config(config)
    elif scheduler == "EulerAncestralDiscreteScheduler":
        from diffusers import EulerAncestralDiscreteScheduler

        return EulerAncestralDiscreteScheduler.from_config(config)
    else:
        raise ValueError(f"Unknown scheduler {scheduler}")


@st.cache_resource
def pipeline_lock() -> threading.Lock:
    """
    Singleton lock used to prevent concurrent access to any model pipeline.
    """
    return threading.Lock()


@st.cache_resource
def load_stable_diffusion_img2img_pipeline(
    checkpoint: str = DEFAULT_CHECKPOINT,
    device: str = "cuda",
    dtype: torch.dtype = torch.float16,
    scheduler: str = SCHEDULER_OPTIONS[0],
) -> StableDiffusionImg2ImgPipeline:
    """
    Load the image to image pipeline.

    TODO(hayk): Merge this into RiffusionPipeline to just load one model.
    """
    if device == "cpu" or device.lower().startswith("mps"):
        print(f"WARNING: Falling back to float32 on {device}, float16 is unsupported")
        dtype = torch.float32

    pipeline = StableDiffusionImg2ImgPipeline.from_pretrained(
        checkpoint,
        revision="main",
        torch_dtype=dtype,
        safety_checker=lambda images, **kwargs: (images, False),
    ).to(device)

    pipeline.scheduler = get_scheduler(scheduler, config=pipeline.scheduler.config)

    return pipeline


@st.cache_data(persist=True)
def run_txt2img(
    prompt: str,
    num_inference_steps: int,
    guidance: float,
    negative_prompt: str,
    seed: int,
    width: int,
    height: int,
    checkpoint: str = DEFAULT_CHECKPOINT,
    device: str = "cuda",
    scheduler: str = SCHEDULER_OPTIONS[0],
) -> Image.Image:
    """
    Run the text to image pipeline with caching.
    """
    with pipeline_lock():
        pipeline = load_stable_diffusion_pipeline(
            checkpoint=checkpoint,
            device=device,
            scheduler=scheduler,
        )

        generator_device = "cpu" if device.lower().startswith("mps") else device
        generator = torch.Generator(device=generator_device).manual_seed(seed)

        output = pipeline(
            prompt=prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance,
            negative_prompt=negative_prompt or None,
            generator=generator,
            width=width,
            height=height,
        )

        return output["images"][0]


@st.cache_resource
def spectrogram_image_converter(
    params: SpectrogramParams,
    device: str = "cuda",
) -> SpectrogramImageConverter:
    return SpectrogramImageConverter(params=params, device=device)


@st.cache
def spectrogram_image_from_audio(
    segment: pydub.AudioSegment,
    params: SpectrogramParams,
    device: str = "cuda",
) -> Image.Image:
    converter = spectrogram_image_converter(params=params, device=device)
    return converter.spectrogram_image_from_audio(segment)


@st.cache_data
def audio_segment_from_spectrogram_image(
    image: Image.Image,
    params: SpectrogramParams,
    device: str = "cuda",
) -> pydub.AudioSegment:
    converter = spectrogram_image_converter(params=params, device=device)
    return converter.audio_from_spectrogram_image(image)


@st.cache_data
def audio_bytes_from_spectrogram_image(
    image: Image.Image,
    params: SpectrogramParams,
    device: str = "cuda",
    output_format: str = "mp3",
) -> io.BytesIO:
    segment = audio_segment_from_spectrogram_image(image=image, params=params, device=device)

    audio_bytes = io.BytesIO()
    segment.export(audio_bytes, format=output_format)

    return audio_bytes


def select_device(container: T.Any = st.sidebar) -> str:
    """
    Dropdown to select a torch device, with an intelligent default.
    """
    default_device = "cpu"
    if torch.cuda.is_available():
        default_device = "cuda"
    elif torch.backends.mps.is_available():
        default_device = "mps"

    device_options = ["cuda", "cpu", "mps"]
    device = st.sidebar.selectbox(
        "Device",
        options=device_options,
        index=device_options.index(default_device),
        help="Which compute device to use. CUDA is recommended.",
    )
    assert device is not None

    return device


def select_audio_extension(container: T.Any = st.sidebar) -> str:
    """
    Dropdown to select an audio extension, with an intelligent default.
    """
    default = "mp3" if pydub.AudioSegment.ffmpeg else "wav"
    extension = container.selectbox(
        "Output format",
        options=AUDIO_EXTENSIONS,
        index=AUDIO_EXTENSIONS.index(default),
    )
    assert extension is not None
    return extension


def select_scheduler(container: T.Any = st.sidebar) -> str:
    """
    Dropdown to select a scheduler.
    """
    scheduler = st.sidebar.selectbox(
        "Scheduler",
        options=SCHEDULER_OPTIONS,
        index=0,
        help="Which diffusion scheduler to use",
    )
    assert scheduler is not None
    return scheduler


def select_checkpoint(container: T.Any = st.sidebar) -> str:
    """
    Provide a custom model checkpoint.
    """
    return container.text_input(
        "Custom Checkpoint",
        value=DEFAULT_CHECKPOINT,
        help="Provide a custom model checkpoint",
    )


@st.cache_data
def load_audio_file(audio_file: io.BytesIO) -> pydub.AudioSegment:
    return pydub.AudioSegment.from_file(audio_file)


@st.cache_resource
def get_audio_splitter(device: str = "cuda"):
    return AudioSplitter(device=device)


@st.cache_resource
def load_magic_mix_pipeline(
    checkpoint: str = DEFAULT_CHECKPOINT,
    device: str = "cuda",
    scheduler: str = SCHEDULER_OPTIONS[0],
):
    pipeline = DiffusionPipeline.from_pretrained(
        checkpoint,
        custom_pipeline="magic_mix",
    ).to(device)

    pipeline.scheduler = get_scheduler(scheduler, pipeline.scheduler.config)

    return pipeline


@st.cache
def run_img2img_magic_mix(
    prompt: str,
    init_image: Image.Image,
    num_inference_steps: int,
    guidance_scale: float,
    seed: int,
    kmin: float,
    kmax: float,
    mix_factor: float,
    checkpoint: str = DEFAULT_CHECKPOINT,
    device: str = "cuda",
    scheduler: str = SCHEDULER_OPTIONS[0],
):
    """
    Run the magic mix pipeline for img2img.
    """
    with pipeline_lock():
        pipeline = load_magic_mix_pipeline(
            checkpoint=checkpoint,
            device=device,
            scheduler=scheduler,
        )

        return pipeline(
            init_image,
            prompt=prompt,
            kmin=kmin,
            kmax=kmax,
            mix_factor=mix_factor,
            seed=seed,
            guidance_scale=guidance_scale,
            steps=num_inference_steps,
        )


@st.cache
def run_img2img(
    prompt: str,
    init_image: Image.Image,
    denoising_strength: float,
    num_inference_steps: int,
    guidance_scale: float,
    seed: int,
    negative_prompt: T.Optional[str] = None,
    checkpoint: str = DEFAULT_CHECKPOINT,
    device: str = "cuda",
    scheduler: str = SCHEDULER_OPTIONS[0],
    progress_callback: T.Optional[T.Callable[[float], T.Any]] = None,
) -> Image.Image:
    with pipeline_lock():
        pipeline = load_stable_diffusion_img2img_pipeline(
            checkpoint=checkpoint,
            device=device,
            scheduler=scheduler,
        )

        generator_device = "cpu" if device.lower().startswith("mps") else device
        generator = torch.Generator(device=generator_device).manual_seed(seed)

        num_expected_steps = max(int(num_inference_steps * denoising_strength), 1)

        def callback(step: int, tensor: torch.Tensor, foo: T.Any) -> None:
            if progress_callback is not None:
                progress_callback(step / num_expected_steps)

        result = pipeline(
            prompt=prompt,
            image=init_image,
            strength=denoising_strength,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt or None,
            num_images_per_prompt=1,
            generator=generator,
            callback=callback,
            callback_steps=1,
        )

        return result.images[0]


class StreamlitCounter:
    """
    Simple counter stored in streamlit session state.
    """

    def __init__(self, key="_counter"):
        self.key = key
        if not st.session_state.get(self.key):
            st.session_state[self.key] = 0

    def increment(self):
        st.session_state[self.key] += 1

    @property
    def value(self):
        return st.session_state[self.key]


def display_and_download_audio(
    segment: pydub.AudioSegment,
    name: str,
    extension: str = "mp3",
) -> None:
    """
    Display the given audio segment and provide a button to download it with
    a proper file name, since st.audio doesn't support that.
    """
    mime_type = f"audio/{extension}"
    audio_bytes = io.BytesIO()
    segment.export(audio_bytes, format=extension)
    st.audio(audio_bytes, format=mime_type)

    st.download_button(
        f"{name}.{extension}",
        data=audio_bytes,
        file_name=f"{name}.{extension}",
        mime=mime_type,
    )
