import torch
import numpy as np

from tqdm import tqdm
from argparse import ArgumentParser
from astropy.table import Table
from aion.model import AION


def _embed_column(dataset, model, column, start, stop, num_encoder_tokens):
    """Encode rows ``start:stop`` of ``dataset[column]`` and average over dim 1.

    The rows are cast to float32, moved to the GPU and flattened to
    ``24 * 24`` values per sample before being passed to ``model.encode``
    under the key ``column``. Returns the result as a NumPy array on the CPU.
    """
    tokens = np.array(dataset[column][start:stop], dtype=np.float32)
    batch = {column: torch.tensor(tokens).cuda().reshape(-1, 24 * 24)}
    encoded = model.encode(batch, num_encoder_tokens=num_encoder_tokens)
    # Pool over dim 1 (presumably the token axis -- confirm against the
    # model's output layout) to get one vector per sample.
    return torch.mean(encoded, dim=1).cpu().numpy()


def embed_data(dataset, model, batch_size=256, num_encoder_tokens=576):
    """Embed the ``tok_image`` and ``tok_image_hsc`` columns of ``dataset``.

    The dataset is processed in consecutive batches of ``batch_size`` rows.
    Each column is encoded separately with ``model.encode`` and the output
    is averaged over dim 1.

    Args:
        dataset: Object supporting ``len()`` and ``dataset[name][a:b]`` for
            the columns ``'tok_image'`` and ``'tok_image_hsc'``.
        model: Object exposing ``encode(batch, num_encoder_tokens=...)``.
        batch_size: Number of rows encoded per step; must be positive.
        num_encoder_tokens: Forwarded unchanged to ``model.encode``.

    Returns:
        A tuple ``(embeddings_ls, embeddings_hsc)`` of NumPy arrays, each
        the concatenation along axis 0 of the per-batch pooled embeddings.

    Raises:
        ValueError: If ``batch_size`` is not positive or ``dataset`` is empty.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    num_samples = len(dataset)
    if num_samples == 0:
        # np.concatenate would fail on an empty list with an opaque message.
        raise ValueError("Cannot embed an empty dataset")

    embeddings_ls = []
    embeddings_hsc = []
    # Iterate over the dataset one batch at a time; the final batch may be
    # shorter than batch_size.
    batch_starts = range(0, num_samples, batch_size)
    for start in tqdm(batch_starts, desc="Embedding data"):
        stop = start + batch_size
        with torch.no_grad():
            embeddings_ls.append(
                _embed_column(dataset, model, 'tok_image',
                              start, stop, num_encoder_tokens)
            )
            embeddings_hsc.append(
                _embed_column(dataset, model, 'tok_image_hsc',
                              start, stop, num_encoder_tokens)
            )

    return (
        np.concatenate(embeddings_ls, axis=0),
        np.concatenate(embeddings_hsc, axis=0),
    )


def main(dset_path, model_path, save_path, batch_size=256):
    """Embed a table of tokenized images and write the augmented table.

    Reads the table at ``dset_path``, loads the pretrained AION model from
    ``model_path``, computes embeddings with ``embed_data`` and stores them
    in the new columns ``embeddings_ls`` and ``embeddings_hsc`` before
    writing the table to ``save_path`` (overwriting any existing file).
    """
    # Load the input table
    table = Table.read(dset_path)

    # Load the pretrained model, compile it, then switch to eval mode on GPU
    pretrained = AION.from_pretrained(model_path)
    compiled = torch.compile(pretrained)
    model = compiled.eval().cuda()

    # Compute the embeddings for both image columns
    ls_embeddings, hsc_embeddings = embed_data(
        table, model, batch_size=batch_size
    )

    # Attach the embeddings as new columns and persist the table
    table['embeddings_ls'] = ls_embeddings
    table['embeddings_hsc'] = hsc_embeddings
    table.write(save_path, overwrite=True)


if __name__ == '__main__':
    # Command-line entry point: parse the paths and batch size, then run main.
    parser = ArgumentParser()
    parser.add_argument(
        '--dset_path',
        type=str,
        default='data/lens_parent_sample_v1.fits',
    )
    parser.add_argument(
        '--model_path',
        type=str,
        default='data/aion/oct24/base',
    )
    parser.add_argument(
        '--save_path',
        type=str,
        default='data/lens_parent_sample_v1_embedded_oct24_base.hdf5',
    )
    parser.add_argument(
        '--batch_size',
        type=int,
        default=256,
    )
    cli_args = parser.parse_args()

    main(
        dset_path=cli_args.dset_path,
        model_path=cli_args.model_path,
        save_path=cli_args.save_path,
        batch_size=cli_args.batch_size,
    )
    