import functools
import logging
import os
import re
from collections.abc import Callable
from dataclasses import dataclass
from pathlib import Path

import hydra
import numpy as np
import torch
import torch.nn as nn
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import Callback, ModelCheckpoint, RichModelSummary, RichProgressBar
from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.utilities.combined_loader import CombinedLoader
from omegaconf import DictConfig
from sklearn.model_selection import StratifiedGroupKFold
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Subset, random_split
from transformers import AutoTokenizer

import wandb
from research.wsl_ece.metric.dataloader import PUDataLoader
from research.wsl_ece.metric.dataset import PositiveUnlabeledDatasets, SizedDataset, load_dataset
from research.wsl_ece.metric.ddi2013 import SPECIAL_TOKENS
from research.wsl_ece.metric.experiment_util import (
    init_seed,
    set_all_loggers_level,
    setup_device_config,
    setup_logger,
)
from research.wsl_ece.metric.loss import LossFunction
from research.wsl_ece.metric.model import MLP, RBertClassifier, ResNet18
from research.wsl_ece.metric.pl_module import accumulate_predictions
from research.wsl_ece.metric.pu_module import PUModule
from research.wsl_ece.metric.supervised_module import SupervisedModule

# Module-level logger created through the project's setup_logger helper.
logger = setup_logger(__name__)


# Disable HuggingFace tokenizers parallelism. The DataLoaders in this module start
# their worker processes with multiprocessing_context="fork"; turning tokenizer
# parallelism off avoids the warnings tokenizers emits when the process forks.
os.environ["TOKENIZERS_PARALLELISM"] = "false"


@dataclass
class ExperimentConfig:
    """
    Configuration for training a classifier with PU learning or supervised learning.

    The first group of fields identifies a trained classifier (see ``train_setup_dict``);
    the remaining fields control data loading and general behaviour.

    Attributes:
        dataset (str): Name of the dataset to use (required), e.g. "mnist", "cifar10" or "ddi2013".
        classifier (nn.Module | None): Classifier instance; only the first 20 characters of its
            string form are recorded, for legacy logging. Default is None.
        num_positive (int): Number of labelled positive samples. Default is 10000.
        train_batch_size (int): Batch size for training. Default is 256.
        max_epochs (int): Maximum number of training epochs. Default is 1.
        lr (float): Learning rate. Default is 0.001.
        loss_function (LossFunction): Loss used for training. Default is LossFunction.SIGMOID.
        predict_probability (bool): Forwarded as ``predict_probability`` to the training module.
            Default is False.
        seed (int): Random seed for reproducibility. Default is 42.
        learning_type (str): "pu" for PU learning or "supervised" for supervised learning.
            Default is "pu".
        balanced_error (bool): Whether to use balanced error for PU learning. Default is False.
        validation_batch_size (int): Batch size for validation. Default is 10000.
        test_batch_size (int): Batch size for testing. Default is 10000.
        work_dir (str | None): Directory for wandb files and Trainer outputs. Default is None.
        root (str): Root directory where the dataset is stored. Default is "./dataset".
        log_level (str): Logging level. Default is "info".
        enable_progress_bar (bool): Whether to show a progress bar during training. Default is True.
    """

    # Configurations that can also be used for identifying the classifier
    dataset: str
    classifier: nn.Module | None = None  # For legacy logging purposes
    num_positive: int = 10000
    train_batch_size: int = 256
    max_epochs: int = 1
    lr: float = 0.001
    loss_function: LossFunction = LossFunction.SIGMOID
    predict_probability: bool = False
    seed: int = 42
    learning_type: str = "pu"  # "pu" or "supervised"
    balanced_error: bool = False  # Whether to use balanced error for PU learning
    # Other dataset and model configuration
    validation_batch_size: int = 10000
    test_batch_size: int = 10000
    work_dir: str | None = None
    root: str = "./dataset"
    # General configuration
    log_level: str = "info"
    enable_progress_bar: bool = True

    @property
    def train_setup_dict(self) -> dict:
        """
        Return the parameters that identify this training setup.

        Keys, in order: dataset, classifier (first 20 characters of its string form),
        num_positive, train_batch_size, max_epochs, lr, loss_function (enum value),
        predict_probability, balanced_error (ddi2013 only), seed and learning_type.

        Returns:
            dict: The train task setup parameters.
        """
        setup = {
            "dataset": self.dataset,
            "classifier": str(self.classifier)[:20],
            "num_positive": self.num_positive,
            "train_batch_size": self.train_batch_size,
            "max_epochs": self.max_epochs,
            "lr": self.lr,
            "loss_function": self.loss_function.value,
            "predict_probability": self.predict_probability,
        }
        # For legacy support, balanced_error is only included for ddi2013. Insertion order
        # matters: get_predictions_table_artifact_name joins the values in this order.
        if self.dataset == "ddi2013":
            setup["balanced_error"] = self.balanced_error
        setup["seed"] = self.seed
        setup["learning_type"] = self.learning_type
        return setup


def select_tokenizer(dataset_name: str):
    """
    Return the tokenizer required by ``dataset_name``, or ``None`` when it needs none.

    Only ``"ddi2013"`` is tokenized: it uses the ``bert-base-uncased`` tokenizer with the
    project's ``SPECIAL_TOKENS`` registered as additional special tokens.

    Args:
        dataset_name (str): The name of the dataset.

    Returns:
        The configured tokenizer for ddi2013, otherwise ``None``.
    """
    if dataset_name != "ddi2013":
        return None

    bert_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    bert_tokenizer.add_special_tokens({"additional_special_tokens": SPECIAL_TOKENS})
    return bert_tokenizer


def select_classifier(dataset_name: str, tokenizer=None) -> nn.Module:
    """
    Selects the classifier based on the dataset name.

    Args:
        dataset_name (str): The name of the dataset.

    Returns:
        type[nn.Module]: The selected classifier class.
    """
    if dataset_name == "mnist":
        return MLP(dim=28 * 28, hidden_layer_sizes=[300, 300])
    elif dataset_name == "cifar10":
        return ResNet18()
    elif dataset_name == "ddi2013":
        return RBertClassifier(tokenizer=tokenizer)
    else:
        raise ValueError(f"Unsupported dataset: {dataset_name}")


def collate_pad(batch, pad_id=0):
    """
    Collate function using PyTorch's built-in padding utilities.
    """
    if not batch:
        raise ValueError("Empty batch passed to collate_pad function")

    # Separate features and labels
    features, labels = zip(*batch, strict=False)
    labels = torch.tensor(labels, dtype=torch.long)

    # Extract sequences for each key
    input_ids = [f["input_ids"] for f in features]
    attention_masks = [f["attention_mask"] for f in features]
    token_type_ids = [f["token_type_ids"] for f in features]

    # Use torch.nn.utils.rnn.pad_sequence for efficient padding

    data = {
        "input_ids": pad_sequence(input_ids, batch_first=True, padding_value=pad_id),
        "attention_mask": pad_sequence(attention_masks, batch_first=True, padding_value=pad_id),
        "token_type_ids": pad_sequence(token_type_ids, batch_first=True, padding_value=pad_id),
    }

    return data, labels


def select_collator(dataset_name: str, pad_id: int):
    """
    Return the DataLoader ``collate_fn`` for ``dataset_name``.

    ddi2013 batches hold variable-length token sequences, so they are padded by
    ``collate_pad`` with ``pad_id``; every other dataset uses PyTorch's default collate.

    Args:
        dataset_name (str): The name of the dataset.
        pad_id (int): Padding id forwarded to ``collate_pad`` (ddi2013 only).

    Returns:
        Callable: The collate function.
    """
    if dataset_name != "ddi2013":
        return torch.utils.data.default_collate
    return functools.partial(collate_pad, pad_id=pad_id)


def pu_split_by_document_stratified(
    pu_datasets: PositiveUnlabeledDatasets, val_ratio: float = 0.1
) -> tuple[list[int], list[int], list[int], list[int]]:
    """
    Split positive and unlabeled data into train/validation indices by document ID.

    The positive and unlabeled datasets are concatenated (positives first, labelled 1;
    unlabeled after, labelled 0) and split once with ``StratifiedGroupKFold`` grouped by
    document ID. The label ratio is approximately preserved, and no document contributes
    samples to both train and validation (no leakage).

    Args:
        pu_datasets: Object exposing ``positive`` and ``unlabeled`` datasets. Their samples
            are either ``(features, label)`` tuples or the features mapping itself, with
            the document ID at ``["meta"]["doc_id"]``.
        val_ratio: Approximate fraction of samples used for validation. The splitter uses
            ``max(2, int(1 / val_ratio))`` folds and keeps the first one.

    Returns:
        tuple: ``(train_pos_indices, val_pos_indices, train_unlabeled_indices,
        val_unlabeled_indices)``. The positive lists index into ``pu_datasets.positive``
        and the unlabeled lists into ``pu_datasets.unlabeled``.

    Raises:
        ValueError: If ``val_ratio`` is not positive, or if the train or validation split
            contains no samples at all.
    """
    if val_ratio <= 0:
        raise ValueError(f"val_ratio must be positive, got {val_ratio}")

    def doc_id_of(sample):
        # Handle both tuple format (features, label) and direct access to the features.
        if isinstance(sample, tuple):
            features, _ = sample
        else:
            features = sample
        return features["meta"]["doc_id"]  # type: ignore

    pos_count = len(pu_datasets.positive)
    unlabeled_count = len(pu_datasets.unlabeled)

    groups = [doc_id_of(pu_datasets.positive[idx]) for idx in range(pos_count)]
    groups += [doc_id_of(pu_datasets.unlabeled[idx]) for idx in range(unlabeled_count)]
    # Positives are labelled 1; unlabeled samples are treated as negatives (0) for stratification.
    label_array = np.asarray([1] * pos_count + [0] * unlabeled_count)
    group_array = np.asarray(groups)

    # Number of folds approximating the desired validation ratio; only the first fold is used.
    n_splits = max(2, int(1 / val_ratio))
    sgkf = StratifiedGroupKFold(n_splits=n_splits, shuffle=True)
    dummy_x = np.arange(pos_count + unlabeled_count)  # the splitter ignores feature values
    train_indices, val_indices = next(sgkf.split(dummy_x, label_array, group_array))

    def partition(indices):
        # Combined indices below pos_count address positives; the rest are offset into the
        # unlabeled dataset. All indices are in [0, pos_count + unlabeled_count), so the
        # resulting local indices are in range by construction.
        positive = [int(i) for i in indices if i < pos_count]
        unlabeled = [int(i) - pos_count for i in indices if i >= pos_count]
        return positive, unlabeled

    train_pos_indices, train_unlabeled_indices = partition(train_indices)
    val_pos_indices, val_unlabeled_indices = partition(val_indices)

    # Ensure we have non-empty splits
    if not train_pos_indices and not train_unlabeled_indices:
        raise ValueError("Both train_pos_indices and train_unlabeled_indices are empty")
    if not val_pos_indices and not val_unlabeled_indices:
        raise ValueError("Both val_pos_indices and val_unlabeled_indices are empty")
    return train_pos_indices, val_pos_indices, train_unlabeled_indices, val_unlabeled_indices


def supervised_split_by_document_stratified(
    dataset: SizedDataset, val_ratio: float = 0.1
) -> tuple[list[int], list[int]]:
    """
    Split a labelled dataset into train/validation indices grouped by document ID.

    Uses ``StratifiedGroupKFold`` so the label distribution is approximately preserved
    while all samples of a document land in the same split (no data leakage).

    Args:
        dataset: Dataset whose samples are ``(features, label)`` pairs (tuple or list),
            with the document ID at ``features["meta"]["doc_id"]``.
        val_ratio: Approximate fraction of samples used for validation. The splitter uses
            ``max(2, int(1 / val_ratio))`` folds and keeps the first one.

    Returns:
        tuple: ``(train_indices, val_indices)`` as lists of indices into ``dataset``.

    Raises:
        ValueError: If ``val_ratio`` is not positive or either split is empty.
        TypeError: If a sample is not a ``(features, label)`` pair.
    """
    if val_ratio <= 0:
        raise ValueError(f"val_ratio must be positive, got {val_ratio}")

    # Extract labels and groups (doc_ids) for all samples
    labels = []
    groups = []
    for idx in range(len(dataset)):
        sample = dataset[idx]
        # A non-pair sample cannot provide both a label and a doc ID: the previous fallback
        # indexed the same object with both 1 and "meta", which no sample type supports.
        if not isinstance(sample, (tuple, list)):
            raise TypeError(f"Expected a (features, label) sample at index {idx}, got {type(sample).__name__}")
        features, label = sample
        labels.append(label)
        groups.append(features["meta"]["doc_id"])

    label_array = np.asarray(labels)
    group_array = np.asarray(groups)

    # Number of folds approximating the desired validation ratio; only the first fold is used.
    n_splits = max(2, int(1 / val_ratio))
    sgkf = StratifiedGroupKFold(n_splits=n_splits, shuffle=True)
    dummy_x = np.arange(len(dataset))  # the splitter ignores feature values
    train_indices, val_indices = next(sgkf.split(dummy_x, label_array, group_array))

    # The splitter only yields indices of dummy_x, so no bounds check is needed; just
    # make sure both sides are non-empty.
    if not len(train_indices):
        raise ValueError("train_indices is empty")
    if not len(val_indices):
        raise ValueError("val_indices is empty")
    return train_indices.tolist(), val_indices.tolist()


def prepare_pu_data(dataset: str, root: str, num_positive: int, tokenizer=None):
    """
    Load a dataset and build the positive-unlabeled train/validation/test splits.

    Args:
        dataset (str): Name of the dataset passed to ``load_dataset``.
        root (str): Directory the dataset is loaded from.
        num_positive (int): Number of labelled positives passed to
            ``PositiveUnlabeledDatasets.from_dataset``.
        tokenizer: Optional tokenizer forwarded to ``load_dataset``.

    Returns:
        tuple: ``(splits, prior)``.
            - ``splits["full_train"]``, ``splits["train"]`` and ``splits["validation"]`` are
              dicts with ``"positive"`` and ``"unlabeled"`` datasets.
            - ``splits["test"]`` is the test dataset.
            - ``prior`` is the ``prior`` of the PU training datasets.

    Raises:
        ValueError: For ddi2013, if any of the document-stratified splits is empty.
    """
    train_data = load_dataset(dataset, root=root, train=True, tokenizer=tokenizer)
    test_data = load_dataset(dataset, root=root, train=False, tokenizer=tokenizer)
    train_data.print_dataset_info()

    # Build the positive / unlabeled views of the training data.
    pu_sets = PositiveUnlabeledDatasets.from_dataset(train_data, num_positive=num_positive)
    prior = pu_sets.prior

    if dataset != "ddi2013":
        # Standard 90/10 random split of each part.
        train_pos, val_pos = random_split(pu_sets.positive, [0.9, 0.1])
        train_unl, val_unl = random_split(pu_sets.unlabeled, [0.9, 0.1])
    else:
        # Document-aware stratified split so no document leaks between train and validation.
        pos_train_idx, pos_val_idx, unl_train_idx, unl_val_idx = pu_split_by_document_stratified(
            pu_sets, val_ratio=0.1
        )
        required_splits = (
            (pos_train_idx, "No positive samples in training set after document-stratified split"),
            (unl_train_idx, "No unlabeled samples in training set after document-stratified split"),
            (pos_val_idx, "No positive samples in validation set after document-stratified split"),
            (unl_val_idx, "No unlabeled samples in validation set after document-stratified split"),
        )
        for split_indices, message in required_splits:
            if not split_indices:
                raise ValueError(message)
        train_pos = Subset(pu_sets.positive, pos_train_idx)
        val_pos = Subset(pu_sets.positive, pos_val_idx)
        train_unl = Subset(pu_sets.unlabeled, unl_train_idx)
        val_unl = Subset(pu_sets.unlabeled, unl_val_idx)

    splits = {
        "full_train": {"positive": pu_sets.positive, "unlabeled": pu_sets.unlabeled},
        "train": {"positive": train_pos, "unlabeled": train_unl},
        "validation": {"positive": val_pos, "unlabeled": val_unl},
        "test": test_data,
    }
    return splits, prior


def prepare_supervised_data(dataset: str, root: str, num_labeled: int, tokenizer=None):
    """
    Prepare the data for supervised learning.

    A random subset of ``num_labeled`` training samples is drawn with the global NumPy
    RNG. The whole training set is used when ``num_labeled`` is at least its size. The
    result is split 90/10 into train and validation; ddi2013 is split by document to
    avoid leakage.

    Args:
        dataset (str): The name of the dataset to load.
        root (str): The root directory where the dataset is stored.
        num_labeled (int): Number of labelled training samples to use.
        tokenizer: Optional tokenizer forwarded to ``load_dataset``.

    Returns:
        dict: A dictionary containing the training, validation, and test datasets.
            - "full_train": The full training dataset.
            - "train": The training split.
            - "validation": The validation split.
            - "test": The test dataset.
    """
    full_train_data = load_dataset(dataset, root=root, train=True, tokenizer=tokenizer)
    test_data = load_dataset(dataset, root=root, train=False, tokenizer=tokenizer)
    full_train_data.print_dataset_info()

    population = len(full_train_data)
    # Clamp the draw size: np.random.choice(..., replace=False) raises when asked for more
    # samples than exist, which made the "use the full data" fallback unreachable for
    # num_labeled > population. For every previously working input the draw (and hence the
    # global RNG consumption) is unchanged.
    indices = np.random.choice(population, min(num_labeled, population), replace=False).tolist()
    train_data = Subset(full_train_data, indices) if num_labeled < population else full_train_data

    # Split train_data into train and validation sets
    if dataset == "ddi2013":
        train_indices, val_indices = supervised_split_by_document_stratified(train_data, val_ratio=0.1)  # type: ignore
        train_split = Subset(train_data, train_indices)
        val_split = Subset(train_data, val_indices)
    else:
        train_split, val_split = random_split(train_data, [0.9, 0.1])

    return {
        "full_train": full_train_data,
        "train": train_split,
        "validation": val_split,
        "test": test_data,
    }


def prepare_pu_dataloaders(
    conf: ExperimentConfig, data: dict, collator: Callable | None = None, num_workers: int = 2
) -> tuple[PUDataLoader, PUDataLoader, PUDataLoader, DataLoader]:
    """
    Prepare the data loaders for the PU training, validation, and test datasets.

    Only the "train" loaders shuffle. Workers are started with the "fork" multiprocessing
    context and kept alive between epochs; memory is pinned for faster host-to-device copies.

    Args:
        conf (ExperimentConfig): Provides train/validation/test batch sizes.
        data (dict): Datasets keyed by "full_train", "train", "validation" and "test".
            The first three map to dicts with "positive" and "unlabeled" datasets;
            "test" maps to a single dataset.
        collator (Callable | None): ``collate_fn`` for every loader (None = PyTorch default).
        num_workers (int): Worker processes per loader. Default is 2; 0 loads in the main process.

    Returns:
        tuple: A tuple containing four elements:
            - full_train_loader (PUDataLoader): Loaders for the full training dataset.
            - train_loader (PUDataLoader): Loaders for the training dataset.
            - validation_loader (PUDataLoader): Loaders for the validation dataset.
            - test_loader (DataLoader): DataLoader for the testing dataset.
    """

    def make_loader(dataset, batch_size: int, shuffle: bool) -> DataLoader:
        # DataLoader rejects persistent_workers / multiprocessing_context without workers.
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            persistent_workers=num_workers > 0,
            multiprocessing_context="fork" if num_workers > 0 else None,
            collate_fn=collator,
            pin_memory=True,
        )

    def make_pu_loader(split: str, batch_size: int, shuffle: bool):
        return {
            "positive": make_loader(data[split]["positive"], batch_size, shuffle),
            "unlabeled": make_loader(data[split]["unlabeled"], batch_size, shuffle),
        }

    full_train_loader: PUDataLoader = make_pu_loader("full_train", conf.train_batch_size, shuffle=False)
    train_loader: PUDataLoader = make_pu_loader("train", conf.train_batch_size, shuffle=True)
    validation_loader: PUDataLoader = make_pu_loader("validation", conf.validation_batch_size, shuffle=False)
    test_loader = make_loader(data["test"], conf.test_batch_size, shuffle=False)

    return full_train_loader, train_loader, validation_loader, test_loader


def prepare_supervised_dataloaders(
    conf: ExperimentConfig, data: dict, collator: Callable | None = None, num_workers: int = 2
) -> tuple[DataLoader, DataLoader, DataLoader, DataLoader]:
    """
    Prepare the data loaders for supervised learning.

    Only the "train" loader shuffles. Workers are started with the "fork" multiprocessing
    context and kept alive between epochs; memory is pinned for faster host-to-device copies.

    Args:
        conf (ExperimentConfig): Provides train/validation/test batch sizes.
        data (dict): Datasets keyed by "full_train", "train", "validation" and "test".
        collator (Callable | None): ``collate_fn`` for every loader (None = PyTorch default).
        num_workers (int): Worker processes per loader. Default is 2; 0 loads in the main process.

    Returns:
        tuple: A tuple containing four elements:
            - full_train_loader (DataLoader): DataLoader for the full training dataset.
            - train_loader (DataLoader): DataLoader for the training dataset.
            - validation_loader (DataLoader): DataLoader for the validation dataset.
            - test_loader (DataLoader): DataLoader for the testing dataset.
    """

    def make_loader(dataset, batch_size: int, shuffle: bool) -> DataLoader:
        # DataLoader rejects persistent_workers / multiprocessing_context without workers.
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            persistent_workers=num_workers > 0,
            multiprocessing_context="fork" if num_workers > 0 else None,
            pin_memory=True,
            collate_fn=collator,
        )

    full_train_loader = make_loader(data["full_train"], conf.train_batch_size, shuffle=False)
    train_loader = make_loader(data["train"], conf.train_batch_size, shuffle=True)
    validation_loader = make_loader(data["validation"], conf.validation_batch_size, shuffle=False)
    test_loader = make_loader(data["test"], conf.test_batch_size, shuffle=False)

    return full_train_loader, train_loader, validation_loader, test_loader


def get_predictions_table_artifact_name(conf: ExperimentConfig):
    """
    Build the wandb artifact name under which the prediction tables are stored.

    Every value of ``conf.train_setup_dict`` is stringified, and each run of characters
    outside ``[a-zA-Z0-9]`` is replaced by a single underscore. The sanitized values are
    joined with ``_`` and prefixed with ``predictions_table_``.
    """
    sanitized_values = (re.sub(r"[^a-zA-Z0-9]+", "_", str(value)) for value in conf.train_setup_dict.values())
    return "predictions_table_" + "_".join(sanitized_values)


class PUTrainAndPredictProcessor:
    """
    Facade that trains a PU-learning classifier, predicts on every split and logs the results.

    ``process()`` runs ``train()``, ``predict()`` and ``log_artifacts()`` in sequence, so
    callers only construct the processor and call one method.
    """

    wandb_logger: WandbLogger
    pl_module: PUModule

    def __init__(
        self,
        conf: ExperimentConfig,
        classifier: nn.Module,
        prior: float,
        full_train_loader: PUDataLoader,
        train_loader: PUDataLoader,
        validation_loader: PUDataLoader,
        test_loader: DataLoader,
    ):
        """
        Start a wandb run, build the Lightning Trainer and wrap ``classifier`` in a ``PUModule``.

        Args:
            conf: Experiment configuration (hyper-parameters, progress bar, work dir).
            classifier: Model to train.
            prior: Positive-class prior forwarded to ``PUModule``.
            full_train_loader: Positive/unlabeled loaders over the full training data.
            train_loader: Positive/unlabeled loaders used for training.
            validation_loader: Positive/unlabeled loaders used for validation.
            test_loader: Loader over the test data.
        """
        self.conf = conf
        self.prior = prior
        self.full_train_loader = full_train_loader
        self.train_loader = train_loader
        self.validation_loader = validation_loader
        self.test_loader = test_loader

        # Initialize wandb run and logger
        run = wandb.init(
            project="wsl-ece",
            name="train-classifier",
            job_type="train",
            config=vars(conf),
            tags=["train"],
            dir=conf.work_dir,
        )
        run.log_code(str(Path(__file__).parent.parent))
        # Hand the run to the logger through WandbLogger's `experiment` argument
        # (unknown keywords are only forwarded to wandb.init).
        self.wandb_logger = WandbLogger(experiment=run)

        # Set up device and create Trainer
        self.device = setup_device_config()
        if _RICH_AVAILABLE and conf.enable_progress_bar:
            callbacks: list[Callback] = [RichProgressBar(), RichModelSummary()]
        else:
            callbacks = []
        # Keep the best checkpoint by validation accuracy (checked after validation, not at
        # train-epoch end) plus the last one; train() reloads "best" for validate/test.
        callbacks.append(
            ModelCheckpoint(
                monitor="val_accuracy",
                mode="max",
                save_top_k=1,
                save_last=True,
                every_n_epochs=1,
                save_on_train_epoch_end=False,
            )
        )
        self.trainer = Trainer(
            max_epochs=conf.max_epochs,
            devices=self.device,
            logger=self.wandb_logger,
            callbacks=callbacks,
            enable_progress_bar=conf.enable_progress_bar,
            default_root_dir=conf.work_dir,
        )
        self.classifier = classifier

        with self.trainer.init_module():
            self.pl_module = PUModule(
                model=self.classifier,
                prior=self.prior,
                loss_fn=conf.loss_function,
                lr=conf.lr,
                predict_probability=conf.predict_probability,
                balanced_error=conf.balanced_error,
            )
            assert isinstance(self.train_loader, dict), "PU learning requires PUDataLoader"
            self.pl_module.estimate_steps_per_epoch(self.train_loader)
        # Note: torch.compile(module) returns a compiled wrapper instead of compiling in place,
        # so calling it without using the result has no effect. Use self.pl_module.compile()
        # here if compilation is wanted.

    @property
    def valididation_loader(self):
        """Misspelled legacy alias of ``validation_loader``, kept for backward compatibility."""
        return self.validation_loader

    @valididation_loader.setter
    def valididation_loader(self, loader):
        self.validation_loader = loader

    def train(self):
        """
        Fit the PU module, then validate and test it using the best checkpoint.

        The positive and unlabeled loaders are combined with ``CombinedLoader`` in
        "max_size_cycle" mode, which cycles the shorter loader until the longer one is exhausted.
        """
        assert isinstance(self.train_loader, dict), "PU learning requires PUDataLoader"
        assert isinstance(self.validation_loader, dict), "PU learning requires PUDataLoader"

        combined_train_loader = CombinedLoader(self.train_loader, mode="max_size_cycle")
        combined_validation_loader = CombinedLoader(self.validation_loader, mode="max_size_cycle")

        self.trainer.fit(
            self.pl_module, train_dataloaders=combined_train_loader, val_dataloaders=combined_validation_loader
        )
        self.trainer.validate(model=self.pl_module, dataloaders=combined_validation_loader, ckpt_path="best")
        self.trainer.test(model=self.pl_module, dataloaders=self.test_loader, ckpt_path="best")

    def _predict(self, loader):
        """Run ``trainer.predict`` on ``loader`` and merge the batch outputs with ``accumulate_predictions``."""
        return accumulate_predictions(self.trainer.predict(self.pl_module, loader))

    def _predict_pu(self, loaders):
        """Predict on the positive and unlabeled loaders of one split (positives first)."""
        positive_preds = self._predict(loaders["positive"])["predictions"]
        unlabeled_output = self._predict(loaders["unlabeled"])
        return {
            "positive_preds": positive_preds,
            "unlabeled_preds": unlabeled_output["predictions"],
            "true_labels": unlabeled_output["labels"],
        }

    def predict(self):
        """
        Make predictions on the test, full_train, and validation data.

        Returns:
            dict: ``"test"`` holds ``predictions`` and ``labels``. ``"full_train"`` and
            ``"validation"`` hold ``positive_preds``, ``unlabeled_preds`` and the unlabeled
            samples' ``true_labels``.
        """
        results = {}

        logger.info("Making predictions on test data")
        test_output = self._predict(self.test_loader)
        results["test"] = {"predictions": test_output["predictions"], "labels": test_output["labels"]}

        logger.info("Making predictions on full train data")
        results["full_train"] = self._predict_pu(self.full_train_loader)

        logger.info("Making predictions on validation data")
        results["validation"] = self._predict_pu(self.validation_loader)
        return results

    @staticmethod
    def _predictions_table(predictions, labels=None) -> wandb.Table:
        """Build a wandb.Table with a "predictions" column and, when given, a "labels" column."""
        if labels is None:
            return wandb.Table(columns=["predictions"], data=[[pred] for pred in predictions])
        return wandb.Table(
            columns=["predictions", "labels"],
            data=[[pred, label] for pred, label in zip(predictions, labels, strict=True)],
        )

    def log_artifacts(self, results):
        """
        Store all prediction tables in one wandb artifact, log it and finish the run.

        The tables created include:
        - test_predictions: Test predictions and labels
        - full_train_positive_predictions: Positive sample predictions from full train data
        - full_train_unlabeled_predictions: Unlabeled sample predictions and true labels from full train data
        - validation_positive_predictions: Positive sample predictions from validation data
        - validation_unlabeled_predictions: Unlabeled sample predictions and true labels from validation data

        Args:
            results (dict): Output of ``predict()``.

        Raises:
            ValueError: If a predictions list and its labels list differ in length.
        """
        tables = {
            "test_predictions": self._predictions_table(results["test"]["predictions"], results["test"]["labels"]),
            "full_train_positive_predictions": self._predictions_table(results["full_train"]["positive_preds"]),
            "full_train_unlabeled_predictions": self._predictions_table(
                results["full_train"]["unlabeled_preds"], results["full_train"]["true_labels"]
            ),
            "validation_positive_predictions": self._predictions_table(results["validation"]["positive_preds"]),
            "validation_unlabeled_predictions": self._predictions_table(
                results["validation"]["unlabeled_preds"], results["validation"]["true_labels"]
            ),
        }

        predictions_table_artifact = wandb.Artifact(get_predictions_table_artifact_name(self.conf), type="dataset")
        for table_name, table in tables.items():
            predictions_table_artifact.add(table, table_name)

        # Log exactly once, through the run this processor created. A separate
        # wandb.log_artifact() call targets the same active run and would log it a second time.
        run = self.wandb_logger.experiment
        run.log_artifact(predictions_table_artifact)
        run.finish()

    def process(self):
        """
        Run training, prediction and artifact logging in sequence.

        Returns:
            dict: The prediction results from ``predict()``.
        """
        self.train()
        results = self.predict()
        self.log_artifacts(results)

        return results


class SupervisedTrainAndPredictProcessor:
    """
    Facade bundling training, prediction, and wandb artifact logging for supervised learning.

    Each stage (``train``, ``predict``, ``log_artifacts``) can be called on its own, and
    ``process`` runs all three in sequence.
    """

    wandb_logger: WandbLogger
    pl_module: SupervisedModule

    def __init__(
        self,
        conf: ExperimentConfig,
        classifier: nn.Module,
        full_train_loader: DataLoader,
        train_loader: DataLoader,
        validation_loader: DataLoader,
        test_loader: DataLoader,
    ):
        """
        Build the wandb logger, the Lightning Trainer (with checkpointing), and the SupervisedModule.

        Args:
            conf: Experiment configuration; ``vars(conf)`` is also sent to wandb as the run config.
            classifier: Network wrapped by the SupervisedModule.
            full_train_loader: Loader over the full training data (only used by ``predict`` here).
            train_loader: Loader used for fitting; must be a ``DataLoader``.
            validation_loader: Loader used for validation and best-checkpoint selection.
            test_loader: Loader used for testing and test-set predictions.
        """
        self.conf = conf
        self.full_train_loader = full_train_loader
        self.train_loader = train_loader
        self.validation_loader = validation_loader
        self.test_loader = test_loader

        # Initialize wandb logger
        self.wandb_logger = WandbLogger(
            project="pu-ece-estimator", name="train-classifier", job_type="train", config=vars(conf), tags=["train"]
        )

        # Set up device and create Trainer
        self.device = setup_device_config()
        if _RICH_AVAILABLE and conf.enable_progress_bar:
            callbacks: list[Callback] = [RichProgressBar(), RichModelSummary()]
        else:
            callbacks = []
        # Keep the checkpoint with the highest val_accuracy (plus the last one);
        # train() reloads it through ckpt_path="best".
        callbacks.append(
            ModelCheckpoint(
                monitor="val_accuracy",
                mode="max",
                save_top_k=1,
                save_last=True,
                every_n_epochs=1,
                save_on_train_epoch_end=False,
            )
        )
        self.trainer = Trainer(
            max_epochs=conf.max_epochs,
            devices=self.device,
            logger=self.wandb_logger,
            callbacks=callbacks,
            enable_progress_bar=conf.enable_progress_bar,
        )
        self.classifier = classifier

        with self.trainer.init_module():
            self.pl_module = SupervisedModule(
                model=self.classifier,
                loss_fn=conf.loss_function,
                lr=conf.lr,
                predict_probability=conf.predict_probability,
            )
            assert isinstance(self.train_loader, DataLoader), "Supervised learning requires DataLoader"
            self.pl_module.estimate_steps_per_epoch(self.train_loader)
        # NOTE: no compilation is applied. `torch.compile(module)` returns a new wrapper
        # instead of compiling in place, so calling it without keeping the result has no
        # effect. To opt in, use `self.pl_module = torch.compile(self.pl_module)`.

    @property
    def valididation_loader(self) -> DataLoader:
        """Misspelled alias of ``validation_loader``, kept for backward compatibility."""
        return self.validation_loader

    @valididation_loader.setter
    def valididation_loader(self, value: DataLoader) -> None:
        self.validation_loader = value

    def train(self):
        """
        Fit the SupervisedModule, then re-run validation and test on the best checkpoint.

        ``ckpt_path="best"`` makes Lightning load the checkpoint selected by the
        ModelCheckpoint callback (highest ``val_accuracy``) before validating/testing.
        """
        # For supervised learning, use standard DataLoader
        assert isinstance(self.train_loader, DataLoader), "Supervised learning requires DataLoader"
        assert isinstance(self.validation_loader, DataLoader), "Supervised learning requires DataLoader"
        self.trainer.fit(self.pl_module, train_dataloaders=self.train_loader, val_dataloaders=self.validation_loader)
        self.trainer.validate(model=self.pl_module, dataloaders=self.validation_loader, ckpt_path="best")
        self.trainer.test(model=self.pl_module, dataloaders=self.test_loader, ckpt_path="best")

    def _predict_split(self, loader: DataLoader) -> dict:
        """Run ``trainer.predict`` on one loader and keep only its predictions and labels."""
        output = accumulate_predictions(self.trainer.predict(self.pl_module, loader))
        return {"predictions": output["predictions"], "labels": output["labels"]}

    def predict(self):
        """
        Predict on the test, full_train, and validation loaders (in that order).

        No ``ckpt_path`` is passed, so the module's current in-memory weights are used.
        After ``train()`` these are presumably the best-checkpoint weights restored by
        ``test(ckpt_path="best")``; confirm this if ``predict`` is called on its own.

        Returns:
            dict: ``{"test" | "full_train" | "validation": {"predictions": ..., "labels": ...}}``.
        """
        results = {}

        logger.info("Making predictions on test data")
        results["test"] = self._predict_split(self.test_loader)

        logger.info("Making predictions on full train data")
        results["full_train"] = self._predict_split(self.full_train_loader)

        assert isinstance(self.validation_loader, DataLoader), "Supervised learning requires DataLoader"
        logger.info("Making predictions on validation data")
        results["validation"] = self._predict_split(self.validation_loader)
        return results

    @staticmethod
    def _make_predictions_table(predictions, labels) -> wandb.Table:
        """Build a (predictions, labels) wandb table; raises ValueError if the lengths differ."""
        return wandb.Table(
            columns=["predictions", "labels"],
            data=[[pred, label] for pred, label in zip(predictions, labels, strict=True)],
        )

    def log_artifacts(self, results):
        """
        Log prediction tables as one wandb artifact, then finish the wandb run.

        The artifact contains the tables ``test_predictions``, ``full_train_predictions``, and
        ``validation_predictions``. Each table has a ``predictions`` column and a ``labels`` column.

        Args:
            results (dict): Output of ``predict()``.
        """
        tables = {
            split: self._make_predictions_table(results[split]["predictions"], results[split]["labels"])
            for split in ("test", "full_train", "validation")
        }

        predictions_table_artifact = wandb.Artifact(get_predictions_table_artifact_name(self.conf), type="dataset")
        for split, table in tables.items():
            predictions_table_artifact.add(table, f"{split}_predictions")

        # Log once, on the logger's run. The extra module-level `wandb.log_artifact` call
        # that used to precede this targeted the same active run and logged a duplicate.
        self.wandb_logger.experiment.log_artifact(predictions_table_artifact)
        self.wandb_logger.experiment.finish()

    def process(self):
        """
        Facade entry point: train, predict, and log artifacts in sequence.

        Returns:
            dict: The prediction results produced by ``predict()``.
        """
        self.train()
        results = self.predict()
        self.log_artifacts(results)

        return results


def main(conf: ExperimentConfig):
    """
    Run one experiment end to end for ``conf.learning_type``.

    Seeds RNGs, builds the tokenizer and classifier, and prepares data and dataloaders for
    either PU learning (``"pu"``) or supervised learning (``"supervised"``). It then runs the
    matching train-and-predict processor.

    Raises:
        ValueError: If ``conf.learning_type`` is neither ``"pu"`` nor ``"supervised"``.
    """
    init_seed(seed=conf.seed)

    use_pu = conf.learning_type == "pu"
    if not use_pu and conf.learning_type != "supervised":
        raise ValueError(f"Unsupported learning type: {conf.learning_type}")

    tokenizer = select_tokenizer(conf.dataset)
    classifier = select_classifier(conf.dataset, tokenizer=tokenizer)
    conf.classifier = classifier  # For logging purposes

    def build_collator():
        # Invoked after data preparation in each branch below.
        return select_collator(conf.dataset, pad_id=tokenizer.pad_token_id if tokenizer else 0)

    if use_pu:
        data, prior = prepare_pu_data(conf.dataset, conf.root, conf.num_positive, tokenizer=tokenizer)
        full_train_loader, train_loader, validation_loader, test_loader = prepare_pu_dataloaders(
            conf, data, collator=build_collator()
        )
        processor = PUTrainAndPredictProcessor(
            conf, classifier, prior, full_train_loader, train_loader, validation_loader, test_loader
        )
    else:
        data = prepare_supervised_data(conf.dataset, conf.root, conf.num_positive, tokenizer=tokenizer)
        full_train_loader, train_loader, validation_loader, test_loader = prepare_supervised_dataloaders(
            conf, data, collator=build_collator()
        )
        processor = SupervisedTrainAndPredictProcessor(
            conf, classifier, full_train_loader, train_loader, validation_loader, test_loader
        )
    processor.process()

    # Finish the wandb run just in case
    wandb.finish()


@hydra.main(
    config_path=Path(__file__).absolute().parents[4].joinpath("configs", "pu_ece_estimator").as_posix(),
    config_name="config.yaml",
    version_base=None,
)
def hydra_main(cfg: DictConfig):
    """
    Hydra entry point: turn the composed config into an ExperimentConfig, set up logging, and run ``main``.

    Keys read as ``cfg.<key>`` are expected to be present in the config. Keys read through
    ``optional`` fall back to the matching ``ExperimentConfig`` class-level default when absent.
    """
    global logger

    def optional(key: str):
        return cfg.get(key, getattr(ExperimentConfig, key))

    experiment_conf = ExperimentConfig(
        seed=cfg.seed,
        dataset=cfg.dataset,
        train_batch_size=optional("train_batch_size"),
        validation_batch_size=optional("validation_batch_size"),
        test_batch_size=optional("test_batch_size"),
        balanced_error=optional("balanced_error"),
        max_epochs=cfg.max_epochs,
        loss_function=LossFunction(cfg.loss_function),
        num_positive=cfg.num_positive,
        log_level=cfg.log_level,
        lr=cfg.lr,
        work_dir=optional("work_dir"),
        enable_progress_bar=optional("enable_progress_bar"),
        predict_probability=optional("predict_probability"),
        learning_type=optional("learning_type"),
    )

    # Unrecognized level names fall back to INFO.
    level = getattr(logging, experiment_conf.log_level.upper(), logging.INFO)
    logger = setup_logger(level=level)
    set_all_loggers_level(level)
    main(experiment_conf)


if __name__ == "__main__":
    # The @hydra.main decorator on hydra_main supplies `cfg`, so it is called without arguments.
    hydra_main()
