# This script generates all the crossmatches between the left dataset
# (LEFT_DATASET, currently gz10) and the other surveys in RIGHT_DATASETS,
# tokenizes on the fly, and exports the results to the data folder.
import torch
from tqdm import tqdm
import numpy as np
import os
from astropy.table import Table, vstack, join

from mmoma.datasets.astropile import CrossMatchedAstroPileLoader
from mmoma.datasets.preprocessing import (
    ClampImage,
    CropImage,
    PadImageBands,
    RescaleToLegacySurvey,
)

from utils import get_tokenizers

from dataclasses import dataclass


@dataclass
class Dataset:
    """Name and on-disk location of one survey used in the cross-match."""

    # Survey identifier: passed to get_tokenizers(), used as the prefix of the
    # per-survey batch keys, and embedded in the output file name.
    dataset: str
    # Dataset location handed to CrossMatchedAstroPileLoader as
    # left_dataset_path / right_dataset_path.
    path: str


# Reference survey: every right-hand survey below is cross-matched against it.
LEFT_DATASET = Dataset("gz10", "data/MultimodalUniverse/gz10")

# Surveys to cross-match with LEFT_DATASET; main() writes one file per entry.
RIGHT_DATASETS = [
    Dataset(dataset=name, path=location)
    for name, location in (
        ("legacysurvey", "data/MultimodalUniverse/legacysurvey"),
        ("hsc", "data/AstroPile_v1/hsc"),
    )
]

# Export version tag, embedded in the output file name:
#   "1": oct 24
#   "2": dec 24
#   "3": raw images (main() stores the right-hand image array instead of
#        encoding it with the image codec)
VERSION = "3"


def main():
    """Cross-match LEFT_DATASET with every survey in RIGHT_DATASETS and export.

    For each right-hand survey, batches from a CrossMatchedAstroPileLoader are
    moved to the GPU, encoded with the codecs returned by ``get_tokenizers``
    (the right-hand image is kept as a raw array when ``VERSION == "3"``),
    stripped of the bulky per-survey entries, converted to numpy and collected
    into one astropy ``Table`` written to
    ``data/{left}_{survey}_v{VERSION}.fits``.

    Surveys whose output file already exists are skipped. A survey that
    yields no batches is reported and no file is written for it.
    """
    left_name = LEFT_DATASET.dataset
    left_codecs = get_tokenizers(left_name)

    for right_dataset in RIGHT_DATASETS:
        survey = right_dataset.dataset
        output_file = f"data/{left_name}_{survey}_v{VERSION}.fits"

        # Check for an existing export before loading this survey's
        # tokenizers, so skipped surveys cost nothing.
        if os.path.exists(output_file):
            print(f"Skipping {output_file}")
            continue

        right_codecs = get_tokenizers(survey)
        # These two modalities are never tokenized for the right-hand survey.
        right_codecs.pop("catalog", None)
        right_codecs.pop("scalar_field", None)

        loader = CrossMatchedAstroPileLoader(
            left_dataset_path=LEFT_DATASET.path,
            right_dataset_path=right_dataset.path,
            batch_size=256,
            formatting_fns=[
                ClampImage(),
                CropImage(96),
                RescaleToLegacySurvey(),
                PadImageBands(version="oct24"),  # v1: oct24, v2: dec24 (latest)
            ],
        )
        loader.setup(stage="fit")

        image_key = f"{survey}_image"
        # One Table per batch; stacked once after the loop. Stacking inside
        # the loop would re-copy every accumulated row on each batch.
        tables = []

        with torch.no_grad():
            # Wrapping the dataloader itself (rather than the enumerate
            # object) lets tqdm pick up the total when it defines __len__.
            for idx, batch in enumerate(tqdm(loader.train_dataloader())):
                batch = loader.transfer_batch_to_device(batch, "cuda", idx)
                # For backwards compatibility: expose the flux under "array".
                batch[image_key]["array"] = batch[image_key]["flux"]

                # Tokenize everything we can.
                # NOTE(review): if both codec sets share a key (e.g. "image"),
                # the right-hand entry below overwrites the left-hand
                # "tok_<key>" entry -- confirm this is intended.
                for key, codec in left_codecs.items():
                    batch["tok_" + key.lower()] = codec.encode(
                        {key: batch[f"{left_name}_{key}"]}
                    )

                for key, codec in right_codecs.items():
                    if key == "image" and VERSION == "3":
                        # Version 3 exports the raw image instead of tokens.
                        batch["tok_" + key.lower()] = batch[f"{survey}_{key}"][
                            "array"
                        ]
                    else:
                        batch["tok_" + key.lower()] = codec.encode(
                            {key: batch[f"{survey}_{key}"]}
                        )

                # Read the left-hand ids before any entries are dropped.
                target_ids = np.array(batch[f"{left_name}_object_id"]).astype(
                    np.int64
                )

                # Drop the bulky per-survey entries to save space.
                if survey == "legacysurvey":
                    # Keep only the RGB rendering of the Legacy Survey image.
                    batch["rgb"] = batch["legacysurvey_image"]["rgb"]
                    batch.pop("legacysurvey_image")
                    batch.pop("legacysurvey_object_id")
                    batch.pop("legacysurvey_object_mask")
                    batch.pop("legacysurvey_catalog")
                elif survey == "hsc":
                    batch.pop("hsc_image")
                    batch.pop("hsc_object_id")
                    # Rename so the HSC image column is survey-specific.
                    batch["tok_image_hsc"] = batch.pop("tok_image")
                batch.pop("object_id")

                # Convert every tensor to numpy; other values pass through.
                columns = {
                    name: value.cpu().numpy()
                    if isinstance(value, torch.Tensor)
                    else value
                    for name, value in batch.items()
                }
                columns["TARGETID"] = target_ids
                tables.append(Table(columns))

        # An empty loader previously crashed on len(None); report and move on.
        if not tables:
            print(f"Found 0 matches for {survey}, nothing written to {output_file}")
            continue

        # Export the results to the data folder.
        results = vstack(tables)
        print(f"Found {len(results)} matches")
        results.write(output_file)


# Run the export only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
