# This script generates all the crossmatches between provabgs and other surveys,
# tokenizes on the fly, and exports the results to the data folder.
import torch
from tqdm import tqdm
import numpy as np
import os
from astropy.table import Table, vstack, join

from mmoma.datasets.astropile import CrossMatchedAstroPileLoader
from mmoma.datasets.preprocessing import ClampImage, CropImage, PadImageBands, PadSpectra, RescaleToLegacySurvey

from utils import get_tokenizers


# Preprocessed PROVABGS catalog that main() reads and joins against.
PROVABGS_PATH = "data/provabgs_preprocessed.hdf5"

# Imaging datasets to cross-match with DESI, keyed by survey name.
MMU_PATHS = dict(
    legacysurvey="data/MultimodalUniverse/legacysurvey",
    hsc="data/AstroPile_v1/hsc",
)

# HEALPix regions passed to the loader: included for the 'eval' split and
# excluded for the 'train' split.
HEALPIX_REGIONS = [
    1708,
    1709,
    1643,
    1640,
    1642,
]

# Version tag embedded in the output file names.
VERSION = "2"


def _build_loader(right_dataset_path, split):
    """Build the DESI x imaging cross-match loader for one split.

    Args:
        right_dataset_path: Path of the imaging dataset to match against DESI.
        split: Either 'train' or 'eval'. 'eval' restricts the loader to
            HEALPIX_REGIONS, 'train' excludes them.

    Returns:
        A CrossMatchedAstroPileLoader (``setup`` has not been called yet).
    """
    return CrossMatchedAstroPileLoader(
        left_dataset_path="data/MultimodalUniverse/desi",
        right_dataset_path=right_dataset_path,
        batch_size=256,
        formatting_fns=[
            ClampImage(),
            CropImage(96),
            RescaleToLegacySurvey(),
            PadImageBands(version='oct24'),
            PadSpectra(7800),
        ],
        include_healpix=HEALPIX_REGIONS if split == 'eval' else None,
        exclude_healpix=HEALPIX_REGIONS if split == 'train' else None,
    )


def _batch_to_table(batch, survey, left_codecs, right_codecs):
    """Tokenize one cross-matched batch and convert it to an astropy Table.

    The batch dict is modified in place: token entries are added and the raw
    image / spectrum / object-id entries are removed to save space.

    Args:
        batch: Batch dict produced by the loader, already moved to the device.
        survey: Name of the imaging survey ('legacysurvey' or 'hsc'), used as
            the key prefix of the right-hand side entries.
        left_codecs: Mapping of modality key -> codec applied to 'desi_<key>'.
        right_codecs: Mapping of modality key -> codec applied to
            '<survey>_<key>'.

    Returns:
        Table holding the remaining batch entries plus a 'TARGETID' column
        built from 'desi_object_id'.
    """
    # For backwards compatibility: also expose the flux under the 'array' key.
    image_key = survey + '_image'
    batch[image_key]['array'] = batch[image_key]['flux']

    # Tokenize everything we can
    for key, codec in left_codecs.items():
        batch['tok_' + key.lower()] = codec.encode({key: batch['desi_' + key]})

    for key, codec in right_codecs.items():
        batch['tok_' + key.lower()] = codec.encode({key: batch[survey + '_' + key]})

    # Grab the ids before the raw entries are dropped below.
    target_ids = np.array(batch['desi_object_id']).astype(np.int64)

    # We drop images and spectra if they are present to save space
    if survey == 'legacysurvey':
        batch.pop('legacysurvey_image')
        batch.pop('legacysurvey_object_id')
        batch.pop('legacysurvey_object_mask')
        batch.pop('legacysurvey_catalog')
    elif survey == 'hsc':
        batch.pop('hsc_image')
        batch.pop('hsc_object_id')
        batch['tok_image_hsc'] = batch.pop('tok_image')
    batch.pop('desi_spectrum')
    batch.pop('desi_object_id')
    batch.pop('object_id')
    batch['tok_spectrum_desi'] = batch.pop('tok_spectrum')

    # Convert everything to numpy
    columns = {
        name: value.cpu().numpy() if isinstance(value, torch.Tensor) else value
        for name, value in batch.items()
    }
    columns['TARGETID'] = target_ids
    return Table(columns)


def main():
    """Cross-match PROVABGS with each imaging survey and export token tables.

    For every survey in MMU_PATHS and every split ('train', 'eval'), the
    DESI x survey cross-match is tokenized batch by batch, inner-joined with
    the PROVABGS catalog on 'TARGETID' and written to
    ``data/provabgs_<survey>_<split>_v<VERSION>.fits``. Output files that
    already exist are skipped, as are splits for which the loader yields no
    batch.
    """
    left_codecs = get_tokenizers('desi')

    provabgs = Table.read(PROVABGS_PATH)
    # Remove rows with failed provabgs fits
    provabgs = provabgs[provabgs['LOG_MSTAR'] > 0]

    # We are going to generate train and eval catalogs for both Legacy Survey
    # and HSC imaging data.
    for survey, right_dataset_path in MMU_PATHS.items():
        print(f"Processing {survey} data")

        # Load the tokenizers, leaving out the ones that are not exported.
        # Membership test + single-argument pop is kept on purpose: it does
        # not require the container to support pop(key, default).
        right_codecs = get_tokenizers(survey)
        for unused in ('catalog', 'scalar_field'):
            if unused in right_codecs.keys():
                right_codecs.pop(unused)

        for split in ['train', 'eval']:
            print(f"Split: {split}")

            output_file = f'data/provabgs_{survey}_{split}_v{VERSION}.fits'

            # If the output file already exists, skip
            if os.path.exists(output_file):
                print(f"Skipping {output_file}")
                continue

            loader = _build_loader(right_dataset_path, split)
            loader.setup(stage='fit')

            # The train dataloader is used for both splits; the split itself
            # is selected through the healpix include/exclude filters.
            batch_tables = []
            with torch.no_grad():
                for idx, batch in tqdm(enumerate(loader.train_dataloader())):
                    batch = loader.transfer_batch_to_device(batch, 'cuda', idx)
                    batch_tables.append(
                        _batch_to_table(batch, survey, left_codecs, right_codecs)
                    )

            # Without any batch there is nothing to join or write.
            if not batch_tables:
                print(f"No data found for {survey} ({split}), not writing {output_file}")
                continue

            # Stack once at the end: stacking inside the loop re-copies the
            # accumulated table on every batch.
            results = vstack(batch_tables)

            # Export the results to the data folder
            results = join(provabgs, results, keys='TARGETID', join_type='inner')
            print(f"Found {len(results)} matches")
            results.write(output_file)


# Run the export only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()