# This script generates the crossmatches between Gaia (see BASE_PATH) and the
# spectroscopic surveys listed in MMU_PATHS, converts each batch to numpy
# arrays, and pickles the collated results into the data folder.
import os
import pickle

import numpy as np
import torch
from aion_eval.utils import collate_concat, index_collated
from mmoma.datasets.astropile import CrossMatchedAstroPileLoader
from mmoma.datasets.preprocessing import PadSpectra
from tqdm import tqdm

from utils import get_tokenizers

# Left side of every crossmatch, unpacked positionally by main():
# (short name used in log messages and output filenames,
#  left dataset path, left dataset name).
BASE_PATH = ("gaia", "/path/to/gaia/data/gaia.py", "parallax_sample")

# Right side of each crossmatch, keyed by survey name. Each value is
# (dataset path, dataset name or None, batch keys to keep). main() also
# accepts a (path, name) tuple in the first slot.
MMU_PATHS = dict(
    sdss=(
        "data/MultimodalUniverse/sdss",
        None,
        ["gaia_object_id", "sdss_object_id", "sdss_spectrum"],
    ),
    desi=(
        "data/MultimodalUniverse/desi",
        None,
        ["gaia_object_id", "desi_object_id", "desi_spectrum"],
    ),
)

# Passed to the loader as ``include_healpix`` to restrict which sky regions
# are crossmatched.
HEALPIX_REGIONS = [
    1708, 1709, 1643, 1640, 1642,
    2698, 1081, 74, 2096, 132,
]


def to_numpy(batch):
    """Convert every torch tensor inside a (possibly nested) batch to numpy.

    Args:
        batch: Any structure ``torch.utils._pytree.tree_map`` can walk
            (dicts, lists, tuples, ...), or a single tensor.

    Returns:
        The same structure with each ``torch.Tensor`` leaf replaced by a
        numpy array; non-tensor leaves are returned unchanged.
    """

    def _to_numpy(x):
        if isinstance(x, torch.Tensor):
            # Tensor.numpy() raises for tensors that require grad or are not on
            # the CPU; detach()/cpu() are cheap for plain CPU tensors.
            return x.detach().cpu().numpy()
        return x

    # NOTE(review): torch.utils._pytree is a private torch module.
    return torch.utils._pytree.tree_map(_to_numpy, batch)


def _extract_batch(batch, extract_keys):
    """Keep only ``extract_keys`` from one loader batch and convert it to numpy.

    Args:
        batch: A dict yielded by the crossmatch dataloader.
        extract_keys: Names of the batch entries to keep; all others are dropped.

    Returns:
        A dict of the kept entries. List values become arrays via ``np.array``,
        tensors are converted by ``to_numpy``, other values pass through.
        ``gaia_object_id``, when present, is cast to ``np.int64``.
    """
    kept = {
        key: np.array(value) if isinstance(value, list) else value
        for key, value in batch.items()
        if key in extract_keys
    }
    kept = to_numpy(kept)
    if "gaia_object_id" in kept:
        # NOTE(review): presumably 64-bit Gaia source IDs. The builtin ``int``
        # dtype is only 32-bit on Windows with NumPy < 2, so pin int64.
        kept["gaia_object_id"] = np.asarray(kept["gaia_object_id"]).astype(np.int64)
    return kept


def main():
    """Crossmatch BASE_PATH against each survey in MMU_PATHS and pickle the results.

    For every survey, the crossmatched train dataloader (restricted to
    HEALPIX_REGIONS, batch size 256, spectra padded with ``PadSpectra(7800)``)
    is streamed, each batch is reduced to the survey's ``extract_keys`` and
    converted to numpy, and the batches are concatenated with
    ``collate_concat`` and written to
    ``data/<base>_x_<survey>[-<dataset name>]_v1.pkl``.
    """
    # Create the output folder up front so a long run cannot fail only at the
    # final write.
    os.makedirs("data", exist_ok=True)

    for survey, (right_dataset_path, right_dataset_name, extract_keys) in MMU_PATHS.items():
        # A (path, name) tuple in the path slot overrides the configured name;
        # otherwise keep the name from the config instead of discarding it.
        if isinstance(right_dataset_path, tuple):
            right_dataset_path, right_dataset_name = right_dataset_path

        full_survey_name = (
            f"{survey}-{right_dataset_name}" if right_dataset_name else survey
        )

        print(f"Processing {BASE_PATH[0]} X {full_survey_name}")

        loader = CrossMatchedAstroPileLoader(
            left_dataset_path=BASE_PATH[1],
            left_dataset_name=BASE_PATH[2],
            right_dataset_path=right_dataset_path,
            right_dataset_name=right_dataset_name,
            batch_size=256,
            formatting_fns=[PadSpectra(7800)],
            include_healpix=HEALPIX_REGIONS,
        )

        loader.setup(None)
        iterator = iter(loader.train_dataloader())

        results = [_extract_batch(batch, extract_keys) for batch in tqdm(iterator)]
        results = collate_concat(results)

        with open(f"data/{BASE_PATH[0]}_x_{full_survey_name}_v1.pkl", "wb") as f:
            pickle.dump(results, f)


if __name__ == "__main__":
    main()
