import torch
import numpy as np
from torchvision import transforms
from collections import defaultdict
from tqdm import tqdm
import clip
from PIL import Image

def extract_clip_features_from_subset(
    dataset_subset,
    batch_size=64,
    model_name="ViT-B/32",
    device=None
):
    """Extract CLIP image embeddings for a dataset, grouped by label.

    Every image is first resized to 32x32 and then upsampled to 224x224
    (both bicubic) before being fed to the CLIP image encoder.

    Args:
        dataset_subset: Indexable, sized collection whose items are
            ``(image, label)`` pairs. ``image`` is either a PIL image or
            anything ``np.array`` + ``Image.fromarray`` can turn into one.
        batch_size: Number of images encoded per forward pass.
        model_name: CLIP model identifier passed to ``clip.load``.
        device: Device to run on. Defaults to CUDA when available, else CPU.

    Returns:
        dict mapping ``str(label)`` to an array of shape
        ``(num_images_with_that_label, feature_dim)``. Empty when the
        dataset is empty.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model, _ = clip.load(model_name, device=device, download_root="./CLIP")
    model.eval()

    # Normalisation statistics of CLIP's own preprocessing transform. The
    # ImageNet statistics used previously do not match what the encoder
    # was trained with.
    mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(3, 1, 1)
    std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(3, 1, 1)

    def _to_model_input(img):
        """Turn one raw sample image into a normalised (3, 224, 224) tensor."""
        if not isinstance(img, Image.Image):
            img = Image.fromarray(np.array(img))
        # Force three channels so grayscale / RGBA inputs do not break the
        # HWC -> CHW permute or the per-channel normalisation below.
        img = img.convert("RGB")

        # Down-sample to 32x32 first, then upsample to the encoder's input size.
        img = img.resize((32, 32), Image.BICUBIC)
        img = img.resize((224, 224), Image.BICUBIC)

        pixels = torch.from_numpy(np.array(img)).permute(2, 0, 1).float() / 255.0
        return (pixels - mean) / std

    features_by_class = defaultdict(list)
    num_samples = len(dataset_subset)

    for start in tqdm(range(0, num_samples, batch_size), desc="Extracting CLIP features"):
        stop = min(start + batch_size, num_samples)
        samples = [dataset_subset[j] for j in range(start, stop)]
        labels = [label for _, label in samples]

        images_tensor = torch.stack(
            [_to_model_input(img) for img, _ in samples]
        ).to(device)

        with torch.no_grad():
            feats = model.encode_image(images_tensor).cpu().numpy()

        for feat, label in zip(feats, labels):
            features_by_class[str(label)].append(feat)

    # Collapse each per-class list of vectors into a single 2-D array.
    return {k: np.stack(v) for k, v in features_by_class.items()}

def extract_dino_features_from_subset(
    dataset_subset,
    model_name='dinov2_vitb14',
    batch_size=64,
    device=None,
    hub_repo='facebookresearch/dinov2',
):
    """Extract DINOv2 image embeddings for a dataset, grouped by label.

    Every image is first resized to 32x32 and then upsampled to 224x224
    (both bicubic) before being fed to the model.

    Args:
        dataset_subset: Indexable, sized collection whose items are
            ``(image, label)`` pairs. ``image`` is either a PIL image or
            anything ``np.array`` + ``Image.fromarray`` can turn into one.
        model_name: Entry point name passed to ``torch.hub.load``.
        batch_size: Number of images encoded per forward pass.
        device: Device to run on. Defaults to CUDA when available, else CPU.
        hub_repo: Either a GitHub ``owner/repo`` spec or a path to a local
            checkout of the DINOv2 repository. An existing directory is
            loaded with ``source='local'``; anything else is fetched from
            GitHub.

    Returns:
        dict mapping ``str(label)`` to an array of shape
        ``(num_images_with_that_label, feature_dim)``. Empty when the
        dataset is empty.
    """
    import os

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # torch.hub only treats the first argument as a filesystem path when
    # source='local'; otherwise it is parsed as a GitHub "owner/repo" spec.
    hub_source = 'local' if os.path.isdir(hub_repo) else 'github'
    model = torch.hub.load(hub_repo, model_name, source=hub_source).to(device).eval()

    # ImageNet normalisation statistics.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    def _to_model_input(img):
        """Turn one raw sample image into a normalised (3, 224, 224) tensor."""
        if not isinstance(img, Image.Image):
            img = Image.fromarray(np.array(img))
        # Force three channels so grayscale / RGBA inputs do not break the
        # HWC -> CHW permute or the per-channel normalisation below.
        img = img.convert("RGB")

        # Down-sample to 32x32 first, then upsample to the model's input size.
        img = img.resize((32, 32), Image.BICUBIC)
        img = img.resize((224, 224), Image.BICUBIC)

        pixels = torch.from_numpy(np.array(img)).permute(2, 0, 1).float() / 255.0
        return (pixels - mean) / std

    features_by_class = defaultdict(list)
    num_samples = len(dataset_subset)

    for start in tqdm(range(0, num_samples, batch_size), desc="Extracting DINOv2 features"):
        stop = min(start + batch_size, num_samples)
        samples = [dataset_subset[j] for j in range(start, stop)]
        labels = [label for _, label in samples]

        images_tensor = torch.stack(
            [_to_model_input(img) for img, _ in samples]
        ).to(device)

        with torch.no_grad():
            feats = model(images_tensor).cpu().numpy()

        for feat, label in zip(feats, labels):
            features_by_class[str(label)].append(feat)

    # Collapse each per-class list of vectors into a single 2-D array.
    return {k: np.stack(v) for k, v in features_by_class.items()}

def extract_clip_text_features(texts, model_name="ViT-B/32", device=None):
    """Encode one or more strings with the CLIP text encoder.

    Args:
        texts: A single string or an iterable of strings.
        model_name: CLIP model identifier passed to ``clip.load``.
        device: Device to run on. Defaults to CUDA when available, else CPU.

    Returns:
        A NumPy array of text features. When exactly one text is given
        (a bare string or a one-element sequence) the result is the 1-D
        feature vector for that text; otherwise it is 2-D with one row
        per input text.

    NOTE: the one-element-sequence case deliberately keeps returning a 1-D
    vector so existing callers see the same shapes as before.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # clip.load also returns an image preprocessing transform, which text
    # encoding has no use for.
    model, _ = clip.load(model_name, device=device, download_root="./CLIP")
    model.eval()

    # Materialise so len() works for any iterable, not just lists.
    texts = [texts] if isinstance(texts, str) else list(texts)

    with torch.no_grad():
        tokens = clip.tokenize(texts).to(device)
        text_features = model.encode_text(tokens).cpu().numpy()

    # Only index row 0 when there is exactly one row; an empty input used to
    # fall into that branch and fail with IndexError.
    return text_features[0] if len(texts) == 1 else text_features

from collections import defaultdict
from tqdm import tqdm
import torch.nn as nn
import numpy as np
from PIL import Image
from collections import defaultdict
from tqdm import tqdm

def extract_clip_features_from_subset_imagenet(
    dataset_subset,
    batch_size=64,
    model_name="ViT-B/32",
    device=None
):
    """Extract CLIP image embeddings for a dataset, grouped by label.

    Every image is resized directly to 224x224 (bicubic, aspect ratio not
    preserved) before being fed to the CLIP image encoder.

    Args:
        dataset_subset: Indexable, sized collection whose items are
            ``(image, label)`` pairs. ``image`` is either a PIL image or
            anything ``np.array`` + ``Image.fromarray`` can turn into one.
        batch_size: Number of images encoded per forward pass.
        model_name: CLIP model identifier passed to ``clip.load``.
        device: Device to run on. Defaults to CUDA when available, else CPU.

    Returns:
        dict mapping ``str(label)`` to an array of shape
        ``(num_images_with_that_label, feature_dim)``. Empty when the
        dataset is empty.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model, _ = clip.load(model_name, device=device, download_root="./CLIP")
    model.eval()

    # Normalisation statistics of CLIP's own preprocessing transform. The
    # ImageNet statistics used previously do not match what the encoder
    # was trained with.
    mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(3, 1, 1)
    std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(3, 1, 1)

    def _to_model_input(img):
        """Turn one raw sample image into a normalised (3, 224, 224) tensor."""
        if not isinstance(img, Image.Image):
            img = Image.fromarray(np.array(img))
        # Force three channels so grayscale / RGBA / CMYK inputs do not break
        # the HWC -> CHW permute or the per-channel normalisation below.
        img = img.convert("RGB")
        img = img.resize((224, 224), Image.BICUBIC)

        pixels = torch.from_numpy(np.array(img)).permute(2, 0, 1).float() / 255.0
        return (pixels - mean) / std

    features_by_class = defaultdict(list)
    num_samples = len(dataset_subset)

    for start in tqdm(range(0, num_samples, batch_size), desc="Extracting CLIP features"):
        stop = min(start + batch_size, num_samples)
        samples = [dataset_subset[j] for j in range(start, stop)]
        labels = [label for _, label in samples]

        images_tensor = torch.stack(
            [_to_model_input(img) for img, _ in samples]
        ).to(device)

        with torch.no_grad():
            feats = model.encode_image(images_tensor).cpu().numpy()

        for feat, label in zip(feats, labels):
            features_by_class[str(label)].append(feat)

    # Collapse each per-class list of vectors into a single 2-D array.
    return {k: np.stack(v) for k, v in features_by_class.items()}

def extract_dino_features_from_subset_imagenet(
    dataset_subset,
    model_name='dinov2_vitb14',
    batch_size=64,
    device=None,
    hub_repo='facebookresearch/dinov2',
):
    """Extract DINOv2 image embeddings for a dataset, grouped by label.

    Every image is resized directly to 224x224 (bicubic, aspect ratio not
    preserved) before being fed to the model.

    Args:
        dataset_subset: Indexable, sized collection whose items are
            ``(image, label)`` pairs. ``image`` is either a PIL image or
            anything ``np.array`` + ``Image.fromarray`` can turn into one.
        model_name: Entry point name passed to ``torch.hub.load``.
        batch_size: Number of images encoded per forward pass.
        device: Device to run on. Defaults to CUDA when available, else CPU.
        hub_repo: Either a GitHub ``owner/repo`` spec or a path to a local
            checkout of the DINOv2 repository. An existing directory is
            loaded with ``source='local'``; anything else is fetched from
            GitHub.

    Returns:
        dict mapping ``str(label)`` to an array of shape
        ``(num_images_with_that_label, feature_dim)``. Empty when the
        dataset is empty.
    """
    import os

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # torch.hub only treats the first argument as a filesystem path when
    # source='local'; otherwise it is parsed as a GitHub "owner/repo" spec.
    hub_source = 'local' if os.path.isdir(hub_repo) else 'github'
    model = torch.hub.load(hub_repo, model_name, source=hub_source).to(device).eval()

    # ImageNet normalisation statistics.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    def _to_model_input(img):
        """Turn one raw sample image into a normalised (3, 224, 224) tensor."""
        if not isinstance(img, Image.Image):
            img = Image.fromarray(np.array(img))
        # Force three channels so grayscale / RGBA / CMYK inputs do not break
        # the HWC -> CHW permute or the per-channel normalisation below.
        img = img.convert("RGB")
        img = img.resize((224, 224), Image.BICUBIC)

        pixels = torch.from_numpy(np.array(img)).permute(2, 0, 1).float() / 255.0
        return (pixels - mean) / std

    features_by_class = defaultdict(list)
    num_samples = len(dataset_subset)

    for start in tqdm(range(0, num_samples, batch_size), desc="Extracting DINOv2 features"):
        stop = min(start + batch_size, num_samples)
        samples = [dataset_subset[j] for j in range(start, stop)]
        labels = [label for _, label in samples]

        images_tensor = torch.stack(
            [_to_model_input(img) for img, _ in samples]
        ).to(device)

        with torch.no_grad():
            feats = model(images_tensor).cpu().numpy()

        for feat, label in zip(feats, labels):
            features_by_class[str(label)].append(feat)

    # Collapse each per-class list of vectors into a single 2-D array.
    return {k: np.stack(v) for k, v in features_by_class.items()}

# Class names used to build the text prompts in get_imagenet100_features.
# Order matters: row i of the returned feature matrix corresponds to entry i.
_IMAGENET100_CLASSES = (
    "tailed frog",
    "bee eater",
    "European green lizard",
    "limpkin",
    "loggerhead sea turtle",
    "kingsnake",
    "American dipper",
    "desert grassland whiptail lizard",
    "wombat",
    "sea snake",
    "duck",
    "spotted salamander",
    "great grey owl",
    "hornbill",
    "white stork",
    "stingray",
    "spoonbill",
    "bittern",
    "snail",
    "magpie",
    "sidewinder",
    "mud turtle",
    "bald eagle",
    "bulbul",
    "chambered nautilus",
    "peacock",
    "sulphur-crested cockatoo",
    "prairie grouse",
    "scorpion",
    "Dungeness crab",
    "tarantula",
    "conch",
    "eastern hog-nosed snake",
    "boa constrictor",
    "tiger shark",
    "smooth newt",
    "goose",
    "wallaby",
    "tick",
    "terrapin",
    "pelican",
    "electric ray",
    "banded gecko",
    "jellyfish",
    "vine snake",
    "ptarmigan",
    "black swan",
    "crayfish",
    "eastern diamondback rattlesnake",
    "hermit crab",
    "worm snake",
    "southern black widow",
    "hen",
    "toucan",
    "chiton",
    "kite",
    "dunlin",
    "barn spider",
    "goldfish",
    "smooth green snake",
    "indigo bunting",
    "sea lion",
    "American coot",
    "American alligator",
    "cock",
    "garter snake",
    "European garden spider",
    "albatross",
    "chickadee",
    "night snake",
    "green iguana",
    "Saharan horned viper",
    "yellow garden spider",
    "common redshank",
    "harvestman",
    "leatherback sea turtle",
    "tench",
    "wolf spider",
    "agama",
    "green mamba",
    "hummingbird",
    "goldfinch",
    "spiny lobster",
    "hammerhead shark",
    "black grouse",
    "lorikeet",
    "sea slug",
    "rock crab",
    "flatworm",
    "nematode",
    "coucal",
    "Komodo dragon",
    "macaw",
    "great white shark",
    "bustard",
    "crane (bird)",
    "flamingo",
    "axolotl",
    "oystercatcher",
    "sea anemone",
)


def get_imagenet100_features(device=None):
    """Return CLIP text features for the 100 class-name prompts.

    Each class name is wrapped in the template ``"a photo of a {cls}."``
    and the resulting prompts are encoded with
    ``extract_clip_text_features`` using its default model.

    Args:
        device: Device forwarded to ``extract_clip_text_features``. ``None``
            lets that function pick CUDA when available and fall back to
            CPU, instead of unconditionally requiring a GPU.

    Returns:
        Whatever ``extract_clip_text_features`` returns for the list of
        100 prompts, with rows in the same order as the class list.
    """
    full_prompts = [f"a photo of a {cls}." for cls in _IMAGENET100_CLASSES]
    return extract_clip_text_features(full_prompts, device=device)
