import argparse
import os
import random
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))

import torch
from tqdm import tqdm
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

from src.utils.data_utils import get_audio_dataset
from src.utils.prompts_utils import (
    PROMPTS_FOR_DATA_GENERATION_V1,
    PROMPTS_FOR_DATA_GENERATION_V2,
    PROMPTS_FOR_DATA_GENERATION_V3,
    PROMPTS_FOR_DATA_GENERATION_AUDIOSET,
    PROMPTS_FOR_DATA_GENERATION_VOCALSOUND,
    PROMPTS_FOR_DATA_GENERATION_NONSPEECH7K,
)
from src.utils.utils import preprocess_text_for_training, get_text_model_path

# Seed Python's global RNG at import time so the per-example instruction picked
# with random.choice(...) in main() is the same on every run.
random.seed(42)

# Example usage:
#   CUDA_VISIBLE_DEVICES=0 python tools/alignment_data_generation.py --dataset_name fleurs
#   CUDA_VISIBLE_DEVICES=0 python tools/alignment_data_generation.py --dataset_name fleurs --use_instruction

def _get_generation_config(dataset_name):
    """Return ``(prompt_pool, temperature, audio_type)`` for *dataset_name*.

    ``prompt_pool`` is the list of instruction prompts to sample from,
    ``temperature`` is the vLLM sampling temperature, and ``audio_type`` is
    ``'speech'`` or ``'general'`` (which controls how the text is wrapped).

    Raises:
        ValueError: if *dataset_name* is not handled by this script.
    """
    # Each row: (dataset names, prompt pool, sampling temperature, audio type).
    table = [
        (("word_less_frequent",),
         PROMPTS_FOR_DATA_GENERATION_V1, 0.0, 'speech'),
        (("VoiceAssistant-400K", "spoken-web-questions", "DeepDialogue-orpheus",
          "DeepDialogue-xtts", "mmsu_openbookqa", "openbookqa"),
         PROMPTS_FOR_DATA_GENERATION_V2, 0.0, 'speech'),
        (("libritts_r_filtered", "common_voice_21_0", "EuroSpeech", "fleurs"),
         PROMPTS_FOR_DATA_GENERATION_V3, 0.0, 'speech'),
        (('audioset_strong', 'audioset_wavcaps'),
         PROMPTS_FOR_DATA_GENERATION_AUDIOSET, 1.0, 'general'),
        (('VocalSound',),
         PROMPTS_FOR_DATA_GENERATION_VOCALSOUND, 1.0, 'general'),
        (('Nonspeech7K',),
         PROMPTS_FOR_DATA_GENERATION_NONSPEECH7K, 1.0, 'general'),
    ]
    for names, prompt_pool, temperature, audio_type in table:
        if dataset_name in names:
            return prompt_pool, temperature, audio_type
    raise ValueError(f"Dataset {dataset_name} not supported in alignment data generation.")


def main():
    """Generate text-LLM responses for an audio dataset and save them to disk.

    Pipeline:
    1. Pick a prompt pool, temperature and text wrapping for ``--dataset_name``.
    2. Load the text LLM with vLLM.
    3. Load the audio dataset, optionally shard it, and drop the audio column.
    4. Optionally prepend a randomly chosen instruction to each example.
    5. Keep only short, first-round examples.
    6. Generate one response per example.
    7. Save the audio-free dataset, with ``response`` and ``finish_reason``
       columns, under ``datasets/response_<dataset_type>/<tag>/<dataset_name>/...``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--text_model_type', type=str, default="qwen2_7b_instruct")
    parser.add_argument('--dataset_name', type=str, required=True)
    parser.add_argument('--shard_index', type=int, default=None)
    # Previously hard-coded to 2; the default keeps existing shard layouts valid.
    # Note: the output directory is keyed only by shard index, not shard count.
    parser.add_argument('--num_shards', type=int, default=2)
    parser.add_argument('--use_instruction', action='store_true')

    args = parser.parse_args()
    shard_index = args.shard_index
    num_shards = args.num_shards
    use_instruction = args.use_instruction

    # Validate sharding up front so we fail before the expensive model load.
    if num_shards < 1:
        parser.error("--num_shards must be >= 1")
    if shard_index is not None and not 0 <= shard_index < num_shards:
        parser.error(f"--shard_index must be in [0, {num_shards - 1}]")

    dataset_type = 'with_instruction' if use_instruction else 'without_instruction'
    split = 'train'

    text_model_type = args.text_model_type
    dataset_name = args.dataset_name

    # choose prompts, sampling temperature and text wrapping according to dataset
    prompt_pool, temperature, audio_type = _get_generation_config(dataset_name)

    # load text model
    text_model_path = get_text_model_path(text_model_type)
    # Last path component names the output directory; ignore a trailing '/'
    # so the tag is never empty.
    tag = text_model_path.rstrip('/').split('/')[-1]

    text_tokenizer = AutoTokenizer.from_pretrained(text_model_path)
    text_model = LLM(model=text_model_path, tensor_parallel_size=torch.cuda.device_count(), max_model_len=2048, max_num_seqs=16, gpu_memory_utilization=0.95)

    # Greedy decoding for temperature 0.0 (speech datasets); sampling at 1.0 for
    # general-audio datasets (presumably to diversify responses).
    sampling_params = SamplingParams(temperature=temperature, max_tokens=1024, repetition_penalty=1.0)

    # load audio dataset
    # .select(range(0, 100)) for debug
    dataset = get_audio_dataset(dataset_name=dataset_name, split=split)

    # add an id column so that we can use it as the index
    dataset = dataset.add_column("id", range(len(dataset)))

    # shard if the dataset is too large
    if shard_index is not None:
        dataset = dataset.shard(num_shards=num_shards, index=shard_index)

    print(dataset)

    # remove audio column to accelerate loading
    dataset_without_audio = dataset.remove_columns(["audio"])

    # randomly select a prompt if we are using a prompt
    print(f"dataset_type: {dataset_type}")

    def add_prompt(example):
        example["prompt"] = random.choice(prompt_pool) if use_instruction else ''
        return example

    dataset_without_audio = dataset_without_audio.map(add_prompt)

    # preprocess text to be the same as the training data, so that the response is fully aligned
    def preprocess_text_for_generation(example):
        text = preprocess_text_for_training(example['text'])
        if audio_type == 'speech':     # for speech dataset
            example['text'] = text
        elif audio_type == 'general':  # for general audio dataset
            example['text'] = '<audio>' + text + '</audio>'
        else:
            raise ValueError(f"Invalid audio type: {audio_type}")
        return example

    dataset_without_audio = dataset_without_audio.map(preprocess_text_for_generation)

    # Filter out examples that are too long (presumably so the prompt plus
    # max_tokens=1024 fits in max_model_len=2048) or not the first round of a
    # dialogue. Whole-column reads avoid building one row dict per index.
    max_words = 448
    num_rows = len(dataset_without_audio)
    texts = dataset_without_audio["text"]
    has_round = 'round' in dataset_without_audio.column_names
    rounds = dataset_without_audio["round"] if has_round else [None] * num_rows
    indices = []
    progress = tqdm(zip(texts, rounds), total=num_rows, desc="filtering data that is not the first round...")
    for index, (text, round_) in enumerate(progress):
        if len(text.split()) >= max_words:
            continue
        # Rounds are stored either as ints or as the string "None".
        if has_round and not (round_ == 1 or round_ == "None"):
            continue
        indices.append(index)
    dataset_without_audio = dataset_without_audio.select(indices)
    print(dataset_without_audio)

    # print examples (guarded: indexing an empty dataset raises IndexError)
    if len(dataset_without_audio) == 0:
        print("Warning: no examples left after filtering; nothing to generate.")
    else:
        print(dataset_without_audio[0])
        print(dataset_without_audio[-1])

    # TODO: currently, we only allow prompts before the text, which is not ideal.
    questions = [
        f"{prompt}{text}"
        for prompt, text in zip(dataset_without_audio["prompt"], dataset_without_audio["text"])
    ]

    # construct generation inputs
    prompts = []
    for question in tqdm(questions, desc="Building prompts..."):
        dialogue = [
            {'role': 'user', 'content': f"{question}"},
        ]
        prompt = text_tokenizer.apply_chat_template(dialogue, tokenize=False, add_generation_prompt=True)
        prompts.append(prompt)

    # Collect responses. Outputs come back in prompt order, so they line up
    # with the rows of dataset_without_audio.
    responses = []
    finish_reasons = []
    batch_size = 128000
    for start in tqdm(range(0, len(prompts), batch_size), desc="Generating..."):
        outputs = text_model.generate(prompts[start:start + batch_size], sampling_params)
        for output in outputs:
            completion = output.outputs[0]
            responses.append(completion.text)
            finish_reasons.append(completion.finish_reason)

    # add responses and finish reasons to dataset
    dataset_without_audio = dataset_without_audio.add_column("response", responses)
    dataset_without_audio = dataset_without_audio.add_column("finish_reason", finish_reasons)

    # save dataset
    save_dir = f"datasets/response_{dataset_type}/{tag}/{dataset_name}"
    if shard_index is not None:
        save_dir = f"{save_dir}/index_{shard_index}"
    dataset_without_audio.save_to_disk(f"{save_dir}/{split}")

    # print first and last samples for debugging (at most 10 of each, never past the end)
    num_debug = min(10, len(dataset_without_audio))
    for index in range(num_debug):
        print('-*-' * 10)
        print(dataset_without_audio[index])
        print('-*-' * 10)
        print(dataset_without_audio[-index - 1])


if __name__ == "__main__":
    # Only run the generation pipeline when executed as a script, not on import.
    main()
