#!/usr/bin/env python3
"""
SEGALE-based Context-Level Evaluation for Translation

This script evaluates discourse-level machine translation quality using COMET
with sentence alignment via VecAlign and LASER embeddings (SEGALE framework).

Steps:
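  1. Read a CSV containing source, reference, and MT columns (the source and
     reference columns are looked up by language code, e.g. `zh` and `en`;
     the MT column is the one named by --target_column).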
  2. Segment texts into sentences and generate LASER embeddings.
  3. Run adaptive-penalty VecAlign to align src-ref and src-mt.
  4. Compute COMET scores on aligned sliding windows.
  5. Aggregate paragraph-level scores and save results.

Usage:
  python segale_ctx.py \\
      --file eval_en_ja.csv \\
      --target_column TMPC \\
      --save eval_en_ja \\
      --src_language Chinese \\
      --task_language English \\
      --gpu_id 0
"""

import os
import re
import json
import spacy
import torch
import random
import argparse
import numpy as np
import pandas as pd
import tempfile
import subprocess
import unicodedata
from multiprocessing import Pool
import datetime

# ---------------------------------------------------------------------------
# Language configuration
# ---------------------------------------------------------------------------
# Maps a language name (the value looked up by get_lang_and_nlp) to a
# (language code, spaCy pipeline name) pair. The pipeline name is what gets
# handed to spacy.load().
LANG_MAP = {
    "English": ("en", "en_core_web_sm"),
    "Russian": ("ru", "ru_core_news_sm"),
    "German": ("de", "de_core_news_sm"),
    "Japanese": ("ja", "ja_core_news_sm"),
    "Korean": ("ko", "ko_core_news_sm"),
    "Spanish": ("es", "es_core_news_sm"),
    "Chinese": ("zh", "zh_core_web_sm"),
}


def get_lang_and_nlp(language):
    """Return ``(language_code, spaCy pipeline)`` for a language name.

    Args:
        language: A key of ``LANG_MAP`` (e.g. ``"English"``).

    Returns:
        A tuple of the language code and the pipeline object produced by
        ``spacy.load`` for the mapped model name.

    Raises:
        ValueError: If ``language`` is not a key of ``LANG_MAP``.
    """
    try:
        code, pipeline_name = LANG_MAP[language]
    except KeyError:
        raise ValueError(
            f"Unsupported language: {language}. Supported: {list(LANG_MAP.keys())}"
        ) from None
    pipeline = spacy.load(pipeline_name)
    return code, pipeline


# ---------------------------------------------------------------------------
# Utility Functions
# ---------------------------------------------------------------------------
def set_seed(seed: int = 42):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: Value passed to every seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def normalize_text(text: str) -> str:
    """Return ``text`` in Unicode NFKC normalization form."""
    nfkc_form = unicodedata.normalize("NFKC", text)
    return nfkc_form


def segment_sentences_by_punctuation(text, lang):
    """Split ``text`` into sentences, each suffixed with ``SEPARATOR``.

    The text is split on newlines; blank lines are skipped. Each remaining
    line is run through ``src_nlp`` when ``lang == SRC_LANG`` and through
    ``mt_nlp`` otherwise, and every sentence the pipeline reports is
    stripped, NFKC-normalized and tagged with the separator.

    Args:
        text: Raw text, possibly spanning several lines.
        lang: Value compared against ``SRC_LANG`` to choose the pipeline.

    Returns:
        A list of sentence strings, in document order.
    """
    collected = []
    for line in text.split("\n"):
        if not line.strip():
            continue
        pipeline = src_nlp if lang == SRC_LANG else mt_nlp
        collected.extend(
            normalize_text(span.text.strip()) + SEPARATOR
            for span in pipeline(line).sents
        )
    return collected


def preprocess_sentences(sentences):
    """Join sentences into newline-separated text with ``SEPARATOR`` removed.

    Args:
        sentences: Iterable of sentence strings.

    Returns:
        One string with a stripped, separator-free sentence per line.
    """
    stripped_lines = (s.replace(SEPARATOR, "").strip() for s in sentences)
    return "\n".join(stripped_lines)


def generate_overlap_and_embedding(text):
    """Build the overlap file and its embeddings for ``text``.

    ``text`` is written to a temporary ``.txt`` file, ``./overlap.py`` is run
    on it (``-n 10``) to produce ``<tmp>.overlaps``, and
    ``$LASER/tasks/embed/embed.sh`` is run on that to produce ``<tmp>.emb``.

    Args:
        text: Newline-separated sentences.

    Returns:
        ``(overlap_content, embeddings_content)``: the overlaps file as
        ``str`` and the embedding file as raw ``bytes``.

    Raises:
        subprocess.CalledProcessError: If either external command exits with
            a non-zero status.
    """
    import shlex  # local import: only needed to quote paths for the shell

    with tempfile.NamedTemporaryFile(delete=True, mode="w+", encoding="utf-8", suffix=".txt") as txt_file:
        txt_file.write(text)
        txt_file.flush()
        txt_filename = txt_file.name
        overlaps_file = txt_filename + ".overlaps"
        embed_file = txt_filename + ".emb"

        try:
            subprocess.run(["./overlap.py", "-i", txt_filename, "-o", overlaps_file, "-n", "10"], check=True)
            # The shell is kept so that $LASER is expanded from the
            # environment; paths are quoted so a temp directory containing
            # spaces or shell metacharacters cannot break the command.
            subprocess.run(
                " ".join([
                    '"$LASER"/tasks/embed/embed.sh',
                    shlex.quote(overlaps_file),
                    shlex.quote(embed_file),
                ]),
                shell=True, check=True,
            )

            with open(embed_file, "rb") as f:
                embeddings_content = f.read()
            with open(overlaps_file, "r", encoding="utf-8") as f:
                overlap_content = f.read()
        finally:
            # Always remove the side files, including when a command failed
            # part-way; they are not managed by NamedTemporaryFile.
            for fp in (overlaps_file, embed_file):
                try:
                    os.remove(fp)
                except OSError:
                    pass

    return overlap_content, embeddings_content


def compute_alignment_stats(alignment_results):
    """Summarize the costs found in alignment output lines.

    The cost is taken as the text after the last ``:`` of each line. Lines
    whose cost cannot be parsed as a float are ignored for the average but
    still count towards the denominator of the zero-cost ratio.

    Args:
        alignment_results: List of alignment output lines.

    Returns:
        ``(avg_cost, zero_cost_ratio)`` where ``avg_cost`` is the mean of the
        non-zero costs (``0.0`` if there are none) and ``zero_cost_ratio`` is
        the number of zero-cost lines divided by the number of lines
        (``0.0`` for empty input).
    """
    nonzero_costs = []
    zero_hits = 0
    for line in alignment_results:
        cost_field = line.rpartition(":")[2]
        try:
            value = float(cost_field)
        except ValueError:
            continue
        if value == 0.0:
            zero_hits += 1
        else:
            nonzero_costs.append(value)

    if nonzero_costs:
        avg_cost = sum(nonzero_costs) / len(nonzero_costs)
    else:
        avg_cost = 0.0

    if alignment_results:
        zero_cost_ratio = zero_hits / len(alignment_results)
    else:
        zero_cost_ratio = 0.0
    return avg_cost, zero_cost_ratio


def _parse_index_list(field):
    """Parse a bracketed, comma-separated index list such as ``[0, 1]``."""
    inner = field.strip("[]")
    return [int(token) for token in inner.split(",")] if inner else []


def run_vecalign_explore(src_text, tgt_text, src_overlap, tgt_overlap, src_embed, tgt_embed):
    """Adaptive penalty search for VecAlign alignment.

    ``./vecalign.py`` is run repeatedly with ``--del_percentile_frac``
    decreasing from 0.2 in steps of 0.005. The search stops as soon as, in
    comparison with the previous run, the zero-cost ratio grows by more than
    a factor of 1.5 or by more than 0.15, the average cost rises, the average
    cost falls below 0.3, or the zero-cost ratio exceeds 0.7. Among the runs
    before the stop, the one with the lowest average cost is kept.

    Args:
        src_text: Source sentences, one per line.
        tgt_text: Target sentences, one per line.
        src_overlap: Contents of the source overlaps file (text).
        tgt_overlap: Contents of the target overlaps file (text).
        src_embed: Contents of the source embedding file (bytes).
        tgt_embed: Contents of the target embedding file (bytes).

    Returns:
        A list of ``(src_indices, tgt_indices)`` tuples parsed from the best
        run's output lines.

    Raises:
        subprocess.CalledProcessError: If ``vecalign.py`` exits with a
            non-zero status. Without this check a failed run would be read as
            an empty alignment with zero cost.
    """
    initial_frac = 0.2
    step_size = 0.005
    # Derive every penalty from an integer step index instead of repeatedly
    # subtracting floats: accumulated rounding error could otherwise leave a
    # tiny positive remainder and trigger an extra run at ~0.
    num_steps = int(round(initial_frac / step_size))

    prev_zero_cost_ratio = None
    prev_avg_cost = None

    best_avg_cost = float("inf")
    best_del_percentile_frac = initial_frac
    best_zero_cost_ratio = 0.0
    best_alignments = []

    with tempfile.NamedTemporaryFile(delete=True, mode="w+", encoding="utf-8", suffix=".txt") as src_file, \
         tempfile.NamedTemporaryFile(delete=True, mode="w+", encoding="utf-8", suffix=".txt") as tgt_file, \
         tempfile.NamedTemporaryFile(delete=True, mode="w+", encoding="utf-8", suffix=".overlaps") as src_overlap_file, \
         tempfile.NamedTemporaryFile(delete=True, mode="w+", encoding="utf-8", suffix=".overlaps") as tgt_overlap_file, \
         tempfile.NamedTemporaryFile(delete=True, mode="wb", suffix=".emb") as src_embed_file, \
         tempfile.NamedTemporaryFile(delete=True, mode="wb", suffix=".emb") as tgt_embed_file:

        # Materialize all inputs on disk once; every run reuses the same files.
        for handle, payload in (
            (src_file, src_text),
            (tgt_file, tgt_text),
            (src_overlap_file, src_overlap),
            (tgt_overlap_file, tgt_overlap),
            (src_embed_file, src_embed),
            (tgt_embed_file, tgt_embed),
        ):
            handle.write(payload)
            handle.flush()

        for step in range(num_steps):
            del_percentile_frac = round(initial_frac - step * step_size, 3)
            result = subprocess.run(
                [
                    "./vecalign.py",
                    "--alignment_max_size", "8",
                    "--del_percentile_frac", str(del_percentile_frac),
                    "--src", src_file.name,
                    "--tgt", tgt_file.name,
                    "--src_embed", src_overlap_file.name, src_embed_file.name,
                    "--tgt_embed", tgt_overlap_file.name, tgt_embed_file.name,
                ],
                stdout=subprocess.PIPE, text=True, check=True,
            )

            output_lines = result.stdout.strip().split("\n")
            avg_cost, zero_cost_ratio = compute_alignment_stats(output_lines)
            print(f"del_percentile_frac: {del_percentile_frac:.3f} | Avg Cost: {avg_cost:.6f} | Zero-Cost Ratio: {zero_cost_ratio:.2%}")

            # Stopping criteria are only evaluated from the second run on,
            # because all but the absolute thresholds need a previous run.
            if prev_zero_cost_ratio is not None:
                ratio_jump = (
                    prev_zero_cost_ratio != 0
                    and (zero_cost_ratio / prev_zero_cost_ratio) > 1.5
                )
                if (
                    ratio_jump
                    or (zero_cost_ratio - prev_zero_cost_ratio) > 0.15
                    or avg_cost > prev_avg_cost
                    or avg_cost < 0.3
                    or zero_cost_ratio > 0.7
                ):
                    break

            if avg_cost < best_avg_cost:
                best_avg_cost = avg_cost
                best_del_percentile_frac = del_percentile_frac
                best_zero_cost_ratio = zero_cost_ratio
                best_alignments = output_lines

            prev_zero_cost_ratio = zero_cost_ratio
            prev_avg_cost = avg_cost

    parsed_alignments = []
    for line in best_alignments:
        if line:
            # Each line has the form "<src list>:<tgt list>:<cost>".
            src_part, tgt_part, _ = line.split(":")
            parsed_alignments.append((_parse_index_list(src_part), _parse_index_list(tgt_part)))

    print(f"\nBest: del_percentile_frac={best_del_percentile_frac:.3f}, Avg Cost={best_avg_cost:.6f}, Zero-Cost Ratio={best_zero_cost_ratio:.2%}")
    return parsed_alignments


# ---------------------------------------------------------------------------
# Sentence / window helpers
# ---------------------------------------------------------------------------
def clean_sentence(sentence):
    """Deduplicate the ``SEPARATOR``-delimited parts of ``sentence``.

    Args:
        sentence: A string of parts delimited by ``SEPARATOR``; may be empty
            or ``None``.

    Returns:
        ``""`` for a falsy input. Otherwise the stripped, non-empty parts in
        first-seen order, joined by ``" SEPARATOR "`` and followed by a
        trailing ``" SEPARATOR"``.
    """
    if not sentence:
        return ""

    # A dict is used as an insertion-ordered set.
    ordered_parts = {}
    for chunk in sentence.split(SEPARATOR):
        chunk = chunk.strip()
        if chunk:
            ordered_parts.setdefault(chunk, None)

    joiner = f" {SEPARATOR} "
    return joiner.join(ordered_parts) + f" {SEPARATOR}"


def sliding_windows(sentences, window_size):
    """Build overlapping windows of cleaned sentences.

    Every window holds up to ``window_size`` consecutive sentences, each
    passed through ``clean_sentence``; duplicates inside a window are dropped
    while keeping first-seen order.

    A non-empty input shorter than ``window_size`` yields a single window
    containing all of its sentences. Previously such input produced no
    windows at all, so short paragraphs were never scored.

    Args:
        sentences: List of sentence strings.
        window_size: Number of consecutive sentences per window.

    Returns:
        A list of windows (lists of strings); empty for empty input.
    """
    last_start = len(sentences) - window_size
    if last_start < 0 and len(sentences) > 0:
        # Too short for a full window: fall back to one truncated window.
        last_start = 0

    windows = []
    for start in range(last_start + 1):
        cleaned = [clean_sentence(s) for s in sentences[start : start + window_size]]
        windows.append(list(dict.fromkeys(cleaned)))
    return windows


def save_windows_to_file(paragraph_id, aligned_src, aligned_ref, aligned_mt,
                         src_windows, ref_windows, mt_windows, output_dir, output_name):
    """Write the windows and the aligned segments of a paragraph as JSON.

    Two files are created in ``output_dir`` (which is created if missing):
    ``windows_<paragraph_id>_<output_name>.json`` and
    ``aligned_<paragraph_id>_<output_name>.json``.

    Args:
        paragraph_id: Identifier embedded in the file names and in the
            windows file.
        aligned_src: Aligned source segments.
        aligned_ref: Aligned reference segments.
        aligned_mt: Aligned MT segments.
        src_windows: Source windows.
        ref_windows: Reference windows.
        mt_windows: MT windows.
        output_dir: Destination directory.
        output_name: Suffix used in both file names.
    """
    os.makedirs(output_dir, exist_ok=True)

    outputs = [
        (
            f"windows_{paragraph_id}_{output_name}.json",
            {
                "paragraph_id": paragraph_id,
                "src_windows": src_windows,
                "ref_windows": ref_windows,
                "mt_windows": mt_windows,
            },
        ),
        (
            f"aligned_{paragraph_id}_{output_name}.json",
            {"src": aligned_src, "ref": aligned_ref, "mt": aligned_mt},
        ),
    ]
    for file_name, payload in outputs:
        destination = os.path.join(output_dir, file_name)
        with open(destination, "w", encoding="utf-8") as handle:
            json.dump(payload, handle, ensure_ascii=False, indent=2)


# ---------------------------------------------------------------------------
# Alignment gap processing
# ---------------------------------------------------------------------------
def process_gaps(alignments):
    """Tag target-only alignments with a marker derived from a nearby source.

    A run of consecutive entries with an empty source list and a non-empty
    target list is a "gap". Each entry of the run is rewritten as
    ``([-gap_key], tgt)`` where ``gap_key`` is the last source index of the
    entry emitted just before the run, or else the first source index found
    after the run minus one, or else ``0``. All other entries are passed
    through unchanged.

    NOTE(review): ``-gap_key`` is ``0`` when ``gap_key`` is 0 and positive
    when ``gap_key`` is -1 (a gap before source index 0), so those markers
    are not negative. Confirm that code testing for a negative source index
    to recognize gaps handles these cases as intended.

    Args:
        alignments: List of ``(src_indices, tgt_indices)`` pairs.

    Returns:
        ``(new_alignments, gap_counts)`` where ``gap_counts`` maps each
        ``gap_key`` to the number of entries tagged with it.
    """
    rewritten = []
    gap_counts = {}
    total = len(alignments)
    pos = 0
    while pos < total:
        src, tgt = alignments[pos]
        if src or not tgt:
            # Not a gap entry: keep as is.
            rewritten.append(alignments[pos])
            pos += 1
            continue

        # Consume the whole run of target-only entries.
        run_start = pos
        while pos < total and not alignments[pos][0] and alignments[pos][1]:
            pos += 1
        run = alignments[run_start:pos]

        if rewritten and rewritten[-1][0]:
            left_src = rewritten[-1][0][-1]
        else:
            left_src = None
        right_src = next((entry[0][0] for entry in alignments[pos:] if entry[0]), None)

        if left_src is not None:
            gap_key = left_src
        elif right_src is not None:
            gap_key = right_src - 1
        else:
            gap_key = 0

        rewritten.extend(([-gap_key], entry[1]) for entry in run)
        gap_counts[gap_key] = gap_counts.get(gap_key, 0) + len(run)
    return rewritten, gap_counts


def complement_gaps(processed, gap_counts, desired_gaps):
    """Pad ``processed`` with empty gap entries up to ``desired_gaps``.

    For every gap key whose count in ``gap_counts`` is lower than in
    ``desired_gaps``, the missing number of ``([-gap], [])`` entries is
    inserted before the first existing entry whose first source index is
    ``-gap``; if there is none, before the first entry whose first source
    index is greater than ``gap``; otherwise at the end. Both ``processed``
    and ``gap_counts`` are modified in place.

    Args:
        processed: List of ``(src_indices, tgt_indices)`` pairs.
        gap_counts: Mapping of gap key to current count.
        desired_gaps: Mapping of gap key to required count.

    Returns:
        The same ``processed`` list.
    """
    for gap in set(gap_counts.keys()) | set(desired_gaps.keys()):
        wanted = desired_gaps.get(gap, 0)
        missing = wanted - gap_counts.get(gap, 0)
        if missing <= 0:
            continue

        marker = -gap
        insert_at = next(
            (pos for pos, (src, _) in enumerate(processed) if src and src[0] == marker),
            None,
        )
        if insert_at is None:
            insert_at = next(
                (pos for pos, (src, _) in enumerate(processed) if src and src[0] > gap),
                len(processed),
            )

        processed[insert_at:insert_at] = [([marker], []) for _ in range(missing)]
        gap_counts[gap] = wanted
    return processed


def custom_sort_key(item):
    """Sort key for ``(src_indices, tgt_indices)`` alignment entries.

    Entries are ordered by the absolute value of their first source index;
    for equal values a non-negative index sorts before a negative one.
    Entries without any source index sort last.

    Args:
        item: A ``(src_indices, tgt_indices)`` pair.

    Returns:
        A ``(position, tie_breaker)`` tuple.
    """
    src_ids, _ = item
    if not src_ids:
        return (float("inf"), 2)
    head = src_ids[0]
    if head < 0:
        return (abs(head), 1)
    return (head, 0)


def fill_empty_alignments(src_ref_alignments, src_mt_alignments):
    """Give both alignments the same number of gap entries per gap key.

    Each alignment is passed through ``process_gaps``; then each side is
    padded via ``complement_gaps`` with the gap counts of the other side,
    and finally both are ordered with ``custom_sort_key``.

    Args:
        src_ref_alignments: Source-to-reference ``(src, tgt)`` pairs.
        src_mt_alignments: Source-to-MT ``(src, tgt)`` pairs.

    Returns:
        ``(ref_entries, mt_entries)`` after padding and sorting.
    """
    ref_entries, ref_gap_counts = process_gaps(src_ref_alignments)
    mt_entries, mt_gap_counts = process_gaps(src_mt_alignments)

    # The first call updates ref_gap_counts in place, so the second call
    # sees the already-raised reference counts.
    ref_entries = complement_gaps(ref_entries, ref_gap_counts, mt_gap_counts)
    mt_entries = complement_gaps(mt_entries, mt_gap_counts, ref_gap_counts)

    return (
        sorted(ref_entries, key=custom_sort_key),
        sorted(mt_entries, key=custom_sort_key),
    )


def find_common_alignments(src_ref_alignments, src_mt_alignments):
    """Combine src-ref and src-mt alignments that share source indices.

    Both alignments are first harmonized with ``fill_empty_alignments``.
    Every reference entry is then compared with every MT entry; whenever
    their source index sets intersect, a triple of the sorted shared source
    indices, the sorted unique reference indices and the sorted unique MT
    indices is produced. An empty reference or MT side is represented as
    ``[-1]``. Duplicate triples are dropped, keeping the first occurrence.

    Args:
        src_ref_alignments: Source-to-reference ``(src, tgt)`` pairs.
        src_mt_alignments: Source-to-MT ``(src, tgt)`` pairs.

    Returns:
        A list of ``(src_indices, ref_indices, mt_indices)`` triples.
    """
    ref_entries, mt_entries = fill_empty_alignments(src_ref_alignments, src_mt_alignments)

    unique_triples = []
    seen_keys = set()
    for ref_src, ref_tgt in ref_entries:
        ref_src_set = set(ref_src)
        for mt_src, mt_tgt in mt_entries:
            shared_src = sorted(ref_src_set & set(mt_src))
            if not shared_src:
                continue
            ref_ids = sorted(set(ref_tgt)) if ref_tgt else [-1]
            mt_ids = sorted(set(mt_tgt)) if mt_tgt else [-1]
            key = (tuple(shared_src), tuple(ref_ids), tuple(mt_ids))
            if key in seen_keys:
                continue
            seen_keys.add(key)
            unique_triples.append((shared_src, ref_ids, mt_ids))
    return unique_triples


# ---------------------------------------------------------------------------
# COMET scoring
# ---------------------------------------------------------------------------
def compute_comet_scores(src_windows, ref_windows, mt_windows, paragraph_id, mt_col):
    """Compute COMET scores for aligned sliding windows.

    Each window is flattened to one line by joining its entries with a
    space. Windows whose source or MT line is empty are not sent to COMET
    and receive a score of ``0.0``. The remaining lines are written to
    temporary files and scored with the ``comet-score`` command line tool.
    The result is also written to
    ``SAVE_FOLDER/scores/scores_<paragraph_id>_<mt_col>.json``.

    Args:
        src_windows: Source windows (lists of strings).
        ref_windows: Reference windows.
        mt_windows: MT windows.
        paragraph_id: Identifier stored in the result and in the file name.
        mt_col: Name used as suffix of the score file.

    Returns:
        A dict with the keys ``paragraph_id``, ``comet_scores``,
        ``windows_length``, ``windows_zero_ratio`` and ``avg_comet``.

    Raises:
        subprocess.CalledProcessError: If ``comet-score`` exits with a
            non-zero status.
        RuntimeError: If the number of parsed segment scores differs from
            the number of lines that were submitted.
    """
    zero_score_windows = []
    scorable_lines = []
    for idx, (src_win, ref_win, mt_win) in enumerate(zip(src_windows, ref_windows, mt_windows)):
        src_line = " ".join(src_win)
        ref_line = " ".join(ref_win)
        mt_line = " ".join(mt_win)
        if src_line and mt_line:
            scorable_lines.append((src_line, ref_line, mt_line))
        else:
            zero_score_windows.append(idx)

    comet_scores = []
    # Skip the external call entirely when there is nothing to score.
    if scorable_lines:
        # Explicit UTF-8: the default encoding is locale-dependent and may
        # not be able to represent the text being evaluated.
        with tempfile.NamedTemporaryFile(mode="w+", delete=True, encoding="utf-8") as src_file, \
             tempfile.NamedTemporaryFile(mode="w+", delete=True, encoding="utf-8") as ref_file, \
             tempfile.NamedTemporaryFile(mode="w+", delete=True, encoding="utf-8") as mt_file:

            for src_line, ref_line, mt_line in scorable_lines:
                src_file.write(src_line + "\n")
                ref_file.write(ref_line + "\n")
                mt_file.write(mt_line + "\n")

            src_file.flush()
            ref_file.flush()
            mt_file.flush()

            # NOTE(review): comet-score documents --gpus as a GPU *count*;
            # GPU_ID is described as a device id by the CLI help of this
            # script. Confirm that passing e.g. "0" here is intended.
            comet_command = [
                "comet-score",
                "-s", src_file.name,
                "-t", mt_file.name,
                "-r", ref_file.name,
                "--model", COMET_MODEL,
                "--enable-context",
                "--gpus", GPU_ID,
                "--quiet",
            ]
            result = subprocess.run(comet_command, stdout=subprocess.PIPE, text=True, check=True)

        # The last "score:" match is dropped: only the per-segment values
        # are kept.
        parsed = re.findall(r"score:\s(-?[0-9.]+)", result.stdout.strip())
        comet_scores = [float(s) for s in parsed][:-1]
        if len(comet_scores) != len(scorable_lines):
            # Re-inserting the zero scores below relies on one score per
            # submitted line; anything else would misalign the results.
            raise RuntimeError(
                f"comet-score returned {len(comet_scores)} segment scores "
                f"for {len(scorable_lines)} windows (paragraph {paragraph_id})"
            )

    # Indices are ascending, so inserting in order restores the original
    # window positions.
    for idx in zero_score_windows:
        comet_scores.insert(idx, 0.0)

    scores_data = {
        "paragraph_id": paragraph_id,
        "comet_scores": comet_scores,
        "windows_length": len(comet_scores),
        "windows_zero_ratio": len(zero_score_windows) / len(comet_scores) if comet_scores else 0,
        "avg_comet": sum(comet_scores) / len(comet_scores) if comet_scores else 0,
    }

    scores_file = os.path.join(SAVE_FOLDER, "scores", f"scores_{paragraph_id}_{mt_col}.json")
    os.makedirs(os.path.dirname(scores_file), exist_ok=True)
    with open(scores_file, "w", encoding="utf-8") as f:
        json.dump(scores_data, f, ensure_ascii=False, indent=2)

    return scores_data


# ---------------------------------------------------------------------------
# Paragraph-level processing
# ---------------------------------------------------------------------------
def paragraph_level_score(row, paragraph_id, src_col=None, ref_col=None, mt_col=None):
    """Align, window and score one paragraph, writing results to disk.

    The source, reference and MT texts of ``row`` are segmented into
    sentences, embedded, aligned (source-reference and source-MT), merged
    into common alignments, turned into sliding windows of ``WINDOW_SIZE``
    and scored with COMET. Scores and windows are saved under
    ``SAVE_FOLDER``.

    Args:
        row: Mapping from column name to text (e.g. a DataFrame row).
        paragraph_id: Identifier used in the output file names.
        src_col: Source column; defaults to ``SRC_LANG``.
        ref_col: Reference column; defaults to ``LANG``.
        mt_col: MT column; defaults to ``TARGET``.
    """
    # Previously src_col had no fallback, so omitting it led to row[None].
    if src_col is None:
        src_col = SRC_LANG
    if ref_col is None:
        ref_col = LANG
    if mt_col is None:
        mt_col = TARGET

    # The segmenter picks its pipeline by language code, so pass the codes
    # rather than the column names (which callers may choose freely).
    src_sentences = segment_sentences_by_punctuation(row[src_col], SRC_LANG)
    ref_sentences = segment_sentences_by_punctuation(row[ref_col], LANG)
    mt_sentences = segment_sentences_by_punctuation(row[mt_col], LANG)

    src_txt = preprocess_sentences(src_sentences)
    ref_txt = preprocess_sentences(ref_sentences)
    mt_txt = preprocess_sentences(mt_sentences)

    src_overlap, src_embed = generate_overlap_and_embedding(src_txt)
    ref_overlap, ref_embed = generate_overlap_and_embedding(ref_txt)
    mt_overlap, mt_embed = generate_overlap_and_embedding(mt_txt)

    src_ref_alignments = run_vecalign_explore(src_txt, ref_txt, src_overlap, ref_overlap, src_embed, ref_embed)
    src_mt_alignments = run_vecalign_explore(src_txt, mt_txt, src_overlap, mt_overlap, src_embed, mt_embed)

    common_alignments = find_common_alignments(src_ref_alignments, src_mt_alignments)

    adjusted_src, adjusted_ref, adjusted_mt = [], [], []
    for src_indices, ref_indices, mt_indices in common_alignments:
        # -1 is the placeholder for "no sentence on this side".
        ref_indices = [x for x in ref_indices if x != -1]
        mt_indices = [x for x in mt_indices if x != -1]
        # A negative first source index is treated as "no source text".
        if src_indices and src_indices[0] < 0:
            aligned_src = ""
        else:
            aligned_src = " ".join(src_sentences[i] for i in src_indices)
        aligned_ref = " ".join(ref_sentences[i] for i in ref_indices) if ref_indices else ""
        aligned_mt = " ".join(mt_sentences[i] for i in mt_indices) if mt_indices else ""
        adjusted_src.append(aligned_src)
        adjusted_ref.append(aligned_ref)
        adjusted_mt.append(aligned_mt)

    src_windows = sliding_windows(adjusted_src, WINDOW_SIZE)
    ref_windows = sliding_windows(adjusted_ref, WINDOW_SIZE)
    mt_windows = sliding_windows(adjusted_mt, WINDOW_SIZE)

    compute_comet_scores(src_windows, ref_windows, mt_windows, paragraph_id, mt_col)

    output_dir = os.path.join(SAVE_FOLDER, "windows")
    save_windows_to_file(
        paragraph_id, adjusted_src, adjusted_ref, adjusted_mt,
        src_windows, ref_windows, mt_windows, output_dir, output_name=mt_col,
    )


def parallel_paragraph_level_score(pool_args):
    """Worker entry point: score one paragraph and report any failure.

    Any exception raised while scoring is caught and printed so that a
    single failing paragraph does not abort the remaining ones.

    Args:
        pool_args: A ``(row, paragraph_id)`` pair.
    """
    row, paragraph_id = pool_args
    try:
        paragraph_level_score(
            row,
            paragraph_id,
            mt_col=TARGET,
            src_col=SRC_LANG,
        )
    except Exception as exc:
        print(f"Error processing paragraph {paragraph_id}: {exc}")


# ---------------------------------------------------------------------------
# Score aggregation
# ---------------------------------------------------------------------------
def aggregate_scores_and_merge(evaluated_file_path, save_folder, target):
    """Merge per-paragraph score files into the CSV and average them.

    The input CSV is loaded and given the columns ``comet`` and
    ``windows_zero_ratio`` (initialized to ``0.0``). For every row index
    with a file ``<save_folder>/scores/scores_<index>_<target>.json`` the
    two values are copied into the row. The result is written to
    ``<save_folder>/evaluated_results_<target>.csv``.

    Args:
        evaluated_file_path: Path of the input CSV.
        save_folder: Folder holding the ``scores`` directory; also the
            output location.
        target: Name used in the score and output file names.

    Returns:
        A dict with the mean ``comet`` and mean ``windows_zero_ratio`` over
        the rows that have a score file (``0`` if there are none).
    """
    def _mean(values):
        return sum(values) / len(values) if values else 0

    frame = pd.read_csv(evaluated_file_path)
    frame["comet"] = 0.0
    frame["windows_zero_ratio"] = 0.0

    scores_root = os.path.join(save_folder, "scores")
    comet_values = []
    zero_ratio_values = []

    for row_id in frame.index:
        score_path = os.path.join(scores_root, f"scores_{row_id}_{target}.json")
        if not os.path.exists(score_path):
            # Rows without a score file keep 0.0 and are left out of the means.
            continue
        with open(score_path, "r", encoding="utf-8") as handle:
            payload = json.load(handle)
        comet = payload.get("avg_comet", 0)
        zero_ratio = payload.get("windows_zero_ratio", 0)
        frame.at[row_id, "comet"] = comet
        frame.at[row_id, "windows_zero_ratio"] = zero_ratio
        comet_values.append(comet)
        zero_ratio_values.append(zero_ratio)

    frame.to_csv(os.path.join(save_folder, f"evaluated_results_{target}.csv"), index=False)
    return {
        "comet": _mean(comet_values),
        "windows_zero_ratio": _mean(zero_ratio_values),
    }


# ---------------------------------------------------------------------------
# Global parameters (parsed at module load for multiprocessing compatibility)
# ---------------------------------------------------------------------------
set_seed(42)

parser = argparse.ArgumentParser(description="SEGALE Context-Level Evaluation")
parser.add_argument("--file", type=str, required=True, help="Input CSV file to evaluate.")
parser.add_argument("--target_column", type=str, required=True, help="MT column name in the CSV.")
parser.add_argument("--save", type=str, default="./eval_output", help="Output folder for scores.")
parser.add_argument("--src_language", type=str, required=True, help="Source language name (e.g., Chinese).")
parser.add_argument("--task_language", type=str, required=True, help="Target language name (e.g., English).")
parser.add_argument("--gpu_id", type=str, default="0", help="GPU ID for COMET scoring.")
parser.add_argument("--comet_model", type=str, default="Unbabel/wmt22-comet-da", help="COMET model name.")
parser.add_argument("--num_workers", type=int, default=2, help="Number of parallel workers.")

args = parser.parse_args()

TARGET = args.target_column
TASK_LANGUAGE = args.task_language
SRC_LANGUAGE = args.src_language
evaluated_file_path = args.file
WINDOW_SIZE = 3
SEPARATOR = "</s>"
SAVE_FOLDER = args.save
GPU_ID = args.gpu_id
COMET_MODEL = args.comet_model

# Unpack a single call per language: calling get_lang_and_nlp twice (once for
# each tuple element) loaded every spaCy pipeline twice.
LANG, mt_nlp = get_lang_and_nlp(TASK_LANGUAGE)
SRC_LANG, src_nlp = get_lang_and_nlp(SRC_LANGUAGE)

os.makedirs(SAVE_FOLDER, exist_ok=True)

print(f"TARGET: {TARGET}, SRC: {SRC_LANGUAGE} ({SRC_LANG}), TGT: {TASK_LANGUAGE} ({LANG})")
print(f"COMET model: {COMET_MODEL}, GPU: {GPU_ID}")

# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Score every CSV row in a worker pool, then merge the per-paragraph
    # score files into one result CSV and print the overall averages.
    frame = pd.read_csv(evaluated_file_path)
    paragraph_count = len(frame)
    print(f"Evaluating {paragraph_count} paragraphs (workers={args.num_workers}) ...")

    jobs = [(row, idx) for idx, row in frame.iterrows()]
    with Pool(args.num_workers) as pool:
        finished = pool.imap_unordered(parallel_paragraph_level_score, jobs)
        for done, _ in enumerate(finished, start=1):
            # Report every tenth paragraph and the final one.
            if done == paragraph_count or done % 10 == 0:
                print(f"  [{done}/{paragraph_count}] paragraphs scored ({done/paragraph_count*100:.1f}%)")

    print(f"\nAll {paragraph_count} paragraphs scored. Aggregating results ...")
    summary = aggregate_scores_and_merge(evaluated_file_path, SAVE_FOLDER, TARGET)
    finished_at = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    print(f"\n{TARGET}: {TASK_LANGUAGE} Overall scores: {summary}, time: {finished_at}")
