"""
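Energy evaluation over CSV rows:
- Input CSV (--csv) columns: question, answer, second_question_source_idx
- Source CSV (--src_csv) columns: question; second_question_source_idx is a
  0-based row index into this file
- Output CSV columns: question, answer, question_related, energy_full, energy_related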
"""

import argparse
import logging
import random
import sys
from pathlib import Path

import pandas as pd
import torch
import nltk
from tqdm import tqdm

# If your project expects these imports to exist, keep them; unused utilities are harmless.
from src.energy_model.config import EBMConfig
from src.energy_model.models import EnergyModel

# Optional: if these imports are required elsewhere in the package
from src.energy_model.utils.energy_network import semantic_sentence_split, normalize_sentences  # noqa: F401
from src.interpreter_model.utils.interpreter_utils import clone_encoder_from_ebm  # noqa: F401

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Startup diagnostics. torch.cuda.current_device() raises when no CUDA device
# is usable, so it is only queried when CUDA is available; CPU-only runs
# (--device defaults to "cpu" in that case) must still be able to import this file.
print("CUDA devices:", torch.cuda.device_count())
if torch.cuda.is_available():
    print("Current device:", torch.cuda.current_device())
else:
    print("Current device: cpu")

# Make sure the NLTK "punkt" tokenizer data is present for any downstream
# tokenization the encoders may do. The lookup is attempted first; only when it
# fails with LookupError is the resource downloaded (quietly).
# NOTE(review): recent NLTK releases appear to load "punkt_tab" instead of
# "punkt" — confirm which resource the encoders actually need.
try:
    nltk.data.find("tokenizers/punkt")
except LookupError:
    nltk.download("punkt", quiet=True)


def set_seed(seed: int = 42) -> None:
    """Seed Python's `random` module and PyTorch with the same value.

    Args:
        seed: Value passed to every seeding call (defaults to 42).

    The CUDA generator is seeded explicitly as well, but only when CUDA is
    reported as available.
    """
    for seeder in (random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)


def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the energy evaluation script.

    Returns:
        An ``argparse.ArgumentParser`` with the required input paths
        (``--csv``, ``--ebm_checkpoint``, ``--src_csv``), the EBM layer counts,
        the optional output path and the misc flags (``--device``, ``--seed``,
        ``--loglevel``, ``--debug_print``).
    """
    # Resolved once here; "cuda" is only the default when torch reports it available.
    default_device = "cuda" if torch.cuda.is_available() else "cpu"

    # (flag, add_argument keyword arguments), kept in the order they appear in --help.
    option_table = [
        # Required
        ("--csv", {
            "required": True,
            "help": "Input CSV: columns question,answer,second_question_source_idx",
        }),
        ("--ebm_checkpoint", {
            "required": True,
            "help": "Path to trained EBM checkpoint (.pt/.bin)",
        }),
        # EBM architecture knobs
        ("--ebm_self_attention_layers", {"type": int, "default": 2}),
        ("--ebm_cross_attention_layers", {"type": int, "default": 6}),
        # Output
        ("--output_csv", {
            "default": None,
            "help": "Where to save the results CSV (default: <input>_energies.csv)",
        }),
        # Misc
        ("--device", {"default": default_device}),
        ("--seed", {"type": int, "default": 42}),
        ("--loglevel", {
            "default": "INFO",
            "choices": ["DEBUG", "INFO", "WARNING", "ERROR"],
        }),
        ("--debug_print", {
            "action": "store_true",
            "help": "Print sample second_question_source_idx and question_related for first 10 rows (no energy computed)",
        }),
        ("--src_csv", {
            "required": True,
            "help": "CSV file that contains the source questions (must have a column named 'question').",
        }),
    ]

    parser = argparse.ArgumentParser(description="Compute energies for question/answer pairs")
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser


@torch.inference_mode()
def compute_energy(energy_model: EnergyModel, device: torch.device, x_text: str, y_text: str) -> float:
    """Score a single (x_text, y_text) pair with the energy model.

    The model is called with two one-element lists of strings, so text
    encoding is left to the model itself. Runs under ``torch.inference_mode``.

    Args:
        energy_model: Callable taking ``(List[str], List[str])`` and returning
            a tensor holding one energy value for the pair.
        device: Not used by this function; kept as part of the signature.
        x_text: First text of the pair.
        y_text: Second text of the pair.

    Returns:
        The energy as a plain Python float.
    """
    batch_energy = energy_model([x_text], [y_text])  # batch of size 1
    scalar_energy = batch_energy.squeeze()
    # .item() extracts the value as a Python number regardless of tensor device.
    return float(scalar_energy.item())


def _parse_source_idx(raw_value) -> int:
    """Return ``raw_value`` as an int index, or -1 when it cannot be converted."""
    try:
        return int(raw_value)
    except (TypeError, ValueError, OverflowError):
        # Missing / NaN / non-numeric / infinite cells are treated as "no source question".
        return -1


def _build_related_question(question: str, src_idx: int, source_questions: list) -> tuple:
    """Return ``(question_related, source_text)`` for one row.

    ``question_related`` is ``question`` with the first occurrence of the source
    question at ``src_idx`` removed and whitespace collapsed. When ``src_idx``
    is out of range ``source_text`` is None, and when the source text is empty
    the question is returned unchanged.
    """
    if not 0 <= src_idx < len(source_questions):
        return question, None
    source_text = source_questions[src_idx]
    if not source_text:
        return question, source_text
    remainder = question.replace(source_text, "", 1)
    # split/join both trims the ends and collapses inner runs of whitespace.
    return " ".join(remainder.split()), source_text


def _print_debug_samples(df, source_questions: list, limit: int = 10) -> None:
    """Print how question_related is derived for the first ``limit`` rows."""
    print("\n=== DEBUG: showing up to 10 samples ===")
    for i, row in df.head(limit).iterrows():
        q = str(row["question"])
        src_idx = _parse_source_idx(row["second_question_source_idx"])
        q_related, src_text = _build_related_question(q, src_idx, source_questions)

        print(f"Row {i}:")
        print(f"  second_question_source_idx = {src_idx}")
        print(f"  source_question_text      = {src_text}")
        print(f"  original_question         = {q}")
        print(f"  question_related          = {q_related}")
        print("-" * 50)


def main():
    """Run the energy evaluation end to end.

    Steps:
      1. Parse CLI arguments, configure logging and seed the RNGs.
      2. Load the input CSV and the source-question CSV, validating columns.
      3. With ``--debug_print``: print sample question_related derivations and
         return without loading the model.
      4. Otherwise load the EBM checkpoint and, for every row, compute
         E(question, answer) and E(question_related, answer).
      5. Write the results to ``--output_csv`` (default: ``<input>_energies.csv``
         next to the input file).

    Raises:
        ValueError: If a required column is missing from either CSV.
    """
    args = build_parser().parse_args()
    logging.basicConfig(level=getattr(logging, args.loglevel), format="%(asctime)s - %(levelname)s - %(message)s")

    set_seed(args.seed)
    torch.backends.cudnn.benchmark = True
    device = torch.device(args.device)

    # Load main CSV (sep=None lets the python engine sniff the delimiter)
    df = pd.read_csv(args.csv, sep=None, engine="python")
    df.columns = [str(c).strip() for c in df.columns]

    required_cols = {"question", "answer", "second_question_source_idx"}
    missing = required_cols - set(df.columns)
    if missing:
        raise ValueError(f"Input CSV missing required columns: {missing}")

    # Load source CSV (always required)
    src_df = pd.read_csv(args.src_csv, sep=None, engine="python")
    src_df.columns = [str(c).strip() for c in src_df.columns]
    if "question" not in src_df.columns:
        raise ValueError(f"Source CSV {args.src_csv} must have a 'question' column. Found: {list(src_df.columns)}")

    # Empty cells are read as NaN; without fillna they would become the literal
    # string "nan" and that substring would be stripped out of real questions.
    source_questions = src_df["question"].fillna("").astype(str).tolist()
    logging.info("Loaded %d source questions from %s", len(source_questions), args.src_csv)

    # Debug mode: just show samples of how question_related is built
    if args.debug_print:
        _print_debug_samples(df, source_questions)
        return  # exit early, no model is loaded

    # Load EBM
    logging.info("Loading EBM from %s …", args.ebm_checkpoint)
    ebm_config = EBMConfig(
        self_attention_n_layers=args.ebm_self_attention_layers,
        cross_attention_n_layers=args.ebm_cross_attention_layers,
    )
    energy_model = EnergyModel(ebm_config).to(device)
    # NOTE: torch.load unpickles the file, which can execute arbitrary code.
    # Only load checkpoints from trusted sources (consider weights_only=True
    # where the installed torch version supports it).
    ckpt = torch.load(args.ebm_checkpoint, map_location=device)
    # Accept either a training checkpoint wrapping the weights or a bare state dict.
    state_dict = ckpt.get("model_state_dict", ckpt)
    energy_model.load_state_dict(state_dict, strict=True)
    energy_model.eval()

    out_columns = ["question", "answer", "question_related", "energy_full", "energy_related"]
    out_rows = []

    for _, row in tqdm(df.iterrows(), total=len(df), desc="Computing energies"):
        q = str(row["question"])
        a = str(row["answer"])

        # 1) Full energy: E(question, answer)
        e_full = compute_energy(energy_model, device, q, a)

        # 2) question_related: the question with the text of the source
        #    question at second_question_source_idx (from the source CSV) removed.
        src_idx = _parse_source_idx(row["second_question_source_idx"])
        q_related, _ = _build_related_question(q, src_idx, source_questions)

        e_related = compute_energy(energy_model, device, q_related, a)

        out_rows.append(
            {
                "question": q,
                "answer": a,
                "question_related": q_related,
                "energy_full": e_full,
                "energy_related": e_related,
            }
        )

    # Explicit columns keep the header in the output even when there are no rows.
    out_df = pd.DataFrame(out_rows, columns=out_columns)

    # Decide output path
    if args.output_csv:
        out_path = Path(args.output_csv)
    else:
        in_path = Path(args.csv)
        out_path = in_path.with_name(in_path.stem + "_energies.csv")

    out_df.to_csv(out_path, index=False, encoding="utf-8")
    logging.info("Saved results to: %s", out_path.resolve())

if __name__ == "__main__":
    main()