"""
Generate synthetic datasets for pretraining.
This script generates synthetic datasets for pretraining the model and saves them to disk.
"""

import multiprocessing as mp
import random
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import torch
import xarray as xr
from loguru import logger
from tqdm import tqdm

from tabicl.config.config_pretrain import ConfigPretrain
from tabicl.core.enums import GeneratorName
from tabicl.core.trainer_pretrain_init import create_synthetic_dataset
from tabicl.data.synthetic_generator_forest import synthetic_dataset_generator_forest
from tabicl.data.synthetic_generator_tabpfn import synthetic_dataset_generator_tabpfn
from tabicl.utils.set_seed import seed_worker


@dataclass()
class ConfigPretrainingGeneration:
    """Settings for the offline pretraining-dataset generation script."""

    # Number of DataLoader worker processes; 0 loads data in the main process.
    n_workers: int
    # Number of datasets written per generator.
    n_datasets: int
    # Sample count forwarded to the generator factories in process_generator_output.
    n_samples: int
    # Feature-count bounds forwarded to the generator factories in
    # process_generator_output and also stored as netCDF attributes.
    min_features: int
    max_features: int
    # Class-count bound, forwarded and stored the same way as the feature bounds.
    max_classes: int


def generate_pretraining_datasets(cfg: ConfigPretrainingGeneration):
    """Write pretraining datasets for every supported generator, one after another.

    The "tabpfn" generator runs first, then "forest"; the name is passed through
    as the ``folder_name`` of ``generate_pretraining_datasets_``.
    """
    for generator_name in ("tabpfn", "forest"):
        generate_pretraining_datasets_(cfg, generator_name)



def process_generator_output(cfg: ConfigPretrainingGeneration, folder_name: str, seed: int, queue: mp.Queue):
    """Run a synthetic-data generator forever and push its datasets onto *queue*.

    Meant as a long-running producer: the loop never exits on its own.

    Args:
        cfg: Generation settings; ``n_samples``, ``min_features``,
            ``max_features`` and ``max_classes`` are forwarded to the
            generator factory.
        folder_name: Which generator to use, ``"tabpfn"`` or ``"forest"``.
        seed: Seed applied to Python's ``random``, NumPy's global RNG and torch.
        queue: Destination for ``(x, y)`` tuples.

    Raises:
        ValueError: If *folder_name* does not name a known generator.
    """
    if folder_name == "tabpfn":
        generator_factory = synthetic_dataset_generator_tabpfn
    elif folder_name == "forest":
        generator_factory = synthetic_dataset_generator_forest
    else:
        # Fail fast with a clear message instead of an UnboundLocalError below.
        raise ValueError(f"Unknown generator {folder_name!r}; expected 'tabpfn' or 'forest'")

    # Seed before constructing the generator so that any randomness consumed
    # while the factory builds it is reproducible too.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    generator = generator_factory(
        n_samples=cfg.n_samples,
        min_features=cfg.min_features,
        max_features=cfg.max_features,
        max_classes=cfg.max_classes,
    )

    while True:
        x, y = next(generator)

        # Keep only the first column when the generator yields a multi-dimensional target.
        if len(y.shape) > 1:
            y = y[:, 0]

        queue.put((x, y))


def identity(x):
    """Return the argument unchanged.

    Used as the DataLoader ``collate_fn`` so the list of samples is passed
    through as-is instead of being collated into batched tensors.
    """
    unchanged = x
    return unchanged

# Defaults reproduce the paths this script has always used.
_DEFAULT_PRETRAIN_CONFIG_PATH = Path("outputs/done/foundation_mix_600k_finetune/config_pretrain.yaml")
_DEFAULT_OUTPUT_ROOT = Path("data/datasets_pretraining")


def _to_numpy(value) -> np.ndarray:
    """Return *value* as a NumPy array; torch tensors are detached and moved to CPU first."""
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().numpy()
    return np.asarray(value)


def generate_pretraining_datasets_(
    cfg: ConfigPretrainingGeneration,
    folder_name: str,
    *,
    config_path: Path = _DEFAULT_PRETRAIN_CONFIG_PATH,
    output_root: Path = _DEFAULT_OUTPUT_ROOT,
):
    """Generate ``cfg.n_datasets`` datasets with one generator and save them as netCDF.

    The data come from the synthetic dataset built from the pretraining config
    at *config_path*, with its generator overridden by ``GeneratorName(folder_name)``.
    Each dataset is written to ``<output_root>/<folder_name>/dataset_<i>.nc``.

    Args:
        cfg: Generation settings. ``n_workers`` configures the DataLoader and
            ``n_datasets`` sets how many files are written. ``min_features``,
            ``max_features`` and ``max_classes`` are only stored as netCDF
            attributes here.
        folder_name: Generator name; converted with ``GeneratorName(...)`` and
            also used as the output sub-folder name.
        config_path: Pretraining config YAML to load.
        output_root: Root directory for the generated files.

    Raises:
        RuntimeError: If the data loader is exhausted before ``cfg.n_datasets``
            datasets have been produced.
    """
    cfg_pretrain = ConfigPretrain.load(Path(config_path))
    cfg_pretrain.data.generator = GeneratorName(folder_name)

    synthetic_dataset = create_synthetic_dataset(cfg_pretrain)
    synthetic_dataloader = torch.utils.data.DataLoader(
        synthetic_dataset,
        batch_size=1,
        collate_fn=identity,
        pin_memory=False,
        num_workers=cfg.n_workers,
        persistent_workers=cfg.n_workers > 0,
        worker_init_fn=seed_worker,
    )

    synthetic_iterator = iter(synthetic_dataloader)
    path = Path(output_root) / folder_name
    path.mkdir(parents=True, exist_ok=True)

    # NOTE(review): cfg.n_samples is not used on this path, and the attrs below
    # record cfg's feature/class bounds rather than whatever cfg_pretrain used to
    # generate the data. Confirm the two configs agree.
    for i in tqdm(range(cfg.n_datasets)):
        try:
            # batch_size=1 with collate_fn=identity yields a one-element list; unwrap it.
            batch = next(synthetic_iterator)[0]
        except StopIteration:
            raise RuntimeError(
                f"Synthetic data loader for {folder_name!r} was exhausted after {i} datasets; "
                f"expected {cfg.n_datasets}"
            ) from None

        # Convert explicitly so torch tensors (possibly on GPU or requiring grad)
        # become plain NumPy arrays before they reach xarray / netCDF.
        x_support = _to_numpy(batch["x_support"])
        x_query = _to_numpy(batch["x_query"])
        y_support = _to_numpy(batch["y_support"])
        y_query = _to_numpy(batch["y_query"])

        ds = xr.Dataset(
            {
                "x_support": (["n_support", "feature"], x_support),
                "x_query": (["n_query", "feature"], x_query),
                "y_support": (["n_support"], y_support),
                "y_query": (["n_query"], y_query),
            },
            coords={
                "n_support": range(x_support.shape[0]),
                "n_query": range(x_query.shape[0]),
                "feature": range(x_support.shape[1]),
            },
            attrs={
                "id": i,
                "min_features": cfg.min_features,
                "max_features": cfg.max_features,
                "max_classes": cfg.max_classes,
            },
        )

        ds.to_netcdf(path / f"dataset_{i}.nc")

    logger.info(f"Finished generating datasets for {folder_name}")


if __name__ == "__main__":

    mp.set_start_method("forkserver")

    cfg = ConfigPretrainingGeneration(
        n_workers=124,
        n_datasets=10_000,
        n_samples=1024 + 256,
        min_features=1,
        max_features=100,
        max_classes=10
    )

    generate_pretraining_datasets(cfg)




