import torch
import torch.nn.functional as F
import laion_clap
import torchaudio as ta
from audio_utils import metrics

def int16_to_float32(x: torch.Tensor):
    """Divide sample values by 32767 and return the result as a float32 tensor.

    Full-scale values of +/-32767 map to +/-1.0.
    """
    scaled = torch.div(x, 32767.0)
    return scaled.to(dtype=torch.float32)


def float32_to_int16(x: torch.Tensor):
    """Clamp samples to [-1.0, 1.0], scale by 32767 and cast to int16.

    Values outside the unit range saturate at +/-32767; the cast to int16
    drops the fractional part.
    """
    bounded = x.clamp(min=-1.0, max=1.0)
    return bounded.mul(32767.0).to(dtype=torch.int16)

class CLAPWrapper:
    """Convenience wrapper around a ``laion_clap.CLAP_Module`` checkpoint.

    Provides audio and text embeddings plus a few cosine-similarity scores
    between them.
    """

    def __init__(self, clap_path: str, device: str, enable_fusion=False, amodel="HTSAT-tiny"):
        """Build the CLAP module and load its weights.

        Args:
            clap_path: Checkpoint path handed to ``load_ckpt``.
            device: Device argument forwarded to ``CLAP_Module``.
            enable_fusion: Forwarded to ``CLAP_Module``.
            amodel: Audio model name forwarded to ``CLAP_Module``.
        """
        self.clap = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=amodel, device=device)
        self.clap.load_ckpt(ckpt=clap_path, verbose=False)

    def audio_embed(self, audios: torch.Tensor, sr: int):
        """Resample audio to 48kHz and return its CLAP embedding.

        Args:
            audios: 3-D tensor, expected as [B, channels, len]; the channel
                axis (dim 1) is averaged away before embedding.
            sr: Sample rate of ``audios``.

        Returns:
            The embedding tensor produced by CLAP, one row per batch item;
            every row is asserted to have unit norm.
        """
        assert len(audios.shape) == 3, f"Audio must be [B, channels, len] but is {audios.shape}"
        audios = ta.functional.resample(waveform=audios, orig_freq=sr, new_freq=48000)
        audios = audios.mean(dim=1)  # downmix to mono

        # Round-trip through int16 so the samples carry 16-bit quantization.
        audios = int16_to_float32(float32_to_int16(audios))
        embed = self.clap.get_audio_embedding_from_data(x=audios, use_tensor=True)
        embed_norm = embed.norm(dim=1)
        assert torch.isclose(embed_norm, torch.ones_like(embed_norm)).all()
        return embed

    def prompt_embed(self, prompts: list[str]):
        """Return CLAP text embeddings for a list of prompts.

        Every embedding row is asserted to have unit norm.
        """
        embed = self.clap.get_text_embedding(x=prompts, use_tensor=True)
        embed_norm = embed.norm(dim=1)
        assert torch.isclose(embed_norm, torch.ones_like(embed_norm)).all()
        return embed

    def audio_prompt_sim(self, audios: torch.Tensor, sr: int, prompts: list[str]) -> list[float]:
        """Return the cosine similarity between audio and prompt embeddings.

        Audio and prompt embeddings are compared row by row (``dim=1``), so
        ``audios`` and ``prompts`` are expected to line up item for item.
        """
        with torch.no_grad():
            audio_embed = self.audio_embed(audios=audios, sr=sr)
            prompt_embed = self.prompt_embed(prompts=prompts)
            sim = F.cosine_similarity(audio_embed, prompt_embed, dim=1)
        return sim.cpu().tolist()

    def audio_prompt_pair_sim(self, audios: torch.Tensor, sample_rate: int, prompts: list[str]) -> dict:
        """Score an (input, output) audio pair against an (input, output) prompt pair.

        Args:
            audios: Tensor of shape [2, channels, len]; index 0 is treated as
                the input clip and index 1 as the output clip.
            sample_rate: Sample rate of ``audios``.
            prompts: Exactly two prompts, matching the two clips in order.

        Returns:
            A dict of floats:
                ``sim_0``: cosine similarity of input audio vs. input prompt.
                ``sim_1``: cosine similarity of output audio vs. output prompt.
                ``sim_dir``: cosine similarity between the audio embedding
                    difference (output - input) and the prompt embedding
                    difference (output - input).
                ``sim_audio_clap``: cosine similarity between the two audio
                    embeddings.
                ``sim_audio_mel``: negated ``MelSpectrogramLoss`` between the
                    two clips at their original sample rate.

        Raises:
            AssertionError: If ``audios`` is not a 3-D pair or ``prompts`` does
                not contain exactly two entries.
        """
        assert len(audios.shape) == 3 and audios.shape[0] == 2, f"Audio must be a pair of audio clips but is {audios.shape}"
        # Both prompt_embed[0] and prompt_embed[1] are read below, so exactly
        # two prompts are required (a bare truthiness check let one through).
        assert len(prompts) == 2, f"Prompt must be a pair of prompts but is of length {len(prompts)}"

        with torch.no_grad():
            audio_embed = self.audio_embed(audios=audios, sr=sample_rate)
            prompt_embed = self.prompt_embed(prompts=prompts)

            input_audio_embed, output_audio_embed = audio_embed[0], audio_embed[1]
            input_prompt_embed, output_prompt_embed = prompt_embed[0], prompt_embed[1]

            sim_0 = F.cosine_similarity(input_audio_embed, input_prompt_embed, dim=0).item()
            sim_1 = F.cosine_similarity(output_audio_embed, output_prompt_embed, dim=0).item()
            sim_dir = F.cosine_similarity(output_audio_embed - input_audio_embed,
                                          output_prompt_embed - input_prompt_embed, dim=0).item()

            # audio similarity
            mel_loss = metrics.MelSpectrogramLoss()
            sim_clap_pair = F.cosine_similarity(input_audio_embed, output_audio_embed, dim=0).item()
            # The loss is negated so it can be reported alongside the similarities.
            sim_mel_pair = -mel_loss(x=audios[0], y=audios[1], sr=sample_rate).item()

            return {
                "sim_0": sim_0,
                "sim_1": sim_1,
                "sim_dir": sim_dir,
                "sim_audio_clap": sim_clap_pair,
                "sim_audio_mel": sim_mel_pair
            }