import os
import numpy as np
from astropy.table import Table
from tqdm import tqdm
from provabgs import models as Models


# Remote URL of the PROVABGS catalog file fetched by `_download_data`.
PROVABGS_ONLINE_FILE = "https://data.desi.lbl.gov/public/edr/vac/edr/provabgs/v1.0/BGS_ANY_full.provabgs.sv3.v0.hdf5"
# Local path used by `main` both as the download target and as the location
# the processed table is written back to (with overwrite).
PROVABGS_SAVE_PATH = "data/preprocessed.hdf5"


def _download_data(save_path: str):
    """Download the PROVABGS catalog to ``save_path`` unless it already exists.

    The file is fetched with ``wget`` into a temporary ``<save_path>.part``
    file and only moved into place once the download succeeded, so an
    interrupted or failed download never leaves a truncated file behind
    that later runs would mistake for a complete one.

    Args:
        save_path: Destination path of the downloaded file. Missing parent
            directories are created.

    Raises:
        FileNotFoundError: If the ``wget`` executable is not available.
        subprocess.CalledProcessError: If ``wget`` exits with a non-zero
            status.
    """
    import subprocess

    if os.path.exists(save_path):
        print("PROVABGS data already exists...skipping download.")
        return

    # wget does not create missing directories for the -O target.
    parent_dir = os.path.dirname(save_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    print("Downloading PROVABGS data...")
    partial_path = f"{save_path}.part"
    try:
        # Argument list (no shell) so paths with spaces or shell
        # metacharacters are passed through verbatim; check=True turns a
        # failed download into an exception instead of a silent success.
        subprocess.run(
            ["wget", PROVABGS_ONLINE_FILE, "-O", partial_path],
            check=True,
        )
        os.replace(partial_path, save_path)
    finally:
        # After a successful os.replace the partial file no longer exists;
        # otherwise remove whatever wget managed to write.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    print("Downloaded PROVABGS data successfully!")


def _get_best_fit(provabgs: Table):
    """Derive best-fit galaxy properties and attach them to the catalog.

    Rows without a positive ``PROVABGS_LOGMSTAR_BF`` or without positive
    ``MAG_G``/``MAG_R``/``MAG_Z`` values are dropped. For every remaining
    row the PROVABGS NMF model is evaluated on the first 12 entries of
    ``PROVABGS_THETA_BF`` at redshift ``Z_HP``.

    Args:
        provabgs: Catalog table containing the columns
            ``PROVABGS_LOGMSTAR_BF``, ``MAG_G``, ``MAG_R``, ``MAG_Z``,
            ``PROVABGS_THETA_BF`` and ``Z_HP``.

    Returns:
        The filtered table with ``PROVABGS_LOGMSTAR_BF`` renamed to
        ``LOG_MSTAR`` and the new float32 columns ``LOG_Z_MW``,
        ``TAGE_MW`` and ``sSFR`` added.
    """
    m_nmf = Models.NMF(burst=True, emulator=True)

    # Filter out galaxies with no best-fit model or non-positive magnitudes.
    valid = (
        (provabgs["PROVABGS_LOGMSTAR_BF"] > 0)
        & (provabgs["MAG_G"] > 0)
        & (provabgs["MAG_R"] > 0)
        & (provabgs["MAG_Z"] > 0)
    )
    provabgs = provabgs[valid]

    # Get the thetas and redshifts for each galaxy
    thetas = provabgs["PROVABGS_THETA_BF"][:, :12]
    zreds = provabgs["Z_HP"]

    Z_mw = []  # Stellar metallicity (m_nmf.Z_MW)
    tage_mw = []  # Stellar age (m_nmf.tage_MW)
    avg_sfr = []  # Average star formation rate (m_nmf.avgSFR)
    # `desc` belongs to tqdm, not range(): range() accepts no keyword
    # arguments and would raise TypeError.
    for i in tqdm(
        range(len(thetas)),
        desc="Calculating best-fit properties using the PROVABGS model",
    ):
        theta = thetas[i]
        zred = zreds[i]

        # Calculate properties using the PROVABGS model
        Z_mw.append(m_nmf.Z_MW(theta, zred=zred))
        tage_mw.append(m_nmf.tage_MW(theta, zred=zred))
        avg_sfr.append(m_nmf.avgSFR(theta, zred=zred))

    # Rename PROVABGS_LOGMSTAR_BF -> LOG_MSTAR (added as a new column, then
    # the old one is dropped).
    provabgs["LOG_MSTAR"] = provabgs["PROVABGS_LOGMSTAR_BF"]
    del provabgs["PROVABGS_LOGMSTAR_BF"]

    # Add new properties to the table.
    # NOTE(review): np.log is the natural logarithm. If LOG_MSTAR is a
    # base-10 log (TODO confirm against the catalog documentation), the
    # sSFR difference below mixes log bases and np.log10 is probably
    # intended.
    provabgs["LOG_Z_MW"] = np.log(np.array(Z_mw)).astype(np.float32)
    provabgs["TAGE_MW"] = np.array(tage_mw, dtype=np.float32)
    provabgs["sSFR"] = (np.log(np.array(avg_sfr)) - provabgs["LOG_MSTAR"]).astype(np.float32)
    return provabgs


def main():
    """Fetch the PROVABGS catalog, add best-fit properties, and save it.

    The processed table is written back to ``PROVABGS_SAVE_PATH``,
    replacing the downloaded file.
    """
    # NOTE(review): the output overwrites the raw download in place. On a
    # second run the download is skipped and the already-processed file is
    # read, which no longer has the PROVABGS_LOGMSTAR_BF column that
    # _get_best_fit filters on.
    path = PROVABGS_SAVE_PATH
    _download_data(path)

    catalog = _get_best_fit(Table.read(path))
    catalog.write(path, overwrite=True)


# Run the preprocessing pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()