import fire
import lightning as L
import polars as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from accelerate import Accelerator
from sentence_transformers import SentenceTransformer
from torch import Tensor
from torch.utils.data import DataLoader
from tqdm import tqdm

from tdc_fusion.data.SingleModalityDM import ModalityType, TDCTaskDataModule
from tdc_fusion.models.MolCLIP import MolCLIP
from tdc_fusion.models.MolT5 import MolT5
from tdc_fusion.models.ProtT5 import ProtT5
from tdc_fusion.tasks import TASK_TO_METRIC
from tdc_fusion.utils.eval import bootstrap_clf, bootstrap_regression

# Print tensors with 4 decimal places and without scientific notation.
torch.set_printoptions(precision=4, sci_mode=False)
# Select the "medium" float32 matmul precision mode process-wide.
# NOTE(review): presumably chosen to trade some matmul precision for speed —
# see the torch documentation for the exact semantics of "medium".
torch.set_float32_matmul_precision("medium")


class MultiTaskNet(L.LightningModule):
    """Multi-task head applied to concatenated embeddings from pretrained encoders.

    The encoders (MolT5 / ProtT5 / sentence-transformer, followed by the
    MolCLIP projections) are only ever run under ``torch.no_grad()`` in
    ``forward``; the two-layer MLP ``oproj`` is the only part evaluated with
    gradients enabled. ``oproj`` emits one output per entry in ``prompt_map``.
    """

    def __init__(
        self,
        modality: ModalityType,
        task_map: dict[int, str],
        prompt_map: dict[int, str],
        hidden_dim: int,
    ):
        """Build the encoders and the output head.

        Args:
            modality: "+"-separated modality codes (e.g. "S+P"). Each code
                contributes to the input width of the head: "S" adds 1536,
                "P" adds 1024 and any other code adds 128.
            task_map: Mapping from task id to task name (stored as-is).
            prompt_map: Mapping whose length fixes the number of head outputs.
            hidden_dim: Width of the hidden layer of the output head.
        """
        super().__init__()
        self.save_hyperparameters()

        # Per-code input widths; every code not listed here counts as 128.
        # NOTE(review): presumably these match the encoder output sizes — confirm.
        code_widths = {"S": 1536, "P": 1024}
        in_dim = sum(code_widths.get(code, 128) for code in modality.split("+"))

        # The two CLIP-style projectors are loaded from disk and frozen.
        self.chem_clip = MolCLIP.load_from_checkpoint("./checkpoints/ChemCLIP.ckpt")
        self.chem_clip.freeze()

        self.prot_clip = MolCLIP.load_from_checkpoint("./checkpoints/ProteinCLIP.ckpt")
        self.prot_clip.freeze()

        # Raw sequence / text encoders.
        self.mol_t5 = MolT5.load_from_checkpoint("./checkpoints/MolT5.ckpt")
        self.prot_t5 = ProtT5()

        self.text_embedding = SentenceTransformer(
            "NovaSearch/stella_en_1.5B_v5",
            model_kwargs={"torch_dtype": "bfloat16"},
        )

        self.ntasks = len(prompt_map)
        self.task_map = task_map
        self.prompt_map = prompt_map

        self.hidden_dim = hidden_dim

        # Output head: in_dim -> hidden_dim -> one value per task.
        head_layers = [
            nn.Linear(in_dim, self.hidden_dim),
            nn.SiLU(),
            nn.Linear(self.hidden_dim, self.ntasks),
        ]
        self.oproj = nn.Sequential(*head_layers)

    def forward(
        self,
        SMILES: list[list[str]] | None = None,
        PROTEINS: list[list[str]] | None = None,
        CELL_LINE: list[list[str]] | None = None,
        DISEASE: list[list[str]] | None = None,
    ) -> Tensor:
        """Encode every provided modality and return the head outputs.

        Modalities passed as ``None`` are skipped. The remaining embeddings
        are flattened from dim 1 onwards and concatenated along dim 1 in the
        fixed order SMILES, PROTEINS, CELL_LINE, DISEASE.

        Returns:
            The output of ``oproj`` on the concatenated features; its last
            dimension has ``ntasks`` entries.
        """
        # Encoders run without gradient tracking; only the head below does.
        with torch.no_grad():
            encoded = [
                None if SMILES is None else self._encode_smiles(SMILES),
                None if PROTEINS is None else self._encode_proteins(PROTEINS),
                None if CELL_LINE is None else self._encode_text(CELL_LINE),
                None if DISEASE is None else self._encode_text(DISEASE),
            ]

        features = torch.cat([emb.flatten(1) for emb in encoded if emb is not None], dim=1)
        return self.oproj(features)

    def _encode_proteins(self, PROTEINS: list[list[str]]):
        """Encode each inner list with ProtT5, stack on dim 1 and project with ``prot_clip``."""
        per_slot = [self.prot_t5.encode_proteins(seqs, disable_pbar=False) for seqs in PROTEINS]
        stacked = torch.stack(per_slot, dim=1).to(
            self.device,
            self.dtype,  # type: ignore
        )
        return self.prot_clip.project_mol(stacked, featureize_mol=True)

    def _encode_smiles(self, SMILES: list[list[str]]):
        """Encode each inner list with MolT5, stack on dim 1 and project with ``chem_clip``."""
        per_slot = [self.mol_t5.encode_smiles(mols, disable_pbar=True) for mols in SMILES]
        stacked = torch.stack(per_slot, dim=1).to(
            self.device,
            self.dtype,  # type: ignore
        )
        return self.chem_clip.project_mol(stacked, featureize_mol=True)

    def _encode_text(self, texts: list[list[str]]):
        """Embed each inner list with the sentence encoder, stack on dim 1 and apply ``chem_clip.project_text``."""
        per_slot = [self.text_embedding.encode(entries, convert_to_tensor=True) for entries in texts]
        stacked = torch.stack(per_slot, dim=1).to(
            self.device,
            self.dtype,  # type: ignore
        )
        return self.chem_clip.project_text(stacked)


def loss_step(
    batch: dict[str, Tensor],
    model: MultiTaskNet,
    is_clf: bool,
):
    """Run ``model`` on one batch and score each sample on its own task output.

    Args:
        batch: Collated batch. The keys "SMILES", "PROTEINS", "CELL_LINE" and
            "DISEASE" are forwarded to the model unchanged; "y" holds the
            targets and "logit_id" the per-sample column of the model output
            to score. NOTE(review): the modality entries are passed to
            parameters typed ``list[list[str]] | None``, so the ``Tensor``
            value annotation is narrower than what is actually read.
        model: Callable returning a 2-D tensor of per-task outputs.
        is_clf: If true use binary cross-entropy on logits, otherwise L1 loss.

    Returns:
        ``(loss, preds)`` where ``preds`` is the 1-D tensor of selected
        outputs (raw logits for classification) and ``loss`` is a scalar.
    """
    preds = model(
        SMILES=batch["SMILES"],
        PROTEINS=batch["PROTEINS"],
        CELL_LINE=batch["CELL_LINE"],
        DISEASE=batch["DISEASE"],
    )

    y = batch["y"]
    logit_id = batch["logit_id"]

    # Build the row index on the predictions' device so the gather does not
    # depend on an implicit CPU -> accelerator index transfer every step.
    row_idx = torch.arange(len(preds), device=preds.device)
    preds = preds[row_idx, logit_id]

    loss_fn = F.binary_cross_entropy_with_logits if is_clf else F.l1_loss
    loss = loss_fn(preds, y)

    return loss, preds


@torch.inference_mode()
def validate(model: MultiTaskNet, val_loader: DataLoader, is_clf: bool):
    """Evaluate ``model`` on ``val_loader`` and report one bootstrapped score per task.

    Args:
        model: Model exposing ``task_map`` (task id -> task name).
        val_loader: Loader yielding batches accepted by ``loss_step`` that
            also carry "task_id" and "y_unscaled".
        is_clf: Selects classification (sigmoid + ``bootstrap_clf``) or
            regression (un-scaling + ``bootstrap_regression``) handling.

    Returns:
        ``(df, mean_loss)``: a polars frame with columns "task", "score" and
        "metric" sorted by task, and the unweighted mean of the batch losses.
    """
    pred_chunks = []
    task_chunks = []
    label_chunks = []
    batch_losses = []

    for batch in tqdm(val_loader, leave=False, desc="Validating"):
        loss, batch_out = loss_step(batch, model, is_clf=is_clf)

        batch_task_ids = batch["task_id"]
        batch_labels = batch["y_unscaled"]

        batch_losses.append(loss.item())

        # Classification outputs are logits; store probabilities instead.
        scored = batch_out.sigmoid() if is_clf else batch_out
        pred_chunks.append(scored.cpu())
        task_chunks.append(batch_task_ids.cpu())
        label_chunks.append(batch_labels.cpu())

    all_preds = torch.cat(pred_chunks, dim=0)
    all_tasks = torch.cat(task_chunks, dim=0)
    all_labels = torch.cat(label_chunks, dim=0)

    df = pl.DataFrame(
        [
            pl.Series("task_id", all_tasks.numpy()),
            pl.Series("preds", all_preds.numpy()),
            pl.Series("labels_unscaled", all_labels.numpy()),
        ]
    )

    # Both bootstrap helpers are called with the same keywords, so pick one up front.
    bootstrap = bootstrap_clf if is_clf else bootstrap_regression

    results = []
    for task_id, task_name in model.task_map.items():
        subset = df.filter(pl.col("task_id").eq(task_id))
        if subset.height == 0:
            # No validation rows for this task.
            continue

        metric_name = TASK_TO_METRIC[task_name]  # type: ignore

        task_preds = subset["preds"].to_numpy()
        task_targets = subset["labels_unscaled"].to_numpy()

        if not is_clf:
            # Regression outputs are z-scored; map them back using the mean and
            # standard deviation of this task's unscaled labels in the subset.
            target_mean = task_targets.mean()
            target_std = task_targets.std()
            task_preds = task_preds * target_std + target_mean

        score = bootstrap(predictions=task_preds, targets=task_targets, n_jobs=16)[metric_name]
        results.append({"task": task_name, "score": score, "metric": metric_name})

    df = pl.from_dicts(results).sort("task")
    total_loss = sum(batch_losses) / len(batch_losses)

    return df, total_loss


def eval_modality(
    modality: ModalityType,
    clf: bool = True,
    batch_size: int = 4096,
):
    """Evaluate the stored adapter for one modality / task-type pair and print the scores.

    Builds the validation loader, constructs a ``MultiTaskNet``, loads the
    matching entry from ``./checkpoints/adapters.pt`` into it, runs
    ``validate`` and prints the per-task result table.

    Args:
        modality: "+"-separated modality string, e.g. "S+P".
        clf: Evaluate the classification adapter if true, else the regression one.
        batch_size: Batch size handed to the data module.

    Raises:
        KeyError: If the adapter file has no entry for this modality / task type.
    """
    accelerator = Accelerator()
    # --- Data ---
    dm = TDCTaskDataModule(
        modality=modality,
        clf=clf,
        batch_size=batch_size,
    )
    dm.setup("")

    val_loader = dm.val_dataloader()

    # Key of this configuration inside the adapter file, e.g. "S+P_reg".
    adapter_key = f"{modality}_{'clf' if clf else 'reg'}"

    model = MultiTaskNet(
        modality=modality,
        task_map=dm.task_map,
        prompt_map=dm.prompt_map,
        hidden_dim=512 if modality != "P+P+P" else 1,
    )
    # Load adapter
    # map_location="cpu" keeps the load independent of the device the file was
    # saved from; load_state_dict copies the values into the existing
    # parameters, and accelerator.prepare below handles device placement.
    # NOTE(review): torch.load unpickles the file — only use trusted checkpoints.
    adapters = torch.load("./checkpoints/adapters.pt", map_location="cpu")
    # strict=False: the adapter entry is not required to cover every key of the model.
    model.load_state_dict(adapters[adapter_key], strict=False)
    model.freeze()

    model, val_loader = accelerator.prepare(model, val_loader)

    val_df, _ = validate(model, val_loader, is_clf=clf)
    # Print the full table (no row/column truncation) with 3 decimals.
    with pl.Config(tbl_rows=-1, tbl_cols=-1, float_precision=3):
        print(val_df)


def main():
    """Run ``eval_modality`` for every (modality, is_clf) pair in a fixed order."""
    # Each entry is (modality string, whether to evaluate the classification adapter).
    runs = (
        ("P+D", False),
        ("P+P", True),
        ("P+P", False),
        ("P+P+P", False),
        ("S+S+S", False),
        ("S", True),
        ("S", False),
        ("S+C", False),
        ("S+P", False),
        ("S+S+C", False),
    )
    for modality, is_clf in runs:
        eval_modality(modality=modality, clf=is_clf)  # type: ignore


# Script entry point: expose `main` through the python-fire command line wrapper.
if __name__ == "__main__":
    fire.Fire(main)
