"""
Generate code embeddings using encoder models.
"""
import argparse
from typing import List

import torch
import pandas as pd
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel


def compute_embeddings(
    codes: List[str],
    tokenizer,
    model,
    device: torch.device,
    max_length: int = 512
) -> List[List[float]]:
    """
    Compute CLS (first-position) embeddings for code snippets.

    Each snippet is tokenized and passed through the model on its own, and
    the vector at sequence position 0 of ``outputs.last_hidden_state`` is
    taken as that snippet's embedding.

    Args:
        codes: Code strings to embed. A single bare string is treated as
            one snippet instead of being iterated character by character.
        tokenizer: Callable tokenizer; its result must support
            ``.to(device)`` and ``**`` unpacking into the model call
        model: Encoder model whose output exposes ``last_hidden_state``
        device: Computation device the tokenized inputs are moved to
        max_length: Maximum sequence length; longer inputs are truncated

    Returns:
        List of embedding vectors (lists of floats), one per snippet, in
        input order. Empty input yields an empty list.
    """
    # A str is itself an iterable of one-character strings, so without this
    # guard a single snippet would silently produce one embedding per char.
    if isinstance(codes, str):
        codes = [codes]

    embeddings = []
    # Gradients are never needed here; one context covers the whole loop.
    with torch.no_grad():
        for code in codes:
            inputs = tokenizer(
                code,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_length
            ).to(device)
            outputs = model(**inputs)
            # [0, 0] selects the first batch row at sequence position 0.
            # Tensor.tolist() is used instead of a NumPy round trip, which
            # raises for dtypes NumPy does not support (e.g. bfloat16).
            embeddings.append(outputs.last_hidden_state[0, 0].cpu().tolist())

    return embeddings


def main():
    """
    Command-line entry point: embed a parquet column and save the result.

    Parses the CLI options, loads the tokenizer and encoder named by
    ``--model``, reads the ``--input`` parquet file, calls
    ``compute_embeddings`` on every cell of the ``--column`` column, stores
    the results in a new ``embedding`` column, and writes the frame to
    ``--output``.

    NOTE(review): each cell is handed to ``compute_embeddings`` as its
    ``codes`` argument, so cells are presumably lists of code strings —
    confirm against the input schema.
    """
    cli_parser = argparse.ArgumentParser(description="Generate code embeddings")
    cli_parser.add_argument("--model", required=True, help="Encoder model path")
    cli_parser.add_argument("--input", required=True, help="Input parquet")
    cli_parser.add_argument("--output", required=True, help="Output parquet")
    cli_parser.add_argument("--column", default="candidate", help="Code column name")
    cli_parser.add_argument("--max-length", type=int, default=512)
    opts = cli_parser.parse_args()

    # Use the GPU when one is visible, otherwise fall back to the CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    print(f"Device: {device}")

    print(f"Loading model: {opts.model}")
    tokenizer = AutoTokenizer.from_pretrained(opts.model)
    encoder = AutoModel.from_pretrained(opts.model)
    encoder = encoder.to(device)
    # Switch to evaluation mode before running inference.
    encoder.eval()

    print(f"Loading data: {opts.input}")
    frame = pd.read_parquet(opts.input)

    # One compute_embeddings call per row; tqdm reports progress over rows.
    row_embeddings = [
        compute_embeddings(cell, tokenizer, encoder, device, opts.max_length)
        for cell in tqdm(frame[opts.column], desc="Computing embeddings")
    ]
    frame["embedding"] = row_embeddings

    print(f"Saving: {opts.output}")
    frame.to_parquet(opts.output)
    print("Done")


# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
