import torch
import numpy as np

from tqdm import tqdm
from argparse import ArgumentParser
from astropy.table import Table
from aion.model import AION


def _model_device(model):
    """Return the device of the model's first parameter, defaulting to CUDA."""
    try:
        return next(model.parameters()).device
    except (AttributeError, StopIteration):
        # No parameters to inspect: keep the historical behaviour of
        # sending inputs to the default CUDA device.
        return torch.device('cuda')


def embed_gz_data(gz_dataset, model, batch_size=256, num_encoder_tokens=576):
    """Embed Galaxy Zoo data using the given model.

    The ``tok_image`` column is read in slices of ``batch_size`` samples,
    cast to float32, flattened to ``24 * 24`` values per sample and passed
    to ``model.encode``. The encoder output is averaged over ``dim=1``
    (presumably the token axis -- confirm against the model) to give one
    vector per sample.

    Args:
        gz_dataset: Table-like object. ``len(gz_dataset)`` must be the
            number of samples and ``gz_dataset['tok_image']`` must support
            slicing along the sample axis.
        model: Object exposing ``encode(batch, num_encoder_tokens=...)``.
            Inputs are placed on the device of its first parameter, or on
            the default CUDA device if it exposes no parameters.
        batch_size: Number of samples encoded per forward pass (positive).
        num_encoder_tokens: Forwarded unchanged to ``model.encode``.

    Returns:
        A NumPy array with one row per sample, concatenated in dataset
        order.

    Raises:
        ValueError: If ``batch_size`` is not positive or the dataset is
            empty.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    num_samples = len(gz_dataset)
    if num_samples == 0:
        raise ValueError("Cannot embed an empty dataset")

    tokens_per_image = 24 * 24
    device = _model_device(model)
    # Look the column up once instead of on every iteration.
    tok_column = gz_dataset['tok_image']

    embeddings = []
    # Iterate over the dataset one batch at a time.
    batch_starts = range(0, num_samples, batch_size)
    for start in tqdm(batch_starts, desc="Embedding Galaxy Zoo data"):
        # np.array makes a fresh float32 copy, so from_numpy can wrap it
        # without a second copy.
        chunk = np.array(tok_column[start:start + batch_size], dtype=np.float32)
        batch = {
            'tok_image': torch.from_numpy(chunk).to(device).reshape(-1, tokens_per_image)
        }

        with torch.no_grad():
            encoded = model.encode(batch, num_encoder_tokens=num_encoder_tokens)
            encoded_mean = torch.mean(encoded, dim=1)

        embeddings.append(encoded_mean.cpu().numpy())

    return np.concatenate(embeddings, axis=0)


def main(dset_path, model_path, save_path, batch_size=256):
    """Embed a Galaxy Zoo table with a pretrained AION model and save it.

    Reads the table at ``dset_path``, loads the model from ``model_path``,
    computes embeddings with ``embed_gz_data``, stores them in an
    ``embeddings`` column and writes the table to ``save_path``,
    overwriting any existing file.
    """
    # Load the catalogue to embed.
    catalog = Table.read(dset_path)

    # Load the pretrained model, call its encoder/decoder freeze hooks,
    # compile it and move it to the GPU.
    aion = AION.from_pretrained(model_path)
    aion.freeze_encoder()
    aion.freeze_decoder()
    aion = torch.compile(aion)
    aion.cuda()

    # Compute the embeddings, attach them as a new column and persist.
    catalog['embeddings'] = embed_gz_data(catalog, aion, batch_size=batch_size)
    catalog.write(save_path, overwrite=True)


if __name__ == '__main__':
    # Command-line entry point. Every option has a default, so the script
    # can be run without arguments.
    cli = ArgumentParser()
    option_specs = (
        ('--dset_path', str, './data/gz5_legacysurvey_matches_mp.hdf5'),
        ('--model_path', str, 'data/aion/dec24/large'),
        ('--save_path', str, './data/gz5_large_embedded.hdf5'),
        ('--batch_size', int, 256),
    )
    for flag, value_type, fallback in option_specs:
        cli.add_argument(flag, type=value_type, default=fallback)
    opts = cli.parse_args()

    main(opts.dset_path, opts.model_path, opts.save_path, opts.batch_size)
    