import wandb
import numpy as np
import torch

class CompletionsTableLogger:
    """Log batches of values as ``wandb.Table`` objects, one table per prefix.

    The column layout used for a prefix is fixed the first time that prefix
    is logged and reused for every later call, so all tables logged under the
    same prefix share one schema.
    """

    def __init__(self, stego_trainer):
        """Store the trainer used for logging.

        Args:
            stego_trainer: Object exposing ``do_wandb_log(dict)``; it receives
                every table built by :meth:`log_table`.
        """
        self.stego_trainer = stego_trainer
        # Maps prefix -> list of column names fixed on first use.
        self.table_columns = {}

    @staticmethod
    def _is_per_row(value):
        """Return True if *value* holds one entry per table row."""
        if isinstance(value, torch.Tensor):
            # A 0-d tensor can be neither measured with len() nor indexed,
            # so it is treated as a scalar shared by every row.
            return value.dim() > 0
        return isinstance(value, list)

    @classmethod
    def _num_entries(cls, batch):
        """Return the row count implied by the first per-row value in *batch*.

        A batch made only of scalars describes a single row.
        """
        for value in batch.values():
            if cls._is_per_row(value):
                return len(value)
        return 1

    @classmethod
    def _cell(cls, batch, key, index):
        """Return the table cell for column *key* in row *index*."""
        if key not in batch:
            return np.nan  # Log NaN for missing keys
        value = batch[key]
        return value[index] if cls._is_per_row(value) else value

    def log_table(self, batch, prefix, columns=None):
        """Build a table from *batch* and log it as ``"{prefix}_completions"``.

        Args:
            batch: Mapping of column name to value. Lists and tensors with at
                least one dimension contribute one entry per row; any other
                value is repeated in every row.
            prefix: Name of the table family. The first call for a prefix
                fixes its columns.
            columns: Column names to use when *prefix* is seen for the first
                time (defaults to all keys of *batch*). Ignored for a prefix
                that has already been logged.

        Raises:
            IndexError: If a per-row value is shorter than the row count
                taken from the first per-row value in *batch*.
        """
        if not batch:
            # Nothing to log. Return before registering columns so an empty
            # first batch cannot lock the prefix to a zero-column schema.
            return

        if prefix not in self.table_columns:
            # New prefix: remember a private copy of the schema so later
            # mutation of the caller's list cannot change it.
            self.table_columns[prefix] = (
                list(batch) if columns is None else list(columns)
            )
        # Existing (or just registered) prefix: always use the saved columns.
        columns = self.table_columns[prefix]

        # Hand the table its own copy so the stored schema cannot be altered
        # through the table object.
        completions = wandb.Table(columns=list(columns))
        for i in range(self._num_entries(batch)):
            completions.add_data(*(self._cell(batch, key, i) for key in columns))

        # Log to wandb
        self.stego_trainer.do_wandb_log({f"{prefix}_completions": completions})