import os
import h5py
import numpy as np
import healpy as hp
import astropy.units as u
import torch
import warnings
import requests
import multiprocessing as mp

from tqdm import tqdm
from argparse import ArgumentParser
from astropy.coordinates import SkyCoord
from astropy.table import Table, hstack, vstack
from torch.utils.data.dataloader import default_collate

from utils import get_tokenizers
from aion_eval.baselines.moco_v2 import Moco_v2, stein_decals_to_rgb
from mmoma.datasets.preprocessing import (
    ClampImage, CropImage, PadImageBands, PadSpectra, RescaleToLegacySurvey
)


# URL fetched by `_download_gz5_decals` (Galaxy Zoo DECaLS volunteer CSV on Zenodo).
GZ5_LINK = "https://zenodo.org/records/4573248/files/gz_decals_volunteers_5.csv?download=1"
# Root of the Legacy Survey shards; `process_healpix` reads
# `<LS_DIR>/healpix=<h>/001-of-001.hdf5` for each healpixel.
LS_DIR = "data/MultimodalUniverse/legacysurvey/dr10_south_21/"

# HDF5 datasets loaded in full for every shard to perform the cross-match.
QUERY_KEYS = ['object_id', 'ra', 'dec']
# HDF5 datasets tokenized one by one; each must also be a key of the tokenizer mapping.
FLUX_KEYS = ['FLUX_G', 'FLUX_R', 'FLUX_I', 'FLUX_Z', 'FLUX_W1', 'FLUX_W2', 'FLUX_W3', 'FLUX_W4']
# HDF5 datasets read (matched rows only) to build the per-object image dict.
IMAGE_KEYS = [
    'image_array', 'image_band', 'image_mask',
    'image_rgb', 'image_scale', 'image_psf_fwhm', 'image_ivar'
]

# Image preprocessing functions, applied in list order to each image dict
# before it is buffered for encoding.
preprocessing_fns = [
    ClampImage(),
    CropImage(96),
    RescaleToLegacySurvey(),
    PadImageBands(version='oct24'),
]

# Globals used by each worker. They stay None in the parent process and are
# populated per worker by `init_worker` (passed as the Pool initializer).
TOKENIZERS = None
STEIN_MODEL = None

def init_worker():
    """Initializer for each worker process.

    Populates the module-level ``TOKENIZERS`` and ``STEIN_MODEL`` globals so
    that ``process_healpix`` can use them without reloading the models for
    every task.
    """
    global TOKENIZERS, STEIN_MODEL
    TOKENIZERS = get_tokenizers("legacysurvey", device='cpu')

    # Load onto CPU explicitly: the tokenizers are created on CPU and the
    # encoder is fed CPU tensors, so a checkpoint restored onto another device
    # would fail with a device mismatch.
    encoder = Moco_v2.load_from_checkpoint(
        checkpoint_path="data/stein.ckpt",
        map_location="cpu",
    ).encoder_q

    # Modules come back in training mode by default. Switch to eval so that
    # BatchNorm uses its stored running statistics (and dropout is disabled);
    # otherwise each embedding would depend on what else is in its batch.
    encoder.eval()
    STEIN_MODEL = encoder

def _download_gz5_decals(survey_path: str) -> None:
    """Download the Galaxy Zoo 5 classifications CSV to ``survey_path``.

    The body is streamed into a temporary ``<survey_path>.part`` file and only
    moved to ``survey_path`` once the transfer finished successfully, so a
    failed or interrupted download never leaves a file that a later run would
    mistake for a valid cached catalogue.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.RequestException: on connection problems or timeouts.
    """
    tmp_path = f"{survey_path}.part"
    try:
        # A timeout keeps the script from hanging forever on a stalled
        # connection; streaming avoids holding the whole file in memory.
        with requests.get(GZ5_LINK, stream=True, timeout=60) as response:
            # Fail loudly instead of saving an HTML error page as the CSV.
            response.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        os.replace(tmp_path, survey_path)
    finally:
        # After a successful os.replace the temp file no longer exists, so
        # this only cleans up after failures.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

def _cross_match_tables(table1: Table, table2: Table, max_sep: float = 0.5) -> tuple[Table, Table]:
    """Cross-match two tables on sky position.

    For every row of ``table1`` the nearest row of ``table2`` on the sky is
    looked up; pairs separated by less than ``max_sep`` are kept.

    Args:
        table1: Table with ``ra`` and ``dec`` columns, in degrees.
        table2: Catalogue table with ``ra`` and ``dec`` columns, in degrees.
        max_sep: Maximum allowed separation, in arcseconds.

    Returns:
        ``(matched1, matched2)``: row-aligned tables of equal length, where
        ``matched1`` keeps the row order of ``table1``. The same ``table2`` row
        may appear more than once if it is the nearest neighbour of several
        ``table1`` rows.
    """
    # match_to_catalog_sky rejects a zero-length catalogue; with nothing to
    # match on either side the answer is simply two empty tables.
    if len(table1) == 0 or len(table2) == 0:
        return table1[:0], table2[:0]

    coords1 = SkyCoord(ra=table1["ra"] * u.degree, dec=table1["dec"] * u.degree)
    coords2 = SkyCoord(ra=table2["ra"] * u.degree, dec=table2["dec"] * u.degree)

    nearest_idx, sky_sep, _ = coords1.match_to_catalog_sky(coords2)
    within_limit = sky_sep < max_sep * u.arcsec

    return table1[within_limit], table2[nearest_idx[within_limit]]

def _build_image_sample(matched_data, i):
    """Assemble the nested image dict for row ``i`` of ``matched_data``."""
    return {
        "image": {
            "flux": matched_data['image_array'][i],
            # Band names are stored as bytes in the HDF5 file.
            "band": [b.decode('utf-8').upper() for b in matched_data['image_band'][i]],
            "mask": matched_data['image_mask'][i],
            "rgb": matched_data['image_rgb'][i],
            "ivar": matched_data['image_ivar'][i],
            "scale": matched_data['image_scale'][i],
            "psf_fwhm": matched_data['image_psf_fwhm'][i]
        }
    }


def _encode_image_batch(img_codec, samples, stein_inputs):
    """Encode one buffered batch with the image tokenizer and the Stein encoder."""
    collated = default_collate(samples)
    # The band names are dropped before encoding, and the flux cube is exposed
    # under the 'array' key as well.
    collated['image'] = {k: v for k, v in collated['image'].items() if k != 'band'}
    collated['image']['array'] = collated['image']['flux']

    with torch.no_grad():
        tokens = img_codec.encode(collated).numpy()
        embeddings = STEIN_MODEL(torch.cat(stein_inputs, dim=0)).detach().numpy()

    # Normalise to native byte order before the array is stored in a Table.
    tokens = tokens.astype(tokens.dtype.newbyteorder('='))
    return tokens, embeddings


def process_healpix(h, gz5_decals, batch_size=256):
    """
    Process a single healpixel:
      1. Load LegacySurvey data for that healpixel (HDF5 file)
      2. Cross-match with the rows of ``gz5_decals`` whose ``healpix`` equals ``h``
      3. Tokenize fluxes and images on CPU and compute Stein embeddings
      4. Return the matched samples as a single table

    Relies on the module globals ``TOKENIZERS`` and ``STEIN_MODEL`` having
    been populated by ``init_worker``.

    Args:
        h: Healpix index; selects both the HDF5 shard and the GZ5 rows.
        gz5_decals: Galaxy Zoo table with ``ra``, ``dec`` and ``healpix``
            columns. Rows belonging to other healpixels are ignored.
        batch_size: Maximum number of images encoded per forward pass.

    Returns:
        Table with ``tok_<flux>``, ``tok_image`` and ``stein_emb`` columns
        stacked next to the matched Galaxy Zoo columns (plus ``object_id``),
        or None if the shard is missing or fewer than two objects match.
    """
    # HDF5 path
    hdf5_path = os.path.join(LS_DIR, f"healpix={h}", "001-of-001.hdf5")
    if not os.path.isfile(hdf5_path):
        print(f"Warning: HDF5 file not found for healpix={h}: {hdf5_path}")
        return None

    # Subset from GZ5 for this healpixel
    gz_subset = gz5_decals[gz5_decals["healpix"] == h]
    if len(gz_subset) == 0:
        # Nothing to match against; the sky matcher cannot handle an empty catalogue.
        return None

    with h5py.File(hdf5_path, "r") as f:
        ls = Table({k: f[k][:] for k in QUERY_KEYS})
        if len(ls) == 0:
            return None

        # Carry each row's position in the file through the cross-match. The
        # match keeps rows in file order, so the surviving positions are
        # strictly increasing, as h5py fancy indexing requires, and duplicate
        # object ids cannot corrupt the lookup.
        ls["row_index"] = np.arange(len(ls))

        ls_matched, gz_matched = _cross_match_tables(ls, gz_subset)

        # No or single match => skip (kept from the original implementation).
        if len(ls_matched) < 2:
            return None

        gz_matched['object_id'] = ls_matched['object_id']
        matched_indices = np.asarray(ls_matched["row_index"], dtype=int)

        # Read only the matched rows of the heavy datasets.
        matched_data = {col: f[col][matched_indices] for col in FLUX_KEYS + IMAGE_KEYS}

    matched_tokens = {}

    # 1) Tokenize fluxes, one codec per flux column
    for key in FLUX_KEYS:
        codec = TOKENIZERS[key]
        # shape (N, 1) so that encoding can handle it easily
        data = torch.tensor(matched_data[key].reshape(-1, 1), dtype=torch.float32)
        with torch.no_grad():
            toks = codec.encode({key: data}).numpy()
        # fix byte order if needed
        toks = toks.astype(toks.dtype.newbyteorder('='))
        matched_tokens[f"tok_{key.lower()}"] = toks

    # 2) Tokenize images and compute Stein embeddings in batches
    img_codec = TOKENIZERS['image']
    tok_imgs, stein_embs = [], []
    buffer_imgs, buffer_stein = [], []
    num_samples = len(matched_data['image_array'])

    for i in range(num_samples):
        img_dict = _build_image_sample(matched_data, i)
        for fn in preprocessing_fns:
            img_dict = fn(img_dict)

        # RGB input for the Stein model, built from slots 5, 6 and 8 of the
        # padded band axis.
        # NOTE(review): presumably the g, r, z bands after PadImageBands - confirm.
        rgb_img = stein_decals_to_rgb(
            torch.tensor(img_dict['image']['flux'][[5, 6, 8], :, :]).unsqueeze(0)
        )

        buffer_stein.append(rgb_img)
        buffer_imgs.append(img_dict)

        # Flush when the buffer holds a full batch or the data is exhausted.
        # Testing the buffer length (rather than `i % batch_size`) keeps the
        # first batch from growing to batch_size + 1 items.
        if len(buffer_imgs) >= batch_size or i == num_samples - 1:
            timg, timg_stein = _encode_image_batch(img_codec, buffer_imgs, buffer_stein)
            tok_imgs.append(timg)
            stein_embs.append(timg_stein)
            buffer_imgs, buffer_stein = [], []

    if not tok_imgs:
        return None  # no images for some reason

    matched_tokens['tok_image'] = np.concatenate(tok_imgs, axis=0)
    matched_tokens['stein_emb'] = np.concatenate(stein_embs, axis=0)

    # Combine the tokens with the matched metadata from Galaxy Zoo
    matched_tokens_table = Table(matched_tokens)
    final_matched = hstack([matched_tokens_table, gz_matched])
    return final_matched

def process_healpix_wrapper(args):
    """Expand one task tuple into positional arguments for ``process_healpix``.

    Pool ``map``/``imap`` variants hand each task to the callable as a single
    object, hence this adapter.
    """
    positional = tuple(args)
    return process_healpix(*positional)

def main(args):
    """Orchestrate the workflow.

    Downloads the Galaxy Zoo catalogue if it is not cached, assigns every
    galaxy to a healpixel, processes the healpixels in a worker pool and
    writes the stacked result to
    ``data/gz5_legacysurvey_matches_w_stein.hdf5``.

    Args:
        args: Namespace with integer attributes ``healpix_nside``,
            ``batch_size`` and ``num_workers``.
    """
    # Suppress warnings
    warnings.simplefilter(action='ignore', category=FutureWarning)

    # Args
    healpix_nside = args.healpix_nside
    batch_size = args.batch_size
    num_workers = args.num_workers

    # 1) Download GZ5 if needed
    gz_csv_path = "data/gz_decals_volunteers_5.csv"
    if not os.path.exists(gz_csv_path):
        os.makedirs("data", exist_ok=True)
        _download_gz5_decals(gz_csv_path)
    else:
        print(f"Using existing GZ5 file: {gz_csv_path}")

    # 2) Read GZ5, compute healpix
    gz5_decals = Table.read(gz_csv_path, format="ascii")
    ra, dec = gz5_decals["ra"], gz5_decals["dec"]
    healpix_indices = hp.ang2pix(healpix_nside, ra, dec, lonlat=True, nest=True)
    gz5_decals["healpix"] = healpix_indices

    healpixes = np.unique(healpix_indices)
    print(f"Number of unique healpixes: {len(healpixes)}")

    # 3) Build one task per healpixel. Each task carries only that
    #    healpixel's rows: task arguments are pickled and sent to the workers,
    #    so passing the full catalogue would serialise it once per healpixel.
    #    The worker's own healpix filter then selects exactly the same rows.
    tasks = [
        (h, gz5_decals[gz5_decals["healpix"] == h], batch_size)
        for h in healpixes
    ]

    # 4) Fan out over the pool and collect the non-empty results
    results = []
    with mp.Pool(processes=num_workers, initializer=init_worker) as pool:
        for res in tqdm(
            pool.imap_unordered(process_healpix_wrapper, tasks),
            total=len(tasks),
            desc="Processing healpixels"
        ):
            if res is not None and len(res) > 0:
                results.append(res)

    if not results:
        print("No matches found. Exiting.")
        return

    # 5) Combine tables
    all_final_matches = vstack(results)
    out_path = "data/gz5_legacysurvey_matches_w_stein.hdf5"
    all_final_matches.write(out_path, overwrite=True)
    print(f"Done! Wrote final table to {out_path}")

if __name__ == "__main__":
    # Command-line entry point: three integer options, then hand off to main().
    parser = ArgumentParser()
    for flag, default_value in (
        ("--healpix_nside", 16),
        ("--batch_size", 256),
        ("--num_workers", 4),
    ):
        parser.add_argument(flag, type=int, default=default_value)
    args = parser.parse_args()

    main(args)