from vit_prisma.models.base_vit import HookedViT
from huggingface_hub import hf_hub_download
from vit_prisma.sae.sae import SparseAutoencoder
import torch
import os, json, math, collections, random
from   pathlib import Path
from   typing  import List, Dict

import torch, einops, numpy as np, pandas as pd
from   torch.utils.data import DataLoader, Dataset
from   torchvision.datasets import DatasetFolder
from   torchvision.datasets.folder import default_loader
from   tqdm.auto import tqdm

from vit_prisma.models.base_vit import HookedViT
from vit_prisma.sae import StandardSparseAutoencoder
from vit_prisma.sae.evals.evals import get_text_embeddings_openclip
from vit_prisma.dataloaders.imagenet_dataset import get_imagenet_index_to_name
from vit_prisma.transforms.model_transforms  import get_clip_val_transforms
import open_clip


from vit_prisma.transforms.model_transforms import get_clip_val_transforms
from torch.utils.data.dataloader import DataLoader

# Evaluation transform applied to every validation image loaded by the
# DatasetFolder below (passed as its `transform=`).
# NOTE(review): the meaning of the `None` argument is defined inside
# get_clip_val_transforms — presumably "use defaults"; confirm.
data_transforms = get_clip_val_transforms(None)

from torchvision.datasets import DatasetFolder
from torchvision.datasets.folder import default_loader

import torch

class IndexedDataset(torch.utils.data.Dataset):
    """Wrap a map-style dataset so every item also reports its own index.

    The wrapped dataset must return an ``(image, label)`` pair for
    ``dataset[i]``; this wrapper returns ``(image, label, i)`` instead, which
    lets downstream code map a sample back to its position in the wrapped
    dataset. The wrapped dataset stays reachable as ``.dataset``.
    """

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        img, target = self.dataset[index]
        return (img, target, index)


# Validation images, wrapped so each sample also yields its dataset index
# (the index is later used to reopen the exact image file).
val_ds = IndexedDataset(
    DatasetFolder(
        root="IMGNET_VAL_PATH",  # placeholder path — presumably replaced per setup
        loader=default_loader,
        extensions=(".jpg", ".jpeg", ".png"),
        transform=data_transforms,
    )
)

val_dl = DataLoader(dataset=val_ds, shuffle=True, batch_size=32, num_workers=4)

OUT_DIR = "FEATURE_DICTIONARY_PATH"  # placeholder output directory

import argparse
parser = argparse.ArgumentParser(
    description="Label sampled SAE features with CLIP predictions on their top images.")
parser.add_argument('--sae_path', type=str, required=True,
                    help="Path/identifier passed to SparseAutoencoder.load_from_pretrained.")
args = parser.parse_args()

SAE_ID_LIST = [args.sae_path]

# NOTE: NUM_WORKERS / BATCH_SIZE_IMG mirror the literals hard-coded in the
# validation DataLoader above; they are not otherwise referenced here.
NUM_WORKERS = 4
BATCH_SIZE_IMG = 32
TOP_K_IMG = 16          # top-activating images kept per feature
TOP_K_CLIP = 5          # CLIP labels (with probabilities) stored per image
FEAT_PER_BIN = 100      # max features sampled from each log-frequency bin
PURITY_THRESH = 0.30    # min top-1 label purity for a feature to count as interpretable
# Fall back to CPU so the script still runs (slowly) on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_float32_matmul_precision("high")
torch.backends.cuda.matmul.allow_tf32 = True

# Single source of truth for the checkpoint shared by the CLIP labeller,
# its tokenizer and the probed ViT, so the three can never drift apart.
CLIP_REPO = "laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K"

imagenet_index = get_imagenet_index_to_name()
# Class-name strings ordered by class index 0..999; used as CLIP text prompts.
batch_labels = [imagenet_index[str(i)][1] for i in range(1000)]

# create_model_and_transforms returns (model, train_transform, val_transform);
# the third (evaluation) transform is kept as `preproc`.
clip_model, _, preproc = open_clip.create_model_and_transforms(f"hf-hub:{CLIP_REPO}")
clip_model = clip_model.to(device).eval()
tokenizer = open_clip.get_tokenizer(f"hf-hub:{CLIP_REPO}")

with torch.no_grad():
    # One text embedding per entry of batch_labels (row i <-> class i).
    text_emb = get_text_embeddings_openclip(
        clip_model, preproc, tokenizer, batch_labels
    ).to(device)

vit = HookedViT.from_pretrained(
    f"open-clip:{CLIP_REPO}", fold_ln=False
).to(device).eval()

@torch.no_grad()
def process_dataset(dl, sae):
    """Count, for every SAE feature, how many tokens it fires on.

    For each batch from ``dl`` the ViT is run, the activations at
    ``sae.cfg.hook_point`` are taken with token position 0 dropped
    (``[:, 1:, :]`` — presumably the CLS token; TODO confirm), encoded with
    ``sae``, and every non-zero feature activation is counted.

    Args:
        dl: DataLoader yielding ``(imgs, labels, idxs)`` batches.
        sae: sparse autoencoder; the second output of ``sae(acts)`` is used
            as the feature activations.

    Returns:
        ``(total_acts, total_tok)``. ``total_acts`` is a float tensor holding
        the number of tokens each feature fired on, and ``total_tok`` is the
        number of tokens seen. ``total_acts / total_tok`` is therefore each
        feature's firing frequency. ``total_acts`` is ``None`` if ``dl``
        yields no batches.
    """
    total_acts, total_tok = None, 0
    hook_pt = sae.cfg.hook_point
    for imgs, _, _ in tqdm(dl, desc="pass‑1 freq", leave=False):
        imgs = imgs.to(device, non_blocking=True)
        acts = vit.run_with_cache(imgs, names_filter=[hook_pt])[1][hook_pt][:, 1:, :]
        _, feats, *_ = sae(acts)
        fired = (feats.abs() > 0).flatten(0, 1)  # one row per token
        # Accumulate raw counts, not per-batch means: the caller divides by the
        # total token count, and summing also weights a short final batch
        # correctly. Integer counts stay exact.
        counts = fired.sum(0)
        total_acts = counts if total_acts is None else total_acts + counts
        total_tok += fired.shape[0]
    if total_acts is not None:
        total_acts = total_acts.float()
    return total_acts, total_tok

@torch.no_grad()
def feature_acts_batch(imgs, sae, enc_w, enc_b):
    """Encode a batch of images into activations of a subset of SAE features.

    The ViT is run on ``imgs`` and its activations at ``sae.cfg.hook_point``
    are read for *all* token positions (no slicing here). ``sae.b_dec`` is
    subtracted, the inputs are projected onto the given encoder columns, the
    bias is added and a ReLU is applied.

    Args:
        imgs: image batch on the model's device.
        sae: sparse autoencoder (supplies ``cfg.hook_point`` and ``b_dec``).
        enc_w: encoder weight columns, shape ``(d, n)`` per the einsum below.
        enc_b: matching encoder biases, length ``n``.

    Returns:
        Tensor of shape ``(b, s, n)`` with non-negative feature activations.

    NOTE(review): this re-implements the encoder by hand (b_dec subtraction +
    ReLU). Confirm it matches what ``sae(...)`` does for this SAE's config.
    """
    hook_name = sae.cfg.hook_point
    cache = vit.run_with_cache(imgs, names_filter=[hook_name])[1]
    centered = cache[hook_name] - sae.b_dec
    pre_acts = torch.einsum("b s d, d n -> b s n", centered, enc_w) + enc_b
    return torch.relu(pre_acts)

@torch.no_grad()
def topk_per_feature(dl, sae, feat_ids, cat_lbls, k):
    """Find, for each selected SAE feature, the k images that activate it most.

    Each image's score for a feature is pooled from its token activations. For
    features whose category label contains ``"CLS_"`` the score is the
    position-0 activation. Otherwise it is the mean over all token positions.

    NOTE(review): the mean includes position 0, whereas process_dataset's
    frequency pass drops it — confirm this is intended.

    Args:
        dl: DataLoader yielding ``(imgs, labels, idxs)`` batches, where
            ``idxs`` are dataset indices (e.g. from ``IndexedDataset``).
        sae: sparse autoencoder whose encoder is sliced to ``feat_ids``.
        feat_ids: feature indices to track.
        cat_lbls: one category label per entry of ``feat_ids``.
        k: number of top images kept per feature.

    Returns:
        dict mapping each feature id to a ``(values, indices)`` pair of CPU
        tensors of length ``k``, sorted by descending pooled activation. Slots
        that were never filled (fewer than ``k`` images seen) keep value
        ``-inf`` and index ``-1``.
    """
    enc_w = sae.W_enc[:, feat_ids].to(device)
    enc_b = sae.b_enc[feat_ids].to(device)
    n_f = len(feat_ids)
    top_v = torch.full((n_f, k), -torch.inf, device=device)
    top_i = torch.full((n_f, k), -1, dtype=torch.long, device=device)

    # Explicit bool dtype so an empty label list still yields a valid mask.
    is_cls = torch.tensor(["CLS_" in c for c in cat_lbls], dtype=torch.bool, device=device)

    for imgs, _, idxs in tqdm(dl, desc="pass‑2 top‑k", leave=False):
        imgs = imgs.to(device, non_blocking=True)
        idxs = idxs.to(device, non_blocking=True)
        z = feature_acts_batch(imgs, sae, enc_w, enc_b)  # (B, S, n_f)

        pooled = torch.where(is_cls[None, :], z[:, 0], z.mean(1))  # (B, n_f)

        # Merge this batch into the running top-k of every feature at once:
        # candidates are the current top-k plus this batch's scores.
        cand_v = torch.cat([top_v, pooled.T], dim=1)                          # (n_f, k+B)
        cand_i = torch.cat([top_i, idxs.unsqueeze(0).expand(n_f, -1)], dim=1)  # (n_f, k+B)
        top_v, order = cand_v.topk(k, dim=1)
        top_i = cand_i.gather(1, order)

    top_v, top_i = top_v.cpu(), top_i.cpu()
    return {fid: (top_v[i], top_i[i]) for i, fid in enumerate(feat_ids)}

def build_pred_dict(topk, k_clip, batch_size=256):
    """Label every feature's top images with CLIP zero-shot predictions.

    Each distinct image referenced in ``topk`` is reloaded from
    ``val_ds.dataset.samples`` and encoded with ``clip_model`` in chunks of
    ``batch_size`` (so GPU memory does not grow with the number of images).
    The ``k_clip`` most similar entries of ``text_emb`` are then kept for each
    image.

    NOTE(review): softmax is taken over raw image/text similarities with no
    logit scale. If ``text_emb`` is unit-normalised, the stored probabilities
    will be very flat; the label ranking is unaffected. Confirm intended.

    Args:
        topk: mapping feature id -> ``(values, indices)`` as returned by
            ``topk_per_feature``. Index ``-1`` marks an unfilled slot and is
            skipped.
        k_clip: number of CLIP labels to keep per image.
        batch_size: images per ``encode_image`` call.

    Returns:
        ``(pred, coeff)``. ``pred`` maps feature id -> list of dicts with keys
        ``rank``, ``act_val``, ``img_idx`` and ``topk`` (a list of
        ``(label, prob)`` pairs). ``coeff`` maps feature id -> dict with
        ``max_coeff`` / ``mean_coeff`` over the kept activation values.
        Features with no valid images are omitted from both.
    """
    pred, coeff = collections.defaultdict(list), {}
    if not topk:
        return pred, coeff

    # Distinct dataset indices across all features, minus -1 placeholders.
    all_idx = torch.cat([idx for _, idx in topk.values()])
    uniq = torch.unique(all_idx[all_idx >= 0]).tolist()
    img_map = {img_idx: row for row, img_idx in enumerate(uniq)}

    label_rows, prob_rows = [], []
    with torch.no_grad():
        for start in range(0, len(uniq), batch_size):
            chunk = uniq[start:start + batch_size]
            batch = torch.stack([
                preproc(default_loader(val_ds.dataset.samples[i][0]))
                for i in chunk]).to(device)
            im_emb = clip_model.encode_image(batch)
            im_emb = im_emb / im_emb.norm(dim=-1, keepdim=True)
            sims = im_emb @ text_emb.T
            probs, l_idx = sims.softmax(-1).topk(k_clip, dim=-1)
            prob_rows.extend(probs.tolist())
            label_rows.extend(l_idx.tolist())

    for feat, (vals, idxs) in topk.items():
        keep = idxs >= 0
        vals, idxs = vals[keep], idxs[keep]
        if vals.numel() == 0:
            continue
        # Values are already sorted descending, so the rank is the list position.
        for rank, (v, img_idx) in enumerate(zip(vals.tolist(), idxs.tolist())):
            row = img_map[img_idx]
            pred[feat].append(dict(
                rank    = rank,
                act_val = float(v),
                img_idx = int(img_idx),
                topk    = [(batch_labels[int(li)], float(p))
                           for li, p in zip(label_rows[row], prob_rows[row])]
            ))
        coeff[feat] = dict(max_coeff=float(vals.max()),
                           mean_coeff=float(vals.mean()))
    return pred, coeff

def majority_label(topk):
    """Return the label of the highest-ranked ``(label, prob)`` pair in *topk*.

    Despite the name, this looks at a single prediction list. Majority
    voting across images happens in ``stats``.
    """
    best = topk[0]
    return best[0]

def semantic_purity_score(topk_list, label_to_idx, text_emb):
    """Mean text-embedding similarity of each image's top-1 label to the majority label.

    The majority label is the most common top-1 label across ``topk_list``.
    Ties go to the label seen first, as with ``Counter.most_common``, matching
    ``stats``. Each image's top-1 label embedding is compared, by cosine
    similarity, with the majority label's embedding, and the similarities
    are averaged.

    Args:
        topk_list: prediction dicts, each with a ``"topk"`` list of
            ``(label, prob)`` pairs ordered best-first.
        label_to_idx: label string -> row index into ``text_emb``.
        text_emb: tensor of text embeddings, one row per label.

    Returns:
        The mean cosine similarity as a float. Returns 0.0 when ``topk_list``
        is empty or the majority label is unknown to ``label_to_idx``.
    """
    pred_labels = [entry["topk"][0][0] for entry in topk_list]
    if not pred_labels:
        return 0.0

    majority_label_str = collections.Counter(pred_labels).most_common(1)[0][0]
    maj_idx = label_to_idx.get(majority_label_str, -1)
    if maj_idx == -1:
        return 0.0

    maj_emb = text_emb[maj_idx]
    maj_emb = maj_emb / maj_emb.norm()

    # Never empty here: the majority label itself is in pred_labels and known.
    pred_idxs = [label_to_idx[lbl] for lbl in pred_labels if lbl in label_to_idx]
    pred_embs = text_emb[pred_idxs]
    pred_embs = pred_embs / pred_embs.norm(dim=-1, keepdim=True)

    return (pred_embs @ maj_emb).mean().item()


def stats(entries):
    """Summarise how consistently a feature's top images agree on a CLIP label.

    Args:
        entries: prediction dicts (as built by ``build_pred_dict``), each with a
            ``"topk"`` list of ``(label, prob)`` pairs ordered best-first.

    Returns:
        dict with:

        - ``majority_label``: the most common top-1 label.
        - ``purity``: that label's share of entries.
        - ``entropy``: Shannon entropy of the top-1 label distribution,
          normalised to [0, 1] by ``log(#distinct labels)``; 0.0 when only
          one label occurs.
        - ``n_images``: the number of entries.

    Raises:
        ValueError: if ``entries`` is empty.
    """
    if not entries:
        raise ValueError("stats() needs at least one entry")
    labels = [majority_label(e["topk"]) for e in entries]
    n = len(labels)
    cnt = collections.Counter(labels)
    maj, mc = cnt.most_common(1)[0]
    n_distinct = len(cnt)
    if n_distinct > 1:
        raw = -sum((c / n) * math.log(c / n) for c in cnt.values())
        ent = raw / math.log(n_distinct)
    else:
        # One label means zero uncertainty; also avoids dividing by log(1) == 0.
        ent = 0.0
    return dict(majority_label=maj, purity=mc / n, entropy=ent, n_images=n)

# Label string -> class index into batch_labels / text_emb rows. If a label
# string occurs more than once (possible if class names repeat), the later
# index wins.
label_to_idx = dict(zip(batch_labels, range(len(batch_labels))))


def agg(pred, coeff, purity_thresh=None):
    """Build the per-feature results table and an interpretability summary.

    Args:
        pred: feature id -> list of prediction dicts (see ``build_pred_dict``).
        coeff: feature id -> dict of activation coefficients; merged into each row.
        purity_thresh: minimum purity for a feature to count as interpretable.
            Defaults to the module-level ``PURITY_THRESH``.

    Returns:
        ``(df, summ)``. ``df`` is indexed by ``feature_id`` and sorted by
        descending purity; the sort is stable, so ties keep insertion order.
        ``summ`` is a dict of aggregate counts and fractions.

    Raises:
        ValueError: if ``pred`` is empty.
    """
    thresh = PURITY_THRESH if purity_thresh is None else purity_thresh
    rows = []
    for fid, entries in pred.items():
        row = stats(entries)
        row["feature_id"] = int(fid)
        row.update(coeff[fid])
        rows.append(row)
    if not rows:
        raise ValueError("agg() needs at least one feature with predictions")

    df = (pd.DataFrame(rows).set_index("feature_id")
            .sort_values("purity", ascending=False, kind="stable"))
    interp = df[df.purity >= thresh]
    n_interp = len(interp)
    n_unique = interp.majority_label.nunique()
    summ = dict(
        n_features=len(df),
        purity_threshold=thresh,
        n_interpretable=n_interp,
        frac_interpretable=n_interp / len(df),
        n_unique_labels=n_unique,
        # Share of interpretable features whose majority label is distinct.
        uniqueness_fraction=n_unique / n_interp if n_interp else 0.0,
    )
    return df, summ

os.makedirs(OUT_DIR, exist_ok=True)
# Seed Python's and torch's RNGs; the torch generator drives the shuffled
# DataLoader order and the per-bin feature sampling below.
random.seed(0)
torch.manual_seed(0)

# log10 firing-frequency bins used to stratify which features are examined.
# The last two catch everything outside [-8, -1): features that fire rarely or
# never (log10(0) = -inf), and very dense ones.
bins = [(-8, -6), (-6, -5), (-5, -4), (-4, -3), (-3, -2), (-2, -1),
        (-float("inf"), -8), (-1, float("inf"))]
btxt = [f"TOTAL_logfreq_[{lo},{hi}]".replace("-inf", "-∞").replace("inf", "∞")
        for lo, hi in bins]

for sae_id in SAE_ID_LIST:
    sae = SparseAutoencoder.load_from_pretrained(sae_id)
    layer = sae.cfg.hook_point_layer
    print(f"\n==========  SAE layer {layer}  ==========\n")

    # Pass 1: per-feature firing frequency over the validation set.
    acts, tok = process_dataset(val_dl, sae)
    logf = torch.log10(acts / tok).to(device)

    # Randomly sample up to FEAT_PER_BIN features from each frequency bin.
    feat_ids, feat_cats = [], []
    for (lo, hi), lbl in zip(bins, btxt):
        in_bin = torch.nonzero((logf >= lo) & (logf < hi), as_tuple=True)[0]
        if len(in_bin) == 0:
            continue
        sel = in_bin[torch.randperm(len(in_bin))[:FEAT_PER_BIN]].tolist()
        feat_ids.extend(sel)
        feat_cats.extend([lbl] * len(sel))

    # Pass 2: top-activating images per sampled feature, then CLIP labels.
    topk = topk_per_feature(val_dl, sae, feat_ids, feat_cats, TOP_K_IMG)
    pred, coeff = build_pred_dict(topk, TOP_K_CLIP)

    df, summ = agg(pred, coeff)

    out_json = Path(OUT_DIR) / f"clip_predictions_{layer}.json"
    with open(out_json, "w", encoding="utf-8") as f:
        json.dump(pred, f, indent=2)
    out_csv = Path(OUT_DIR) / f"results_{layer}.csv"
    df.to_csv(out_csv)
    # Persist agg()'s aggregate summary alongside the per-feature table.
    out_summary = Path(OUT_DIR) / f"summary_{layer}.json"
    with open(out_summary, "w", encoding="utf-8") as f:
        json.dump(summ, f, indent=2)
    print(json.dumps(summ, indent=2))