# This script preprocesses a parent sample for lens search
# following closely the choices of https://academic.oup.com/mnras/article/535/2/1625/7842018
# The steps are:
#  - Select HSC and Legacy Survey (LS) stamps with stellar mass and redshift
#    cuts following the choices of the paper
#  - Crossmatch HSC and LS so that both datasets can be used together
#  - Tokenize all the data and save it to disk as one big table
# This script generates all the crossmatches between provabgs and other surveys,
# tokenizes on the fly, and exports the results to the data folder.

import os

import torch
from tqdm import tqdm
import numpy as np
from astropy.table import Table, vstack, join

from mmoma.datasets.astropile import CrossMatchedAstroPileLoader
from mmoma.datasets.preprocessing import ClampImage, CropImage, PadImageBands, RescaleToLegacySurvey
from aion_eval.baselines.moco_v2 import Moco_v2, stein_decals_to_rgb

from utils import get_tokenizers


# Parent HSC catalog (FITS table) that the selection cuts are applied to.
HSC_CATALOG = "data/hsc_pdr3_catalog.fits"

# On-disk dataset locations, keyed by survey name.
MMU_PATHS = {
    "legacysurvey": "data/MultimodalUniverse/legacysurvey",
    "hsc": "data/AstroPile_v1/hsc",
}
# HEALPIX_REGIONS = [1708, 1709, 1643, 1640, 1642]

# Version tag embedded in the output file name.
VERSION = "2"

def main():
    """Build the tokenized HSC x Legacy Survey parent sample for lens search.

    Steps:
      1. Read ``HSC_CATALOG`` and keep rows passing the ``photoz_median``,
         ``stellar_mass`` and ``sfr / stellar_mass`` cuts.
      2. Iterate over cross-matched HSC / Legacy Survey batches, tokenize
         every modality that has a codec, and compute an embedding of the
         Legacy Survey image with the ``data/stein.ckpt`` MoCo v2 encoder.
      3. Inner-join each batch with the selected catalog on ``object_id``,
         stack all batches and write them to
         ``data/lens_parent_sample_v{VERSION}.fits``.

    Raises:
        FileExistsError: if the output file already exists. This is checked
            up front so the failure does not happen after the full
            tokenization pass.
        ValueError: if the cross-matched loader yields no batch at all.
    """
    output_file = f'data/lens_parent_sample_v{VERSION}.fits'
    # Fail fast: the final Table.write() does not overwrite an existing file,
    # so detect that before spending time on the tokenization loop.
    if os.path.exists(output_file):
        raise FileExistsError(
            f"{output_file} already exists; remove it or bump VERSION before re-running"
        )

    left_codecs = get_tokenizers('hsc')
    right_codecs = get_tokenizers('legacysurvey')
    # Removed from the codecs so the tokenization loop below skips them.
    right_codecs.pop('catalog')
    right_codecs.pop('scalar_field')

    # set up stein model
    # NOTE(review): the encoder is never explicitly switched to eval mode or
    # moved to a device here. If it contains BatchNorm/Dropout layers the
    # embeddings would depend on batch composition -- confirm whether
    # `.eval()` is needed (adding it would change the produced embeddings).
    stein_model = Moco_v2.load_from_checkpoint(checkpoint_path="data/stein.ckpt").encoder_q

    # Load output of query, and apply selections
    catalog = Table.read(HSC_CATALOG)
    catalog = catalog[
        (catalog['photoz_median'] >= 0.2) &
        (catalog['photoz_median'] <= 1.2) &
        (catalog['stellar_mass'] > 5.0e10) &
        ((catalog['sfr'] / catalog['stellar_mass']) < 1.0e-10)
    ]

    print(f"Found {len(catalog)} objects after cuts")

    loader = CrossMatchedAstroPileLoader(
        left_dataset_path=MMU_PATHS['hsc'],
        right_dataset_path=MMU_PATHS['legacysurvey'],
        batch_size=256,
        num_workers=32,
        formatting_fns=[
            ClampImage(),
            CropImage(96),
            RescaleToLegacySurvey(),
            PadImageBands(version='oct24'),
        ]
    )

    loader.setup(stage='fit')
    dataloader = loader.train_dataloader()

    results = []

    with torch.no_grad():
        # Wrap the dataloader itself (not the enumerate object) so tqdm can
        # pick up its length, when available, and display a total / ETA.
        for idx, batch in enumerate(tqdm(dataloader)):
            batch = loader.transfer_batch_to_device(batch, 'cuda', idx)
            # For backwards compatibility: expose the flux under the 'array' key
            batch['hsc_image']['array'] = batch['hsc_image']['flux']
            batch['legacysurvey_image']['array'] = batch['legacysurvey_image']['flux']

            # Tokenize everything we can
            for key, codec in left_codecs.items():
                batch['tok_'+key.lower()] = codec.encode({key: batch['hsc_'+key]})
            # Rename the HSC image tokens so the Legacy Survey image tokens,
            # written to 'tok_image' just below, do not overwrite them.
            batch['tok_image_hsc'] = batch.pop('tok_image')

            for key, codec in right_codecs.items():
                batch['tok_'+key.lower()] = codec.encode({key: batch['legacysurvey_'+key]})

            # Embed the Legacy Survey images with the stein model.
            # Channels 5, 6 and 8 are selected per image -- presumably the
            # Legacy Survey bands in the padded band layout; TODO confirm.
            stein_images = [
                stein_decals_to_rgb(b[[5, 6, 8], :, :].unsqueeze(0)).squeeze()
                for b in batch['legacysurvey_image']['array']
            ]
            stein_images = torch.stack(stein_images).to('cuda')
            batch['stein_embedding'] = stein_model(stein_images)

            target_ids = np.array(batch['hsc_object_id']).astype(np.int64)

            # Keep only the 'rgb' entry of the Legacy Survey image, then drop
            # the raw images and the remaining per-survey entries to save space.
            batch['rgb'] = batch['legacysurvey_image']['rgb']
            batch.pop('legacysurvey_image')
            batch.pop('legacysurvey_object_id')
            batch.pop('legacysurvey_object_mask')
            batch.pop('legacysurvey_catalog')
            batch.pop('hsc_image')
            batch.pop('hsc_object_id')

            # Convert everything to numpy
            batch = {k: v.cpu().numpy() if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
            batch['object_id'] = target_ids
            res = Table(batch)

            # Join with parent catalog to remove objects that don't match early
            res = join(res, catalog, keys='object_id', join_type='inner')
            results.append(res)

    # vstack() cannot stack an empty list; report that case explicitly.
    if not results:
        raise ValueError("The cross-matched loader produced no batches; nothing to write")

    # Fusing all results
    results = vstack(results)

    print(f"Found {len(results)} matches")
    results.write(output_file)


if __name__ == '__main__':
    main()