# This script generates all the crossmatches between the DESI DD-Payne catalog
# and other surveys, tokenizes on the fly, and exports the results to the data
# folder.
import os
import pickle

import numpy as np
import torch
from aion_eval.utils import collate_concat, index_collated
from astropy.table import Table
from mmoma.datasets.astropile import CrossMatchedAstroPileLoader
from mmoma.datasets.preprocessing import PadSpectra
from tqdm import tqdm

from utils import get_tokenizers

# Path to the DESI EDR DD-Payne FITS catalog; main() reads it as an astropy
# Table.
DD_PATH = "/path/to/desi/data/DESI_EDR_DDPAYNE.fits"
# Surveys to cross-match against DESI, keyed by survey name. Each value is
# either a dataset path or a (dataset path, dataset name) tuple; main()
# accepts both forms.
MMU_PATHS = {
    "gaia": ("/path/to/gaia/data/gaia.py", "parallax_sample"),
}
# HEALPix regions that define the split: main() passes them as
# include_healpix for the "eval" split and as exclude_healpix for "train".
HEALPIX_REGIONS = [1708, 1709, 1643, 1640, 1642]
# Version tag embedded in the output file names (..._v{VERSION}.pkl).
VERSION = "1"


# Maps the token keys created during tokenization ("tok_desi_<key>" for DESI,
# "tok_<key>" for the cross-matched survey) to the key names used in the
# exported file. main() pops every key listed here from each batch and
# re-inserts the value under the new name, so all of them must be present.
rename_map = {
    "tok_desi_spectrum": "tok_spectrum_desi",
    "tok_bp_coefficients": "tok_xp_bp",
    "tok_rp_coefficients": "tok_xp_rp",
    "tok_phot_g_mean_flux": "tok_flux_g_gaia",
    "tok_phot_bp_mean_flux": "tok_flux_bp_gaia",
    "tok_phot_rp_mean_flux": "tok_flux_rp_gaia",
}


def to_numpy(batch):
    """Convert every ``torch.Tensor`` inside a nested container to NumPy.

    The container is traversed with ``tree_map`` from torch's pytree
    utilities, so the nesting of ``batch`` is preserved and only the leaves
    are touched.

    Args:
        batch: A tensor, or an arbitrarily nested structure of containers
            whose leaves may be tensors.

    Returns:
        A structure with the same layout as ``batch`` in which each tensor
        leaf has been replaced by a NumPy array on the host. Leaves that are
        not tensors are returned unchanged.
    """

    def _to_numpy(x):
        if isinstance(x, torch.Tensor):
            # detach() first: Tensor.numpy() refuses tensors that require
            # grad, and the exported arrays never need the autograd graph.
            return x.detach().cpu().numpy()
        return x

    return torch.utils._pytree.tree_map(_to_numpy, batch)


def _dump_pickle_atomically(obj, path):
    """Pickle ``obj`` to ``path`` so that ``path`` only ever holds a full file.

    The data is written to ``<path>.tmp`` first and moved into place with
    ``os.replace`` once the dump has finished. An interrupted run therefore
    leaves at most a stray ``.tmp`` file (overwritten by the next run) and
    never a truncated ``path`` that the "skip if exists" check in ``main``
    would mistake for a finished export.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "wb") as f:
        pickle.dump(obj, f)
    os.replace(tmp_path, path)


def main():
    """Cross-match DESI with each survey in ``MMU_PATHS`` and export tokens.

    For every survey and for both the "train" and "eval" split, the function
    iterates over the cross-matched batches, tokenizes the DESI and survey
    fields with the codecs returned by ``get_tokenizers``, keeps only the
    objects that also appear in the filtered DD-Payne catalog, attaches all
    catalog columns to them, and pickles the concatenated result to
    ``data/desiddpayne_<survey>_<split>_v<VERSION>.pkl``. Splits whose output
    file already exists are skipped.
    """
    os.makedirs("data", exist_ok=True)

    left_codecs = get_tokenizers("desi")

    catalog = Table.read(DD_PATH)
    catalog = catalog[catalog["CHISQ_FLAG"] < 3]
    # The catalog is fixed from here on, so convert its ids once instead of
    # repeating the conversion for every batch.
    catalog_ids = catalog["TARGET_ID"].data.astype(int)

    for survey, right_dataset_path in MMU_PATHS.items():
        # An entry is either a bare path or a (path, dataset name) tuple.
        if isinstance(right_dataset_path, tuple):
            right_dataset_path, right_dataset_name = right_dataset_path
        else:
            right_dataset_name = None

        full_survey_name = (
            f"{survey}{'-' + right_dataset_name if right_dataset_name else ''}"
        )

        print(f"Processing {full_survey_name} data")

        # Load the tokenizers
        right_codecs = get_tokenizers(full_survey_name)

        for split in ["train", "eval"]:
            print(f"Split: {split}")

            output_file = f"data/desiddpayne_{full_survey_name}_{split}_v{VERSION}.pkl"

            # If the output file already exists, skip
            if os.path.exists(output_file):
                print(f"Skipping {output_file}")
                continue

            # The split is selected through the HEALPix filters: "eval" keeps
            # only HEALPIX_REGIONS, "train" excludes them.
            loader = CrossMatchedAstroPileLoader(
                left_dataset_path="data/MultimodalUniverse/desi",
                left_dataset_name=None,
                right_dataset_path=right_dataset_path,
                right_dataset_name=right_dataset_name,
                batch_size=256,
                formatting_fns=[PadSpectra(7800)],
                include_healpix=HEALPIX_REGIONS if split == "eval" else None,
                exclude_healpix=HEALPIX_REGIONS if split == "train" else None,
            )

            loader.setup(stage="fit")

            results = []

            with torch.no_grad():
                # train_dataloader() is used for both splits because the
                # HEALPix filters above already restrict the data. Wrapping
                # the dataloader itself (not the enumerate) lets tqdm show
                # the total when the dataloader has a length.
                for idx, batch in enumerate(tqdm(loader.train_dataloader())):
                    # NOTE(review): the batch index is passed as the third
                    # argument; confirm that is what the loader expects.
                    batch = loader.transfer_batch_to_device(batch, "cuda", idx)

                    # Tokenize everything we can
                    for key, codec in left_codecs.items():
                        batch["tok_desi_" + key.lower()] = codec.encode(
                            {key: batch["desi_" + key]}
                        )

                    for key, codec in right_codecs.items():
                        batch["tok_" + key.lower()] = codec.encode(
                            {key: batch[survey + "_" + key]}
                        )

                    # Drop the per-survey ids; the shared "object_id" entry is
                    # the one used for matching below.
                    batch.pop(f"{survey}_object_id")
                    batch.pop("desi_object_id")

                    for old_key, new_key in rename_map.items():
                        batch[new_key] = batch.pop(old_key)

                    # Convert everything to numpy
                    batch = to_numpy(batch)
                    batch["object_id"] = np.array(batch["object_id"]).astype(int)

                    # Indices of the objects present in both the catalog and
                    # this batch; both index arrays follow the same (sorted
                    # by id) order, so rows stay aligned.
                    _, ix_cat, ix_batch = np.intersect1d(
                        catalog_ids,
                        batch["object_id"],
                        return_indices=True,
                        assume_unique=False,
                    )

                    batch = index_collated(batch, ix_batch)

                    # Attach every catalog column for the matched objects.
                    for k in catalog.columns:
                        batch[k] = catalog[k][ix_cat].data

                    results.append(batch)

            results = collate_concat(results)

            _dump_pickle_atomically(results, output_file)


# Run the export only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()
