from typing import Literal

import datasets
import lightning as L
import polars as pl
import torch
from torch.utils.data import DataLoader, Dataset

# Supported input-modality combinations. Each "+"-separated letter is one input
# slot; the id column backing each slot is listed in MODALITY_TO_COLS below.
ModalityType = Literal[
    "P+D",
    "P+P",
    "P+P+P",
    "S",
    "S+C",
    "S+D",
    "S+P",
    "S+S+C",
    "S+S+S",
]

# Modality code -> ordered tuple of id-column names, one per input slot.
# Column prefixes: SMI* (used for "S"), PROT* ("P"), CellLine ("C"), Disease ("D").
MODALITY_TO_COLS: dict[ModalityType, tuple[str, ...]] = {
    # protein-led combinations
    "P+D": ("PROT1_ID", "Disease_ID"),
    "P+P": ("PROT1_ID", "PROT2_ID"),
    "P+P+P": ("PROT1_ID", "PROT2_ID", "PROT3_ID"),
    # SMILES-led combinations
    "S": ("SMI1_ID",),
    "S+C": ("SMI1_ID", "CellLine_ID"),
    "S+D": ("SMI1_ID", "Disease_ID"),
    "S+P": ("SMI1_ID", "PROT1_ID"),
    "S+S+C": ("SMI1_ID", "SMI2_ID", "CellLine_ID"),
    "S+S+S": ("SMI1_ID", "SMI2_ID", "SMI3_ID"),
}


class TDCTaskDataset(Dataset):
    """Batched dataset over the ``*_test`` parquet file of one input modality.

    The parquet file ``./data/<modality with '+' replaced by '_'>_test.parquet``
    is loaded, filtered on its ``is_clf`` column, and exposed through
    ``__getitems__`` so that a whole batch of rows is fetched in one call.
    """

    def __init__(
        self,
        modality: ModalityType,
        clf: bool,
    ):
        """Load and index the rows for ``modality``.

        Args:
            modality: Key into ``MODALITY_TO_COLS``; also determines the parquet
                file name.
            clf: Only rows whose ``is_clf`` column equals this value are kept.
        """
        super().__init__()
        self.modality_cols = MODALITY_TO_COLS[modality]

        # e.g. "S+S+C" -> "S_S_C_test"
        modal_str = modality.replace("+", "_") + "_test"

        df = pl.read_parquet(f"./data/{modal_str}.parquet").filter(pl.col("is_clf").eq(clf))

        # Dense rank minus one re-indexes the ids that survive the filter to
        # contiguous zero-based integers.
        self.df = df.with_columns(  # type: ignore
            pl.col("task_id").rank("dense") - 1,
            pl.col("prompt_id").rank("dense") - 1,
        )

        # Convert to a HF dataset exactly once; the previous code performed this
        # conversion twice and discarded the first result.
        self.examples: datasets.Dataset = datasets.Dataset.from_polars(self.df)
        self.examples.set_format(type="torch")

    def __len__(self):
        """Return the number of rows kept after the ``is_clf`` filter."""
        return len(self.examples)

    def _get_inputs(self, rows: dict[str, list]):
        """Group the batch's input columns by input kind and collate them.

        For every id column of the modality (e.g. ``"SMI1_ID"``) the column
        named by the part before the first underscore (``"SMI1"``) is read from
        ``rows`` and appended to the matching bucket.

        Args:
            rows: Column name -> batch values, as returned by indexing
                ``self.examples`` with a list of indices.

        Returns:
            The result of ``collate`` on the buckets ``"SMILES"``,
            ``"PROTEINS"``, ``"CELL_LINE"`` and ``"DISEASE"``.

        Raises:
            ValueError: If an id column matches none of the known kinds.
        """
        inputs = {
            "SMILES": [],
            "PROTEINS": [],
            "CELL_LINE": [],
            "DISEASE": [],
        }
        for col in self.modality_cols:
            # "SMI1_ID" -> "SMI1": the column that is actually read from rows.
            name = col.split("_")[0]
            if "SMI" in col:
                inputs["SMILES"].append(rows[name])
            elif "PROT" in col:
                inputs["PROTEINS"].append(rows[name])
            elif "CellLine" in col:
                inputs["CELL_LINE"].append(rows[name])
            elif "Disease" in col:
                inputs["DISEASE"].append(rows[name])
            else:
                raise ValueError(f"Unknown modality: {col}")

        return collate(inputs)

    def __getitems__(self, indices: list[int]):
        """Fetch a whole batch of rows at once.

        Args:
            indices: Row indices making up the batch.

        Returns:
            A dict with the collated modality inputs plus ``"y"`` (from
            ``Y_scaled``), ``"y_unscaled"`` (from ``Y``), ``"logit_id"`` (from
            ``prompt_id``), ``"task_id"`` and ``"is_clf"`` (cast to bool).
        """
        rows = self.examples[indices]
        modality_input = self._get_inputs(rows)

        return {
            **modality_input,
            "y": rows["Y_scaled"],
            "y_unscaled": rows["Y"],
            "logit_id": rows["prompt_id"],
            "task_id": rows["task_id"],
            "is_clf": rows["is_clf"].bool(),
        }


class TDCTaskDataModule(L.LightningDataModule):
    """Lightning data module serving a single ``TDCTaskDataset`` for validation.

    Besides the dataset it keeps two lookup tables, ``task_map`` and
    ``prompt_map``, that translate the integer ids found in the dataset's
    frame back to their names.
    """

    def __init__(self, clf: bool, modality: ModalityType, batch_size: int = 32):
        """Build the dataset and the id -> name lookup tables.

        Args:
            clf: Forwarded to ``TDCTaskDataset``.
            modality: Forwarded to ``TDCTaskDataset``.
            batch_size: Batch size used by ``val_dataloader``.
        """
        super().__init__()
        self.save_hyperparameters()

        self.batch_size = batch_size
        self.test_ds = TDCTaskDataset(modality=modality, clf=clf)

        frame = self.test_ds.df
        task_pairs = frame.select("task_name", "task_id").unique()
        prompt_pairs = frame.select("prompt_name", "prompt_id").unique()

        # Rows come back as tuples in the selected column order: (name, id).
        self.task_map = {task_id: task_name for task_name, task_id in task_pairs.iter_rows()}
        self.prompt_map = {prompt_id: prompt_name for prompt_name, prompt_id in prompt_pairs.iter_rows()}

    def val_dataloader(self):
        """Return a ``DataLoader`` over ``self.test_ds`` using ``_collate``."""
        loader = DataLoader(
            self.test_ds,
            batch_size=self.batch_size,
            num_workers=8,
            pin_memory=True,
            collate_fn=_collate,  # type: ignore
        )
        return loader


def _collate(batch: dict):
    return batch


def collate(inputs: dict[str, list[torch.Tensor | list[str]]]):
    # Collate the inputs into a single batch
    batch = {}
    for key, value in inputs.items():
        if len(value) == 0:
            batch[key] = None
        elif isinstance(value[0], torch.Tensor):
            batch[key] = torch.stack(value, dim=1)  # type: ignore
        elif isinstance(value[0], list):
            batch[key] = value
        else:
            raise ValueError(f"Unknown input type: {type(value[0])}")

    return batch
