import argparse
import json
import logging
import os
import random
import shutil
import subprocess
import tempfile
from collections import defaultdict

import pandas as pd
import spacy
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead

# ---------------------------------------------------------------------------
# Argument parser
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description="TMPC for WMT Translation")
parser.add_argument("--input_file", type=str, required=True, help="Input CSV file for the translation task.")
parser.add_argument("--rm", type=str, default="rl-bandits-lab/translation_rm", help="Reward model HuggingFace path.")
parser.add_argument("--src_language", type=str, default="Japanese", help="Source language name.")
parser.add_argument("--task_language", type=str, default="English", help="Target language name.")
parser.add_argument("--threshold", type=float, default=1.0, help="Reward threshold for buffer admission.")
parser.add_argument("--max_iterations", type=int, default=5, help="Maximum number of iterations.")
parser.add_argument("--good_ref_contexts_num", type=int, default=5, help="Number of good reference contexts to sample.")
parser.add_argument("--cuda_num", type=int, default=0, help="CUDA device index.")
args = parser.parse_args()

TASK_LANGUAGE = args.task_language
SRC_LANGUAGE = args.src_language
csv_path = args.input_file
device = torch.device(f"cuda:{args.cuda_num}" if torch.cuda.is_available() else "cpu")

max_iterations = args.max_iterations
# 0-based loop iterations at which process_chunk produces the buffer-conditioned final
# translation and saves a snapshot: every iteration except the first.
stop_memory = list(range(1, max_iterations))
# Output folder for snapshots: the input path without its extension.
# os.path.splitext removes only the trailing extension, whereas str.replace(".csv", "")
# also rewrote ".csv" occurring anywhere else in the path (e.g. a directory name).
MEMORY_FOLDER = os.path.splitext(args.input_file)[0]
THRESHOLD = args.threshold
good_ref_contexts_num = args.good_ref_contexts_num

print(f"SRC_LANGUAGE: {SRC_LANGUAGE}, TASK_LANGUAGE: {TASK_LANGUAGE}, device: {device}")

# ---------------------------------------------------------------------------
# Language configuration
# ---------------------------------------------------------------------------
# Supported language name -> (ISO code, spaCy pipeline package).
lang_map = {
    name: (code, pipeline)
    for name, code, pipeline in (
        ("English", "en", "en_core_web_sm"),
        ("Russian", "ru", "ru_core_news_sm"),
        ("German", "de", "de_core_news_sm"),
        ("Japanese", "ja", "ja_core_news_sm"),
        ("Korean", "ko", "ko_core_news_sm"),
        ("Spanish", "es", "es_core_news_sm"),
        ("Chinese", "zh", "zh_core_web_sm"),
    )
}


def get_lang_and_nlp(language):
    """Look up *language* in ``lang_map`` and load its spaCy pipeline.

    Returns:
        ``(iso_code, nlp)`` where ``nlp`` is the loaded spaCy pipeline.

    Raises:
        ValueError: if *language* is not a key of ``lang_map``.
    """
    try:
        code, pipeline = lang_map[language]
    except KeyError:
        raise ValueError(f"Unsupported language: {language}") from None
    return code, spacy.load(pipeline)


# Resolve ISO codes and spaCy pipelines for both languages once, at import time.
# src_lang doubles as the CSV column name holding the source text (row[src_lang] in
# process_chunk); mt_nlp segments generated target-language translations.
src_lang, src_nlp = get_lang_and_nlp(SRC_LANGUAGE)
tgt_lang, mt_nlp = get_lang_and_nlp(TASK_LANGUAGE)

# ---------------------------------------------------------------------------
# Generation model (local LLaMA-3.1-8B-Instruct)
# ---------------------------------------------------------------------------
GEN_MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct"
gen_tokenizer = AutoTokenizer.from_pretrained(GEN_MODEL_NAME)
if gen_tokenizer.pad_token is None:
    # No pad token defined: reuse EOS so padding=True in apply_chat_template works.
    gen_tokenizer.pad_token = gen_tokenizer.eos_token
# trust_remote_code is intentionally not set: this architecture is built into transformers,
# so enabling it would only permit executing code shipped inside the model repository.
gen_model = AutoModelForCausalLM.from_pretrained(
    GEN_MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map={"": device},
).to(device)
# Inference only: disables dropout.
gen_model.eval()


def generate_local(sys_prompt, user_prompt, max_new_tokens=2048, temperature=0.7, top_p=0.9):
    """Generate a single chat completion with the local generation model.

    Args:
        sys_prompt: optional system message; omitted when empty or whitespace-only.
        user_prompt: content of the user message.
        max_new_tokens: maximum number of tokens to generate.
        temperature: sampling temperature; a value ``<= 0`` selects greedy decoding.
        top_p: nucleus-sampling probability mass (only used when sampling).

    Returns:
        The decoded assistant reply with surrounding whitespace removed.
    """
    chat = []
    if sys_prompt and sys_prompt.strip():
        chat.append({"role": "system", "content": sys_prompt})
    chat.append({"role": "user", "content": user_prompt})

    # add_generation_prompt=True appends the assistant header, so the model starts its answer
    # immediately. Without it the model had to generate the header itself, which is what
    # leaked a leading "assistant" into the decoded text.
    encoded = gen_tokenizer.apply_chat_template(
        chat, add_generation_prompt=True, return_tensors="pt", padding=True, truncation=True
    )
    # Depending on the transformers version/arguments this is a tensor or a dict-like.
    if isinstance(encoded, torch.Tensor):
        inputs = {"input_ids": encoded.to(device)}
    else:
        inputs = {k: v.to(device) for k, v in encoded.items()}
    if "attention_mask" not in inputs:
        # A single unpadded conversation: every position is a real token.
        inputs["attention_mask"] = torch.ones_like(inputs["input_ids"])

    gen_kwargs = {"max_new_tokens": max_new_tokens, "pad_token_id": gen_tokenizer.eos_token_id}
    if temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
    else:
        # Greedy decoding: don't pass sampling-only parameters.
        gen_kwargs["do_sample"] = False

    with torch.no_grad():
        out = gen_model.generate(**inputs, **gen_kwargs)
    prompt_len = inputs["input_ids"].shape[1]
    text = gen_tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True).strip()
    # Defensive: if a bare role-header line still appears, drop only that line. Text that
    # merely starts with the word "Assistant" (e.g. "Assistant professors ...") is kept.
    head, sep, rest = text.partition("\n")
    if sep and head.strip().lower() == "assistant":
        text = rest.lstrip()
    return text


# ---------------------------------------------------------------------------
# Folder / file processing
# ---------------------------------------------------------------------------
def clear_folder(folder_path):
    """Make *folder_path* an empty directory, deleting any existing tree at that path first."""
    already_there = os.path.exists(folder_path)
    if already_there:
        shutil.rmtree(folder_path)
    os.makedirs(folder_path, exist_ok=True)


def delete_files_with_mt(folder_path):
    """Remove regular files whose name contains "mt" directly inside *folder_path*.

    Subdirectories are left alone, and a missing folder is silently ignored.
    """
    if not os.path.exists(folder_path):
        return
    candidates = (
        os.path.join(folder_path, entry)
        for entry in os.listdir(folder_path)
        if "mt" in entry
    )
    for path in candidates:
        if os.path.isfile(path):
            os.remove(path)


# ---------------------------------------------------------------------------
# Reward model
# ---------------------------------------------------------------------------
class RewardModel:
    """Scores (source, translation) pairs with a value-head reward model.

    Wraps a trl ``AutoModelForCausalLMWithValueHead`` loaded from ``rm_path``; the value-head
    weights are downloaded separately from ``value_head.safetensors`` in the same repo.
    """

    # Source-language name written into the reward prompt. The template has always said
    # "Chinese" whatever the real source language is; that remains the default because the
    # RM may have been trained on exactly this wording.
    # NOTE(review): confirm against the RM's training template before overriding.
    prompt_src_language = "Chinese"

    def __init__(self, rm_path, device, torch_dtype=torch.bfloat16, prompt_src_language=None):
        """Load the reward model and its value head onto *device*.

        Args:
            rm_path: Hugging Face repo id holding the RM and ``value_head.safetensors``.
            device: torch device to run on.
            torch_dtype: dtype for the base model weights.
            prompt_src_language: optional override of ``prompt_src_language``.
        """
        self.device = device
        if prompt_src_language is not None:
            self.prompt_src_language = prompt_src_language
        # NOTE(review): the tokenizer is the generation model's, not rm_path's — presumably
        # the RM shares its vocabulary and chat template; confirm.
        self.tokenizer = AutoTokenizer.from_pretrained(GEN_MODEL_NAME)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.RM = AutoModelForCausalLMWithValueHead.from_pretrained(
            rm_path, torch_dtype=torch_dtype
        ).to(self.device)
        self.RM.eval()
        # NOTE(review): gradient checkpointing brings nothing to no-grad inference; kept to
        # leave the model configuration unchanged.
        self.RM.gradient_checkpointing_enable()

        value_head_file = hf_hub_download(repo_id=rm_path, filename="value_head.safetensors")
        value_head_weights = load_file(value_head_file)
        # The file stores keys as "v_head.<param>"; the v_head module expects "<param>".
        new_state_dict = {
            key.replace("v_head.", "") if key.startswith("v_head.") else key: value
            for key, value in value_head_weights.items()
        }
        self.RM.v_head.load_state_dict(new_state_dict)

    def _create_single_message(self, language, source, translation):
        """Build the system/user/assistant chat the RM scores for one translation."""
        src_name = self.prompt_src_language
        return [
            {"role": "system", "content": "You are a helpful translator and only output the result."},
            {"role": "user", "content": f"### Translate this from {src_name} to {language}, {src_name}:\n{source}\n### {language}:"},
            {"role": "assistant", "content": translation},
        ]

    def _process_inputs(self, messages):
        """Tokenize one chat (no generation prompt) into a batch of one on ``self.device``."""
        encoded = self.tokenizer.apply_chat_template(
            messages, add_generation_prompt=False, return_tensors="pt", padding=True, truncation=True
        )
        # Depending on the transformers version/arguments this is a tensor or a dict-like
        # (BatchEncoding); torch.ones_like would fail on the latter.
        if isinstance(encoded, torch.Tensor):
            input_ids = encoded
            attention_mask = None
        else:
            input_ids = encoded["input_ids"]
            attention_mask = encoded.get("attention_mask")
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)
            attention_mask = attention_mask.unsqueeze(0)
        return {"input_ids": input_ids.to(self.device), "attention_mask": attention_mask.to(self.device)}

    def reward_fn(self, language, source, translations):
        """Score each translation of *source* independently.

        Returns:
            One float per translation, in input order: the value-head output at the last
            position of the tokenized chat.
        """
        all_rewards = []
        for translation in translations:
            messages = self._create_single_message(language, source, translation)
            inputs = self._process_inputs(messages)
            with torch.no_grad():
                outputs = self.RM(**inputs, return_value=True)
                # presumably (lm_logits, loss, value) as returned by trl's value-head model
                rewards = outputs[2]
            reward = rewards[0, -1].cpu().item()
            all_rewards.append(reward)
        return all_rewards

    def reward_fn_batch(self, language, src_list, mt_list):
        """Score aligned ``(src_list[i], mt_list[i])`` pairs; returns one float per pair."""
        all_rewards = []
        for src, mt in zip(src_list, mt_list):
            rewards = self.reward_fn(language, src, [mt])
            all_rewards.extend(rewards)
        return all_rewards


# Module-level reward model shared by batch_rm_find_best_translation; loaded once at import.
reward_model = RewardModel(rm_path=args.rm, device=device)


def batch_rm_find_best_translation(evals, language, scorer=None, threshold=None):
    """Return the best translation (reward >= threshold) for each source, with its score.

    Args:
        evals: sequence of ``(src, translations)`` pairs, one per source window.
        language: target-language name forwarded to the scorer.
        scorer: callable ``(language, src_list, mt_list) -> list[float]`` returning one
            reward per pair; defaults to ``reward_model.reward_fn_batch``.
        threshold: minimum reward for a candidate to be returned; defaults to ``THRESHOLD``.

    Returns:
        One ``(translation, score)`` tuple per input pair, in input order. ``score`` is the
        highest reward among the non-blank candidates (the first one wins ties).
        ``translation`` is that candidate, or ``None`` when ``score < threshold``. Both are
        ``None`` when the pair has no non-blank candidate.
    """
    if scorer is None:
        scorer = reward_model.reward_fn_batch
    if threshold is None:
        threshold = THRESHOLD

    evals = list(evals)
    # Blank candidates come from source windows a translation had no aligned text for.
    # They are never useful subgoals, so they are neither scored nor admitted.
    groups = [[mt for mt in translations if mt.strip()] for _, translations in evals]
    src_list = [src for (src, _), group in zip(evals, groups) for _ in group]
    mt_list = [mt for group in groups for mt in group]
    rewards = scorer(language, src_list, mt_list)

    best_translations = []
    offset = 0
    for group in groups:
        group_rewards = rewards[offset : offset + len(group)]
        offset += len(group)
        if not group:
            best_translations.append((None, None))
            continue
        best_index = group_rewards.index(max(group_rewards))
        best_score = group_rewards[best_index]
        # A lone candidate is held to the same threshold as larger groups; it used to be
        # accepted whatever its reward.
        if best_score >= threshold:
            best_translations.append((group[best_index], best_score))
        else:
            best_translations.append((None, best_score))
    return best_translations


# ---------------------------------------------------------------------------
# Translation generation (local model)
# ---------------------------------------------------------------------------
def translate_local(source_sentence, buffer, good_sent_size, src_language, tgt_language):
    """Generate multiple translation rollouts conditioned on the subgoal buffer.

    With an empty buffer a plain translation prompt is used. Otherwise up to
    ``good_sent_size`` randomly sampled buffer keys that occur verbatim in
    ``source_sentence`` are annotated in place with their best buffered translation
    (inserted as ``"\\n(<translation>)\\n"`` after the key). The model is then asked to
    refine the annotated parts and complete the rest.

    Args:
        source_sentence: full source text.
        buffer: mapping source window -> list of ``(translation, score)``, best first.
        good_sent_size: maximum number of buffer entries to sample as references.
        src_language: source-language name used in the prompts.
        tgt_language: target-language name used in the prompts.

    Returns:
        One translation per translator persona (system prompt), in persona order.
    """
    system_prompts = [
        "You are a meticulous translator. Provide a literal, word-for-word translation that preserves the structure and meaning of each individual word.",
        "You are a professional translator. Deliver a clear, formal, and precise translation that faithfully conveys the original meaning.",
        "You are a creative and expressive translator. Render the text in a vivid and imaginative way, as if narrating a captivating story.",
    ]

    if not buffer:
        context_prompt = (
            f"### Translate this from {src_language} to {tgt_language} and only output the result."
            f"\n### {src_language}:\n {source_sentence}"
            f"\n### {tgt_language}:\n"
        )
    else:
        processed_source = source_sentence
        selected_keys = random.sample(list(buffer.keys()), min(len(buffer), good_sent_size))
        for key in selected_keys:
            key_sentence = key.strip()
            if not key_sentence or key_sentence not in source_sentence:
                continue
            # Look up with the original key. The stripped text may not be a key, and
            # indexing a defaultdict with it would insert [] and then raise IndexError.
            entries = buffer.get(key)
            if not entries:
                continue
            translated_sentence = entries[0][0]
            # Skip if this translation was already inserted (e.g. by an overlapping key).
            if f"\n({translated_sentence})\n" not in processed_source:
                processed_source = processed_source.replace(
                    key_sentence,
                    f"{key_sentence}\n({translated_sentence})\n",
                )

        context_prompt = (
            f"Below is a specialized, intermediate translation task. The input text is a mix of {src_language} and partial {tgt_language} translations. "
            f"In the text, some {src_language} sentences are already followed by preliminary {tgt_language} translations enclosed in parentheses. "
            f"These provided translations are rough references \u2013 they may be incomplete, inconsistent, or not fully aligned with the original meaning.\n\n"
            f"Your task is to produce an improved {tgt_language} translation according to the following guidelines:\n"
            f"1. **Refinement:** For sections with existing {tgt_language} translations (in parentheses), refine and polish them so that they are fluent, accurate, and coherent, fully capturing the meaning of the corresponding {src_language} text.\n"
            f"2. **Completion:** For sections that remain untranslated, translate the {src_language} text accurately and naturally in the specified style.\n"
            f"3. **Translation Order and Structure Preservation:** Maintain the original order and structure of the text. Every {src_language} sentence must appear in the same sequence as in the source text, with its corresponding {tgt_language} translation (if available) inserted immediately after it. Do not rearrange or reorder any part of the text.\n"
            f"4. **Consistency:** Ensure a uniform tone and style across the entire translation, adhering to the translator role specified.\n"
            f"5. **Final Output:** Provide the final output as a single, well-structured {tgt_language} text. Do not include any extraneous commentary, explanations, annotations, or headers \u2013 output only the translation in the correct order.\n\n"
            f"Note: This translation is an intermediate version that may later be merged with other translations. Focus on clarity, coherence, and fidelity to the source text.\n"
            f"\nHere is the input data for translation:\n{processed_source}\n\n"
            "Apply the above guidelines to produce an improved, coherent translation that strictly follows the original order of the text.\n"
        )

    return [generate_local(prompt, context_prompt) for prompt in system_prompts]


def process_buffer_sentences(source_sentences, buffer):
    """Assemble the best buffered translations in source order.

    Args:
        source_sentences: source segments in document order.
        buffer: mapping source window text -> list of ``(translation, score)``, best first.
            Windows may span several consecutive segments joined with single spaces.

    Returns:
        Translation strings in source order:
        - A segment that is itself a buffer key contributes its best translation.
        - Otherwise, the longest run of consecutive segments starting there whose
          space-joined text is a buffer key contributes that window's translation, and
          the segments it covers are consumed.
        - Segments covered by neither are omitted.
    """
    translation_map = {}
    for src_key, trans_list in buffer.items():
        if not trans_list or not isinstance(trans_list, list):
            continue
        translation_map[src_key] = trans_list[0]

    # No key is longer than this, so window growth can stop once it is exceeded.
    max_key_len = max(map(len, translation_map), default=0)
    translations = []
    i = 0
    n = len(source_sentences)
    while i < n:
        best = translation_map.get(source_sentences[i])
        consumed = 1
        if not best:
            # Multi-sentence windows from the aligner were previously never matched,
            # so their translations silently vanished from the initial translation.
            window = source_sentences[i]
            for j in range(i + 1, n):
                window = f"{window} {source_sentences[j]}"
                if len(window) > max_key_len:
                    break
                if translation_map.get(window):
                    best, consumed = translation_map[window], j - i + 1
        if best:
            translations.append(best[0])
        i += consumed
    return translations


def final_translate_local(source_sentence, source_segments, buffer, src_language, tgt_language):
    """Subgoal-conditioned final re-generation.

    The buffered translations of ``source_segments`` (as selected by
    ``process_buffer_sentences``) are joined with newlines into an initial draft. The model
    is then asked to rewrite that draft into a refined, complete translation of
    ``source_sentence``.

    Returns:
        The model's refined translation.
    """
    draft = "\n".join(process_buffer_sentences(source_segments, buffer))

    instructions = "\n".join([
        f"1. Ensure that every detail in the original {src_language} text is accurately represented.",
        "2. Correct any grammatical errors, unnatural expressions, or inconsistencies.",
        "3. Improve the natural flow so that the translation reads as if written by a native speaker.",
        "4. Do not add, omit, or change any essential details from the source text.",
        "5. Output only the final refined translation without any additional commentary.",
    ])
    sections = [
        f"Below is an initial translation of a {src_language} text into {tgt_language}. "
        "This translation may include omissions, inaccuracies, or awkward phrasing. "
        "Your task is to produce a refined version that is fluent, accurate, and coherent, "
        f"while faithfully preserving the full meaning of the original {src_language} text.",
        "### Instructions:\n" + instructions,
        f"### Original {src_language} Text:\n{source_sentence}",
        f"### Initial {tgt_language} Translation:\n{draft}",
        "### Refined Translation:",
    ]
    rewrite_prompt = "\n\n".join(sections)

    return generate_local("You are a helpful translator and only output the result.", rewrite_prompt)


# ---------------------------------------------------------------------------
# Alignment functions
# ---------------------------------------------------------------------------
def save_sentences_to_txt(sentences, filename):
    """Write *sentences* to *filename* (UTF-8), one per line, each newline-terminated."""
    with open(filename, "w", encoding="utf-8") as out:
        out.writelines(line + "\n" for line in sentences)


def segment_sentences_by_punctuation(text, lang):
    """Split *text* into sentences with spaCy, paragraph by paragraph.

    Args:
        text: text to segment; ``"\\n"`` separates paragraphs, and blank paragraphs are skipped.
        lang: ISO code; ``src_lang`` selects the source pipeline, anything else the
            target pipeline.

    Returns:
        Non-empty, whitespace-stripped sentences in document order.
    """
    # The pipeline choice does not depend on the paragraph, so pick it once.
    nlp = src_nlp if lang == src_lang else mt_nlp
    segmented_sentences = []
    for paragraph in text.split("\n"):
        if not paragraph.strip():
            continue
        for sent in nlp(paragraph).sents:
            sentence = sent.text.strip()
            # spaCy may yield whitespace-only "sentences". They would become blank lines
            # in the VecAlign input and blank buffer keys.
            if sentence:
                segmented_sentences.append(sentence)
    return segmented_sentences


def generate_overlap_and_embedding(txt_file):
    """Create VecAlign overlap and LASER embedding files for *txt_file*.

    Runs ``./overlap.py -n 10`` and then ``$LASER/tasks/embed/embed.sh``. A non-zero exit
    status from either tool is logged as a warning, not raised.

    Args:
        txt_file: one-sentence-per-line text file.

    Returns:
        ``(overlaps_file, embed_file)``, i.e. ``txt_file + ".overlaps"`` and ``txt_file + ".emb"``.

    Raises:
        RuntimeError: if the ``LASER`` environment variable is unset or empty.
    """
    overlaps_file = txt_file + ".overlaps"
    embed_file = txt_file + ".emb"

    laser_root = os.environ.get("LASER")
    if not laser_root:
        # Through the shell, an unset $LASER expanded to "" and ran a non-existent
        # "/tasks/embed/embed.sh", silently leaving VecAlign without embeddings.
        raise RuntimeError("The LASER environment variable must point to the LASER installation.")

    commands = (
        ["./overlap.py", "-i", txt_file, "-o", overlaps_file, "-n", "10"],
        # Argument list instead of a shell string: paths containing spaces or shell
        # metacharacters are passed through verbatim.
        [f"{laser_root}/tasks/embed/embed.sh", overlaps_file, embed_file],
    )
    for command in commands:
        completed = subprocess.run(command)
        if completed.returncode != 0:
            logging.warning("%s exited with status %d", command[0], completed.returncode)
    return overlaps_file, embed_file


def compute_alignment_stats(alignment_results):
    """Summarize VecAlign output lines whose last ``:``-separated field is a cost.

    Lines whose last field is not a float are ignored for the costs but still count in the
    ratio's denominator.

    Returns:
        ``(avg_cost, zero_cost_ratio)``:
        - ``avg_cost`` is the mean of the non-zero costs (0.0 if there are none).
        - ``zero_cost_ratio`` is the number of exactly-zero costs divided by the number of
          lines (0.0 for no lines).
    """
    nonzero_costs = []
    zeros = 0
    for line in alignment_results:
        tail = line.split(":")[-1]
        try:
            value = float(tail)
        except ValueError:
            continue
        if value == 0.0:
            zeros += 1
        else:
            nonzero_costs.append(value)

    total = len(alignment_results)
    avg_cost = sum(nonzero_costs) / len(nonzero_costs) if nonzero_costs else 0.0
    zero_cost_ratio = zeros / total if total else 0.0
    return avg_cost, zero_cost_ratio


def _parse_index_list(field):
    """Parse ``"[0,1]"`` / ``"[]"`` into a list of ints; raises ValueError when malformed."""
    inner = field.strip("[]")
    return [int(x) for x in inner.split(",")] if inner else []


def _parse_alignment_line(line):
    """Parse a VecAlign line ``[i,...]:[j,...]:cost`` into ``(src_indices, tgt_indices)``.

    Returns None for lines that do not have exactly three ``:``-separated fields or whose
    index lists are not comma-separated integers.
    """
    fields = line.split(":")
    if len(fields) != 3:
        return None
    try:
        return _parse_index_list(fields[0]), _parse_index_list(fields[1])
    except ValueError:
        return None


def run_vecalign_explore(src_txt, tgt_txt, src_embed, tgt_embed):
    """Adaptive penalty search for VecAlign.

    Runs ``./vecalign.py`` with ``--del_percentile_frac`` stepping down from 0.2 by 0.005.
    From the second run on, the search stops at the first setting where any of these holds:
    - the zero-cost ratio grows more than 1.5x or by more than 0.15;
    - the average cost rises;
    - the average cost drops below 0.3;
    - the zero-cost ratio exceeds 0.7.

    Among the accepted settings, the one with the lowest average cost is kept.

    Args:
        src_txt: one-sentence-per-line source file; ``src_txt + ".overlaps"`` must exist.
        tgt_txt: one-sentence-per-line target file; ``tgt_txt + ".overlaps"`` must exist.
        src_embed: embeddings of the source overlaps.
        tgt_embed: embeddings of the target overlaps.

    Returns:
        List of ``(src_indices, tgt_indices)`` pairs from the best run. Empty if nothing
        parseable was accepted.
    """
    start_frac = 0.2
    step_size = 0.005
    # Each fraction is derived from an integer step count. Repeatedly subtracting 0.005
    # accumulates floating-point error: the values passed to VecAlign drift, and the old
    # "> 0" end test could run an extra near-zero step.
    num_steps = int(round(start_frac / step_size))
    prev_zero_cost_ratio = None
    prev_avg_cost = None

    best_avg_cost = float("inf")
    best_del_percentile_frac = start_frac
    best_zero_cost_ratio = 0.0
    best_alignments = []

    for step in range(num_steps):
        del_percentile_frac = round(start_frac - step * step_size, 6)
        result = subprocess.run(
            [
                "./vecalign.py",
                "--alignment_max_size", "8",
                "--del_percentile_frac", str(del_percentile_frac),
                "--src", src_txt,
                "--tgt", tgt_txt,
                "--costs_sample_size", "200000",
                "--search_buffer_size", "20",
                "--src_embed", src_txt + ".overlaps", src_embed,
                "--tgt_embed", tgt_txt + ".overlaps", tgt_embed,
            ],
            stdout=subprocess.PIPE,
            text=True,
        )

        output_lines = result.stdout.strip().split("\n")
        avg_cost, zero_cost_ratio = compute_alignment_stats(output_lines)
        print(f"del_percentile_frac: {del_percentile_frac:.3f} | Avg Cost: {avg_cost:.6f} | Zero-Cost Ratio: {zero_cost_ratio:.2%}")

        # The very first run is always accepted; stopping rules need a previous run.
        if prev_zero_cost_ratio is not None:
            ratio_jump = prev_zero_cost_ratio != 0 and (zero_cost_ratio / prev_zero_cost_ratio) > 1.5
            if (
                ratio_jump
                or (zero_cost_ratio - prev_zero_cost_ratio) > 0.15
                or avg_cost > prev_avg_cost
                or avg_cost < 0.3
                or zero_cost_ratio > 0.7
            ):
                break

        if avg_cost < best_avg_cost:
            best_avg_cost = avg_cost
            best_del_percentile_frac = del_percentile_frac
            best_zero_cost_ratio = zero_cost_ratio
            best_alignments = output_lines

        prev_zero_cost_ratio = zero_cost_ratio
        prev_avg_cost = avg_cost

    parsed_alignments = []
    for line in best_alignments:
        if not line:
            continue
        # Stray non-alignment lines on stdout used to crash the unpacking; skip them.
        parsed = _parse_alignment_line(line)
        if parsed is not None:
            parsed_alignments.append(parsed)

    print(f"\nBest: del_percentile_frac={best_del_percentile_frac:.3f}, Avg Cost={best_avg_cost:.6f}, Zero-Cost Ratio={best_zero_cost_ratio:.2%}")
    return parsed_alignments


def standardize_common_alignments(common_alignments_list):
    """Re-express every alignment on the source windows of a common reference.

    The reference is the alignment with the fewest entries (the first such on ties). For
    each reference source window, an alignment's target indices are found as follows:
    - use its entry with exactly the same source indices, if there is one;
    - otherwise use the sorted union of its single-sentence entries for those source
      indices, which may be empty.

    Args:
        common_alignments_list: one alignment per translation, each a list of
            ``(src_indices, mt_indices)`` pairs.

    Returns:
        One standardized alignment per input, all listing the reference's source windows in
        the reference's order. An empty input gives an empty list (``min`` used to raise).
    """
    if not common_alignments_list:
        return []
    reference_alignments = min(common_alignments_list, key=len)

    standardized_results = []
    for alignments in common_alignments_list:
        mt_idx_map = {tuple(src): mt for src, mt in alignments}
        standardized_alignment = []
        for src_indices, _ in reference_alignments:
            key = tuple(src_indices)
            if key in mt_idx_map:
                mt_indices = mt_idx_map[key]
            else:
                merged = set()
                for src in src_indices:
                    merged.update(mt_idx_map.get((src,), ()))
                mt_indices = sorted(merged)
            standardized_alignment.append((src_indices, mt_indices))
        standardized_results.append(standardized_alignment)
    return standardized_results


def generate_windows(source, translations):
    """Align each candidate translation to the source and cut both into parallel windows.

    The source and every translation are sentence-segmented. Each translation is aligned
    to the source with VecAlign, and all alignments are standardized onto one reference,
    so window ``i`` of every translation covers the same source sentences.

    Args:
        source: full source text.
        translations: candidate full translations.

    Returns:
        ``(src_windows, mt_windows_list)``:
        - ``src_windows[i]`` is the space-joined source text of window ``i``.
        - ``mt_windows_list[k][i]`` is translation ``k``'s text for that window (``""``
          when nothing aligned).

        Both are empty when ``translations`` is empty.
    """
    translations = list(translations)
    if not translations:
        # Previously this crashed in standardize_common_alignments / left adjusted_src unbound.
        return [], []

    source_segments = segment_sentences_by_punctuation(source, lang=src_lang)
    mt_segments_list = [segment_sentences_by_punctuation(t, lang=tgt_lang) for t in translations]

    # A fresh directory per call. The old fixed "{src}_{tgt}_temp" folder was shared by every
    # process running the same language pair in the same working directory, so parallel
    # runs overwrote and deleted each other's alignment files.
    temp_folder = tempfile.mkdtemp(prefix=f"{src_lang}_{tgt_lang}_temp_", dir=".")
    try:
        src_txt = os.path.join(temp_folder, "tmpc_src.txt")
        mt_txt = os.path.join(temp_folder, "tmpc_mt.txt")

        save_sentences_to_txt(source_segments, src_txt)
        _, src_embed = generate_overlap_and_embedding(src_txt)

        common_alignments_list = []
        for mt_segments in mt_segments_list:
            save_sentences_to_txt(mt_segments, mt_txt)
            _, mt_embed = generate_overlap_and_embedding(mt_txt)
            common_alignments_list.append(run_vecalign_explore(src_txt, mt_txt, src_embed, mt_embed))
            # Remove this translation's files before the next one reuses the same names.
            delete_files_with_mt(temp_folder)
    finally:
        # Always clean up, even on failure. To inspect intermediate alignment files
        # (overlaps, embeddings), comment out the line below.
        shutil.rmtree(temp_folder, ignore_errors=True)

    standardized = standardize_common_alignments(common_alignments_list)

    # All standardized alignments list the same source windows in the same order,
    # so the source side can be read off the first one.
    adjusted_src = [
        " ".join(source_segments[i] for i in src_indices)
        for src_indices, _ in standardized[0]
        if src_indices
    ]

    adjusted_mt_list = []
    for mt_segments, alignment in zip(mt_segments_list, standardized):
        adjusted_mt = []
        for src_indices, mt_indices in alignment:
            if not src_indices:
                continue
            kept = [x for x in mt_indices if x != -1]
            adjusted_mt.append(" ".join(mt_segments[i] for i in kept))
        adjusted_mt_list.append(adjusted_mt)
    return adjusted_src, adjusted_mt_list


# ---------------------------------------------------------------------------
# Main function
# ---------------------------------------------------------------------------
def saving_memory(buffer, index, iteration, final_translations_record):
    """Dump the buffer and final-translation record for one row/iteration to JSON.

    Writes ``buffer_{index}_iter_{iteration}.json`` and
    ``metadata_{index}_iter_{iteration}.json`` under ``MEMORY_FOLDER``, creating the folder
    if needed. Buffer values are converted to lists; JSON stores tuples as arrays.
    """
    os.makedirs(MEMORY_FOLDER, exist_ok=True)
    stem = f"{index}_iter_{iteration}"
    buffer_file_path = f"{MEMORY_FOLDER}/buffer_{stem}.json"
    metadata_file_path = f"{MEMORY_FOLDER}/metadata_{stem}.json"

    payloads = (
        (buffer_file_path, {key: list(value) for key, value in buffer.items()}),
        (metadata_file_path, {"final_translations_record": final_translations_record}),
    )
    for target_path, payload in payloads:
        with open(target_path, "w", encoding="utf-8") as handle:
            json.dump(payload, handle, ensure_ascii=False, indent=4)

    print(f"Saved buffer -> {buffer_file_path}")
    print(f"Saved metadata -> {metadata_file_path}")


def process_chunk():
    """Run the TMPC loop over every row of the input CSV (``csv_path``).

    Setup per row:
    - Read the source text from the ``src_lang`` column, flattening newlines to spaces.
    - Skip the row if the cell is missing or blank.

    Each iteration does the following:
    - At iterations listed in ``stop_memory``, generate a buffer-conditioned final
      translation and save a snapshot via ``saving_memory``.
    - Except on the last iteration, generate rollouts with ``translate_local`` and align
      them into windows.
    - Keep each window's best reward-model candidate in a per-row buffer, sorted by score
      with the best first.
    """
    data = pd.read_csv(csv_path)
    for index, row in data.iterrows():
        print(f"\n{'='*60} Index {index} {'='*60}")
        raw_source = row[src_lang]
        # pandas reads empty cells as NaN (a float), which has no .replace.
        if not isinstance(raw_source, str) or not raw_source.strip():
            print(f"Index {index}: missing or empty '{src_lang}' source text, skipped.")
            continue

        buffer = defaultdict(list)
        source_sentence = raw_source.replace("\n", " ")
        source_segments = segment_sentences_by_punctuation(source_sentence, lang=src_lang)

        for iteration in range(max_iterations):
            print(f"\nIteration {iteration + 1}/{max_iterations}")

            if iteration in stop_memory:
                final_translations = final_translate_local(
                    source_sentence, source_segments, buffer, SRC_LANGUAGE, TASK_LANGUAGE
                )
                final_translations_record = [final_translations]
                saving_memory(buffer, index, iteration, final_translations_record)

            # The last iteration only produces the final snapshot above.
            if iteration == max_iterations - 1:
                break

            # The number of sampled reference contexts grows by one each iteration.
            translations = translate_local(
                source_sentence, buffer, good_ref_contexts_num + iteration, SRC_LANGUAGE, TASK_LANGUAGE
            )
            src_windows, mt_windows_list = generate_windows(source_sentence, translations)

            src_context_list = list(src_windows)
            candidates_list = [
                [mt_windows[window_index] for mt_windows in mt_windows_list]
                for window_index in range(len(src_windows))
            ]

            best_candidate_results = batch_rm_find_best_translation(
                list(zip(src_context_list, candidates_list)), TASK_LANGUAGE
            )

            for src, best_tuple in zip(src_context_list, best_candidate_results):
                if best_tuple[0] is None:
                    continue
                # buffer is a defaultdict(list), so appending creates the entry on first use.
                buffer[src].append(best_tuple)
                # Highest reward first: readers take buffer[src][0] as the best translation.
                buffer[src].sort(key=lambda x: x[1], reverse=True)

        print(f"Index {index} done.")


if __name__ == "__main__":
    process_chunk()
