import argparse
import json
import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))

import pandas as pd 
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoFeatureExtractor

from src.model.audio_model.modeling_alignchat import AlignChatForConditionalGeneration
from src.utils.io_utils import load_and_resample
from src.utils.generation_utils import generate_response_from_audio
from src.utils.utils import MODEL_DIR, OPENAUDIOBENCH_ATTRIBUTES, get_text_model_path, get_generation_output_path

"""
CUDA_VISIBLE_DEVICES=0 python eval/openaudiobench/generate_openaudiobench_alignchat.py --model_path path/to/model --dataset_name alpaca_eval
"""


# Run on the first visible GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Sampling rate (in Hz) that every audio clip is resampled to before generation.
sampling_rate = 16000

def _instruction_for(prompt_type):
    """Return the instruction prefix that corresponds to ``prompt_type``.

    Raises:
        NotImplementedError: if ``prompt_type`` is not '', '_ignore' or '_correct'.
    """
    instructions = {
        '': '',
        '_ignore': 'Ignore the typos in the following sentence and provide the answer:\n',
        '_correct': 'Correct the typos in the following sentence and then provide the answer:\n',
    }
    try:
        return instructions[prompt_type]
    except KeyError:
        raise NotImplementedError(f'Unsupported prompt type: {prompt_type}') from None


def _load_records(dataset_path):
    """Read the benchmark CSV into a list of dicts, storing each row's index under 'uuid'."""
    frame = pd.read_csv(dataset_path)
    return [{**dict(row), 'uuid': row_index} for row_index, row in frame.iterrows()]


def main():
    """Generate AlignChat responses for one OpenAudioBench subset and save them as JSON lines.

    Command-line arguments:
        --model_path:   path to the AlignChat checkpoint (required). Its
                        third-from-last '/'-separated component is used as the
                        audio model type, and the part of that type after the
                        last '-' selects the text model.
        --dataset_name: OpenAudioBench subset to evaluate (required); must be a
                        key of ``OPENAUDIOBENCH_ATTRIBUTES``.
        --prompt_type:  '', '_ignore' or '_correct'; selects the instruction
                        passed to the generation helper.
        --ratio:        float forwarded to the output-path helper and to the
                        generation helper.

    Every dataset row is written to the result file as one JSON object extended
    with the keys 'uuid', 'asr_response' and 'infer_response'.

    Raises:
        SystemExit: (via argparse) when a required argument is missing, the
            model path has fewer than three components, or the dataset name is
            unknown.
        NotImplementedError: for an unsupported ``--prompt_type``.
    """
    parser = argparse.ArgumentParser()
    # Both arguments are mandatory: without them the script used to crash later
    # with an unrelated AttributeError / KeyError.
    parser.add_argument('--model_path', type=str, required=True)
    parser.add_argument('--dataset_name', type=str, required=True)
    parser.add_argument('--prompt_type', type=str, default='')
    parser.add_argument('--ratio', type=float, default=1.0)

    args = parser.parse_args()
    audio_model_path = args.model_path.rstrip('/')
    dataset_name = args.dataset_name
    prompt_type = args.prompt_type
    ratio = args.ratio

    # The model type is read from the third-from-last path component, so the
    # path needs at least three components for the lookup to be meaningful.
    path_parts = audio_model_path.split('/')
    if len(path_parts) < 3:
        parser.error(
            f'--model_path must contain at least three path components, got: {args.model_path}'
        )
    audio_model_type = path_parts[-3]

    instruction = _instruction_for(prompt_type)

    # Validate the dataset name before any model is loaded so typos fail fast.
    try:
        attributes = OPENAUDIOBENCH_ATTRIBUTES[dataset_name]
    except KeyError:
        parser.error(f'Unknown dataset name: {dataset_name}')

    # output path
    result_path = get_generation_output_path(
        audio_model_path,
        dataset_name='OpenAudioBench',
        subset_name=dataset_name,
        prompt_type=prompt_type,
        ratio=ratio,
    )

    # load datasets (done before the models so a missing CSV is reported early)
    audio_dir = f'datasets/baichuan-inc/OpenAudioBench/eval_datas/{dataset_name}/audios'
    dataset_path = f'datasets/baichuan-inc/OpenAudioBench/eval_datas/{dataset_name}/{dataset_name}.csv'
    datas = _load_records(dataset_path)

    # load audio model
    alignchat_dir = os.path.join(MODEL_DIR, f'alignchat/{audio_model_type}')
    audio_model = AlignChatForConditionalGeneration.from_pretrained(audio_model_path, device_map=device)

    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # Python objects, so these two files must come from a trusted source.
    embed_tokens = torch.load(os.path.join(MODEL_DIR, f'alignchat/{audio_model_type}/embed_tokens.pt'), map_location=device, weights_only=False)
    embed_tokens = embed_tokens.float()
    embed_tokens.requires_grad_(False)

    proj_out = torch.load(os.path.join(MODEL_DIR, f'alignchat/{audio_model_type}/proj_out.pt'), map_location=device, weights_only=False)
    proj_out = proj_out.float()
    proj_out.requires_grad_(False)
    whisper_processor = AutoFeatureExtractor.from_pretrained(alignchat_dir)

    # load text model
    text_model_type = audio_model_type.split('-')[-1]
    text_model_path = get_text_model_path(text_model_type)

    text_model = AutoModelForCausalLM.from_pretrained(text_model_path, device_map=device, torch_dtype=torch.float16)
    text_tokenizer = AutoTokenizer.from_pretrained(alignchat_dir, use_fast=False)

    # generate responses
    # Generation is inference only; torch.no_grad() guarantees that no autograd
    # state is kept even if the helper does not disable it itself.
    with open(result_path, 'w', encoding='utf-8') as file, torch.no_grad():
        for index, item in enumerate(tqdm(datas)):

            # load audio and convert to target sampling rate
            audio_path = os.path.join(audio_dir, item['audio_filename'])
            audio_array, _ = load_and_resample(audio_path, target_sr=sampling_rate)

            asr_response, infer_response = generate_response_from_audio(
                audio_array=audio_array,
                sampling_rate=sampling_rate,
                whisper_processor=whisper_processor,
                audio_model=audio_model,
                embed_tokens=embed_tokens,
                proj_out=proj_out,
                text_model=text_model,
                text_tokenizer=text_tokenizer,
                ratio=ratio,
                instruction=instruction,
                device=device,
                dataset_name=dataset_name,
            )

            print(f"[INDEX] {index}")
            print(f"[QUESTION] {item[attributes[0]]}")
            print(f"[ASR OUT ] {asr_response}")
            print(f"[ANSWER  ] {item[attributes[1]]}")
            print(f"[RESPONSE] {infer_response}")
            print(" -*- -*- -*- ")

            item['asr_response'] = asr_response
            item['infer_response'] = infer_response
            file.write(json.dumps(item, ensure_ascii=False) + '\n')
            # Flush per record so completed rows survive a crash in a long run.
            file.flush()

    print(f"Generation finished for {dataset_name}, congrats!")

    # sleep to inspect the results
    # NOTE(review): this keeps the process (and its GPU memory) alive for about
    # 11.5 days after the 'web_questions' subset finishes; it looks like a
    # debugging aid -- confirm it is still wanted.
    if dataset_name == 'web_questions':
        print("Sleeping for 1000000 seconds...")
        import time
        time.sleep(1000000)

# Run the generation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
