import json

import polars as pl
import torch

from tdc_fusion.models.CLIPWrapper import MolCLIPWrapper
from tdc_fusion.utils.eval import bootstrap_clf


def compute_scores(
    smiles: list[str],
    text_pos: list[str],
    text_neg: list[str],
    model: "MolCLIPWrapper",
    temp: float = 1.0,
):
    """Score molecules by how much better they match positive than negative text anchors.

    The model is queried once with all anchors (positive anchors first, then
    negative ones). Per molecule, the logits of the positive anchors and of the
    negative anchors are averaged separately, and a two-way softmax over
    ``[neg_mean, pos_mean] / temp`` gives the probability of the positive class.

    Args:
        smiles: SMILES strings of the molecules to score.
        text_pos: Text anchors describing the positive class.
        text_neg: Text anchors describing the negative class.
        model: Model exposing ``batch_inference(smiles=..., text=...)``.
        temp: Softmax temperature; larger values pull scores toward 0.5.

    Returns:
        NumPy array holding the positive-class probability for each row of the
        model output.

    Raises:
        ValueError: If ``text_pos`` or ``text_neg`` is empty (the mean over an
            empty set of anchors would be NaN).
    """
    if not text_pos or not text_neg:
        raise ValueError("text_pos and text_neg must each contain at least one anchor")

    logits = model.batch_inference(
        smiles=smiles,
        text=text_pos + text_neg,
    )
    # assumes logits is (n_molecules, n_texts) with columns in the order of
    # `text` above — TODO confirm against MolCLIPWrapper.batch_inference.

    # Split by the real anchor counts: `chunk(2)` would cut the columns in half
    # and misassign anchors whenever len(text_pos) != len(text_neg).
    pos_logits, neg_logits = logits.split([len(text_pos), len(text_neg)], dim=-1)

    # Average over anchors, keeping a column axis so the two can be concatenated.
    pos_mean = pos_logits.mean(dim=1, keepdim=True)
    neg_mean = neg_logits.mean(dim=1, keepdim=True)

    # Column 0 = negative class, column 1 = positive class.
    full_logits = torch.cat([neg_mean, pos_mean], dim=1)
    probs = (full_logits / temp).softmax(dim=-1)

    # detach() so the numpy conversion works even if the output tracks gradients.
    return probs[:, 1].detach().cpu().numpy()


# Zero-shot evaluation: for every task in text_anchors.json, score that task's
# molecules against its positive/negative text anchors and print the AUROC
# reported by `bootstrap_clf`.
clip = MolCLIPWrapper.load_from_checkpoint("./checkpoints/CLIPWrapper.ckpt", strict=False).cuda()
# NOTE(review): presumably Lightning's `freeze()` (no grads, eval mode) — confirm.
clip.freeze()

with open("./data/text_anchors.json", encoding="utf-8") as f:
    # Each task maps to a dict with "text_pos" and "text_neg" anchor lists.
    text_anchors = json.load(f)

# Read the evaluation set once; every task below is a filtered view of it.
zero_shot_df = pl.read_parquet("./data/zero-shot.parquet")

for task, anchors in text_anchors.items():
    test_df = zero_shot_df.filter(pl.col("prompt_name").eq(task))

    test_smiles = test_df["Drug SMILES"].to_list()
    test_labels = test_df["Y"].to_numpy()

    scores = compute_scores(
        smiles=test_smiles,
        text_pos=anchors["text_pos"],
        text_neg=anchors["text_neg"],
        model=clip,
    )

    auc = bootstrap_clf(
        predictions=scores,
        targets=test_labels,
    )["AUROC"]
    print(f"{task} (AUROC): {auc:.3f}")
