import numpy as np
from tqdm import tqdm
from astropy.table import Table, join

from mmoma.datasets.astropile import CrossMatchedAstroPileLoader
from mmoma.datasets.preprocessing import ClampImage, CropImage, RescaleToLegacySurvey

# HEALPix region identifiers handed to the loader in main(): they are the only
# regions included for the 'eval' split and are excluded from the 'train' split.
HEALPIX_REGIONS = [
    1708,
    1709,
    1643,
    1640,
    1642,
]

# Splits processed by main(), in this order.
splits = ["train", "eval"]

# Version tag embedded in the FITS file names (``..._v{version}.fits``).
version = 2

# Directory name of the right-hand dataset under data/MultimodalUniverse/.
survey = 'legacysurvey'

def _collect_split_images(split):
    """Iterate the cross-matched loader for ``split`` and gather its images.

    The loader is restricted to ``HEALPIX_REGIONS`` for ``'eval'`` and
    excludes them for ``'train'``; any other split name applies no healpix
    filter at all.

    Args:
        split: Name of the split being built (``'train'`` or ``'eval'``).

    Returns:
        An astropy ``Table`` with a ``TARGETID`` column (``object_id`` cast to
        int64) and an ``image`` column (the per-batch ``flux`` arrays
        concatenated along axis 0).

    Raises:
        ValueError: If the loader yields no batches.
    """
    loader = CrossMatchedAstroPileLoader(
        left_dataset_path="data/MultimodalUniverse/provabgs",
        right_dataset_path=f"data/MultimodalUniverse/{survey}",
        batch_size=256,
        formatting_fns=[
            ClampImage(),
            CropImage(96),
            RescaleToLegacySurvey(),
        ],
        num_workers=10,
        include_healpix=HEALPIX_REGIONS if split == 'eval' else None,
        exclude_healpix=HEALPIX_REGIONS if split == 'train' else None,
    )

    # The 'fit' stage / train dataloader is used for both splits; the split is
    # selected purely through the healpix filters above.
    # NOTE(review): if train_dataloader() drops the last incomplete batch,
    # some objects will be missing from the output -- confirm in the loader.
    loader.setup(stage='fit')

    target_ids, images = [], []
    # Pass the dataloader itself (not iter(...)) so tqdm can show a total
    # whenever the dataloader defines __len__.
    for batch in tqdm(loader.train_dataloader(), desc=f"{split} images"):
        target_ids.extend(batch['object_id'])
        images.append(batch['legacysurvey_image']['flux'].numpy())

    if not images:
        # np.concatenate([]) would raise a ValueError with an opaque message;
        # keep the exception type but say what actually went wrong.
        raise ValueError(f"Loader returned no batches for split '{split}'")

    table = Table()
    table['TARGETID'] = np.array(target_ids).astype(np.int64)
    table['image'] = np.concatenate(images, axis=0)
    return table


def main():
    """Add an ``image`` column to the PROVABGS x Legacy Survey tables.

    For every entry of ``splits``:

    1. check that the output ``..._w_image.fits`` does not exist yet and read
       the existing table ``data/provabgs_legacysurvey_{split}_v{version}.fits``
       (both done *before* the slow loader pass so problems surface early);
    2. collect ``TARGETID`` / ``image`` from the cross-matched loader;
    3. inner-join both tables on ``TARGETID`` and write the result next to the
       input file, never overwriting an existing file.

    Raises:
        FileExistsError: If the output file for a split already exists.
        ValueError: If the loader yields no batches for a split.
        Exception: Whatever ``Table.read`` raises when the input table is
            missing or unreadable.
    """
    import os  # local import: only needed for the early output-path check

    for split in splits:
        dset_path = f"data/provabgs_legacysurvey_{split}_v{version}.fits"
        new_path = dset_path.replace('.fits', '_w_image.fits')

        # The final write uses overwrite=False; fail now rather than after
        # iterating the whole dataset.
        if os.path.exists(new_path):
            raise FileExistsError(
                f"{new_path} already exists; remove it before re-running"
            )

        old_table = Table.read(dset_path)
        new_table = _collect_split_images(split)

        # join() defaults to an inner join: rows present in only one of the
        # two tables are dropped.
        joint_table = join(old_table, new_table, keys='TARGETID')
        if len(joint_table) != len(old_table):
            print(
                f"Note: {len(old_table)} rows in {dset_path}, "
                f"{len(new_table)} images loaded, "
                f"{len(joint_table)} rows after joining on TARGETID"
            )

        print(f"Writing to {new_path}, total rows: {len(joint_table)}")
        joint_table.write(new_path, overwrite=False)

    print("Done!")

# Run the conversion only when this file is executed as a script, so that
# importing the module has no side effects beyond defining its names.
if __name__ == "__main__":
    main()