import pdb

import torch
import torch.nn.functional as F


def nllb_generate_one(model, input_ids, decoder_input_ids):
    """Run one seq2seq forward pass and return the logits of the last position.

    Args:
        model: Callable accepting ``input_ids``, ``decoder_input_ids`` and
            ``attention_mask`` keyword arguments and returning an object
            with a ``logits`` attribute.
        input_ids: Source token ids; an all-ones attention mask of the same
            shape is built for them.
        decoder_input_ids: The output sequence generated so far. A leading
            dimension is added before it is handed to the model
            (presumably a 1-D tensor of token ids -- confirm with callers).

    Returns:
        ``outputs.logits[:, -1, :]``, i.e. the logits at the final position
        along dimension 1.
    """
    # Attend to every source position: the mask is all ones.
    full_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device)
    # Pass the current output sequence with a leading dimension added.
    batched_decoder_ids = decoder_input_ids.unsqueeze(0)

    result = model(
        input_ids=input_ids,
        decoder_input_ids=batched_decoder_ids,
        attention_mask=full_mask,
    )
    # Keep only the logits of the last position.
    return result.logits[:, -1, :]


def mistral_generate_one(model, input_ids):
    """Run one forward pass and return the logits of the last position.

    Args:
        model: Callable accepting an ``input_ids`` keyword argument and
            returning an object with a ``logits`` attribute.
        input_ids: Token ids forwarded to the model unchanged.

    Returns:
        ``outputs.logits[:, -1, :]``, i.e. the logits at the final position
        along dimension 1.
    """
    # Call the model, then keep only the logits of the last position.
    return model(input_ids=input_ids).logits[:, -1, :]


def mistral_generate_one_loss(model, input_ids):
    """Run one forward pass with labels and return the last-position logits.

    ``input_ids`` is passed both as the input and as ``labels``. The
    original author's note lists the output keys as
    ``['loss', 'logits', 'past_key_values']``, so the model presumably also
    computes a loss -- but only the logits are returned here.

    Args:
        model: Callable accepting ``input_ids`` and ``labels`` keyword
            arguments and returning an object with a ``logits`` attribute.
        input_ids: Token ids used as both inputs and labels.

    Returns:
        ``outputs.logits[:, -1, :]``, i.e. the logits at the final position
        along dimension 1.
    """
    # Prepare the model inputs; the labels are the inputs themselves.
    model_inputs = {
        "input_ids": input_ids,
        "labels": input_ids,
    }
    # Call the model; observed output keys: loss, logits, past_key_values.
    outputs = model(**model_inputs)

    # NOTE: a leftover ``pdb.set_trace()`` used to sit here and blocked every
    # non-interactive run; it has been removed.
    return outputs.logits[:, -1, :]  # logits of the last position


def nllb_translate(model, tokenizer, source_text, max_length=200, tgt_lang="deu_Latn", device="cuda"):
    """Greedily decode a translation of ``source_text`` with a seq2seq model.

    The decoder is seeded with ``[eos_token_id, <target language id>]`` and
    the highest-scoring token is appended step by step until the EOS token
    is produced or ``max_length - 1`` steps have run. The whole model is
    re-run on the full sequence at every step (no key/value cache is used).

    Args:
        model: Callable accepting ``input_ids``, ``decoder_input_ids`` and
            ``attention_mask`` keyword arguments and returning an object
            with a ``logits`` attribute.
        tokenizer: Tokenizer providing ``encode``, ``decode``,
            ``eos_token_id`` and either a ``lang_code_to_id`` mapping or
            ``convert_tokens_to_ids``.
        source_text: Text to translate. Decoding calls ``.item()`` on the
            chosen token, so the encoded input must hold a single sequence.
        max_length: The loop runs at most ``max_length - 1`` steps.
        tgt_lang: Target language code used to seed the decoder.
        device: Device the encoded input is moved to.

    Returns:
        The generated ids decoded with ``skip_special_tokens=True``.

    Raises:
        KeyError: If ``tokenizer.lang_code_to_id`` exists but does not
            contain ``tgt_lang``.
    """
    # Encode the input text.
    input_ids = tokenizer.encode(source_text, return_tensors="pt").to(device)
    # The mask is all ones and never changes, so build it once.
    attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device)

    # Some tokenizer versions no longer expose ``lang_code_to_id``; fall back
    # to a plain token lookup in that case.
    lang_code_to_id = getattr(tokenizer, "lang_code_to_id", None)
    if lang_code_to_id is not None:
        tgt_lang_id = lang_code_to_id[tgt_lang]
    else:
        tgt_lang_id = tokenizer.convert_tokens_to_ids(tgt_lang)

    eos_token_id = tokenizer.eos_token_id
    # The decoder prefix is the EOS token followed by the target language id.
    generated_ids = [eos_token_id, tgt_lang_id]

    # Pure inference: disabling autograd avoids building a graph per step.
    with torch.no_grad():
        for _ in range(max_length - 1):
            outputs = model(
                input_ids=input_ids,
                # Pass the output sequence generated so far.
                decoder_input_ids=torch.tensor([generated_ids], device=input_ids.device),
                attention_mask=attention_mask,
            )
            # Softmax is monotonic, so the argmax of the raw last-position
            # logits is the greedy choice.
            next_token_id = torch.argmax(outputs.logits[:, -1, :], dim=-1).item()
            generated_ids.append(next_token_id)

            # Stop once the end-of-sequence token has been generated.
            if next_token_id == eos_token_id:
                break

    # Decode the generated ids, dropping special tokens.
    return tokenizer.decode(generated_ids, skip_special_tokens=True)


def mistral_translate(model, tokenizer, source_text, max_length, device, stop_token_ids=(13,)):
    """Greedily continue ``source_text`` with a causal LM and decode the new tokens.

    At each step the highest-scoring token at the last position is chosen.
    Generation stops when that token is the tokenizer's EOS token or one of
    ``stop_token_ids`` (the stop token itself is not kept), or after
    ``max_length - 1`` steps. The model is re-run on the full sequence at
    every step (no key/value cache is used).

    Args:
        model: Callable accepting an ``input_ids`` keyword argument and
            returning an object with a ``logits`` attribute.
        tokenizer: Callable tokenizer whose result exposes ``input_ids``;
            must also provide ``eos_token_id`` and ``decode``.
        source_text: Prompt text. Decoding calls ``.item()`` on the chosen
            token, so the prompt must encode to a single sequence.
        max_length: The loop runs at most ``max_length - 1`` steps.
        device: Device the encoded prompt is moved to.
        stop_token_ids: Extra token ids that end generation. Defaults to
            ``(13,)``, the value previously hard-coded here (presumably the
            newline token of the Mistral vocabulary -- confirm against the
            tokenizer in use). Pass ``()`` to stop on EOS only.

    Returns:
        The newly generated tokens (prompt excluded) decoded with
        ``skip_special_tokens=False``.
    """
    input_ids = tokenizer(source_text, return_tensors="pt").input_ids.to(device)
    prompt_length = input_ids.shape[1]
    stop_ids = {tokenizer.eos_token_id, *(stop_token_ids or ())}

    # Pure inference: disabling autograd avoids building a graph per step.
    with torch.no_grad():
        for _ in range(max_length - 1):
            # Call the model and keep the logits of the last position.
            logits = model(input_ids=input_ids).logits[:, -1, :]
            # Softmax is monotonic, so the argmax of the raw logits is the
            # greedy choice.
            next_token = torch.argmax(logits, dim=-1).view(-1, 1)
            # Check the stop condition before appending, so the stop token
            # never reaches the output.
            if next_token.item() in stop_ids:
                break
            input_ids = torch.cat((input_ids, next_token), dim=1)

    # Decode only the tokens generated after the prompt; special tokens are
    # kept in the output.
    return tokenizer.decode(input_ids[:, prompt_length:].tolist()[0],
                            skip_special_tokens=False)
