import os
import torch
import torchvision
import open_clip
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset, DistributedSampler
import webdataset as wds
from pathlib import Path

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

import time


# --- DISTRIBUTED SETUP ---
# One process per GPU; LOCAL_RANK is presumably set by the launcher (e.g. torchrun).
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

# Sort the shard list: rglob order is filesystem-dependent, and splitting shards
# across ranks only yields disjoint, complete coverage if every rank builds the
# exact same list in the exact same order.
shards = sorted(
    str(p) for p in Path("/scratch/svillalo/imagenet21k/wds_streamed").rglob("*.tar")
)
print(f"Found {len(shards)} shards")
if not shards:
    # Fail fast with a clear message instead of an obscure error inside the data pipeline.
    raise FileNotFoundError(
        "No .tar shards found under /scratch/svillalo/imagenet21k/wds_streamed"
    )

# --- CONFIG ---
# DATA_PATH = "/scratch/svillalo/imagenet21k/extracted"
MODEL_NAME = "ViT-g-14"
PRETRAINED = "laion2b_s34b_b88k"
BATCH_SIZE = 32  # images per forward pass, per rank
NUM_WORKERS = 16  # DataLoader worker processes per rank (also the CPU thread count on CPU)
# Keep DEVICE as a device *type* string. torch.cuda.set_device(local_rank) has
# already been called, so "cuda" resolves to this rank's own GPU for .to(), and
# the same string can be passed directly to torch.amp.autocast, which rejects
# integer device indices. It also keeps the `DEVICE == "cpu"` check below meaningful.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
OUTPUT_FILE = "ImageNet21k_clip_embeddings.pt"
CACHE_DIR = "/scratch/svillalo/hf_cache"

print(f"device = {DEVICE}, rank = {dist.get_rank()}, local_rank = {local_rank}")

# --- SET THREADS ---
if DEVICE == "cpu":
    torch.set_num_threads(NUM_WORKERS)
    print(f"Using {torch.get_num_threads()} CPU threads")

# --- LOAD MODEL ---
# Inference only: each rank holds a full copy of the model and embeds its own
# shards, so there are no gradients to synchronize. A DistributedDataParallel
# wrapper would only add a startup parameter broadcast. It would also hide
# `encode_image`, because DDP, like any nn.Module, only resolves parameters,
# buffers and submodules as attributes, not methods of the wrapped network.
model, preprocess, _ = open_clip.create_model_and_transforms(
    MODEL_NAME, pretrained=PRETRAINED, cache_dir=CACHE_DIR
)
model = model.to(DEVICE)
model.eval()

# --- DATASET WRAPPER ---
# class PreprocessedDataset(Dataset):
#     def __init__(self, root_dir, transform):
#         self.ds = torchvision.datasets.ImageFolder(root=root_dir)
#         self.transform = transform

#     def __len__(self):
#         return len(self.ds)

#     def __getitem__(self, idx):
#         img, label = self.ds[idx]
#         return self.transform(img), label

# dataset = PreprocessedDataset(DATA_PATH, preprocess)

# Class index = position of each shard subdirectory name in sorted order.
# NOTE(review): assumes subdirectories are named by the same WNIDs stored in
# each sample's .cls field; confirm against the shard layout.
classes = sorted(
    entry.name
    for entry in Path("/scratch/svillalo/imagenet21k/wds_streamed").iterdir()
    if entry.is_dir()
)
wnid_to_idx = dict(zip(classes, range(len(classes))))

# def wnid_to_int(label_bytes):
#     wnid = label_bytes.decode("utf-8")
#     return wnid_to_idx[wnid]

# def decode_with_str_labels(sample):
#     sample["jpg"] = wds.imagehandler("pil")(sample["jpg"])
#     sample["cls"] = sample["cls"].decode("utf-8")  # decode bytes to string
#     return sample

# dataset = (
#     wds.WebDataset(shards, shardshuffle=False)
#     .map(decode_with_str_labels)
#     .to_tuple("jpg", "cls")
#     .map_tuple(preprocess, lambda x: x)
# )

# def wds_preprocess(sample):
#     image, wnid = sample
#     return preprocess(image), wnid_to_idx[wnid]

# --- DATA PIPELINE ---
# WebDataset is an IterableDataset, so work is split at the shard level inside
# the pipeline rather than with a sampler:
#   * nodesplitter=split_by_node gives each distributed rank its own slice of `shards`;
#   * workersplitter=split_by_worker divides that slice among DataLoader workers.
# DistributedSampler cannot be used here: it needs len(dataset), and DataLoader
# rejects any sampler for an IterableDataset.
dataset = (
    wds.WebDataset(
        shards,
        shardshuffle=False,
        nodesplitter=wds.split_by_node,
        workersplitter=wds.split_by_worker,
        handler=wds.warn_and_continue,
    )
    # Decode images to PIL, and decode the raw `.cls` bytes to str so they can be
    # looked up in wnid_to_idx.
    .decode(
        "pil",
        wds.handle_extension(".cls", lambda x: x.decode()),
        handler=wds.warn_and_continue,
    )
    .to_tuple("jpg", "cls")
    .map_tuple(preprocess, lambda x: torch.tensor(wnid_to_idx[x], dtype=torch.int32))
)

loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

# --- EMBEDDING SAVE FUNCTION ---
def save_embeddings_chunk(chunk_idx, embeddings_list, labels_list, folder="chunks", rank=None):
    """Concatenate buffered batches, save them to one .pt file, and empty the buffers.

    Args:
        chunk_idx: Sequence number of this chunk for the writing rank; part of the file name.
        embeddings_list: List of embedding tensors, concatenated along dim 0.
            Cleared in place after saving.
        labels_list: List of label tensors, concatenated along dim 0.
            Cleared in place after saving.
        folder: Output directory; created if it does not exist.
        rank: Rank used in the file name. Defaults to the global process-group
            rank, or 0 when torch.distributed is not initialized. The global rank
            is used rather than LOCAL_RANK so that ranks on different nodes that
            write into a shared folder do not overwrite each other's files.

    Returns:
        The path of the file that was written.

    Raises:
        RuntimeError: from torch.cat if either list is empty.
    """
    if rank is None:
        rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
    os.makedirs(folder, exist_ok=True)
    path = os.path.join(folder, f"embeddings_rank_{rank}_chunk_{chunk_idx}.pt")
    torch.save(
        {
            "embeddings": torch.cat(embeddings_list),
            "labels": torch.cat(labels_list),
        },
        path,
    )
    # Clear the caller's buffers in place so the saved batches can be freed.
    embeddings_list.clear()
    labels_list.clear()
    return path

# --- ENCODE ---
# Per-rank buffers of CPU tensors, flushed to disk every `chunk_size` images.
all_embs, all_labels = [], []
curr_len = 0
chunk_size = 10000  # Save every 10k images to reduce RAM pressure
chunk_idx = 0

# Call encode_image on the underlying network even if `model` is wrapped
# (DistributedDataParallel exposes the wrapped net as `.module` and does not
# forward custom methods such as encode_image).
encoder = getattr(model, "module", model)
# torch.amp.autocast expects a device *type* string ("cuda"/"cpu"), not an
# integer index or "cuda:N"; normalize whatever DEVICE holds.
amp_device_type = torch.device(DEVICE).type

with torch.no_grad(), torch.amp.autocast(amp_device_type):
    # Only rank 0 shows a progress bar; the others would just interleave output.
    for imgs, labels in tqdm(loader, disable=dist.get_rank() != 0):
        imgs = imgs.to(DEVICE, non_blocking=True)
        emb = encoder.encode_image(imgs)
        # L2-normalize each embedding so cosine similarity becomes a dot product.
        emb = emb / emb.norm(dim=-1, keepdim=True)

        all_embs.append(emb.cpu())
        all_labels.append(labels)
        curr_len += len(labels)

        # Flush to disk in chunks to avoid huge RAM usage.
        if curr_len >= chunk_size:
            save_embeddings_chunk(chunk_idx, all_embs, all_labels)
            curr_len = 0
            chunk_idx += 1

# Save any remaining embeddings (the final, partial chunk).
if all_embs:
    save_embeddings_chunk(chunk_idx, all_embs, all_labels)

# # --- MERGE CHUNKS ---
# print("Merging chunks...")
# chunk_files = sorted([f for f in os.listdir("chunks") if f.endswith(".pt")])
# all_embs, all_labels = [], []
# for f in tqdm(chunk_files):
#     data = torch.load(os.path.join("chunks", f))
#     all_embs.append(data["embeddings"])
#     all_labels.append(data["labels"])

# embs = torch.cat(all_embs)
# labels = torch.cat(all_labels)

# # --- FINAL SAVE ---
# torch.save({
#     "embeddings": embs,
#     "labels": labels,
#     "model": f"{MODEL_NAME}-{PRETRAINED}",
#     "dataset": "ImageNet21k"
# }, OUTPUT_FILE)

# print(f"Saved final embeddings: {embs.shape}")

# --- TEARDOWN ---
# Release the process group now that this rank has flushed all of its chunks.
dist.destroy_process_group()