from torchvision.datasets import DatasetFolder
from torchvision.datasets.folder import default_loader
from torch.utils.data import DataLoader
import importlib

import transforms.model_transforms

importlib.reload(transforms.model_transforms)

from transforms.model_transforms import get_model_transforms

from vit_prisma.models.base_vit import HookedViT
from vit_prisma.sae.sae import SparseAutoencoder

import numpy as np
from tqdm import tqdm

# --- Model -----------------------------------------------------------------
# Load the hooked ViT named below (layer-norm folding disabled) and place it
# on the GPU.
model_name = 'open-clip:laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K'
model = HookedViT.from_pretrained(model_name, fold_ln=False).to('cuda')

# --- Data ------------------------------------------------------------------
# Both datasets read the same image folder and differ only in the transform
# variant requested from get_model_transforms.
_image_folder_kwargs = dict(
    root='IMGNET_VAL_PATH',
    loader=default_loader,
    extensions=('.jpg', '.jpeg', '.png'),
)
# Fixed-order loading so repeated runs visit batches in the same sequence.
_loader_kwargs = dict(batch_size=256, shuffle=False, num_workers=4)

# Transform variant 'none'.
ds_normal = DatasetFolder(
    transform=get_model_transforms(model.cfg.model_name, 'none'),
    **_image_folder_kwargs,
)

# Transform variant 'patch_shuffle'.
ds_unstructured = DatasetFolder(
    transform=get_model_transforms(model.cfg.model_name, 'patch_shuffle'),
    **_image_folder_kwargs,
)

dl_normal = DataLoader(ds_normal, **_loader_kwargs)
dl_unstructured = DataLoader(ds_unstructured, **_loader_kwargs)

# Filled in by the evaluation loop: hook layer -> L0 summary statistics.
layer_stats = {}

def get_avg_image_l0_per_batch(batch_tensor, sae):
    """Return the mean number of non-zero SAE latents per token for a batch.

    Runs the module-level ``model`` on ``batch_tensor``, takes the activation
    cached at ``sae.cfg.hook_point``, drops index 0 along dim 1, encodes the
    remaining positions with ``sae`` and counts the non-zero latents.

    Args:
        batch_tensor: Input batch, passed unchanged to
            ``model.run_with_cache``.
        sae: Autoencoder exposing ``cfg.hook_point`` and ``encode``; the
            second element returned by ``encode`` is used as the latents.

    Returns:
        float: non-zero latent count, averaged over dim 1 of the latents and
        then over the samples in the batch.
    """
    # Function-scope import: the file's top-level imports only bring in
    # ``torch.utils.data``, not the ``torch`` namespace itself.
    import torch

    # This is a pure measurement, so disable autograd. Without it every
    # forward pass builds (and keeps alive, via the cache) a graph for the
    # whole batch, which only costs memory and time.
    # NOTE(review): run_with_cache is called without any filter, so it
    # presumably caches every hook; restricting it to sae.cfg.hook_point
    # would cut memory further if the API supports it -- confirm.
    with torch.no_grad():
        _, cache = model.run_with_cache(batch_tensor)
        hook_point_activation = cache[sae.cfg.hook_point]

        # Skip the first position along dim 1 (presumably the CLS token --
        # TODO confirm) so only the remaining tokens are scored.
        sae_input = hook_point_activation[:, 1:, :]
        _, sae_latents, *_ = sae.encode(sae_input)

        # Count active latents along the last dim, then average per sample
        # and across the batch.
        non_zero_per_patch = (sae_latents != 0).sum(dim=2)
        avg_non_zero = non_zero_per_patch.float().mean(dim=1).mean()

    return avg_non_zero.cpu().item()

def collect_l0_stats(dataloader, sae):
    """Summarise the per-batch mean L0 of ``sae`` over ``dataloader``.

    Every batch of images is moved to the 'cuda' device and scored with
    ``get_avg_image_l0_per_batch``; labels are ignored. All statistics are
    taken over the resulting per-batch values, one value per batch.

    Args:
        dataloader: Iterable yielding ``(images, labels)`` pairs.
        sae: Autoencoder forwarded to ``get_avg_image_l0_per_batch``.

    Returns:
        dict with keys ``'l0_mean'``, ``'l0_std'`` (NumPy's default
        population standard deviation) and ``'ci95'``
        (``1.96 * std / sqrt(number of batches)``).
    """
    per_batch_l0 = np.asarray([
        get_avg_image_l0_per_batch(images.to('cuda'), sae)
        for images, _ in tqdm(dataloader)
    ])

    spread = per_batch_l0.std()
    half_width = 1.96 * spread / np.sqrt(per_batch_l0.size)

    return {
        'l0_mean': per_batch_l0.mean(),
        'l0_std': spread,
        'ci95': half_width,
    }

import json

# Identifiers of the SAE checkpoints to evaluate; populate before running.
sae_list = [
    # <SAE checkpoints>
]

# Score each SAE on the patch-shuffled loader and record its statistics under
# the layer index stored in the SAE's config. The running dict is printed
# after every checkpoint so progress is visible.
for sae_id in sae_list:
    sae = SparseAutoencoder.load_from_pretrained(sae_id)
    stats = collect_l0_stats(dl_unstructured, sae)
    layer_stats[sae.cfg.hook_point_layer] = stats
    print(layer_stats)

# Persist the collected per-layer statistics once all checkpoints are done.
with open('OUTPUT_DIR/patch_shuffle_28_l10_avg_stats.json', 'w') as f:
    json.dump(layer_stats, f)
