import os
import sys
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM

root_dir = '/data/home/username/Experiments/LLM_ensemble'
sys.path.insert(0, root_dir)
from src.nllb.Ensembler import Ensembler
from src.nllb.Ensembler_generator import Ensembler_generator
from src.nllb.Demon_prompt_generator import demon_prompt_generator

if __name__ == '__main__':
    # Translate the Flores Romanian source file to English with an ensemble of
    # NLLB-200-distilled-600M and Llama-2-13b (4-shot prompt), writing one
    # hypothesis per input line.
    # Usage: python <this_script> <learning_rate>
    # The learning rate is passed through to the ensemble generator and is
    # also embedded in the output file name.
    if len(sys.argv) < 2:
        sys.exit(f"Usage: {sys.argv[0]} <learning_rate>")
    learning_rate = float(sys.argv[1])

    device = "cuda:0"
    src_lang = "ron_Latn"
    tgt_lang = "eng_Latn"
    src_lang_full = "Romanian"
    tgt_lang_full = "English"
    mode = "dev"
    translate_direction = f"{src_lang}-{tgt_lang}-4shot"

    # Resolve and read the input before loading any model, so a bad dataset
    # path fails immediately instead of after loading a 13B checkpoint.
    if mode == "dev":
        input_file_path = f"{root_dir}/Datasets/Flores/{mode}/{src_lang}.dev"
    else:
        # Other splits live under "dev{mode}" with a ".devtest" suffix
        # (presumably mode == "test" -> Flores/devtest/; confirm for other modes).
        input_file_path = f"{root_dir}/Datasets/Flores/dev{mode}/{src_lang}.devtest"

    with open(input_file_path, 'r', encoding="utf-8") as src_file:
        src_contents = src_file.readlines()

    output_file_path = (
        f"{root_dir}/Eval/Flores-{src_lang}-{tgt_lang}/"
        f"v4-Llama-2-13b-hf-NLLB-200-distilled-600M-{src_lang}-{tgt_lang}-{mode}/"
        f"{tgt_lang}_{learning_rate}.txt"
    )
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(os.path.dirname(output_file_path), exist_ok=True)

    NLLB_model_path = "/data/home/username/ModelsHub/facebook/nllb-200-distilled-600M"
    NLLB_tokenizer = AutoTokenizer.from_pretrained(NLLB_model_path, src_lang=src_lang)
    NLLB_model = AutoModelForSeq2SeqLM.from_pretrained(NLLB_model_path).to(device)
    NLLB_model.eval()

    LLM_model_path = "/data3/username/ModelsHub/Llama-2-13b-hf"
    LLM_tokenizer = AutoTokenizer.from_pretrained(LLM_model_path)
    LLM_model = AutoModelForCausalLM.from_pretrained(LLM_model_path, torch_dtype="auto").to(device)
    LLM_model.eval()

    matrix_dir = (
        f"{root_dir}/probability_transfer_matrix/ablation-anchor-point-count/"
        "Llama-2-13b-hf-NLLB-200-distilled-600M"
    )
    LLM_probability_transfer_matrix_path = f"{matrix_dir}/Llama-2-13b-hf.pth"
    NLLB_probability_transfer_matrix_path = f"{matrix_dir}/NLLB-200-distilled-600M.pth"

    print("Start Load Probability Transfer Matrix")
    LLM_probability_transfer_matrix = torch.load(LLM_probability_transfer_matrix_path,
                                                 map_location=device)
    NLLB_probability_transfer_matrix = torch.load(NLLB_probability_transfer_matrix_path,
                                                  map_location=device)
    print("End Load Probability Transfer Matrix")

    task_instruction = f"Translate the sentence from {src_lang_full} to {tgt_lang_full}:"
    demon_prompt = demon_prompt_generator(translate_direction)

    ensembler = Ensembler(LLM_probability_transfer_matrix, NLLB_probability_transfer_matrix)
    ensembler_generator = Ensembler_generator(LLM_model, LLM_tokenizer, NLLB_model, NLLB_tokenizer, ensembler)

    # Open the output once and truncate it. Appending would stack a rerun's
    # hypotheses onto a previous run's file and misalign them with the
    # reference lines. Flushing after each line keeps partial progress on disk
    # if the run is interrupted.
    with open(output_file_path, "w", encoding="utf-8") as f_result:
        for line in tqdm(src_contents):
            nllb_input_text = line.strip()
            llm_input_text = task_instruction + demon_prompt + f"\n{src_lang_full}:" + nllb_input_text + f"\n{tgt_lang_full}:"
            # NOTE(review): the generator method is named for Mistral but is
            # driven with Llama-2 here; confirm it is model-agnostic.
            result = ensembler_generator.mistral_nllb_ensemble_translate(llm_input_text=llm_input_text,
                                                                         nllb_input_text=nllb_input_text,
                                                                         nllb_tgt_lang=tgt_lang,
                                                                         learning_rate=learning_rate)
            f_result.write(result + "\n")
            f_result.flush()
