import concurrent
import concurrent.futures
import io
import os
import re
import time

import requests
import soundfile as sf
import torch
from google import genai
from google.genai import types


class AudioAnalysisWrapper:
    """Base interface for backends that judge whether audio matches text prompts.

    Subclasses must override :meth:`does_audio_match_prompt`.
    """

    def __init__(self):
        pass

    def does_audio_match_prompt(
            self,
            prompts: list,
            audio: torch.Tensor,
            sample_rate: int,
            print_status: bool = False
    ) -> list[bool]:
        """Return, for each prompt, whether the corresponding audio item matches it.

        Args:
            prompts: One text prompt per audio item in ``audio``.
            audio: Batched audio tensor; the expected layout is defined by the
                concrete subclass.
            sample_rate: Sample rate of ``audio``.
            print_status: If True, implementations may print diagnostics.

        Returns:
            One boolean per prompt, in the same order as ``prompts``.

        Raises:
            NotImplementedError: Always; concrete subclasses must override this
                method (previously the base silently returned ``None``).
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement does_audio_match_prompt()"
        )


class GeminiAudioWrapper(AudioAnalysisWrapper):
    """Judge whether audio clips match text captions using the Gemini API.

    Each (caption, audio) pair is evaluated with two ``generate_content`` calls:
    the first asks Gemini to list the sound elements of the audio and reason
    about whether it matches the caption; the second asks it to turn that
    analysis into a 1-10 similarity score. A pair counts as a match when the
    parsed score is at least ``match_threshold``.

    The API key is read from the ``GEMINI_API_KEY`` environment variable.
    """

    # Sampling settings used for both requests of a caption evaluation.
    _GENERATION_CONFIG = {
        "temperature": 0.7,
        "top_p": 0.95,
        "top_k": 40,
        "max_output_tokens": 2048,
        "response_mime_type": "text/plain",
    }

    # An integer or decimal number ("7", "10", "7.5"); the last match in the
    # reasoning reply is taken as the score.
    _SCORE_PATTERN = re.compile(r"\d+(?:\.\d+)?")

    def __init__(self, model: str = "gemini-2.0-flash", match_threshold: float = 6):
        """Create the Gemini client.

        Args:
            model: Gemini model name passed to ``generate_content``.
            match_threshold: Minimum similarity score (on the 1-10 scale the
                model is asked for) for a caption to count as a match.

        Raises:
            KeyError: If ``GEMINI_API_KEY`` is not set.
        """
        super().__init__()
        self.client = genai.Client(api_key=os.environ["GEMINI_API_KEY"])
        self.model = model
        self.match_threshold = match_threshold

    def _generate_with_retries(self, conversation_history, config, max_retries=10, base_delay=1.0, max_delay=30.0):
        """Call ``generate_content``, retrying on API and SSL errors.

        The wait before retry ``k`` is ``base_delay * 2 ** (k - 1)`` seconds,
        capped at ``max_delay``.

        Returns:
            The model response, or ``None`` once more than ``max_retries``
            retries have failed.
        """
        retries = 0

        while True:
            try:
                return self.client.models.generate_content(
                    model=self.model, contents=conversation_history, config=config
                )
            except (genai.errors.ClientError, genai.errors.ServerError, requests.exceptions.SSLError) as e:
                retries += 1
                if retries > max_retries:
                    print("Exceeded max retries for Gemini API")
                    return None

                wait_time = min(base_delay * (2 ** (retries - 1)), max_delay)
                print(f"Error from Gemini encountered {e}. Retrying {retries}/{max_retries} in {wait_time:.1f}s...")
                time.sleep(wait_time)

    @classmethod
    def _parse_score(cls, text: str) -> float:
        """Return the last number in ``text``, or 0 if it contains none.

        The reasoning prompt asks Gemini to end its reply with the score. Taking
        the last number (rather than the last two characters) tolerates endings
        such as ``**7**`` or ``7.5``.
        """
        numbers = cls._SCORE_PATTERN.findall(text)
        return float(numbers[-1]) if numbers else 0

    def _send_message(self, caption: str, audio_bytes: bytes, print_status: bool = False) -> bool:
        """Ask Gemini whether the WAV-encoded ``audio_bytes`` matches ``caption``.

        Returns:
            True if the parsed similarity score is >= ``match_threshold``.
            False if either request fails, the analysis is empty, or no score
            can be parsed from the reasoning reply.
        """
        # Copy so the shared class-level defaults can never be mutated downstream.
        generation_config = dict(self._GENERATION_CONFIG)

        conversation_history = [
            {"role": "user", "parts": [types.Part.from_text(
                text=f"Please list all the sound elements and characteristics of the provided audio file. Then give your reasoning if the provided audio matches the caption '{caption}'.")]},
            {"role": "user", "parts": [types.Part.from_bytes(data=audio_bytes, mime_type="audio/wav")]}
        ]

        response_analysis = self._generate_with_retries(conversation_history=conversation_history, config=generation_config)
        if response_analysis is None:
            print(f"Response analysis failed for caption '{caption}'")
            return False

        # `.text` may be None/empty (presumably when no text candidate came back);
        # it cannot be fed into Part.from_text, and there is nothing to score.
        analysis_text = response_analysis.text
        if not analysis_text:
            print(f"Response analysis failed for caption '{caption}'")
            return False

        # Second turn: only the textual analysis is sent, not the audio.
        conversation_history = [
            {"role": "user", "parts": [types.Part.from_text(text=analysis_text)]},
            {"role": "user", "parts": [types.Part.from_text(text=f"Based on your analysis and reasoning, give a score between 1 and 10 how similar the audio and the caption '{caption}' are. End your response with a single number that you gave for the similarity.")]}
        ]
        response_reasoning = self._generate_with_retries(conversation_history=conversation_history, config=generation_config)
        if response_reasoning is None:
            print(f"Response reasoning failed for caption '{caption}'")
            return False

        reasoning_text = response_reasoning.text or ""

        if print_status:
            print(f"***Gemini for caption '{caption}'***")
            print("- Analysis:", analysis_text.strip())
            print("- Reasoning:", reasoning_text.strip())

        score = self._parse_score(reasoning_text)
        if print_status:
            print("- Raw Score:", score)

        return score >= self.match_threshold

    def does_audio_match_prompt(
            self,
            prompts: list,
            audio: torch.Tensor,
            sample_rate: int,
            print_status: bool = False
    ) -> list[bool]:
        """Return, for each prompt, whether Gemini judges its audio clip a match.

        Args:
            prompts: One caption per clip; ``len(prompts)`` must equal
                ``audio.shape[0]``.
            audio: Tensor of shape (batch, channels, samples). It may live on
                any device or have any dtype soundfile can write after
                upcasting floats to float32.
            sample_rate: Sample rate written into the WAV header.
            print_status: If True, print the model's analysis, reasoning and score.

        Returns:
            One boolean per prompt, in input order.
        """
        assert len(audio.shape) == 3, f"expected audio of shape (batch, channels, samples), got {tuple(audio.shape)}"
        assert len(prompts) == audio.shape[0], f"got {len(prompts)} prompts for {audio.shape[0]} audio clips"

        def process_audio(prompt, clip):
            # soundfile needs a CPU numpy array of a supported dtype, so drop
            # autograd, move off the GPU, and upcast half/bfloat16 to float32.
            clip = clip.detach().cpu()
            if clip.is_floating_point():
                clip = clip.float()
            buffer = io.BytesIO()
            # soundfile expects (frames, channels), hence the transpose.
            sf.write(file=buffer, data=clip.numpy().T, samplerate=sample_rate, format="WAV")
            return self._send_message(prompt, buffer.getvalue(), print_status=print_status)

        # The work is dominated by network calls, so threads overlap the waits;
        # executor.map yields results in input order.
        with concurrent.futures.ThreadPoolExecutor() as executor:
            return list(executor.map(process_audio, prompts, audio))
