import json, random, collections, functools
from pathlib import Path
from typing import Sequence, Tuple

import torch
import torch.nn.functional as F
from torchvision.transforms.functional import to_pil_image
from tqdm import tqdm

from vit_prisma.transforms.model_transforms import get_clip_val_transforms
from torch.utils.data.dataloader import DataLoader
from vit_prisma.sae.evals.evals import get_text_embeddings_openclip
from vit_prisma.dataloaders.imagenet_index import imagenet_index
from transforms.model_transforms import get_model_transforms

import open_clip

import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm

from vit_prisma.sae.sae import SparseAutoencoder
from vit_prisma.models.base_vit import HookedViT

import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--sae_path', type=str, required=True)
args = parser.parse_args()

# Single source of truth for the CLIP checkpoint: the HookedViT, the open_clip
# reference model and the tokenizer below must all refer to the same weights.
CLIP_CHECKPOINT = 'laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K'

# Same device rule the rest of the script uses; previously the models were
# moved to 'cuda' unconditionally, which crashes on CPU-only machines.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hooked copy of CLIP's vision tower (LayerNorm left unfolded).
model = HookedViT.from_pretrained(f'open-clip:{CLIP_CHECKPOINT}', fold_ln=False)
model = model.to(device)

sae = SparseAutoencoder.load_from_pretrained(args.sae_path)

# Human-readable ImageNet class names, ordered by class index 0..999.
num_imagenet_classes = 1000
batch_label_names = [imagenet_index[str(label)][1] for label in range(num_imagenet_classes)]

# Reference open_clip model, used here only to build the class-name text embeddings.
og_model, _, preproc = open_clip.create_model_and_transforms(f'hf-hub:{CLIP_CHECKPOINT}')
og_model = og_model.to(device)
tokenizer = open_clip.get_tokenizer(f'hf-hub:{CLIP_CHECKPOINT}')

text_embeddings = get_text_embeddings_openclip(og_model, preproc, tokenizer, batch_label_names)

data_transforms = get_clip_val_transforms(None)

from torchvision.datasets import DatasetFolder
from torchvision.datasets.folder import default_loader

import torch

class IndexedDataset(torch.utils.data.Dataset):
    """Wrap a map-style dataset so every sample also reports its own index.

    The wrapped dataset must yield ``(image, label)`` pairs; this wrapper
    yields ``(image, label, index)`` triples, where ``index`` is the key
    that was used to fetch the sample.
    """

    def __init__(self, dataset):
        # Underlying dataset that actually loads/transforms the samples.
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        sample = self.dataset[index]
        image, label = sample
        return (image, label, index)


# ImageNet validation images; every sample is run through `data_transforms`
# and then tagged with its dataset index by IndexedDataset.
_imagenet_val_folder = DatasetFolder(
    root='IMGNET_VAL_PATH',
    loader=default_loader,
    extensions=('.jpg', '.jpeg', '.png'),
    transform=data_transforms,
)
val_set = IndexedDataset(_imagenet_val_folder)

device = "cuda" if torch.cuda.is_available() else "cpu"

# Everything that takes part in steering lives on one device, in eval mode.
model = model.to(device).eval()
sae = sae.to(device).eval()
txt_emb = text_embeddings.to(device)

def ensure_pil(x):
    """Coerce *x* to a PIL image.

    Tensors are moved to the CPU, clamped to [0, 1] and converted with
    ``to_pil_image``; any non-tensor input is returned unchanged (it is
    assumed to already be a PIL image).

    NOTE(review): if the incoming tensor has already been mean/std-normalised
    (e.g. by the dataset's transform), the [0, 1] clamp discards information
    before the image is preprocessed again downstream -- confirm what the
    incoming tensors contain.
    """
    if not torch.is_tensor(x):
        return x
    clamped = x.cpu().clamp(0, 1)
    return to_pil_image(clamped)

def majority_idx(record_list, label_names=None) -> int:
    """Return the class index of the label seen most often across all top-k lists.

    Args:
        record_list: sequence of records, each a mapping with a ``"topk"``
            entry holding ``(label_name, score)`` pairs (scores are ignored).
        label_names: class-index -> label-name list used to turn the winning
            name back into an index. Defaults to the module-level
            ``batch_label_names``.

    Returns:
        Index of the most frequent label name within ``label_names``. Ties are
        broken in favour of the label encountered first
        (``Counter.most_common`` ordering).

    Raises:
        ValueError: if no labels occur in any record, or if the winning
            label is not present in ``label_names``.

    NOTE: records carry names only, so if ``label_names`` contains the same
    name more than once, the first (lowest) matching index is returned.
    """
    if label_names is None:
        label_names = batch_label_names
    counts = collections.Counter(lbl for rec in record_list for lbl, _ in rec["topk"])
    if not counts:
        # Fail with a clear message instead of an IndexError from most_common().
        raise ValueError("majority_idx: no top-k labels found in record_list")
    maj_name, _ = counts.most_common(1)[0]
    return label_names.index(maj_name)

@torch.no_grad()
def steer_once(img_pil,
               feat_id: int,
               scale: float,
               maj_idx: int) -> Tuple[bool, float]:
    """Steer one image along an SAE feature's decoder direction and score the effect.

    The image is preprocessed with ``preproc`` and run through ``model`` twice:
    once unmodified, and once with ``scale * sae.W_dec[feat_id]`` added at
    ``sae.cfg.hook_point`` to every token except position 0 (presumably the
    CLS token -- confirm). Both outputs are scored against ``txt_emb`` by a
    dot product.

    Args:
        img_pil: PIL image, or an image tensor (converted via ``ensure_pil``).
        feat_id: row of ``sae.W_dec`` to inject.
        scale: multiplier applied to that row.
        maj_idx: class index whose score is tracked.

    Returns:
        ``(hit, delta_logit)``. ``hit`` is True iff the unsteered argmax class
        is not ``maj_idx`` but the steered argmax class is. ``delta_logit`` is
        the steered minus unsteered score for ``maj_idx``.

    NOTE(review): no normalisation of the image embedding happens here; if the
    model output is not unit-norm, ``delta_logit`` also reflects changes in
    embedding norm (the argmax-based ``hit`` is unaffected) -- confirm intent.
    """
    img = preproc(ensure_pil(img_pil)).unsqueeze(0).to(device)

    # Unsteered pass. A plain forward call suffices: the activation cache that
    # `run_with_cache` would also build was never used and cost memory/time on
    # every call.
    base_logits = model(img) @ txt_emb.T

    # The steering direction depends only on (feat_id, scale); compute it once.
    direction = scale * sae.W_dec[feat_id, :]

    def inject(acts, hook):
        # Add the direction to every token except position 0.
        acts[:, 1:, :] += direction
        return acts

    steered_emb = model.run_with_hooks(
        img,
        fwd_hooks=[(sae.cfg.hook_point, inject)],
    )
    steered_logits = steered_emb @ txt_emb.T

    hit = (base_logits.argmax(1) != maj_idx) & (steered_logits.argmax(1) == maj_idx)
    dlog = (steered_logits - base_logits)[0, maj_idx].item()
    return bool(hit.item()), dlog

@torch.no_grad()
def _is_neutral(img_pil, feat_idx: int, threshold: float) -> bool:
    """Return True if SAE feature `feat_idx` stays below `threshold` on every non-CLS token."""
    x = preproc(img_pil).unsqueeze(0).to(device)
    _, cache = model.run_with_cache(x)
    _, acts, *_ = sae.encode(cache[sae.cfg.hook_point])
    # acts is indexed (batch, token, feature). Drop token 0 and look at the
    # feature's peak over all remaining tokens. The previous check read only
    # token 1, so an image could count as "neutral" while the feature fired
    # elsewhere in it.
    patch_acts = acts[0, 1:, feat_idx]
    return bool(patch_acts.max() < threshold)


def _sample_neutral_images(val_ds, feat_idx: int, n_trials: int,
                           max_attempts: int, threshold: float) -> list:
    """Collect up to `n_trials` distinct PIL images from `val_ds` on which `feat_idx` is inactive."""
    neutral = []
    if n_trials <= 0 or len(val_ds) == 0:
        return neutral
    # Sample without replacement so no image is counted twice, and bound the
    # number of candidates so a feature that fires on (almost) every image
    # cannot hang the sweep forever.
    n_candidates = min(max_attempts, len(val_ds))
    for ds_idx in random.sample(range(len(val_ds)), n_candidates):
        img, *_ = val_ds[ds_idx]
        img_pil = ensure_pil(img)
        if _is_neutral(img_pil, feat_idx, threshold):
            neutral.append(img_pil)
            if len(neutral) == n_trials:
                break
    return neutral


def steerability_sweep(json_path: str,
                       val_ds,
                       n_trials: int = 32,
                       scales: Sequence[float] = tuple(range(0, 10)),
                       max_attempts=None,
                       neutral_threshold: float = 0.1):
    """Measure how well each SAE feature steers the model towards its majority class.

    For every feature in the JSON feature dictionary at ``json_path``:

    1. The majority class is derived from its top-k records (``majority_idx``).
    2. Up to ``n_trials`` "neutral" validation images are drawn. A neutral
       image is one where the feature stays below ``neutral_threshold`` on
       every non-CLS token.
    3. Every scale in ``scales`` is applied to all of those images via
       ``steer_once``.

    Args:
        json_path: JSON file mapping feature id (string) -> list of records,
            each with a ``"topk"`` list of ``(label_name, score)`` pairs.
        val_ds: indexable dataset whose items start with an image.
        n_trials: number of neutral images to collect per feature.
        scales: steering magnitudes to evaluate.
        max_attempts: maximum number of distinct candidate images examined
            per feature; defaults to ``100 * n_trials``.
        neutral_threshold: activation below which the feature counts as inactive.

    Returns:
        ``{feat_id: {"majority": name, "n_neutral": k, "metrics":
        {scale: {"hit_rate": ..., "delta_logit": ...}}}}``. Metrics are
        averaged over the ``k`` neutral images actually found, and
        ``metrics`` is empty when ``k == 0``.
    """
    if max_attempts is None:
        max_attempts = 100 * n_trials

    feature_dict = json.loads(Path(json_path).read_text(encoding="utf-8"))
    results = {}

    for feat_id, recs in feature_dict.items():
        feat_idx = int(feat_id)
        maj_idx = majority_idx(recs)
        maj_name = batch_label_names[maj_idx]

        print(f"\n★ Feature {feat_id}   (majority ≈ {maj_name})")

        neutral = _sample_neutral_images(val_ds, feat_idx, n_trials,
                                         max_attempts, neutral_threshold)
        if len(neutral) < n_trials:
            print(f"  warning: only {len(neutral)}/{n_trials} neutral images found "
                  f"(examined at most {max_attempts} candidates)")

        feat_res = {}
        if neutral:
            for s in scales:
                outcomes = [steer_once(img, feat_idx, s, maj_idx) for img in neutral]
                n_hits = sum(hit for hit, _ in outcomes)
                total_dlog = sum(dlog for _, dlog in outcomes)
                feat_res[s] = dict(hit_rate=n_hits / len(neutral),
                                   delta_logit=total_dlog / len(neutral))
                print(f"  scale {s:<3}: hit-rate {feat_res[s]['hit_rate']*100:5.1f}%   "
                      f"delta logit {feat_res[s]['delta_logit']:+.3f}")
        results[feat_id] = dict(majority=maj_name, n_neutral=len(neutral), metrics=feat_res)
    return results

# Resolve the output location and create its directory *before* the (long)
# sweep, so a missing or unwritable folder fails fast instead of discarding
# all results at the final write.
out_path = Path(f'OUTPUT_DIR/{sae.cfg.hook_point_layer}.json')
out_path.parent.mkdir(parents=True, exist_ok=True)

data = steerability_sweep(
    f"FEATURE_DICTIONARY_PATH/layer_{sae.cfg.hook_point_layer}.json",
    val_set
)

# Write-only access is all that's needed (the old 'w+' also opened it for reading).
with out_path.open('w', encoding='utf-8') as f:
    json.dump(data, f)
