import torch

# Human-readable class names per dataset; the position in each list is the
# integer label (looked up by _get_class_name). Multi-word names are written
# with spaces (e.g. 'aquarium fish').
# CIFAR-10: 10 names.
CIFAR_10_CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
# CIFAR-100: 100 fine-label names, in alphabetical order.
CIFAR_100_CLASSES = ['apple', 'aquarium fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle', 'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel', 'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur', 'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster', 'house', 'kangaroo', 'keyboard', 'lamp', 'lawn mower', 'leopard', 'lion', 'lizard', 'lobster', 'man', 'maple tree', 'motorcycle', 'mountain', 'mouse', 'mushroom', 'oak tree', 'orange', 'orchid', 'otter', 'palm tree', 'pear', 'pickup truck', 'pine tree', 'plain', 'plate', 'poppy', 'porcupine', 'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose', 'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake', 'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet pepper', 'table', 'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout', 'tulip', 'turtle', 'wardrobe', 'whale', 'willow tree', 'wolf', 'woman', 'worm']
# Tiny-ImageNet: 200 names; the literal continues onto the next line.
# NOTE(review): label i maps to entry i only if the data loader numbers the
# classes in this same order (presumably sorted WordNet IDs) — confirm.
TINY_IMAGENET_CLASSES = ['goldfish', 'European fire salamander', 'bullfrog', 'tailed frog', 'American alligator', 'boa constrictor', 'trilobite', 'scorpion', 'black widow', 'tarantula', 'centipede', 'goose', 'koala', 'jellyfish', 'brain coral', 'snail', 'slug', 'sea slug', 'American lobster', 'spiny lobster', 'black stork', 'king penguin', 'albatross', 'dugong', 'Chihuahua', 'Yorkshire terrier', 'golden retriever', 'Labrador retriever', 'German shepherd', 'standard poodle', 'tabby', 'Persian cat', 'Egyptian cat', 'cougar', 'lion', 'brown bear', 'ladybug', 'fly', 'bee', 'grasshopper', 'walking stick', 'cockroach', 'mantis', 'dragonfly', 'monarch', 'sulphur butterfly', 'sea cucumber', 'guinea pig', 'hog', 'ox', 'bison', 'bighorn', 'gazelle', 'Arabian camel', 'orangutan', 'chimpanzee', 'baboon', 'African elephant', 'lesser panda', 'abacus', 'academic gown', 'altar', 'apron', 'backpack', 'bannister', 'barbershop', 'barn', 'barrel', 'basketball', 'bathtub', 'beach wagon', 'beacon', 'beaker', 'beer bottle', 'bikini', 'binoculars', 'birdhouse', 'bow tie', 'brass', 'broom', 'bucket', 'bullet train', 'butcher shop', 'candle', 'cannon', 'cardigan', 'cash machine', 'CD player', 'chain', 'chest', 'Christmas stocking', 'cliff dwelling', 'computer keyboard', 'confectionery', 'convertible', 'crane', 'dam', 'desk', 'dining table', 'drumstick', 'dumbbell', 'flagpole', 'fountain', 'freight car', 'frying pan', 'fur coat', 'gasmask', 'go-kart', 'gondola', 'hourglass', 'iPod', 'jinrikisha', 'kimono', 'lampshade', 'lawn mower', 'lifeboat', 'limousine', 'magnetic compass', 'maypole', 'military uniform', 'miniskirt', 'moving van', 'nail', 'neck brace', 'obelisk', 'oboe', 'organ', 'parking meter', 'pay-phone', 'picket fence', 'pill bottle', 'plunger', 'pole', 'police van', 'poncho', 'pop bottle', "potter's wheel", 'projectile', 'punching bag', 'reel', 'refrigerator', 'remote control', 'rocking chair', 'rugby ball', 'sandal', 'school bus', 'scoreboard', 'sewing machine', 'snorkel', 'sock', 
'sombrero', 'space heater', 'spider web', 'sports car', 'steel arch bridge', 'stopwatch', 'sunglasses', 'suspension bridge', 'swimming trunks', 'syringe', 'teapot', 'teddy', 'thatch', 'torch', 'tractor', 'triumphal arch', 'trolleybus', 'turnstile', 'umbrella', 'vestment', 'viaduct', 'volleyball', 'water jug', 'water tower', 'wok', 'wooden spoon', 'comic book', 'plate', 'guacamole', 'ice cream', 'ice lolly', 'pretzel', 'mashed potato', 'cauliflower', 'bell pepper', 'mushroom', 'orange', 'lemon', 'banana', 'pomegranate', 'meat loaf', 'pizza', 'potpie', 'espresso', 'alp', 'cliff', 'coral reef', 'lakeside', 'seashore', 'acorn']
# Text prompt templates; '{class}' marks where a class name is inserted.
# NOTE(review): `class` is a Python keyword, so `t.format(class=name)` is a
# SyntaxError; fill a template with `t.format_map({'class': name})` or
# `t.replace('{class}', name)`. How callers fill these is not visible here.
# The fixed article 'a' also produces e.g. 'a photo of a apple.'.
PROMPTS = [
    'a photo of a {class}.',
    'a blurry photo of a {class}.',
    'a black and white photo of a {class}.',
    'a high contrast photo of a {class}.',
    'a good photo of a {class}.',
    'a photo of a small {class}.',
    'a photo of a big {class}.',
]

import torchvision.transforms as transforms

def get_transform(dataset: str):
    """Build the image preprocessing pipeline for ``dataset``.

    The pipeline is ``ToTensor`` followed by a per-channel ``Normalize``
    whose mean/std depend on the dataset.

    Args:
        dataset: one of ``'cifar-10'``, ``'cifar-100'`` or ``'tiny-imagenet'``.

    Returns:
        A ``transforms.Compose`` of ``ToTensor()`` and ``Normalize(mean, std)``.

    Raises:
        NotImplementedError: if ``dataset`` is not one of the supported names.
    """
    if dataset in ('cifar-10', 'cifar-100'):
        # ToTensor yields values in [0, 1] for PIL / uint8 inputs, so
        # (x - 0.5) / 0.5 maps them to [-1, 1].
        mean, std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
    elif dataset == 'tiny-imagenet':
        # NOTE(review): these values match the commonly quoted CIFAR-100
        # channel statistics rather than Tiny-ImageNet's own. They are kept
        # unchanged so models trained with them still see the same inputs;
        # confirm this is intended.
        mean, std = [0.507, 0.487, 0.441], [0.267, 0.256, 0.276]
    else:
        raise NotImplementedError(f"get_transform: unsupported dataset {dataset!r}")
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

def check_tensor(name: str, tensor: torch.Tensor) -> bool:
    """Check that every element of ``tensor`` is finite, printing a report if not.

    Args:
        name: label used in the printed report.
        tensor: tensor to inspect.

    Returns:
        True if all elements are finite (an empty tensor counts as finite),
        False otherwise. On False, a line with the NaN/Inf counts and the
        min/max over the finite elements is printed.
    """
    finite_mask = torch.isfinite(tensor)
    if finite_mask.all():
        return True

    values = tensor.detach()  # diagnostics only; no autograd graph needed
    n_nan = int(torch.isnan(values).sum().item())
    n_inf = int(torch.isinf(values).sum().item())
    # min()/max() over the whole tensor would just return NaN whenever any NaN
    # is present, so report the range of the finite elements instead.
    finite_values = values[finite_mask]
    if finite_values.numel() > 0:
        range_msg = (f"finite_min={finite_values.min().item()}, "
                     f"finite_max={finite_values.max().item()}")
    else:
        range_msg = "no finite values"
    print(f"{name} has NaN/Inf, nan={n_nan}, inf={n_inf}, {range_msg}")
    return False

def _get_class_name(dataset: str, label: int) -> str:
    """Return the human-readable class name of ``label`` in ``dataset``.

    Args:
        dataset: one of ``'cifar-10'``, ``'cifar-100'`` or ``'tiny-imagenet'``.
        label: class index. Anything usable as a list index via ``__index__``
            is accepted (e.g. a Python int, or an integer tensor with a single
            element).

    Returns:
        The class name at position ``label`` in that dataset's class list.

    Raises:
        NotImplementedError: if ``dataset`` is not supported.
        TypeError: if ``label`` cannot be used as an integer index.
        IndexError: if ``label`` is outside ``[0, number of classes)``.
    """
    import operator

    class_lists = {
        'cifar-10': CIFAR_10_CLASSES,
        'cifar-100': CIFAR_100_CLASSES,
        'tiny-imagenet': TINY_IMAGENET_CLASSES,
    }
    if dataset not in class_lists:
        raise NotImplementedError(f"_get_class_name: unsupported dataset {dataset!r}")
    classes = class_lists[dataset]

    index = operator.index(label)
    # Plain list indexing would silently wrap negative labels (-1 -> last
    # class); a negative label is a bug upstream, so reject it explicitly.
    if not 0 <= index < len(classes):
        raise IndexError(
            f"label {index} out of range for {dataset!r} ({len(classes)} classes)")
    return classes[index]

def _set_trainable_adapter(pipe, adapter_name: str):
    for n, p in pipe.unet.named_parameters():
        if "lora" in n:
            p.requires_grad = (adapter_name in n)

def _adapter_params(pipe, adapter_name: str):
    return [p for n, p in pipe.unet.named_parameters()
            if p.requires_grad and ("lora" in n) and (adapter_name in n)]
