import time

import numpy as np
import torch
from loguru import logger
from tqdm import tqdm

from transformers import AutoModelForCausalLM, AutoTokenizer, AutoFeatureExtractor
from transformers.modeling_outputs import BaseModelOutput


@torch.inference_mode()
def generate_response_from_audio(
    audio_array: np.ndarray,
    sampling_rate: int = 16_000,
    whisper_processor: AutoFeatureExtractor = None,
    audio_model: torch.nn.Module = None,
    embed_tokens: torch.nn.Module = None,
    proj_out: torch.nn.Module = None,
    text_model: AutoModelForCausalLM = None,
    text_tokenizer: AutoTokenizer = None,
    ratio: float = 1.0,
    instruction: str = '',
    device: str = "cuda:0",
    dataset_name: str = None,
    only_perform_asr: bool = False,
) -> "str | tuple[str, str]":
    """Transcribe an audio clip and optionally answer it with the text model.

    The speech decoder is run greedily (argmax) for at most 448 steps, or until
    it emits the end token. Its token ids are decoded into the ASR transcript.
    Unless only ASR is requested, the decoder's hidden states are then fed to
    ``text_model`` as input embeddings inside a chat-style prompt, and a reply
    is generated greedily for at most 1024 tokens.

    Args:
        audio_array: Waveform handed to ``whisper_processor``.
        sampling_rate: Sampling rate forwarded to ``whisper_processor``.
        whisper_processor: Callable feature extractor whose result exposes
            ``input_features``. Required.
        audio_model: Speech model; its ``config.vocab_configs`` supplies the
            start and end token ids. Required.
        embed_tokens: Forwarded unchanged to ``audio_model``.
        proj_out: Forwarded unchanged to ``audio_model``.
        text_model: Causal language model. When ``None``, only ASR is done.
        text_tokenizer: Tokenizer used to decode the transcript and the reply
            and to tokenize the prompt pieces. Required.
        ratio: Weight of the decoder hidden states when blending them with the
            text model's embeddings of the transcript tokens
            (``ratio * hidden + (1 - ratio) * embeddings``). ``1.0`` skips the
            blend and uses the hidden states alone.
        instruction: Text placed at the start of the user turn, before the
            audio representations.
        device: Device that the inputs are moved to.
        dataset_name: Not used by this function.
        only_perform_asr: Return after transcription without calling
            ``text_model``.

    Returns:
        The transcript string when only ASR is performed; otherwise the tuple
        ``(asr_response, infer_response)``.

    Raises:
        ValueError: If ``whisper_processor``, ``audio_model`` or
            ``text_tokenizer`` is ``None``.
    """
    # Fail early with a clear message instead of an opaque NoneType error.
    for component_name, component in (
        ("whisper_processor", whisper_processor),
        ("audio_model", audio_model),
        ("text_tokenizer", text_tokenizer),
    ):
        if component is None:
            raise ValueError(f"`{component_name}` must be provided.")

    if text_model is None:
        print("Text model is not provided. Only ASR will be performed!")
        only_perform_asr = True

    # Hard-coded step limits of the two greedy decoding loops.
    # NOTE(review): 448 presumably matches the speech decoder's maximum target
    # length - confirm against the model config.
    max_speech_steps = 448
    max_new_tokens = 1024

    # process audio into mel spectrogram features
    input_features = whisper_processor(
        audio_array, sampling_rate=sampling_rate, return_tensors="pt"
    ).input_features.to(device)

    # special tokens are borrowed from the language model
    vocab_configs = audio_model.config.vocab_configs
    start_token_id = vocab_configs['start_token_id']
    end_token_id = vocab_configs['end_token_id']

    encoder_outputs = None
    past_key_values = None
    # Per-step results are collected and concatenated once after the loop,
    # which avoids re-copying the growing tensors on every step.
    speech_token_steps = []
    speech_hidden_steps = []

    _input_ids = torch.tensor([[start_token_id]], dtype=torch.long, device=device)

    # autoregressive generation of speech representations
    for _ in range(max_speech_steps):
        outputs = audio_model(
            input_features=input_features,
            decoder_input_ids=_input_ids,
            encoder_outputs=encoder_outputs,
            past_key_values=past_key_values,
            embed_tokens=embed_tokens,
            proj_out=proj_out,
        )

        past_key_values = outputs.past_key_values

        # Keep the encoder result of the first step and pass it back in on
        # every later step.
        if encoder_outputs is None:
            encoder_outputs = BaseModelOutput(
                last_hidden_state=outputs.encoder_last_hidden_state,
                hidden_states=outputs.encoder_hidden_states,
                attentions=outputs.encoder_attentions,
            )

        # greedy choice for the last position; also the next decoder input
        _input_ids = outputs.logits[:, -1:].argmax(dim=-1)

        speech_token_steps.append(_input_ids)
        speech_hidden_steps.append(outputs.decoder_last_hidden_state)

        if _input_ids[0, -1] == end_token_id:
            break

    output_ids = torch.cat(speech_token_steps, dim=-1)
    output_hidden_states = torch.cat(speech_hidden_steps, dim=1)

    asr_response = text_tokenizer.decode(output_ids[0], skip_special_tokens=True)
    if only_perform_asr:
        return asr_response

    # Drop the final decoding step (the one that produced the end token).
    # NOTE(review): when the step limit is reached without an end token this
    # also discards the last real step - confirm that this is intended.
    hidden_states = output_hidden_states[:, :-1]  # [1, seq_len, d_model]

    input_embeddings = text_model.get_input_embeddings()
    if ratio != 1.0:
        # blend the speech representations with the transcript's embeddings
        text_embeddings = input_embeddings(output_ids[:, :-1])  # [1, seq_len, d_model]
        hidden_states = ratio * hidden_states + (1 - ratio) * text_embeddings

    # continue generation with language model
    hidden_states = hidden_states.to(input_embeddings.weight.dtype)

    # 1) system turn plus the start of the user turn (with the instruction)
    prompt = f'<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n{instruction}'
    inputs = text_tokenizer(prompt, return_tensors='pt').to(device)
    outputs = text_model(**inputs, use_cache=True, num_logits_to_keep=1)
    past_key_values = outputs.past_key_values

    # 2) the audio representations, appended to the same cache. Skipped when
    #    the speech decoder stopped on its first step, because there is then
    #    nothing to feed (a zero-length sequence).
    if hidden_states.shape[1] > 0:
        outputs = text_model(
            inputs_embeds=hidden_states,
            past_key_values=past_key_values,
            use_cache=True,
            num_logits_to_keep=1,
        )
        past_key_values = outputs.past_key_values

    # 3) close the user turn and open the assistant turn
    prompt = '<|im_end|>\n<|im_start|>assistant\n'
    inputs = text_tokenizer(prompt, return_tensors='pt').to(device)

    eos_token_id = text_tokenizer.eos_token_id
    reply_token_steps = []
    _input_ids = inputs.input_ids
    for _ in tqdm(range(max_new_tokens)):
        outputs = text_model(
            input_ids=_input_ids,
            past_key_values=past_key_values,
            use_cache=True,
            num_logits_to_keep=1,
        )

        past_key_values = outputs.past_key_values

        # greedy choice; only the newest token is fed back in, the rest of the
        # context lives in the cache
        _input_ids = torch.argmax(outputs.logits, dim=-1)
        reply_token_steps.append(_input_ids)

        if _input_ids[0, 0].item() == eos_token_id:
            break

    input_ids = torch.cat(reply_token_steps, dim=-1)
    infer_response = text_tokenizer.decode(input_ids[0], skip_special_tokens=True)

    # return the asr response and the generated response
    return asr_response, infer_response


def generate_text_chat(client, *args, **kwargs):
    """Call ``client.chat.completions.create`` with retries.

    All positional and keyword arguments are forwarded unchanged. The call is
    attempted up to 25 times with a 10 second wait before every retry. An
    attempt counts as failed when it raises an exception (which is logged) or
    returns ``None``.

    Args:
        client: Object exposing ``chat.completions.create``.
        *args: Positional arguments forwarded to ``create``.
        **kwargs: Keyword arguments forwarded to ``create``.

    Returns:
        The first non-``None`` response, or ``None`` when every attempt failed.
    """
    max_attempts = 25
    retry_delay_seconds = 10

    for attempt in range(max_attempts):
        if attempt > 0:
            # Announce the wait before sleeping so the message appears when
            # the pause starts rather than after it is already over.
            print(f'sleep 10 seconds before try {attempt + 1} times...')
            time.sleep(retry_delay_seconds)

        # Only the remote call is guarded; everything else runs outside `try`.
        try:
            response = client.chat.completions.create(*args, **kwargs)
        except Exception as error:
            logger.info(error)
            continue

        # Short pause after every completed request, kept from the original
        # (presumably request throttling - confirm).
        time.sleep(0.5)
        if response is not None:
            return response

    return None
