import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))

import json

import torch
from datasets import load_dataset
from datasets.features import Audio
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoFeatureExtractor, WhisperForConditionalGeneration
from transformers.modeling_outputs import BaseModelOutput

from src.model.audio_model.configuration_alignchat import AlignChatConfig
from src.model.audio_model.modeling_alignchat import AlignChatForConditionalGeneration
from src.utils.utils import EXPERIMENT_BASE_DIR, DATASET_BASE_DIR, CACHE_DIR, QWEN2_START_TOKEN, QWEN2_END_TOKEN, get_text_model_path, get_whisper_model_path

# Run on the first visible CUDA device when there is one, otherwise on CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda:0"


# Example invocation (bare string literal, no runtime effect):
"""
CUDA_VISIBLE_DEVICES=0 python tools/initialize_alignchat.py
"""

# === configuration ===
# Combined identifier of the form "<whisper model type>-<text model type>".
audio_model_type = "large_v3_turbo-qwen2_7b_instruct"

# The identifier must contain exactly one '-'; any other count makes the
# two-name unpacking below raise ValueError.
_model_type_parts = audio_model_type.split('-')
whisper_model_type, text_model_type = _model_type_parts

# Resolve each short model type to a checkpoint path via the project helpers.
whisper_model_path = get_whisper_model_path(whisper_model_type)
text_model_path = get_text_model_path(text_model_type)
# ===


# Load the pretrained Whisper checkpoint (and its feature extractor) whose
# weights are used below to initialize the AlignChat model.
whisper_processor = AutoFeatureExtractor.from_pretrained(whisper_model_path)
whisper_model = WhisperForConditionalGeneration.from_pretrained(
    whisper_model_path,
    device_map="auto",
    torch_dtype="float16",
)
# Clear any forced decoder ids stored in the checkpoint's config.
whisper_model.config.forced_decoder_ids = None

# create AlignChat model
# The JSON config is looked up relative to the current working directory
# first (the documented way to launch this script). If it is not there, fall
# back to the parent of this file's directory -- the same directory that is
# appended to sys.path at the top of the file -- so the script also works
# when launched from elsewhere.
config_path = f'configs/model_configs/alignchat/{audio_model_type}.json'
if not os.path.isfile(config_path):
    config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", config_path)

with open(config_path, mode='r', encoding='utf-8') as file:
    config = json.load(file)

# Turn the plain dict into a config object, then build the model from it
# (no pretrained weights yet) in float16 on the selected device.
config = AlignChatConfig(**config)
audio_model = AlignChatForConditionalGeneration(config).to(dtype=torch.float16, device=device)

# Copy the Whisper weights into the AlignChat model.
# Entries that share a name with an AlignChat parameter but differ in shape
# are reported and replaced by AlignChat's own current tensor, so loading them
# leaves that parameter unchanged. Whisper-only names stay in the dict and
# appear in the "unexpected keys" report below.
whisper_model_state_dict = whisper_model.state_dict()
model_state_dict = audio_model.state_dict()
# Iterate over a snapshot of the items because entries are reassigned in-loop.
for key, _whisper_tensor in list(whisper_model_state_dict.items()):
    _target_tensor = model_state_dict.get(key)
    if _target_tensor is None:
        # name does not exist in the AlignChat model
        continue
    if _whisper_tensor.shape == _target_tensor.shape:
        # compatible as-is
        continue
    print(key, _whisper_tensor.shape, _target_tensor.shape)
    whisper_model_state_dict[key] = _target_tensor

# strict=False: tolerate names present on only one side and report them.
_load_result = audio_model.load_state_dict(whisper_model_state_dict, strict=False)
_missing_keys, _unexpected_keys = _load_result

print("missing keys:", _missing_keys)
print("unexpected keys:", _unexpected_keys)


# Load the text language model and its slow (non-fast) tokenizer.
text_model = AutoModelForCausalLM.from_pretrained(
    text_model_path,
    device_map=device,
    torch_dtype="float16",
)
tokenizer = AutoTokenizer.from_pretrained(text_model_path, use_fast=False)
# Freeze the text model: turn off gradient tracking for every parameter.
text_model.requires_grad_(False)


# Store the text tokenizer's ids for the Qwen2 start/end tokens in the
# AlignChat config, and use the end token as the config's EOS id.
_vocab_configs = audio_model.config.vocab_configs
_vocab_configs['start_token_id'] = tokenizer.convert_tokens_to_ids(QWEN2_START_TOKEN)
_vocab_configs['end_token_id'] = tokenizer.convert_tokens_to_ids(QWEN2_END_TOKEN)
audio_model.config.eos_token_id = _vocab_configs['end_token_id']


# load audio dataset
dataset_path = f'{DATASET_BASE_DIR}/libritts_r_filtered'
split = 'dev.clean'

dataset = load_dataset(dataset_path, name='clean', split=split, cache_dir=CACHE_DIR)
dataset = dataset.cast_column("audio", Audio(sampling_rate=16_000))  # Cast audio files to correct format

# Fetch the first row once and reuse it. Indexing the dataset twice would
# load (and decode the audio of) the same row a second time just for the
# debug print.
first_example = dataset[0]
sample = first_example["audio"]
print(first_example)


# Sanity check: greedily decode up to 32 tokens for the sample utterance with
# the freshly initialized model, then print the token ids and decoded text.
with torch.inference_mode():
    # Borrow the text model's input embedding and output projection layers;
    # they are passed into every forward call of the audio model.
    embed_tokens = text_model.get_input_embeddings()
    proj_out = text_model.get_output_embeddings()

    # Turn the raw waveform into Whisper input features (float16, on device).
    input_features = whisper_processor(
        sample["array"],
        sampling_rate=sample["sampling_rate"],
        return_tensors="pt",
    ).input_features
    input_features = input_features.to(dtype=torch.float16, device=device)

    _start_token_id = audio_model.config.vocab_configs['start_token_id']
    _end_token_id = audio_model.config.vocab_configs['end_token_id']

    # State carried across steps: encoder outputs captured from the first
    # call, and the past key/values returned by the most recent call.
    encoder_outputs = None
    past_key_values = None
    # One (1, 1) tensor of predicted ids per step; concatenated after the loop.
    _generated_steps = []

    # Decoding starts from a single start token, shaped (batch=1, length=1).
    _input_ids = torch.LongTensor([[_start_token_id]]).to(device)

    # generate
    for i in range(32):
        outputs = audio_model(
            input_features=input_features,
            decoder_input_ids=_input_ids,
            encoder_outputs=encoder_outputs,
            past_key_values=past_key_values,
            embed_tokens=embed_tokens,
            proj_out=proj_out,
        )

        past_key_values = outputs.past_key_values

        # Keep the encoder results from the first step and feed them back
        # on all subsequent steps.
        if encoder_outputs is None:
            encoder_outputs = BaseModelOutput(
                last_hidden_state=outputs.encoder_last_hidden_state,
                hidden_states=outputs.encoder_hidden_states,
                attentions=outputs.encoder_attentions,
            )

        # Greedy choice at the last position; it is also the next step's input.
        _input_ids = outputs.logits[:, -1:].argmax(dim=-1)
        _generated_steps.append(_input_ids)

        # Stop as soon as the end token has been produced.
        if _input_ids[0, -1] == _end_token_id:
            break

    # The loop always runs at least once, so there is at least one step.
    output_ids = torch.cat(_generated_steps, dim=-1)

# decode
_first_sequence = output_ids[0]
print(_first_sequence)
print(tokenizer.decode(_first_sequence))


# save model
# Everything is written into one directory named after the audio model type.
_output_dir = f'{EXPERIMENT_BASE_DIR}/model_weights/alignchat/{audio_model_type}'

# Order matters: the tokenizer is saved after whisper_processor, because
# whisper_processor will overwrite tokenizer; the model itself goes last.
for _artifact in (whisper_processor, tokenizer, audio_model):
    _artifact.save_pretrained(_output_dir)

# The borrowed text-model layers are stored as whole objects via torch.save.
for _file_name, _layer in (('embed_tokens.pt', embed_tokens), ('proj_out.pt', proj_out)):
    torch.save(_layer, f'{_output_dir}/{_file_name}')
