
import torch
import torch.distributed as dist
from loguru import logger

from tabicl.config.config_pretrain import ConfigPretrain
from tabicl.core.collator import CollatorWithPadding
from tabicl.core.enums import ModelName
from tabicl.core.flops import FLOPSNotImplementedError, calculate_flops
from tabicl.data.dataset_synthetic import SyntheticDataset
from tabicl.utils.set_seed import seed_worker


def log_parameter_count(cfg: ConfigPretrain, model: torch.nn.Module) -> None:
    """Log the number of trainable parameters of ``model`` from the main process.

    Only parameters with ``requires_grad`` set are counted. The message is
    emitted on rank 0 only, so a multi-process run logs it once. When
    ``torch.distributed`` is unavailable or no process group has been
    initialized, the current process is treated as the main process.

    Args:
        cfg: Pretraining configuration. Not read by this function; kept so the
            signature stays unchanged for callers.
        model: Module whose trainable parameters are counted.
    """
    # dist.get_rank() raises when no default process group exists, so only
    # query the rank when a distributed run has actually been set up.
    is_distributed = dist.is_available() and dist.is_initialized()
    if is_distributed and dist.get_rank() != 0:
        return

    trainable_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"Model has {trainable_parameters:,} trainable parameters")


def log_flops(cfg: ConfigPretrain) -> None:
    """Log the total pretraining FLOPs for ``cfg`` from the main process.

    The value comes from ``calculate_flops(cfg)``. If that raises
    ``FLOPSNotImplementedError``, the exception is logged at error level and
    the function returns without raising. Logging happens on rank 0 only; when
    ``torch.distributed`` is unavailable or no process group has been
    initialized, the current process is treated as the main process.

    Args:
        cfg: Pretraining configuration passed to ``calculate_flops``.
    """
    # dist.get_rank() raises when no default process group exists, so only
    # query the rank when a distributed run has actually been set up.
    is_distributed = dist.is_available() and dist.is_initialized()
    if is_distributed and dist.get_rank() != 0:
        return

    try:
        flops = calculate_flops(cfg)
    except FLOPSNotImplementedError as e:
        # A missing FLOPs estimate is reported but must not abort the run.
        logger.error(e)
        return

    logger.info(f"Total Pretraining FLOPs: {flops:.3e}")


def prepare_ddp_model(cfg: ConfigPretrain, model: torch.nn.Module) -> torch.nn.Module:
    """Wrap ``model`` in ``DistributedDataParallel`` when ``cfg.use_ddp`` is truthy.

    Args:
        cfg: Pretraining configuration; only ``cfg.use_ddp`` is read.
        model: Module to wrap.

    Returns:
        ``model`` itself when DDP is disabled; otherwise a
        ``DistributedDataParallel`` wrapper bound to the current CUDA device
        with ``find_unused_parameters=False``.
    """
    if not cfg.use_ddp:
        return model

    # Bind the replica to the CUDA device currently selected for this process.
    local_device = torch.cuda.current_device()
    ddp_model = torch.nn.parallel.DistributedDataParallel(
        model,
        device_ids=[local_device],
        find_unused_parameters=False,
    )
    return ddp_model


def create_synthetic_dataloader(cfg: ConfigPretrain, synthetic_dataset: SyntheticDataset) -> torch.utils.data.DataLoader:
    """Build the DataLoader that serves batches from ``synthetic_dataset``.

    Whether the collator pads to ``cfg.data.max_features`` depends on
    ``cfg.model_name``: TABPFN, FOUNDATION and FOUNDATION_FLASH pad, TAB2D
    does not.

    Args:
        cfg: Pretraining configuration. Reads ``model_name``,
            ``optim.batch_size``, ``data.max_features`` and ``workers_per_gpu``.
        synthetic_dataset: Dataset handed to the DataLoader.

    Returns:
        A DataLoader with pinned memory, ``cfg.workers_per_gpu`` workers seeded
        through ``seed_worker``, and workers kept alive across iterations
        whenever at least one worker is configured.

    Raises:
        NotImplementedError: If ``cfg.model_name`` is none of the models above.
    """
    model_name = cfg.model_name

    if model_name in (ModelName.TABPFN, ModelName.FOUNDATION, ModelName.FOUNDATION_FLASH):
        pad_to_max_features = True
    elif model_name == ModelName.TAB2D:
        pad_to_max_features = False
    else:
        raise NotImplementedError(f"Model {cfg.model_name} not implemented")

    # NOTE: ``shuffle_features`` (cfg.preprocessing.shuffle_features) is not
    # forwarded to the collator; that argument is currently disabled.
    collator = CollatorWithPadding(
        max_features=cfg.data.max_features,
        pad_to_max_features=pad_to_max_features,
    )

    worker_count = cfg.workers_per_gpu
    # persistent_workers is only enabled when worker processes exist.
    keep_workers_alive = worker_count > 0

    return torch.utils.data.DataLoader(
        synthetic_dataset,
        batch_size=cfg.optim.batch_size,
        collate_fn=collator,
        pin_memory=True,
        num_workers=worker_count,
        persistent_workers=keep_workers_alive,
        worker_init_fn=seed_worker,
    )