# Analysis script for provabgs: runs inference of the models in all situations of interest
# and saves the results for later plotting and analysis.
import torch
import numpy as np
from tqdm import tqdm
import argparse
from astropy.table import Table
import wandb 
import glob
from aion_eval.benchmarks.provabgs.models import AIONCrossAttentionProbing, AIONLinearProbing, MultiBackbonePROVABGSModel
import os
import pandas as pd

# Mean and standard deviation used to undo the normalization of the model outputs
# (`prediction * std + mean`, see batch_process_catalog).
# These values are obtained from data/provabgs_legacysurvey_train_v1.fits

# Redshift
mean_redshift = 0.2400511627531073
std_redshift = 0.11494735079568326

# Stellar mass (written out as AION_LOGMSTAR)
mean_mstar = 10.674901294629791
std_mstar = 0.6968279084439876

# Age (written out as AION_TAGE_MW)
mean_age = 8.6095915
std_age = 1.6635737

# Log Metallicity (written out as AION_LOG_Z_MW)
mean_logzmw = -5.339284
std_logzmw = 0.9432177

# sSFR (written out as AION_sSFR)
mean_ssfr = -11.151731
std_ssfr = 2.2690578

# WandB API
wandb_api = wandb.Api()
# W&B entity (user or team) and project that own the runs to evaluate.
# Both are placeholders to fill in before running: they are used to build the
# run path passed to `wandb_api.run`, and `project_name` is also part of the
# checkpoint directory layout.
entity_name = ''
project_name = ''


def undo_normalization(predictions):
    """
    Maps normalized model outputs back to their original scale.

    Parameters:
        predictions (np.ndarray): Normalized predictions whose last axis holds, in this order,
            redshift, stellar mass, age, log metallicity and sSFR (it must broadcast against
            an array of shape (1, 5)).

    Returns:
        np.ndarray: `predictions * std + mean`, applied per column.
    """
    means = np.array([mean_redshift, mean_mstar, mean_age, mean_logzmw, mean_ssfr]).reshape(1, -1)
    # The scale must be built from the standard deviations only; the means are the offset.
    stds = np.array([std_redshift, std_mstar, std_age, std_logzmw, std_ssfr]).reshape(1, -1)
    return predictions * stds + means


def batch_process_catalog(catalog,
                          model,
                          input_keys,
                          batch_size=128,
                          device='cuda'):
    """
    Runs a model over a catalog in batches and returns the de-normalized predictions.

    Parameters:
        catalog: Table-like object supporting `len(catalog)` (number of rows) and
            `catalog[key][start:stop]` slicing. It must provide a 'TARGETID' column in
            addition to every key listed in `input_keys`.
        model (torch.nn.Module): Called as `model(batch)` with a dict mapping each input key
            to a tensor; its output is indexed as `[:, 0]` ... `[:, 4]`.
        input_keys (list): Keys of the catalog that are fed to the model.
        batch_size (int, optional): Number of rows per batch. Default is 128.
        device (str or torch.device, optional): Device the input tensors are moved to.
            Default is 'cuda'.

    Returns:
        astropy.table.Table: One row per catalog entry with the columns 'TARGETID',
            'AION_Z_HP', 'AION_LOGMSTAR', 'AION_TAGE_MW', 'AION_LOG_Z_MW' and 'AION_sSFR'.

    Raises:
        ValueError: If `batch_size` is not positive, if the catalog is empty, or if an
            input batch contains NaN values.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    n_rows = len(catalog)
    if n_rows == 0:
        # Fail early with a clear message instead of an opaque np.concatenate error.
        raise ValueError("catalog is empty, there is nothing to process")

    predictions = []
    target_ids = []

    with torch.no_grad():
        for start in tqdm(range(0, n_rows, batch_size)):
            rows = slice(start, start + batch_size)

            batch = {}
            for k in input_keys:
                dat = catalog[k][rows]

                if k == 'image':
                    # Images are passed as float32
                    tensor = torch.from_numpy(np.asarray(dat, dtype=np.float32)).to(device)
                else:
                    # Every other input is converted to int32
                    tensor = torch.from_numpy(np.asarray(dat, dtype=np.int32)).to(device)
                    if k in ('tok_image_hsc', 'tok_image'):
                        # Image tokens are flattened to (batch, -1)
                        tensor = tensor.view(tensor.shape[0], -1)

                # Sanity check
                if torch.isnan(tensor).any():
                    raise ValueError('nan in batch')

                batch[k] = tensor

            # Apply model
            res = model(batch)

            predictions.append(res.cpu().numpy())
            target_ids.append(catalog['TARGETID'][rows])

    predictions = np.concatenate(predictions, axis=0)

    # Undo the normalization of each output column and attach the target ids
    return Table(
        {
            'TARGETID': np.concatenate(target_ids),
            'AION_Z_HP': predictions[:, 0] * std_redshift + mean_redshift,
            'AION_LOGMSTAR': predictions[:, 1] * std_mstar + mean_mstar,
            'AION_TAGE_MW': predictions[:, 2] * std_age + mean_age,
            'AION_LOG_Z_MW': predictions[:, 3] * std_logzmw + mean_logzmw,
            'AION_sSFR': predictions[:, 4] * std_ssfr + mean_ssfr,
        }
    )


def get_model(run, run_id):
    """
    Loads the trained model of a W&B run from its checkpoint.

    Parameters:
        run: W&B run whose `config` provides `['model']['class_path']` and
            `['trainer']['default_root_dir']`.
        run_id (str): Id of the run, used to locate its checkpoint directory.

    Returns:
        The loaded model, in eval mode and moved to 'cuda'.

    Raises:
        FileNotFoundError: If no checkpoint file matches the expected location.
        ValueError: If the model class of the run is not handled by this script.
    """
    model_name = run.config['model']['class_path']
    model_path = run.config['trainer']['default_root_dir']+f'/{project_name}/{run_id}/checkpoints/*.ckpt'
    print(f"Loading model {model_name} from {model_path}")

    # There should only be one model checkpoint. glob returns matches in arbitrary
    # order, so sort them to make the choice deterministic if there are several.
    checkpoints = sorted(glob.glob(model_path))
    if not checkpoints:
        raise FileNotFoundError(f"No checkpoint found for run {run_id} matching {model_path}")
    model_path = checkpoints[0]

    if 'AIONLinearProbing' in model_name:
        model = AIONLinearProbing.load_from_checkpoint(model_path)
    elif 'AIONCrossAttentionProbing' in model_name:
        model = AIONCrossAttentionProbing.load_from_checkpoint(model_path)
    elif 'MultiBackbonePROVABGSModel' in model_name:
        model = MultiBackbonePROVABGSModel.load_from_checkpoint(model_path)
    else:
        raise ValueError(f"Model {model_name} not implemented in eval script.")

    model = model.eval()
    model = model.to('cuda')
    return model


def experiment_indomain(models_to_evaluate, overwrite: bool = False, version: str = '2'):
    """
    Evaluates every adapted model on the task it was adapted for.

    Parameters:
        models_to_evaluate: Table-like object whose 'ID' column lists the W&B run ids.
        overwrite (bool, optional): Recompute runs whose output file already exists.
        version (str, optional): Version of the evaluation catalog and of the output directory.

    One FITS file named `<run name>_<run id>.fits` is written per run.
    """
    output_dir = f'data/analysis/provabgs/indomain_oct24_v{version}'

    # Evaluation catalog of the legacysurvey data, shared by all runs
    catalog = Table.read(f'data/provabgs_legacysurvey_eval_v{version}.fits')
    os.makedirs(output_dir, exist_ok=True)

    for run_id in models_to_evaluate['ID']:
        run = wandb_api.run(f"{entity_name}/{project_name}/{run_id}")

        print(f"Processing run {run_id}")
        output_file = f'{output_dir}/{run.name}_{run_id}.fits'

        # Runs that already have an output file are skipped unless overwriting
        already_processed = os.path.exists(output_file)
        if already_processed and not overwrite:
            print(f"Run {run_id} already processed, skipping")
            continue

        model = get_model(run, run_id)

        # Feed the model the input fields it was trained with
        predictions = batch_process_catalog(
            catalog,
            model,
            run.config['data']['init_args']['input_fields'],
        )

        predictions.write(output_file, overwrite=True)


def main(args):
    """
    Reads the list of runs from the CSV file given in `args` and evaluates them in-domain.
    """
    # CSV file holding the ids of the runs to analyse
    runs_table = pd.read_csv(args.wandb_csv_file)

    experiment_indomain(
        runs_table,
        overwrite=args.overwrite,
        version=args.version,
    )


if __name__ == "__main__":
    # Command line interface: each entry is (flag, keyword arguments of add_argument)
    cli_options = [
        ('--wandb_csv_file', dict(type=str,
                                  default='scripts/provabgs_runs_oct24_phot.csv',
                                  help="csv file of all the runs we want to analyse.")),
        ('--overwrite', dict(action='store_true',
                             help="Overwrite existing files if true.")),
        ('--version', dict(type=str,
                           default='2',
                           help="Version of the catalog to use.")),
    ]

    parser = argparse.ArgumentParser(description="Run model inference on a catalog of data.")
    for flag, options in cli_options:
        parser.add_argument(flag, **options)

    args = parser.parse_args()
    main(args)