import time
from functools import partial

import omegaconf
import torch
import numpy
from datasets import load_dataset, load_from_disk
from composer.core import Evaluator
from composer.utils import dist
from torch.utils.data import DataLoader
from io import BytesIO
from urllib.request import urlopen
import librosa
from transformers import AutoProcessor
# Shared Qwen2-Audio processor, instantiated once at import time. The functions
# below use it for chat templating, audio feature extraction and tokenization.
# NOTE(review): the path is relative, so it presumably resolves against the
# current working directory — confirm how this module is launched.
processor = AutoProcessor.from_pretrained("models/Qwen2-Audio-7B-Instruct")

def _find_last_subsequence(sequence, pattern):
    """Return the start index of the last occurrence of ``pattern`` in ``sequence``.

    Both arguments are plain Python lists. Returns -1 when ``pattern`` is empty
    or does not occur in ``sequence``.
    """
    width = len(pattern)
    if width == 0:
        return -1
    for start in range(len(sequence) - width, -1, -1):
        if sequence[start:start + width] == pattern:
            return start
    return -1


# TODO: Currently use the <start_header> stuff as labels as well but probably want to remove them
def build_tokens(audios, feedbacks, feedback_prompt, keep_targets):
    """Turn (audio, feedback) pairs into chat-formatted model inputs and LM labels.

    Every sample becomes a two-turn conversation: a user turn holding the audio
    and ``feedback_prompt``, followed by an assistant turn holding the feedback.

    Args:
        audios: Audio file paths, one per sample. Each path is prefixed with
            ``file://`` and opened with ``urlopen``.
        feedbacks: Assistant feedback text, one per audio.
        feedback_prompt: User-side text prompt shared by all samples.
        keep_targets: When true, labels keep the token ids from the start of the
            feedback onwards (everything before it and all padding is -100).
            When false, every label is -100.

    Returns:
        Tuple ``(input_ids, attention_mask, labels, input_features,
        feature_attention_mask)`` as produced by the module-level ``processor``
        (plus the derived ``labels``).
    """
    sampling_rate = processor.feature_extractor.sampling_rate

    conversations = [
        [{"role": "user", "content": [
            {"type": "audio", "audio_url": 'file://' + audio},
            {"type": "text", "text": feedback_prompt},
        ]},
        {"role": "assistant", "content": [
            {"type": "text", "text": feedback},
        ]}]
        for audio, feedback in zip(audios, feedbacks)
    ]

    texts = [
        processor.apply_chat_template(conversation, add_generation_prompt=False, tokenize=False)
        for conversation in conversations
    ]

    # Decode every audio element referenced by the conversations, in order.
    waveforms = []
    for conversation in conversations:
        for message in conversation:
            content = message["content"]
            if not isinstance(content, list):
                continue
            for ele in content:
                if ele["type"] != "audio":
                    continue
                # Close the handle deterministically instead of relying on GC.
                with urlopen(ele['audio_url']) as response:
                    raw_audio = response.read()
                waveforms.append(librosa.load(BytesIO(raw_audio), sr=sampling_rate)[0])

    inputs = processor(text=texts, audios=waveforms, return_tensors='pt', padding=True, sampling_rate=sampling_rate)

    input_ids = inputs.input_ids
    attention_mask = inputs.attention_mask
    input_features = inputs.input_features
    feature_attention_mask = inputs.feature_attention_mask

    # Build the labels. With keep_targets only the feedback region contributes to
    # the loss and everything else is -100; otherwise nothing contributes.
    # Per the original author's note, Qwen2-Audio expects unshifted labels and
    # shifts them inside the model code.
    if not keep_targets:
        labels = torch.full_like(input_ids, fill_value=-100)
        return input_ids, attention_mask, labels, input_features, feature_attention_mask

    labels = input_ids.clone()
    for i, feedback in zip(range(len(texts)), feedbacks):
        # TODO: decide whether the target should also include feedback_prompt
        # (i.e. target_text = feedback_prompt + feedback).
        target_ids = processor.tokenizer(feedback, add_special_tokens=False).input_ids

        # The assistant turn comes last, so take the LAST match: a feedback string
        # that also appears inside the prompt must not unmask prompt tokens.
        start_pos = _find_last_subsequence(input_ids[i].tolist(), target_ids)

        if start_pos != -1:
            labels[i, :start_pos] = -100
        else:
            labels[i] = -100  # fallback: ignore the whole sample rather than train on wrong tokens

    # Never compute loss on padding, whichever side the tokenizer pads on.
    labels[attention_mask == 0] = -100

    return input_ids, attention_mask, labels, input_features, feature_attention_mask

def transform_to_seq_len(data, pad_val, max_seq_len):
    """Fit a list of ints to exactly ``max_seq_len`` entries and return a long tensor.

    Longer inputs are cut on the right; shorter (or equal-length) inputs are
    right-padded with ``pad_val``.
    """
    overflow = len(data) - max_seq_len
    sized = data[:max_seq_len] if overflow > 0 else data + [pad_val] * (-overflow)
    return torch.tensor(sized, dtype=torch.long)

def batch_transform_to_seq_len(batch_data, pad_val, max_seq_len):
    """Truncate or right-pad every tensor in a batch to ``max_seq_len`` along dim 0.

    Args:
        batch_data: Iterable of tensors (e.g. the rows of a 2-D tensor).
        pad_val: Fill value used for padding.
        max_seq_len: Target length of the first dimension of every row.

    Returns:
        The fitted rows stacked into a single tensor.
    """
    fitted_rows = []
    for row in batch_data:
        if len(row) > max_seq_len:
            fitted = row[:max_seq_len]
        else:
            pad_len = max_seq_len - len(row)
            # new_full inherits dtype AND device from the row, so the concat also
            # works for non-CPU tensors; trailing dims are carried over as well.
            padding = row.new_full((pad_len, *row.shape[1:]), pad_val)
            fitted = torch.cat([row, padding])
        fitted_rows.append(fitted)
    return torch.stack(fitted_rows)



def feedback_collate_fn(
    tokenizer,
    max_seq_len,
    max_audio_len,
    feedback_method,
    cot_prompt,
    data,
):
    """Collate chosen/rejected feedback samples into a single batch dict.

    Args:
        tokenizer (Tokenizer): The model's tokenizer. Not used in this body; the
            module-level ``processor`` does the tokenization.
        max_seq_len (int): The maximum sequence length of the model. Not used in
            this body (truncation/padding to a fixed length is disabled).
        max_audio_len (int): Maximum audio feature length. Not used in this body.
        feedback_method (str): Targets are kept as labels only for ``"csft"``
            and ``"teacher"``.
        cot_prompt (str): Prompt text forwarded to ``build_tokens``.
        data (list[dict]): Samples with the keys ``chosen``, ``rejected``,
            ``chosen_feedback``, ``rejected_feedback``, ``chosen_mos`` and
            ``rejected_mos``.

    Returns:
        dict with ``chosen_*`` / ``rejected_*`` tensors (input ids, attention
        mask, LM labels, audio features, feature attention mask) and the
        ``chosen_mos`` / ``rejected_mos`` lists.
    """
    samples = list(data)
    chosen_audios = [sample["chosen"] for sample in samples]
    chosen_feedbacks = [sample["chosen_feedback"] for sample in samples]
    rejected_audios = [sample["rejected"] for sample in samples]
    rejected_feedbacks = [sample["rejected_feedback"] for sample in samples]
    chosen_mos = [sample["chosen_mos"] for sample in samples]
    rejected_mos = [sample["rejected_mos"] for sample in samples]

    keep_targets = feedback_method in ["csft", "teacher"]

    # Tokenizing chosen and rejected separately would give the two halves
    # different sequence lengths, so both go through one build_tokens call and
    # therefore share the same padding.
    input_ids, attn_mask, labels, features, feature_attn_mask = build_tokens(
        chosen_audios + rejected_audios,
        chosen_feedbacks + rejected_feedbacks,
        cot_prompt,
        keep_targets,
    )

    # First half of the batch is "chosen", the remainder is "rejected".
    # NOTE: per the original author's remark this split is arguably redundant,
    # since the halves get concatenated again before being fed to the model.
    split = len(chosen_audios)
    chosen_input_ids, rejected_input_ids = input_ids[:split], input_ids[split:]
    chosen_attn_mask, rejected_attn_mask = attn_mask[:split], attn_mask[split:]
    chosen_labels, rejected_labels = labels[:split], labels[split:]
    chosen_features, rejected_features = features[:split], features[split:]
    chosen_feature_attn_mask = feature_attn_mask[:split]
    rejected_feature_attn_mask = feature_attn_mask[split:]

    eos_id = processor.tokenizer.eos_token_id
    for ids, lm_labels in (
        (chosen_input_ids, chosen_labels),
        (rejected_input_ids, rejected_labels),
    ):
        # Force last token to be eos token id regardless
        ids[:, -1] = eos_id
        # Handle labels when overflow seq length: a supervised second-to-last
        # position is made to predict eos, and the final position is never scored.
        supervised = lm_labels[:, -2] != -100
        lm_labels[supervised, -2] = eos_id
        lm_labels[:, -1] = -100

    return {
        "chosen_input_ids": chosen_input_ids,
        "chosen_attention_mask": chosen_attn_mask,
        "chosen_lm_labels": chosen_labels,
        "chosen_features": chosen_features,
        "chosen_feature_attention_mask": chosen_feature_attn_mask,
        "rejected_input_ids": rejected_input_ids,
        "rejected_attention_mask": rejected_attn_mask,
        "rejected_lm_labels": rejected_labels,
        "rejected_features": rejected_features,
        "rejected_feature_attention_mask": rejected_feature_attn_mask,
        "chosen_mos": chosen_mos,
        "rejected_mos": rejected_mos,
    }


def build_feedback_dataloader(
    cfg,
    device_batch_size,
    tokenizer,
    feedback_method,
    cot_prompt,
):
    """Build a dataloader over on-disk preference/feedback data.

    Args:
        cfg (DictConfig): loader config. Reads ``cfg.dataset.remote``,
            ``cfg.dataset.split``, ``cfg.dataset.shuffle``, ``cfg.drop_last`` and
            ``cfg.num_workers``, plus the optional ``pin_memory``,
            ``prefetch_factor`` and ``persistent_workers``. ``max_seq_len`` and
            ``max_audio_len`` are popped from ``cfg.dataset`` (this mutates cfg).
        device_batch_size (int): batch size per device.
        tokenizer (Tokenizer): the model's tokenizer, forwarded to the collator.
        feedback_method (str): forwarded to ``feedback_collate_fn``.
        cot_prompt (str): forwarded to ``feedback_collate_fn``.

    Returns:
        DataLoader: yields batches produced by ``feedback_collate_fn``.
    """
    max_seq_len = cfg.dataset.pop("max_seq_len", None)
    max_audio_len = cfg.dataset.pop("max_audio_len", None)

    # for debug
    dataset = load_from_disk(cfg.dataset.remote)[cfg.dataset.split]
    # dataset = load_dataset(cfg.dataset.remote, split=cfg.dataset.split)
    dist_sampler = dist.get_sampler(dataset, shuffle=cfg.dataset.shuffle, drop_last=cfg.drop_last)

    num_workers = cfg.num_workers
    # prefetch_factor / persistent_workers are only valid with worker processes;
    # DataLoader raises ValueError if they are set while num_workers == 0.
    worker_options = {}
    if num_workers > 0:
        worker_options['prefetch_factor'] = cfg.get('prefetch_factor', 2)
        worker_options['persistent_workers'] = cfg.get('persistent_workers', True)

    dataloader = DataLoader(
        dataset,
        collate_fn=partial(
            feedback_collate_fn, tokenizer, max_seq_len, max_audio_len, feedback_method, cot_prompt
        ),
        sampler=dist_sampler,
        batch_size=device_batch_size,
        num_workers=num_workers,
        pin_memory=cfg.get('pin_memory', True),
        **worker_options,
    )
    return dataloader


def build_evaluators(
    eval_loader_config,
    tokenizer,
    device_eval_batch_size,
    feedback_method,
    cot_prompt,
    metric_names
):
    """Create one Composer ``Evaluator`` per entry of ``eval_loader_config``.

    Each entry's optional ``label`` key is popped (default ``eval-<index>``) and
    used as ``eval/<label>``; the rest of the entry configures the dataloader
    built by ``build_feedback_dataloader``.

    Args:
        eval_loader_config (omegaconf.ListConfig): one loader config per eval set.
        tokenizer: forwarded to ``build_feedback_dataloader``.
        device_eval_batch_size (int): per-device eval batch size, also used as
            the evaluator's microbatch size.
        feedback_method (str): forwarded to ``build_feedback_dataloader``.
        cot_prompt (str): forwarded to ``build_feedback_dataloader``.
        metric_names: metric names attached to every evaluator.

    Returns:
        list[Evaluator]: evaluators in config order.
    """
    assert isinstance(eval_loader_config, omegaconf.ListConfig)

    def _make_evaluator(index, loader_cfg):
        # Pop the label first so it is not left inside the loader config.
        name = loader_cfg.pop('label', f'eval-{index}')
        loader = build_feedback_dataloader(
            loader_cfg,
            device_eval_batch_size,
            tokenizer,
            feedback_method,
            cot_prompt
        )
        return Evaluator(
            label=f'eval/{name}',
            dataloader=loader,
            metric_names=metric_names,
            device_eval_microbatch_size=device_eval_batch_size,
        )

    return [
        _make_evaluator(index, loader_cfg)
        for index, loader_cfg in enumerate(eval_loader_config)
    ]


if __name__ == "__main__":
    # Manual smoke test for feedback_collate_fn.
    # Usage: python <this file> CHOSEN_AUDIO_PATH REJECTED_AUDIO_PATH
    # The collator prefixes each path with "file://" and reads it via urlopen,
    # so real audio files are required.
    import sys

    from cloud.train.train import COT_PROMPT

    if len(sys.argv) != 3:
        raise SystemExit(f"usage: python {sys.argv[0]} CHOSEN_AUDIO_PATH REJECTED_AUDIO_PATH")
    chosen_path, rejected_path = sys.argv[1], sys.argv[2]

    # The sample carries every key feedback_collate_fn reads.
    # NOTE(review): the MOS values are arbitrary placeholders; the collator only
    # passes them through.
    sample_data = [{
        "chosen": chosen_path,
        "rejected": rejected_path,
        "chosen_feedback": "This response is correct.",
        "rejected_feedback": "This response is funny but wrong.",
        "chosen_mos": 4.0,
        "rejected_mos": 2.0,
    }]

    # The ids are produced by the module-level processor, so decode them with the
    # matching tokenizer.
    tokenizer = processor.tokenizer
    ret = feedback_collate_fn(tokenizer, 60, None, "teacher", COT_PROMPT, sample_data)

    print("CHOSEN TEXT")
    print(tokenizer.decode(ret["chosen_input_ids"][0]))
    print("=" * 100)
    print("CHOSEN LABELS")
    print(tokenizer.decode(ret["chosen_lm_labels"][0][ret["chosen_lm_labels"][0] != -100]))
    print("=" * 100)
    print("REJECTED TEXT")
    print(tokenizer.decode(ret["rejected_input_ids"][0]))
    print("=" * 100)
    print("REJECTED LABELS")
    print(tokenizer.decode(ret["rejected_lm_labels"][0][ret["rejected_lm_labels"][0] != -100]))
    print("=" * 100)
