import wandb
import torch
import inspect
from dataclasses import fields

from vit_prisma.sae.config import VisionModelSAERunnerConfig
from vit_prisma.sae.train_sae import VisionSAETrainer
from vit_prisma.models.base_vit import HookedViT

from noise_dataset import RandomImageDataset

def _loggable_config(cfg):
    """Return the dataclass fields of *cfg* as a plain dict suitable for logging.

    ``torch.dtype`` values are reduced to their short name (the part after the
    last dot of ``str(value)``), ``torch.device`` values are stringified, and
    fields whose value is a callable or a class are left out.
    """
    loggable = {}
    for field in fields(cfg):
        value = getattr(cfg, field.name)
        if isinstance(value, torch.dtype):
            value = str(value).split('.')[-1]  # Convert torch.dtype to string
        elif isinstance(value, torch.device):
            value = str(value)  # Convert torch.device to string
        elif callable(value) or inspect.isclass(value):
            continue  # Skip methods and classes
        loggable[field.name] = value
    return loggable


def train():
    """Train one SAE per hook layer (0 through 11) on random-image data.

    Each layer gets its own W&B run: the run is opened, the configuration is
    built and logged, a fresh model and datasets are created, the trainer is
    run, and the run is closed. If anything fails inside a layer's run, the
    run is closed with a non-zero exit code and the exception is re-raised.
    """
    # Do not hard-code the API key in source. Without an explicit key,
    # wandb.login() resolves credentials itself (e.g. the WANDB_API_KEY
    # environment variable or a previously stored login).
    wandb.login()

    for layer in range(0, 12):
        wandb.init()
        try:
            cfg = VisionModelSAERunnerConfig(hook_point_layer=layer)

            print(f'Training Noise SAE for {cfg.hook_point_layer}')

            # NOTE(review): a dataclass already invokes __post_init__ from its
            # generated __init__, so this looks like a second invocation.
            # Kept to preserve existing behavior -- confirm whether it is needed.
            cfg.__post_init__()

            # Log the full configuration
            full_config = _loggable_config(cfg)
            wandb.config.update(full_config)

            print("Configuration:")
            print(full_config)

            model = HookedViT.from_pretrained(cfg.model_name)
            train_data, val_data = RandomImageDataset(), RandomImageDataset(num_images=50_000)

            trainer = VisionSAETrainer(cfg, model, train_data, val_data)
            trainer.run()
        except BaseException:
            # Close the active run as failed so it is not left open (and is
            # not reported as successful) when a layer's training aborts.
            wandb.finish(exit_code=1)
            raise
        wandb.finish()

# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    train()