"""
Difficulty score computation for code samples.
"""
import argparse
from typing import List, Tuple

import ray
import torch
import pandas as pd
from jinja2 import Template
from tqdm import tqdm

from utils import normalize_text


@torch.no_grad()
def _compute_loss(model, input_ids: torch.Tensor, labels: torch.Tensor = None):
    """Run a single forward pass and return the model's loss as a float.

    Args:
        model: Callable model exposing a ``device`` attribute; it is invoked
            with ``input_ids`` and ``labels`` keyword arguments and must
            return an object carrying a scalar ``loss`` tensor.
        input_ids: Token ids to feed to the model.
        labels: Target ids; when omitted, ``input_ids`` is used as the target.

    Returns:
        The scalar loss converted to a Python float.
    """
    # With no explicit targets the inputs double as their own labels.
    targets = input_ids if labels is None else labels

    # Move both tensors onto the model's device before the forward pass.
    device = model.device
    result = model(input_ids=input_ids.to(device), labels=targets.to(device))

    return float(result.loss.item())


@torch.no_grad()
def compute_scores(model, tokenizer, prompt: str, response: str) -> Tuple[float, float]:
    """
    Compute difficulty scores for one (prompt, response) pair.

    Two losses are obtained through ``_compute_loss``:

    * the response-only loss, where the model sees ``[BOS] + response`` and
      the labels are the input ids themselves;
    * the conditional loss, where the model sees ``[BOS] + prompt + response``
      and the labels of the BOS/prompt positions are set to -100 (the label
      value ignored by Hugging Face causal-LM losses), leaving only the
      response positions as targets.

    Prompt and response are tokenized separately with
    ``add_special_tokens=False``; a BOS token is prepended only when the
    tokenizer defines ``bos_token_id``.

    Args:
        model: Language model
        tokenizer: Tokenizer
        prompt: Input prompt
        response: Response text

    Returns:
        (response_loss, conditional_loss)
    """
    prompt_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)["input_ids"]
    response_ids = tokenizer(response, return_tensors="pt", add_special_tokens=False)["input_ids"]

    bos = tokenizer.bos_token_id
    # Compare against None explicitly: 0 is a legitimate token id, so a
    # truthiness test would treat a tokenizer whose BOS id is 0 as having no
    # BOS and make the mask below one position too short.
    has_bos = bos is not None
    bos_tensor = torch.tensor([[bos]]) if has_bos else None

    def add_bos(ids: torch.Tensor) -> torch.Tensor:
        """Prepend the BOS token when the tokenizer defines one."""
        if not has_bos:
            return ids
        return torch.cat([bos_tensor, ids], dim=1)

    # Response-only loss
    resp_loss = _compute_loss(model, add_bos(response_ids))

    # Conditional loss: mask BOS + prompt so they are not used as targets.
    full = add_bos(torch.cat([prompt_ids, response_ids], dim=1))
    labels = full.clone()
    prefix_len = int(has_bos) + prompt_ids.size(1)
    labels[:, :prefix_len] = -100
    cond_loss = _compute_loss(model, full, labels)

    return resp_loss, cond_loss


@ray.remote(num_gpus=1)
class ScoreWorker:
    """Ray worker for score computation.

    Each actor is declared with ``num_gpus=1`` and loads its own tokenizer,
    causal language model (bfloat16, ``device_map="auto"``) and Jinja
    template when it is constructed.
    """

    def __init__(self, model_path: str, template_path: str):
        """Load the tokenizer, model and response template.

        Args:
            model_path: Path or identifier passed to ``from_pretrained`` for
                both the tokenizer and the model.
            template_path: Path of the Jinja template file used to render
                each candidate code into the response text.
        """
        # Imported here so transformers is only loaded inside the actor process.
        from transformers import AutoTokenizer, AutoModelForCausalLM

        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            device_map="auto"
        )
        self.model.eval()

        # Read the template inside a context manager so the file handle is
        # closed deterministically instead of being left to the garbage collector.
        with open(template_path) as template_file:
            self.template = Template(template_file.read())

    def compute_batch(
        self,
        instruction: str,
        codes: List[str],
        min_codes: int = 11
    ) -> Tuple[List[float], List[float]]:
        """Compute scores for a batch of codes.

        Args:
            instruction: Prompt shared by all candidates; it is passed through
                ``normalize_text`` before scoring.
            codes: Candidate code strings for this instruction.
            min_codes: Batches with at most this many candidates are not
                scored; every candidate then receives the placeholder 1.0
                for both losses.

        Returns:
            ``(response_losses, conditional_losses)``, two lists with one
            entry per candidate, in the order of ``codes``.
        """
        instruction = normalize_text(instruction)

        # Too few candidates: skip the model and return placeholder scores.
        if len(codes) <= min_codes:
            return [1.0] * len(codes), [1.0] * len(codes)

        resp_losses, cond_losses = [], []
        prefix = "Here is the correct Python program:\n"

        for code in codes:
            # The response is the fixed prefix followed by the rendered template.
            response = prefix + self.template.render(code=normalize_text(code, 0))
            resp, cond = compute_scores(self.model, self.tokenizer, instruction, response)
            resp_losses.append(resp)
            cond_losses.append(cond)

        return resp_losses, cond_losses


def main():
    """Command-line entry point: score every row of a parquet file.

    Reads the input parquet, dispatches one ``compute_batch`` call per row
    (using the ``instruction`` and ``candidate`` columns) to a pool of
    ``ScoreWorker`` actors in round-robin order, stores the per-row loss lists
    in the ``response_loss`` and ``condition_loss`` columns, and writes the
    result to the output parquet. Ray is shut down even if scoring fails.
    """
    parser = argparse.ArgumentParser(description="Compute difficulty scores")
    parser.add_argument("--input", required=True, help="Input parquet")
    parser.add_argument("--output", required=True, help="Output parquet")
    parser.add_argument("--model", required=True, help="Model path")
    parser.add_argument("--template", required=True, help="Jinja template")
    parser.add_argument("--workers", type=int, default=8, help="GPU workers")
    args = parser.parse_args()

    # Without at least one worker there is nothing to dispatch to, and the
    # round-robin modulo below would divide by zero.
    if args.workers < 1:
        parser.error("--workers must be at least 1")

    ray.init(ignore_reinit_error=True, num_gpus=args.workers)
    try:
        print(f"Starting {args.workers} workers...")
        workers = [ScoreWorker.remote(args.model, args.template) for _ in range(args.workers)]

        print(f"Loading: {args.input}")
        df = pd.read_parquet(args.input)

        # Round-robin by row position. iterrows() yields index *labels*, which
        # may be non-integer or non-contiguous, so they are not used here.
        tasks = []
        for position, (_, row) in enumerate(df.iterrows()):
            worker = workers[position % len(workers)]
            task = worker.compute_batch.remote(row["instruction"], row["candidate"].tolist())
            tasks.append(task)

        # Collect results in submission order so they line up with the rows.
        resp_losses, cond_losses = [], []
        for task in tqdm(tasks, desc="Scoring"):
            resp, cond = ray.get(task)
            resp_losses.append(resp)
            cond_losses.append(cond)

        df["response_loss"] = resp_losses
        df["condition_loss"] = cond_losses

        print(f"Saving: {args.output}")
        df.to_parquet(args.output, index=False)
        print("Done")
    finally:
        # Release the Ray runtime on failure as well as on success.
        ray.shutdown()


# Run the command-line entry point only when executed as a script.
if __name__ == "__main__":
    main()
