import argparse, os, sys, glob
import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from einops import rearrange
from torchvision.utils import make_grid
from pytorch_lightning import seed_everything
import re

from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler, DDIMSampler1D
from ldm.models.diffusion.plms import PLMSSampler

from safetensors.torch import load_file, save_file


def remove_specific_classes(all_classes, specific_dataset_classes):
    # Convert specific classes list to a set for faster lookup
    specific_classes_set = set(specific_dataset_classes)
    
    # Filter all_classes to exclude those in specific_classes_set
    filtered_classes = [cls for cls in all_classes if cls not in specific_classes_set]
    
    return filtered_classes


def load_model_from_config(config, ckpt, verbose=False):
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)

    model.cuda()
    model.eval()
    return model


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--config",
        type=str,
        nargs="?",
        default="",
        help="the config file"
    )

    parser.add_argument(
        "--prompt",
        type=str,
        nargs="?",
        default="",
        help="the prompt to render"
    )

    parser.add_argument(
        "--adapter",
        type=str,
        nargs="?",
        default="",
        help="the checkpoint file"
    )

    parser.add_argument(
        "--ctx",
        type=int,
        default=4,
        help="number of ctx length",
    )

    parser.add_argument(
        "--start_t",
        type=int,
        default=-1,
        help="start timestep",
    )

    parser.add_argument(
        "--negative_prompt",
        action='store_true',
        help="use neg sampling",
    )

    parser.add_argument(
        "--unconditional",
        action='store_true',
        help="use uncond",
    )

    parser.add_argument(
        "--class_file",
        type=str,
        nargs="?",
        default="a painting of a virus monster playing guitar",
        help="the prompt to render"
    )

    parser.add_argument(
        "--ckpt",
        type=str,
        nargs="?",
        default="",
        help="the checkpoint file"
    )

    parser.add_argument(
        "--outdir",
        type=str,
        nargs="?",
        help="dir to write results to",
        default="outputs/txt2img-samples"
    )
    parser.add_argument(
        "--ddim_steps",
        type=int,
        default=200,
        help="number of ddim sampling steps",
    )
    parser.add_argument(
        "--epoch_num",
        type=int,
        default=11,
        help="number to save ckpt",
    )

    parser.add_argument(
        "--plms",
        action='store_true',
        help="use plms sampling",
    )

    parser.add_argument(
        "--composable_diffusion",
        action='store_true',
        help="use composable",
    )

    parser.add_argument(
        "--negation_diffusion",
        action='store_true',
        help="use composable",
    )

    parser.add_argument(
        "--use_dsname",
        action='store_true',
        help="use dataset name",
    )

    parser.add_argument(
        "--avg_samples",
        action='store_true',
        help="use avg ctx",
    )

    parser.add_argument(
        "--ddim_eta",
        type=float,
        default=0.0,
        help="ddim eta (eta=0.0 corresponds to deterministic sampling",
    )
    parser.add_argument(
        "--n_iter",
        type=int,
        default=1,
        help="sample this often",
    )

    parser.add_argument(
        "--image_path",
        type=str,
        nargs="?",
        default="",
        help="the prompt to render"
    )

    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="the seed (for reproducible sampling)",
    )

    parser.add_argument(
        "--H",
        type=int,
        default=512,
        help="image height, in pixel space",
    )

    parser.add_argument(
        "--n_samples",
        type=int,
        default=1,
        help="how many samples to produce for the given prompt",
    )

    parser.add_argument(
        "--scale",
        type=float,
        default=5.0,
        help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
    )
    opt = parser.parse_args()
    # seed_everything(opt.seed)


    config = OmegaConf.load(opt.config)  # TODO: Optionally download from same location as ckpt and chnage this logic
    model = load_model_from_config(config, opt.ckpt)  # TODO: check path

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)

    datasets = ["oxford_flowers", "oxford_pets", "stanford_cars", "imagenet", "food101", "dtd", "eurosat", "ucf101", "sun397", "fgvc_aircraft", "caltech101"]

    datasets_paths = {
        "oxford_flowers": "coop_prompts_1d_autoencoder_kl_oxford_flowers/checkpoints/epoch=003029.ckpt",
        "oxford_pets": "coop_prompts_1d_autoencoder_kl_oxford_pets/checkpoints/epoch=002982.ckpt",
        "stanford_cars": "coop_prompts_1d_autoencoder_kl_stanford_cars/checkpoints/epoch=002986.ckpt",
        "imagenet": "coop_prompts_1d_autoencoder_kl_imagenet/checkpoints/epoch=009667.ckpt",
        "food101": "coop_prompts_1d_autoencoder_kl_food101/checkpoints/epoch=002979.ckpt",
        "dtd": "coop_prompts_1d_autoencoder_kl_dtd/checkpoints/epoch=003002.ckpt",
        "eurosat": "coop_prompts_1d_autoencoder_kl_eurosat/checkpoints/epoch=003187.ckpt",
        "ucf101": "coop_prompts_1d_autoencoder_kl_ucf101/checkpoints/epoch=002981.ckpt",
        "sun397": "coop_prompts_1d_autoencoder_kl_sun397/checkpoints/epoch=002993.ckpt",
        "fgvc_aircraft": "coop_prompts_1d_autoencoder_kl_fgvc/checkpoints/epoch=002746.ckpt",
        "caltech101": "coop_prompts_1d_autoencoder_kl_caltech101/checkpoints/epoch=003005.ckpt"
    }

    if opt.plms:
        sampler = PLMSSampler(model)
    else:
        sampler = DDIMSampler(model)
    
    def extract_class_names(file_path):
        class_names = []
        with open(file_path, 'r') as file:
            for line in file:
                parts = line.split(' ', 1)
                if len(parts) == 2:
                    class_names.append(parts[1].strip())
        return class_names
    

    outpath = opt.outdir

    prompt = opt.prompt
    mprompt = opt.class_file
    eph = opt.epoch_num
    all_classnames = ['accordion', 'airplane', 'anchor', 'ant', 'barrel', 'bass', 'beaver', 'binocular', 'bonsai', 'brain', 'brontosaurus', 'buddha', 'butterfly', 'camera', 'cannon', 'car_side', 'ceiling_fan', 'cellphone', 'chair', 'chandelier', 'cougar_body', 'cougar_face', 'crab', 'crayfish', 'crocodile', 'crocodile_head', 'cup', 'dalmatian', 'dollar_bill', 'dolphin', 'dragonfly', 'electric_guitar', 'elephant', 'emu', 'euphonium', 'ewer', 'face', 'ferry', 'flamingo', 'flamingo_head', 'garfield', 'gerenuk', 'gramophone', 'grand_piano', 'hawksbill', 'headphone', 'hedgehog', 'helicopter', 'ibis', 'inline_skate', 'joshua_tree', 'kangaroo', 'ketch', 'lamp', 'laptop', 'leopard', 'llama', 'lobster', 'lotus', 'mandolin', 'mayfly', 'menorah', 'metronome', 'minaret', 'motorbike', 'nautilus', 'octopus', 'okapi', 'pagoda', 'panda', 'pigeon', 'pizza', 'platypus', 'pyramid', 'revolver', 'rhino', 'rooster', 'saxophone', 'schooner', 'scissors', 'scorpion', 'sea_horse', 'snoopy', 'soccer_ball', 'stapler', 'starfish', 'stegosaurus', 'stop_sign', 'strawberry', 'sunflower', 'tick', 'trilobite', 'umbrella', 'watch', 'water_lilly', 'wheelchair', 'wild_cat', 'windsor_chair', 'wrench', 'yin_yang',
    'abyssinian', 'american_bulldog', 'american_pit_bull_terrier', 'basset_hound', 'beagle', 'bengal', 'birman', 'bombay', 'boxer', 'british_shorthair', 'chihuahua', 'egyptian_mau', 'english_cocker_spaniel', 'english_setter', 'german_shorthaired', 'great_pyrenees', 'havanese', 'japanese_chin', 'keeshond', 'leonberger', 'maine_coon', 'miniature_pinscher', 'newfoundland', 'persian', 'pomeranian', 'pug', 'ragdoll', 'russian_blue', 'saint_bernard', 'samoyed', 'scottish_terrier', 'shiba_inu', 'siamese', 'sphynx', 'staffordshire_bull_terrier', 'wheaten_terrier', 'yorkshire_terrier',
    'abbey', 'airplane_cabin', 'airport_terminal', 'alley', 'amphitheater', 'amusement_arcade', 'amusement_park', 'anechoic_chamber', 'aquarium', 'aqueduct', 'arch', 'archive', 'art_gallery', 'art_school', 'art_studio', 'assembly_line', 'attic', 'auditorium', 'auto_factory', 'backseat car_interior', 'badlands', 'baggage_claim', 'ball_pit', 'ballroom', 'bamboo_forest', 'banquet_hall', 'bar', 'barn', 'barndoor', 'barrel_storage wine_cellar', 'baseball stadium', 'baseball_field', 'basement', 'basilica', 'bathroom', 'batters_box', 'bayou', 'beach', 'beauty_salon', 'bedroom', 'berth', 'biology_laboratory', 'block waterfall', 'boardwalk', 'boat_deck', 'boathouse', 'bookstore', 'botanical_garden', 'bottle_storage wine_cellar', 'bowling_alley', 'boxing_ring', 'bridge', 'broadleaf forest', 'building_facade', 'bullring', 'burial_chamber', 'bus_interior', 'butchers_shop', 'butte', 'cafeteria', 'campsite', 'campus', 'candy_store', 'canyon', 'carrousel', 'castle', 'catacomb', 'cemetery', 'chalet', 'cheese_factory', 'chemistry_lab', 'childs_room', 'classroom', 'clean_room', 'cliff', 'closet', 'clothing_store', 'coast', 'cockpit', 'coffee_shop', 'computer_room', 'conference_center', 'conference_room', 'construction_site', 'control_room', 'coral_reef underwater', 'corn_field', 'corral', 'corridor', 'cottage_garden', 'courthouse', 'courtroom', 'courtyard', 'creek', 'crevasse', 'crosswalk', 'cultivated field', 'dam', 'delicatessen', 'dentists_office', 'dining_car', 'dining_room', 'discotheque', 'dock', 'door elevator', 'dorm_room', 'driveway', 'drugstore', 'east_asia temple', 'electrical_substation', 'elevator_shaft', 'engine_room', 'establishment poolroom', 'excavation', 'exterior balcony', 'exterior covered_bridge', 'exterior gazebo', 'fairway', 'fan waterfall', 'fastfood_restaurant', 'fire_escape', 'fire_station', 'fishpond', 'food_court', 'football stadium', 'forest_path', 'forest_road', 'formal_garden', 'fountain', 'frontseat car_interior', 'galley', 'game_room', 'garbage_dump', 'gas_station', 'gift_shop', 'golf_course', 'harbor', 'hayfield', 'heliport', 'herb_garden', 'highway', 'hill', 'home dinette', 'home poolroom', 'home_office', 'hospital', 'hospital_room', 'hot_spring', 'hotel_room', 'house', 'ice_cream_parlor', 'ice_floe', 'ice_shelf', 'iceberg', 'igloo', 'indoor apse', 'indoor badminton_court', 'indoor bazaar', 'indoor bistro', 'indoor booth', 'indoor bow_window', 'indoor brewery', 'indoor casino', 'indoor cathedral', 'indoor cavern', 'indoor chicken_coop', 'indoor church', 'indoor cloister', 'indoor diner', 'indoor escalator', 'indoor factory', 'indoor firing_range', 'indoor florist_shop', 'indoor garage', 'indoor general_store', 'indoor greenhouse', 'indoor gymnasium', 'indoor hangar', 'indoor ice_skating_rink', 'indoor jacuzzi', 'indoor jail', 'indoor kennel', 'indoor library', 'indoor market', 'indoor mosque', 'indoor movie_theater', 'indoor museum', 'indoor parking_garage', 'indoor pilothouse', 'indoor podium', 'indoor pub', 'indoor shopping_mall', 'indoor stage', 'indoor swimming_pool', 'indoor synagogue', 'indoor tennis_court', 'indoor volleyball_court', 'indoor warehouse', 'indoor wrestling_ring', 'indoor_procenium theater', 'indoor_seats theater', 'industrial_area', 'interior balcony', 'interior elevator', 'islet', 'jail_cell', 'jewelry_shop', 'kasbah', 'kindergarden_classroom', 'kitchen', 'kitchenette', 'landfill', 'landing_deck', 'laundromat', 'lecture_room', 'lift_bridge', 'lighthouse', 'limousine_interior', 'living_room', 'lobby', 'lock_chamber', 'locker_room', 'mansion', 'manufactured_home', 'marsh', 'martial_arts_gym', 'mausoleum', 'medina', 'motel', 'mountain', 'mountain_snowy', 'music_store', 'music_studio', 'natural canal', 'natural lake', 'needleleaf forest', 'nursery', 'oast_house', 'ocean', 'office', 'office cubicle', 'office_building', 'oilrig', 'operating_room', 'orchard', 'outdoor apartment_building', 'outdoor arrival_gate', 'outdoor athletic_field', 'outdoor basketball_court', 'outdoor bazaar', 'outdoor bow_window', 'outdoor cabin', 'outdoor cathedral', 'outdoor chicken_coop', 'outdoor church', 'outdoor control_tower', 'outdoor diner', 'outdoor doorway', 'outdoor driving_range', 'outdoor general_store', 'outdoor greenhouse', 'outdoor hangar', 'outdoor hot_tub', 'outdoor hotel', 'outdoor hunting_lodge', 'outdoor ice_skating_rink', 'outdoor inn', 'outdoor kennel', 'outdoor labyrinth', 'outdoor library', 'outdoor lido_deck', 'outdoor market', 'outdoor monastery', 'outdoor mosque', 'outdoor nuclear_power_plant', 'outdoor observatory', 'outdoor oil_refinery', 'outdoor outhouse', 'outdoor parking_garage', 'outdoor planetarium', 'outdoor podium', 'outdoor power_plant', 'outdoor swimming_pool', 'outdoor synagogue', 'outdoor tennis_court', 'outdoor tent', 'outdoor track', 'outdoor volleyball_court', 'pagoda', 'palace', 'pantry', 'park', 'parking_lot', 'parlor', 'pasture', 'patio', 'pavilion', 'pharmacy', 'phone_booth', 'physics_laboratory', 'picnic_area', 'platform subway_station', 'platform train_station', 'playground', 'playroom', 'plaza', 'plunge waterfall', 'pond', 'promenade_deck', 'public atrium', 'pulpit', 'putting_green', 'racecourse', 'raceway', 'raft', 'railroad_track', 'rainforest', 'reception', 'recreation_room', 'residential_neighborhood', 'restaurant', 'restaurant_kitchen', 'restaurant_patio', 'rice_paddy', 'riding_arena', 'river', 'rock_arch', 'rope_bridge', 'ruin', 'runway', 'sand desert', 'sandbar', 'sandbox', 'sauna', 'schoolhouse', 'sea_cliff', 'server_room', 'shed', 'shoe_shop', 'shop bakery', 'shopfront', 'shower', 'skatepark', 'ski_lodge', 'ski_resort', 'ski_slope', 'sky', 'skyscraper', 'slum', 'snowfield', 'south_asia temple', 'squash_court', 'stable', 'staircase', 'street', 'subway_interior', 'supermarket', 'sushi_bar', 'swamp', 'television_studio', 'thriftshop', 'throne_room', 'ticket_booth', 'toll_plaza', 'topiary_garden', 'tower', 'toyshop', 'train_railway', 'tree_farm', 'tree_house', 'trench', 'urban canal', 'utility_room', 'valley', 'van_interior', 'vegetable_garden', 'vegetation desert', 'vehicle dinette', 'veranda', 'veterinarians_office', 'viaduct', 'videostore', 'village', 'vineyard', 'volcano', 'waiting_room', 'water moat', 'water_tower', 'watering_hole', 'wave', 'wet_bar', 'wheat_field', 'wild field', 'wind_farm', 'windmill', 'yard', 'youth_hostel',
    'Apply_Eye_Makeup', 'Apply_Lipstick', 'Archery', 'Baby_Crawling', 'Balance_Beam', 'Band_Marching', 'Baseball_Pitch', 'Basketball', 'Basketball_Dunk', 'Bench_Press', 'Biking', 'Billiards', 'Blow_Dry_Hair', 'Blowing_Candles', 'Body_Weight_Squats', 'Bowling', 'Boxing_Punching_Bag', 'Boxing_Speed_Bag', 'Breast_Stroke', 'Brushing_Teeth', 'Clean_And_Jerk', 'Cliff_Diving', 'Cricket_Bowling', 'Cricket_Shot', 'Cutting_In_Kitchen', 'Diving', 'Drumming', 'Fencing', 'Field_Hockey_Penalty', 'Floor_Gymnastics', 'Frisbee_Catch', 'Front_Crawl', 'Golf_Swing', 'Haircut', 'Hammer_Throw', 'Hammering', 'Handstand_Pushups', 'Handstand_Walking', 'Head_Massage', 'High_Jump', 'Horse_Race', 'Horse_Riding', 'Hula_Hoop', 'Ice_Dancing', 'Javelin_Throw', 'Juggling_Balls', 'Jump_Rope', 'Jumping_Jack', 'Kayaking', 'Knitting', 'Long_Jump', 'Lunges', 'Military_Parade', 'Mixing', 'Mopping_Floor', 'Nunchucks', 'Parallel_Bars', 'Pizza_Tossing', 'Playing_Cello', 'Playing_Daf', 'Playing_Dhol', 'Playing_Flute', 'Playing_Guitar', 'Playing_Piano', 'Playing_Sitar', 'Playing_Tabla', 'Playing_Violin', 'Pole_Vault', 'Pommel_Horse', 'Pull_Ups', 'Punch', 'Push_Ups', 'Rafting', 'Rock_Climbing_Indoor', 'Rope_Climbing', 'Rowing', 'Salsa_Spin', 'Shaving_Beard', 'Shotput', 'Skate_Boarding', 'Skiing', 'Skijet', 'Sky_Diving', 'Soccer_Juggling', 'Soccer_Penalty', 'Still_Rings', 'Sumo_Wrestling', 'Surfing', 'Swing', 'Table_Tennis_Shot', 'Tai_Chi', 'Tennis_Swing', 'Throw_Discus', 'Trampoline_Jumping', 'Typing', 'Uneven_Bars', 'Volleyball_Spiking', 'Walking_With_Dog', 'Wall_Pushups', 'Writing_On_Board', 'Yo_Yo',
    'Annual Crop Land', 'Forest', 'Herbaceous Vegetation Land', 'Highway or Road', 'Industrial Buildings', 'Pasture Land', 'Permanent Crop Land', 'Residential Buildings', 'River', 'Sea or Lake',
    'banded', 'blotchy', 'braided', 'bubbly', 'bumpy', 'chequered', 'cobwebbed', 'cracked', 'crosshatched', 'crystalline', 'dotted', 'fibrous', 'flecked', 'freckled', 'frilly', 'gauzy', 'grid', 'grooved', 'honeycombed', 'interlaced', 'knitted', 'lacelike', 'lined', 'marbled', 'matted', 'meshed', 'paisley', 'perforated', 'pitted', 'pleated', 'polka-dotted', 'porous', 'potholed', 'scaly', 'smeared', 'spiralled', 'sprinkled', 'stained', 'stratified', 'striped', 'studded', 'swirly', 'veined', 'waffled', 'woven', 'wrinkled', 'zigzagged',
    'DHC-8-100', 'A330-200', '737-900', '737-700', 'CRJ-700', 'A340-600', 'Tornado', '747-100', '747-200', 'Metroliner', 'ATR-42', '757-200', '737-400', 'ATR-72', 'Challenger 600', 'DC-6', 'A320', 'DC-8', 'Cessna 560', 'E-170', 'MD-90', 'MD-80', 'Gulfstream IV', 'Dornier 328', '737-600', 'Boeing 717', '737-500', 'A321', 'Falcon 2000', 'PA-28', '737-800', 'BAE-125', 'Fokker 100', '737-200', 'Cessna 208', 'F-16A/B', 'A319', 'MD-11', 'EMB-120', '747-400', '737-300', 'F/A-18', 'Beechcraft 1900', '767-200', 'A300B4', '747-300', 'SR-20', 'BAE 146-300', 'DHC-1', 'A310', 'Il-76', '777-300', 'ERJ 145', 'Tu-134', 'DC-9-30', 'Spitfire', 'C-47', 'C-130', 'An-12', '767-400', 'CRJ-900', 'Falcon 900', 'Saab 2000', '767-300', 'Embraer Legacy 600', 'Saab 340', 'BAE 146-200', 'Cessna 172', 'DHC-6', 'ERJ 135', 'A340-200', 'E-190', 'A380', 'Yak-42', '757-300', 'Hawk T1', 'DC-3', '707-320', 'A330-300', 'A340-300', 'Tu-154', 'Cessna 525', '777-200', 'DHC-8-300', 'Fokker 70', 'DH-82', 'E-195', 'DR-400', 'L-1011', 'Global Express', 'MD-87', 'A340-500', 'Gulfstream V', 'CRJ-200', 'Model B200', '727-200', 'Eurofighter Typhoon', 'A318', 'DC-10', 'Fokker 50',
    'apple_pie', 'baby_back_ribs', 'baklava', 'beef_carpaccio', 'beef_tartare', 'beet_salad', 'beignets', 'bibimbap', 'bread_pudding', 'breakfast_burrito', 'bruschetta', 'caesar_salad', 'cannoli', 'caprese_salad', 'carrot_cake', 'ceviche', 'cheese_plate', 'cheesecake', 'chicken_curry', 'chicken_quesadilla', 'chicken_wings', 'chocolate_cake', 'chocolate_mousse', 'churros', 'clam_chowder', 'club_sandwich', 'crab_cakes', 'creme_brulee', 'croque_madame', 'cup_cakes', 'deviled_eggs', 'donuts', 'dumplings', 'edamame', 'eggs_benedict', 'escargots', 'falafel', 'filet_mignon', 'fish_and_chips', 'foie_gras', 'french_fries', 'french_onion_soup', 'french_toast', 'fried_calamari', 'fried_rice', 'frozen_yogurt', 'garlic_bread', 'gnocchi', 'greek_salad', 'grilled_cheese_sandwich', 'grilled_salmon', 'guacamole', 'gyoza', 'hamburger', 'hot_and_sour_soup', 'hot_dog', 'huevos_rancheros', 'hummus', 'ice_cream', 'lasagna', 'lobster_bisque', 'lobster_roll_sandwich', 'macaroni_and_cheese', 'macarons', 'miso_soup', 'mussels', 'nachos', 'omelette', 'onion_rings', 'oysters', 'pad_thai', 'paella', 'pancakes', 'panna_cotta', 'peking_duck', 'pho', 'pizza', 'pork_chop', 'poutine', 'prime_rib', 'pulled_pork_sandwich', 'ramen', 'ravioli', 'red_velvet_cake', 'risotto', 'samosa', 'sashimi', 'scallops', 'seaweed_salad', 'shrimp_and_grits', 'spaghetti_bolognese', 'spaghetti_carbonara', 'spring_rolls', 'steak', 'strawberry_shortcake', 'sushi', 'tacos', 'takoyaki', 'tiramisu', 'tuna_tartare', 'waffles',
    'alpine sea holly', 'anthurium', 'artichoke', 'azalea', 'ball moss', 'balloon flower', 'barbeton daisy', 'bearded iris', 'bee balm', 'bird of paradise', 'bishop of llandaff', 'black-eyed susan', 'blackberry lily', 'blanket flower', 'bolero deep blue', 'bougainvillea', 'bromelia', 'buttercup', 'californian poppy', 'camellia', 'canna lily', 'canterbury bells', 'cape flower', 'carnation', 'cautleya spicata', 'clematis', "colt's foot", 'columbine', 'common dandelion', 'corn poppy', 'cyclamen', 'daffodil', 'desert-rose', 'english marigold', 'fire lily', 'foxglove', 'frangipani', 'fritillary', 'garden phlox', 'gaura', 'gazania', 'geranium', 'giant white arum lily', 'globe thistle', 'globe-flower', 'grape hyacinth', 'great masterwort', 'hard-leaved pocket orchid', 'hibiscus', 'hippeastrum', 'japanese anemone', 'king protea', 'lenten rose', 'lotus', 'love in the mist', 'magnolia', 'mallow', 'marigold', 'mexican aster', 'mexican petunia', 'monkshood', 'moon orchid', 'morning glory', 'orange dahlia', 'osteospermum', 'oxeye daisy', 'passion flower', 'pelargonium', 'peruvian lily', 'petunia', 'pincushion flower', 'pink primrose', 'pink-yellow dahlia', 'poinsettia', 'primula', 'prince of wales feathers', 'purple coneflower', 'red ginger', 'rose', 'ruby-lipped cattleya', 'siam tulip', 'silverbush', 'snapdragon', 'spear thistle', 'spring crocus', 'stemless gentian', 'sunflower', 'sweet pea', 'sweet william', 'sword lily', 'thorn apple', 'tiger lily', 'toad lily', 'tree mallow', 'tree poppy', 'trumpet creeper', 'wallflower', 'water lily', 'watercress', 'wild pansy', 'windflower', 'yellow iris',
    '1991 Volkswagen Golf Hatchback', '1993 Geo Metro Convertible', '1993 Mercedes-Benz 300-Class Convertible', '1993 Volvo 240 Sedan', '1994 Audi 100 Sedan', '1994 Audi 100 Wagon', '1994 Audi V8 Sedan', '1997 Dodge Caravan Minivan', '1998 Eagle Talon Hatchback', '1998 Nissan 240SX Coupe', '1999 Plymouth Neon Coupe', '2000 AM General Hummer SUV', '2001 Acura Integra Type R', '2001 Lamborghini Diablo Coupe', '2002 Daewoo Nubira Wagon', '2006 Ford GT Coupe', '2007 Audi S4 Sedan', '2007 BMW 6 Series Convertible', '2007 BMW X5 SUV', '2007 Bentley Continental Flying Spur Sedan', '2007 Bentley Continental GT Coupe', '2007 Buick Rainier SUV', '2007 Cadillac Escalade EXT Crew Cab', '2007 Chevrolet Corvette Ron Fellows Edition Z06', '2007 Chevrolet Express Cargo Van', '2007 Chevrolet Express Van', '2007 Chevrolet Impala Sedan', '2007 Chevrolet Malibu Sedan', '2007 Chevrolet Monte Carlo Coupe', '2007 Chevrolet Silverado 1500 Classic Extended Cab', '2007 Dodge Caliber Wagon', '2007 Dodge Dakota Club Cab', '2007 Dodge Durango SUV', '2007 Ford F-150 Regular Cab', '2007 Ford Focus Sedan', '2007 Ford Freestar Minivan', '2007 Ford Mustang Convertible', '2007 Honda Odyssey Minivan', '2007 Hyundai Elantra Sedan', '2007 Suzuki Aerio Sedan', '2007 Volvo XC90 SUV', '2008 Acura TL Type-S', '2008 Audi RS 4 Convertible', '2008 Chrysler Crossfire Convertible', '2008 Chrysler PT Cruiser Convertible', '2008 Dodge Magnum Wagon', '2008 Isuzu Ascender SUV', '2008 Lamborghini Reventon Coupe', '2009 Bentley Arnage Sedan', '2009 Bugatti Veyron 16.4 Convertible', '2009 Bugatti Veyron 16.4 Coupe', '2009 Chevrolet TrailBlazer SS', '2009 Chrysler Aspen SUV', '2009 Dodge Charger SRT-8', '2009 Dodge Ram Pickup 3500 Quad Cab', '2009 Dodge Sprinter Cargo Van', '2009 Ford Expedition EL SUV', '2009 HUMMER H2 SUT Crew Cab', '2009 Mercedes-Benz SL-Class Coupe', '2009 Spyker C8 Convertible', '2009 Spyker C8 Coupe', '2010 BMW M5 Sedan', '2010 BMW M6 Convertible', '2010 Chevrolet Cobalt SS', '2010 Chevrolet HHR SS', '2010 Chevrolet Malibu Hybrid Sedan', '2010 Chrysler 300 SRT-8', '2010 Chrysler Sebring Convertible', '2010 Dodge Dakota Crew Cab', '2010 Dodge Ram Pickup 3500 Crew Cab', '2010 HUMMER H3T Crew Cab', '2011 Audi S6 Sedan', '2011 Audi TT Hatchback', '2011 Bentley Mulsanne Sedan', '2011 Dodge Challenger SRT8', '2011 Ford Ranger SuperCab', '2011 Infiniti QX56 SUV', '2011 Lincoln Town Car Sedan', '2011 Mazda Tribute SUV', '2012 Acura RL Sedan', '2012 Acura TL Sedan', '2012 Acura TSX Sedan', '2012 Acura ZDX Hatchback', '2012 Aston Martin V8 Vantage Convertible', '2012 Aston Martin V8 Vantage Coupe', '2012 Aston Martin Virage Convertible', '2012 Aston Martin Virage Coupe', '2012 Audi A5 Coupe', '2012 Audi R8 Coupe', '2012 Audi S4 Sedan', '2012 Audi S5 Convertible', '2012 Audi S5 Coupe', '2012 Audi TT RS Coupe', '2012 Audi TTS Coupe', '2012 BMW 1 Series Convertible', '2012 BMW 1 Series Coupe', '2012 BMW 3 Series Sedan', '2012 BMW 3 Series Wagon', '2012 BMW ActiveHybrid 5 Sedan', '2012 BMW M3 Coupe', '2012 BMW X3 SUV', '2012 BMW X6 SUV', '2012 BMW Z4 Convertible', '2012 Bentley Continental GT Coupe', '2012 Bentley Continental Supersports Conv. Convertible', '2012 Buick Enclave SUV', '2012 Buick Regal GS', '2012 Buick Verano Sedan', '2012 Cadillac CTS-V Sedan', '2012 Cadillac SRX SUV', '2012 Chevrolet Avalanche Crew Cab', '2012 Chevrolet Camaro Convertible', '2012 Chevrolet Corvette Convertible', '2012 Chevrolet Corvette ZR1', '2012 Chevrolet Silverado 1500 Extended Cab', '2012 Chevrolet Silverado 1500 Hybrid Crew Cab', '2012 Chevrolet Silverado 1500 Regular Cab', '2012 Chevrolet Silverado 2500HD Regular Cab', '2012 Chevrolet Sonic Sedan', '2012 Chevrolet Tahoe Hybrid SUV', '2012 Chevrolet Traverse SUV', '2012 Chrysler Town and Country Minivan', '2012 Dodge Caliber Wagon', '2012 Dodge Charger Sedan', '2012 Dodge Durango SUV', '2012 Dodge Journey SUV', '2012 FIAT 500 Abarth', '2012 FIAT 500 Convertible', '2012 Ferrari 458 Italia Convertible', '2012 Ferrari 458 Italia Coupe', '2012 Ferrari California Convertible', '2012 Ferrari FF Coupe', '2012 Fisker Karma Sedan', '2012 Ford E-Series Wagon Van', '2012 Ford Edge SUV', '2012 Ford F-150 Regular Cab', '2012 Ford F-450 Super Duty Crew Cab', '2012 Ford Fiesta Sedan', '2012 GMC Acadia SUV', '2012 GMC Canyon Extended Cab', '2012 GMC Savana Van', '2012 GMC Terrain SUV', '2012 GMC Yukon Hybrid SUV', '2012 Honda Accord Coupe', '2012 Honda Accord Sedan', '2012 Honda Odyssey Minivan', '2012 Hyundai Accent Sedan', '2012 Hyundai Azera Sedan', '2012 Hyundai Elantra Touring Hatchback', '2012 Hyundai Genesis Sedan', '2012 Hyundai Santa Fe SUV', '2012 Hyundai Sonata Hybrid Sedan', '2012 Hyundai Sonata Sedan', '2012 Hyundai Tucson SUV', '2012 Hyundai Veloster Hatchback', '2012 Hyundai Veracruz SUV', '2012 Infiniti G Coupe IPL', '2012 Jaguar XK XKR', '2012 Jeep Compass SUV', '2012 Jeep Grand Cherokee SUV', '2012 Jeep Liberty SUV', '2012 Jeep Patriot SUV', '2012 Jeep Wrangler SUV', '2012 Lamborghini Aventador Coupe', '2012 Lamborghini Gallardo LP 570-4 Superleggera', '2012 Land Rover LR2 SUV', '2012 Land Rover Range Rover SUV', '2012 MINI Cooper Roadster Convertible', '2012 Maybach Landaulet Convertible', '2012 McLaren MP4-12C Coupe', '2012 Mercedes-Benz C-Class Sedan', '2012 Mercedes-Benz E-Class Sedan', '2012 Mercedes-Benz S-Class Sedan', '2012 Mercedes-Benz Sprinter Van', '2012 Mitsubishi Lancer Sedan', '2012 Nissan Juke Hatchback', '2012 Nissan Leaf Hatchback', '2012 Nissan NV Passenger Van', '2012 Porsche Panamera Sedan', '2012 Ram C/V Cargo Van Minivan', '2012 Rolls-Royce Ghost Sedan', '2012 Rolls-Royce Phantom Drophead Coupe Convertible', '2012 Rolls-Royce Phantom Sedan', '2012 Scion xD Hatchback', '2012 Suzuki Kizashi Sedan', '2012 Suzuki SX4 Hatchback', '2012 Suzuki SX4 Sedan', '2012 Tesla Model S Sedan', '2012 Toyota 4Runner SUV', '2012 Toyota Camry Sedan', '2012 Toyota Corolla Sedan', '2012 Toyota Sequoia SUV', '2012 Volkswagen Beetle Hatchback', '2012 Volkswagen Golf Hatchback', '2012 Volvo C30 Hatchback', '2012 smart fortwo Convertible',
    ]
    if "caltech" in prompt:
        classnames = ['accordion', 'airplane', 'anchor', 'ant', 'barrel', 'bass', 'beaver', 'binocular', 'bonsai', 'brain', 'brontosaurus', 'buddha', 'butterfly', 'camera', 'cannon', 'car_side', 'ceiling_fan', 'cellphone', 'chair', 'chandelier', 'cougar_body', 'cougar_face', 'crab', 'crayfish', 'crocodile', 'crocodile_head', 'cup', 'dalmatian', 'dollar_bill', 'dolphin', 'dragonfly', 'electric_guitar', 'elephant', 'emu', 'euphonium', 'ewer', 'face', 'ferry', 'flamingo', 'flamingo_head', 'garfield', 'gerenuk', 'gramophone', 'grand_piano', 'hawksbill', 'headphone', 'hedgehog', 'helicopter', 'ibis', 'inline_skate', 'joshua_tree', 'kangaroo', 'ketch', 'lamp', 'laptop', 'leopard', 'llama', 'lobster', 'lotus', 'mandolin', 'mayfly', 'menorah', 'metronome', 'minaret', 'motorbike', 'nautilus', 'octopus', 'okapi', 'pagoda', 'panda', 'pigeon', 'pizza', 'platypus', 'pyramid', 'revolver', 'rhino', 'rooster', 'saxophone', 'schooner', 'scissors', 'scorpion', 'sea_horse', 'snoopy', 'soccer_ball', 'stapler', 'starfish', 'stegosaurus', 'stop_sign', 'strawberry', 'sunflower', 'tick', 'trilobite', 'umbrella', 'watch', 'water_lilly', 'wheelchair', 'wild_cat', 'windsor_chair', 'wrench', 'yin_yang']
        prompt = ";;".join(classnames)
        datasets_path = datasets_paths["caltech101"]
    elif "pets" in prompt:
        classnames = ['abyssinian', 'american_bulldog', 'american_pit_bull_terrier', 'basset_hound', 'beagle', 'bengal', 'birman', 'bombay', 'boxer', 'british_shorthair', 'chihuahua', 'egyptian_mau', 'english_cocker_spaniel', 'english_setter', 'german_shorthaired', 'great_pyrenees', 'havanese', 'japanese_chin', 'keeshond', 'leonberger', 'maine_coon', 'miniature_pinscher', 'newfoundland', 'persian', 'pomeranian', 'pug', 'ragdoll', 'russian_blue', 'saint_bernard', 'samoyed', 'scottish_terrier', 'shiba_inu', 'siamese', 'sphynx', 'staffordshire_bull_terrier', 'wheaten_terrier', 'yorkshire_terrier']
        prompt = ";;".join(classnames)
        datasets_path = datasets_paths["oxford_pets"]
    elif "sun" in prompt:
        classnames = ['abbey', 'airplane_cabin', 'airport_terminal', 'alley', 'amphitheater', 'amusement_arcade', 'amusement_park', 'anechoic_chamber', 'aquarium', 'aqueduct', 'arch', 'archive', 'art_gallery', 'art_school', 'art_studio', 'assembly_line', 'attic', 'auditorium', 'auto_factory', 'backseat car_interior', 'badlands', 'baggage_claim', 'ball_pit', 'ballroom', 'bamboo_forest', 'banquet_hall', 'bar', 'barn', 'barndoor', 'barrel_storage wine_cellar', 'baseball stadium', 'baseball_field', 'basement', 'basilica', 'bathroom', 'batters_box', 'bayou', 'beach', 'beauty_salon', 'bedroom', 'berth', 'biology_laboratory', 'block waterfall', 'boardwalk', 'boat_deck', 'boathouse', 'bookstore', 'botanical_garden', 'bottle_storage wine_cellar', 'bowling_alley', 'boxing_ring', 'bridge', 'broadleaf forest', 'building_facade', 'bullring', 'burial_chamber', 'bus_interior', 'butchers_shop', 'butte', 'cafeteria', 'campsite', 'campus', 'candy_store', 'canyon', 'carrousel', 'castle', 'catacomb', 'cemetery', 'chalet', 'cheese_factory', 'chemistry_lab', 'childs_room', 'classroom', 'clean_room', 'cliff', 'closet', 'clothing_store', 'coast', 'cockpit', 'coffee_shop', 'computer_room', 'conference_center', 'conference_room', 'construction_site', 'control_room', 'coral_reef underwater', 'corn_field', 'corral', 'corridor', 'cottage_garden', 'courthouse', 'courtroom', 'courtyard', 'creek', 'crevasse', 'crosswalk', 'cultivated field', 'dam', 'delicatessen', 'dentists_office', 'dining_car', 'dining_room', 'discotheque', 'dock', 'door elevator', 'dorm_room', 'driveway', 'drugstore', 'east_asia temple', 'electrical_substation', 'elevator_shaft', 'engine_room', 'establishment poolroom', 'excavation', 'exterior balcony', 'exterior covered_bridge', 'exterior gazebo', 'fairway', 'fan waterfall', 'fastfood_restaurant', 'fire_escape', 'fire_station', 'fishpond', 'food_court', 'football stadium', 'forest_path', 'forest_road', 'formal_garden', 'fountain', 'frontseat car_interior', 'galley', 'game_room', 'garbage_dump', 'gas_station', 'gift_shop', 'golf_course', 'harbor', 'hayfield', 'heliport', 'herb_garden', 'highway', 'hill', 'home dinette', 'home poolroom', 'home_office', 'hospital', 'hospital_room', 'hot_spring', 'hotel_room', 'house', 'ice_cream_parlor', 'ice_floe', 'ice_shelf', 'iceberg', 'igloo', 'indoor apse', 'indoor badminton_court', 'indoor bazaar', 'indoor bistro', 'indoor booth', 'indoor bow_window', 'indoor brewery', 'indoor casino', 'indoor cathedral', 'indoor cavern', 'indoor chicken_coop', 'indoor church', 'indoor cloister', 'indoor diner', 'indoor escalator', 'indoor factory', 'indoor firing_range', 'indoor florist_shop', 'indoor garage', 'indoor general_store', 'indoor greenhouse', 'indoor gymnasium', 'indoor hangar', 'indoor ice_skating_rink', 'indoor jacuzzi', 'indoor jail', 'indoor kennel', 'indoor library', 'indoor market', 'indoor mosque', 'indoor movie_theater', 'indoor museum', 'indoor parking_garage', 'indoor pilothouse', 'indoor podium', 'indoor pub', 'indoor shopping_mall', 'indoor stage', 'indoor swimming_pool', 'indoor synagogue', 'indoor tennis_court', 'indoor volleyball_court', 'indoor warehouse', 'indoor wrestling_ring', 'indoor_procenium theater', 'indoor_seats theater', 'industrial_area', 'interior balcony', 'interior elevator', 'islet', 'jail_cell', 'jewelry_shop', 'kasbah', 'kindergarden_classroom', 'kitchen', 'kitchenette', 'landfill', 'landing_deck', 'laundromat', 'lecture_room', 'lift_bridge', 'lighthouse', 'limousine_interior', 'living_room', 'lobby', 'lock_chamber', 'locker_room', 'mansion', 'manufactured_home', 'marsh', 'martial_arts_gym', 'mausoleum', 'medina', 'motel', 'mountain', 'mountain_snowy', 'music_store', 'music_studio', 'natural canal', 'natural lake', 'needleleaf forest', 'nursery', 'oast_house', 'ocean', 'office', 'office cubicle', 'office_building', 'oilrig', 'operating_room', 'orchard', 'outdoor apartment_building', 'outdoor arrival_gate', 'outdoor athletic_field', 'outdoor basketball_court', 'outdoor bazaar', 'outdoor bow_window', 'outdoor cabin', 'outdoor cathedral', 'outdoor chicken_coop', 'outdoor church', 'outdoor control_tower', 'outdoor diner', 'outdoor doorway', 'outdoor driving_range', 'outdoor general_store', 'outdoor greenhouse', 'outdoor hangar', 'outdoor hot_tub', 'outdoor hotel', 'outdoor hunting_lodge', 'outdoor ice_skating_rink', 'outdoor inn', 'outdoor kennel', 'outdoor labyrinth', 'outdoor library', 'outdoor lido_deck', 'outdoor market', 'outdoor monastery', 'outdoor mosque', 'outdoor nuclear_power_plant', 'outdoor observatory', 'outdoor oil_refinery', 'outdoor outhouse', 'outdoor parking_garage', 'outdoor planetarium', 'outdoor podium', 'outdoor power_plant', 'outdoor swimming_pool', 'outdoor synagogue', 'outdoor tennis_court', 'outdoor tent', 'outdoor track', 'outdoor volleyball_court', 'pagoda', 'palace', 'pantry', 'park', 'parking_lot', 'parlor', 'pasture', 'patio', 'pavilion', 'pharmacy', 'phone_booth', 'physics_laboratory', 'picnic_area', 'platform subway_station', 'platform train_station', 'playground', 'playroom', 'plaza', 'plunge waterfall', 'pond', 'promenade_deck', 'public atrium', 'pulpit', 'putting_green', 'racecourse', 'raceway', 'raft', 'railroad_track', 'rainforest', 'reception', 'recreation_room', 'residential_neighborhood', 'restaurant', 'restaurant_kitchen', 'restaurant_patio', 'rice_paddy', 'riding_arena', 'river', 'rock_arch', 'rope_bridge', 'ruin', 'runway', 'sand desert', 'sandbar', 'sandbox', 'sauna', 'schoolhouse', 'sea_cliff', 'server_room', 'shed', 'shoe_shop', 'shop bakery', 'shopfront', 'shower', 'skatepark', 'ski_lodge', 'ski_resort', 'ski_slope', 'sky', 'skyscraper', 'slum', 'snowfield', 'south_asia temple', 'squash_court', 'stable', 'staircase', 'street', 'subway_interior', 'supermarket', 'sushi_bar', 'swamp', 'television_studio', 'thriftshop', 'throne_room', 'ticket_booth', 'toll_plaza', 'topiary_garden', 'tower', 'toyshop', 'train_railway', 'tree_farm', 'tree_house', 'trench', 'urban canal', 'utility_room', 'valley', 'van_interior', 'vegetable_garden', 'vegetation desert', 'vehicle dinette', 'veranda', 'veterinarians_office', 'viaduct', 'videostore', 'village', 'vineyard', 'volcano', 'waiting_room', 'water moat', 'water_tower', 'watering_hole', 'wave', 'wet_bar', 'wheat_field', 'wild field', 'wind_farm', 'windmill', 'yard', 'youth_hostel']
        prompt = ";;".join(classnames)
        datasets_path = datasets_paths["sun397"]
    elif "ucf101" in prompt:
        classnames = ['Apply_Eye_Makeup', 'Apply_Lipstick', 'Archery', 'Baby_Crawling', 'Balance_Beam', 'Band_Marching', 'Baseball_Pitch', 'Basketball', 'Basketball_Dunk', 'Bench_Press', 'Biking', 'Billiards', 'Blow_Dry_Hair', 'Blowing_Candles', 'Body_Weight_Squats', 'Bowling', 'Boxing_Punching_Bag', 'Boxing_Speed_Bag', 'Breast_Stroke', 'Brushing_Teeth', 'Clean_And_Jerk', 'Cliff_Diving', 'Cricket_Bowling', 'Cricket_Shot', 'Cutting_In_Kitchen', 'Diving', 'Drumming', 'Fencing', 'Field_Hockey_Penalty', 'Floor_Gymnastics', 'Frisbee_Catch', 'Front_Crawl', 'Golf_Swing', 'Haircut', 'Hammer_Throw', 'Hammering', 'Handstand_Pushups', 'Handstand_Walking', 'Head_Massage', 'High_Jump', 'Horse_Race', 'Horse_Riding', 'Hula_Hoop', 'Ice_Dancing', 'Javelin_Throw', 'Juggling_Balls', 'Jump_Rope', 'Jumping_Jack', 'Kayaking', 'Knitting', 'Long_Jump', 'Lunges', 'Military_Parade', 'Mixing', 'Mopping_Floor', 'Nunchucks', 'Parallel_Bars', 'Pizza_Tossing', 'Playing_Cello', 'Playing_Daf', 'Playing_Dhol', 'Playing_Flute', 'Playing_Guitar', 'Playing_Piano', 'Playing_Sitar', 'Playing_Tabla', 'Playing_Violin', 'Pole_Vault', 'Pommel_Horse', 'Pull_Ups', 'Punch', 'Push_Ups', 'Rafting', 'Rock_Climbing_Indoor', 'Rope_Climbing', 'Rowing', 'Salsa_Spin', 'Shaving_Beard', 'Shotput', 'Skate_Boarding', 'Skiing', 'Skijet', 'Sky_Diving', 'Soccer_Juggling', 'Soccer_Penalty', 'Still_Rings', 'Sumo_Wrestling', 'Surfing', 'Swing', 'Table_Tennis_Shot', 'Tai_Chi', 'Tennis_Swing', 'Throw_Discus', 'Trampoline_Jumping', 'Typing', 'Uneven_Bars', 'Volleyball_Spiking', 'Walking_With_Dog', 'Wall_Pushups', 'Writing_On_Board', 'Yo_Yo']
        prompt = ";;".join(classnames)
        datasets_path = datasets_paths["ucf101"]
    elif "imagenet" in prompt:
        file_path = "/data8/user/CoOp/data/imagenet/classnames.txt"
        prompt = extract_class_names(file_path)
        prompt = ";;".join(prompt)
        datasets_path = datasets_paths["imagenet"]
        # prompt = ";;".join(prompt[:len(prompt)//2])
        # prompt = ";;".join(['totem pole', 'patas monkey', 'combine harvester', 'wok', 'purse'])
    elif "eurosat" in prompt:
        classnames = ['Annual Crop Land', 'Forest', 'Herbaceous Vegetation Land', 'Highway or Road', 'Industrial Buildings', 'Pasture Land', 'Permanent Crop Land', 'Residential Buildings', 'River', 'Sea or Lake']
        prompt = ";;".join(classnames)
        datasets_path = datasets_paths["eurosat"]
    elif "dtd" in prompt:
        classnames = ['banded', 'blotchy', 'braided', 'bubbly', 'bumpy', 'chequered', 'cobwebbed', 'cracked', 'crosshatched', 'crystalline', 'dotted', 'fibrous', 'flecked', 'freckled', 'frilly', 'gauzy', 'grid', 'grooved', 'honeycombed', 'interlaced', 'knitted', 'lacelike', 'lined', 'marbled', 'matted', 'meshed', 'paisley', 'perforated', 'pitted', 'pleated', 'polka-dotted', 'porous', 'potholed', 'scaly', 'smeared', 'spiralled', 'sprinkled', 'stained', 'stratified', 'striped', 'studded', 'swirly', 'veined', 'waffled', 'woven', 'wrinkled', 'zigzagged']
        prompt = ";;".join(classnames)
        datasets_path = datasets_paths["dtd"]
    elif "aircraft" in prompt:
        classnames = ['DHC-8-100', 'A330-200', '737-900', '737-700', 'CRJ-700', 'A340-600', 'Tornado', '747-100', '747-200', 'Metroliner', 'ATR-42', '757-200', '737-400', 'ATR-72', 'Challenger 600', 'DC-6', 'A320', 'DC-8', 'Cessna 560', 'E-170', 'MD-90', 'MD-80', 'Gulfstream IV', 'Dornier 328', '737-600', 'Boeing 717', '737-500', 'A321', 'Falcon 2000', 'PA-28', '737-800', 'BAE-125', 'Fokker 100', '737-200', 'Cessna 208', 'F-16A/B', 'A319', 'MD-11', 'EMB-120', '747-400', '737-300', 'F/A-18', 'Beechcraft 1900', '767-200', 'A300B4', '747-300', 'SR-20', 'BAE 146-300', 'DHC-1', 'A310', 'Il-76', '777-300', 'ERJ 145', 'Tu-134', 'DC-9-30', 'Spitfire', 'C-47', 'C-130', 'An-12', '767-400', 'CRJ-900', 'Falcon 900', 'Saab 2000', '767-300', 'Embraer Legacy 600', 'Saab 340', 'BAE 146-200', 'Cessna 172', 'DHC-6', 'ERJ 135', 'A340-200', 'E-190', 'A380', 'Yak-42', '757-300', 'Hawk T1', 'DC-3', '707-320', 'A330-300', 'A340-300', 'Tu-154', 'Cessna 525', '777-200', 'DHC-8-300', 'Fokker 70', 'DH-82', 'E-195', 'DR-400', 'L-1011', 'Global Express', 'MD-87', 'A340-500', 'Gulfstream V', 'CRJ-200', 'Model B200', '727-200', 'Eurofighter Typhoon', 'A318', 'DC-10', 'Fokker 50']
        datasets_path = datasets_paths["fgvc_aircraft"]
        prompt = ";;".join(classnames)                
    elif "food" in prompt:
        classnames = ['apple_pie', 'baby_back_ribs', 'baklava', 'beef_carpaccio', 'beef_tartare', 'beet_salad', 'beignets', 'bibimbap', 'bread_pudding', 'breakfast_burrito', 'bruschetta', 'caesar_salad', 'cannoli', 'caprese_salad', 'carrot_cake', 'ceviche', 'cheese_plate', 'cheesecake', 'chicken_curry', 'chicken_quesadilla', 'chicken_wings', 'chocolate_cake', 'chocolate_mousse', 'churros', 'clam_chowder', 'club_sandwich', 'crab_cakes', 'creme_brulee', 'croque_madame', 'cup_cakes', 'deviled_eggs', 'donuts', 'dumplings', 'edamame', 'eggs_benedict', 'escargots', 'falafel', 'filet_mignon', 'fish_and_chips', 'foie_gras', 'french_fries', 'french_onion_soup', 'french_toast', 'fried_calamari', 'fried_rice', 'frozen_yogurt', 'garlic_bread', 'gnocchi', 'greek_salad', 'grilled_cheese_sandwich', 'grilled_salmon', 'guacamole', 'gyoza', 'hamburger', 'hot_and_sour_soup', 'hot_dog', 'huevos_rancheros', 'hummus', 'ice_cream', 'lasagna', 'lobster_bisque', 'lobster_roll_sandwich', 'macaroni_and_cheese', 'macarons', 'miso_soup', 'mussels', 'nachos', 'omelette', 'onion_rings', 'oysters', 'pad_thai', 'paella', 'pancakes', 'panna_cotta', 'peking_duck', 'pho', 'pizza', 'pork_chop', 'poutine', 'prime_rib', 'pulled_pork_sandwich', 'ramen', 'ravioli', 'red_velvet_cake', 'risotto', 'samosa', 'sashimi', 'scallops', 'seaweed_salad', 'shrimp_and_grits', 'spaghetti_bolognese', 'spaghetti_carbonara', 'spring_rolls', 'steak', 'strawberry_shortcake', 'sushi', 'tacos', 'takoyaki', 'tiramisu', 'tuna_tartare', 'waffles']
        prompt = ";;".join(classnames)
        datasets_path = datasets_paths["food101"]
    elif "flowers" in prompt:
        classnames = ['alpine sea holly', 'anthurium', 'artichoke', 'azalea', 'ball moss', 'balloon flower', 'barbeton daisy', 'bearded iris', 'bee balm', 'bird of paradise', 'bishop of llandaff', 'black-eyed susan', 'blackberry lily', 'blanket flower', 'bolero deep blue', 'bougainvillea', 'bromelia', 'buttercup', 'californian poppy', 'camellia', 'canna lily', 'canterbury bells', 'cape flower', 'carnation', 'cautleya spicata', 'clematis', "colt's foot", 'columbine', 'common dandelion', 'corn poppy', 'cyclamen', 'daffodil', 'desert-rose', 'english marigold', 'fire lily', 'foxglove', 'frangipani', 'fritillary', 'garden phlox', 'gaura', 'gazania', 'geranium', 'giant white arum lily', 'globe thistle', 'globe-flower', 'grape hyacinth', 'great masterwort', 'hard-leaved pocket orchid', 'hibiscus', 'hippeastrum', 'japanese anemone', 'king protea', 'lenten rose', 'lotus', 'love in the mist', 'magnolia', 'mallow', 'marigold', 'mexican aster', 'mexican petunia', 'monkshood', 'moon orchid', 'morning glory', 'orange dahlia', 'osteospermum', 'oxeye daisy', 'passion flower', 'pelargonium', 'peruvian lily', 'petunia', 'pincushion flower', 'pink primrose', 'pink-yellow dahlia', 'poinsettia', 'primula', 'prince of wales feathers', 'purple coneflower', 'red ginger', 'rose', 'ruby-lipped cattleya', 'siam tulip', 'silverbush', 'snapdragon', 'spear thistle', 'spring crocus', 'stemless gentian', 'sunflower', 'sweet pea', 'sweet william', 'sword lily', 'thorn apple', 'tiger lily', 'toad lily', 'tree mallow', 'tree poppy', 'trumpet creeper', 'wallflower', 'water lily', 'watercress', 'wild pansy', 'windflower', 'yellow iris']
        prompt = ";;".join(classnames)
        datasets_path = datasets_paths["oxford_flowers"]
    elif "stanford" in prompt:
        classnames = ['1991 Volkswagen Golf Hatchback', '1993 Geo Metro Convertible', '1993 Mercedes-Benz 300-Class Convertible', '1993 Volvo 240 Sedan', '1994 Audi 100 Sedan', '1994 Audi 100 Wagon', '1994 Audi V8 Sedan', '1997 Dodge Caravan Minivan', '1998 Eagle Talon Hatchback', '1998 Nissan 240SX Coupe', '1999 Plymouth Neon Coupe', '2000 AM General Hummer SUV', '2001 Acura Integra Type R', '2001 Lamborghini Diablo Coupe', '2002 Daewoo Nubira Wagon', '2006 Ford GT Coupe', '2007 Audi S4 Sedan', '2007 BMW 6 Series Convertible', '2007 BMW X5 SUV', '2007 Bentley Continental Flying Spur Sedan', '2007 Bentley Continental GT Coupe', '2007 Buick Rainier SUV', '2007 Cadillac Escalade EXT Crew Cab', '2007 Chevrolet Corvette Ron Fellows Edition Z06', '2007 Chevrolet Express Cargo Van', '2007 Chevrolet Express Van', '2007 Chevrolet Impala Sedan', '2007 Chevrolet Malibu Sedan', '2007 Chevrolet Monte Carlo Coupe', '2007 Chevrolet Silverado 1500 Classic Extended Cab', '2007 Dodge Caliber Wagon', '2007 Dodge Dakota Club Cab', '2007 Dodge Durango SUV', '2007 Ford F-150 Regular Cab', '2007 Ford Focus Sedan', '2007 Ford Freestar Minivan', '2007 Ford Mustang Convertible', '2007 Honda Odyssey Minivan', '2007 Hyundai Elantra Sedan', '2007 Suzuki Aerio Sedan', '2007 Volvo XC90 SUV', '2008 Acura TL Type-S', '2008 Audi RS 4 Convertible', '2008 Chrysler Crossfire Convertible', '2008 Chrysler PT Cruiser Convertible', '2008 Dodge Magnum Wagon', '2008 Isuzu Ascender SUV', '2008 Lamborghini Reventon Coupe', '2009 Bentley Arnage Sedan', '2009 Bugatti Veyron 16.4 Convertible', '2009 Bugatti Veyron 16.4 Coupe', '2009 Chevrolet TrailBlazer SS', '2009 Chrysler Aspen SUV', '2009 Dodge Charger SRT-8', '2009 Dodge Ram Pickup 3500 Quad Cab', '2009 Dodge Sprinter Cargo Van', '2009 Ford Expedition EL SUV', '2009 HUMMER H2 SUT Crew Cab', '2009 Mercedes-Benz SL-Class Coupe', '2009 Spyker C8 Convertible', '2009 Spyker C8 Coupe', '2010 BMW M5 Sedan', '2010 BMW M6 Convertible', '2010 Chevrolet Cobalt SS', '2010 Chevrolet HHR SS', '2010 Chevrolet Malibu Hybrid Sedan', '2010 Chrysler 300 SRT-8', '2010 Chrysler Sebring Convertible', '2010 Dodge Dakota Crew Cab', '2010 Dodge Ram Pickup 3500 Crew Cab', '2010 HUMMER H3T Crew Cab', '2011 Audi S6 Sedan', '2011 Audi TT Hatchback', '2011 Bentley Mulsanne Sedan', '2011 Dodge Challenger SRT8', '2011 Ford Ranger SuperCab', '2011 Infiniti QX56 SUV', '2011 Lincoln Town Car Sedan', '2011 Mazda Tribute SUV', '2012 Acura RL Sedan', '2012 Acura TL Sedan', '2012 Acura TSX Sedan', '2012 Acura ZDX Hatchback', '2012 Aston Martin V8 Vantage Convertible', '2012 Aston Martin V8 Vantage Coupe', '2012 Aston Martin Virage Convertible', '2012 Aston Martin Virage Coupe', '2012 Audi A5 Coupe', '2012 Audi R8 Coupe', '2012 Audi S4 Sedan', '2012 Audi S5 Convertible', '2012 Audi S5 Coupe', '2012 Audi TT RS Coupe', '2012 Audi TTS Coupe', '2012 BMW 1 Series Convertible', '2012 BMW 1 Series Coupe', '2012 BMW 3 Series Sedan', '2012 BMW 3 Series Wagon', '2012 BMW ActiveHybrid 5 Sedan', '2012 BMW M3 Coupe', '2012 BMW X3 SUV', '2012 BMW X6 SUV', '2012 BMW Z4 Convertible', '2012 Bentley Continental GT Coupe', '2012 Bentley Continental Supersports Conv. Convertible', '2012 Buick Enclave SUV', '2012 Buick Regal GS', '2012 Buick Verano Sedan', '2012 Cadillac CTS-V Sedan', '2012 Cadillac SRX SUV', '2012 Chevrolet Avalanche Crew Cab', '2012 Chevrolet Camaro Convertible', '2012 Chevrolet Corvette Convertible', '2012 Chevrolet Corvette ZR1', '2012 Chevrolet Silverado 1500 Extended Cab', '2012 Chevrolet Silverado 1500 Hybrid Crew Cab', '2012 Chevrolet Silverado 1500 Regular Cab', '2012 Chevrolet Silverado 2500HD Regular Cab', '2012 Chevrolet Sonic Sedan', '2012 Chevrolet Tahoe Hybrid SUV', '2012 Chevrolet Traverse SUV', '2012 Chrysler Town and Country Minivan', '2012 Dodge Caliber Wagon', '2012 Dodge Charger Sedan', '2012 Dodge Durango SUV', '2012 Dodge Journey SUV', '2012 FIAT 500 Abarth', '2012 FIAT 500 Convertible', '2012 Ferrari 458 Italia Convertible', '2012 Ferrari 458 Italia Coupe', '2012 Ferrari California Convertible', '2012 Ferrari FF Coupe', '2012 Fisker Karma Sedan', '2012 Ford E-Series Wagon Van', '2012 Ford Edge SUV', '2012 Ford F-150 Regular Cab', '2012 Ford F-450 Super Duty Crew Cab', '2012 Ford Fiesta Sedan', '2012 GMC Acadia SUV', '2012 GMC Canyon Extended Cab', '2012 GMC Savana Van', '2012 GMC Terrain SUV', '2012 GMC Yukon Hybrid SUV', '2012 Honda Accord Coupe', '2012 Honda Accord Sedan', '2012 Honda Odyssey Minivan', '2012 Hyundai Accent Sedan', '2012 Hyundai Azera Sedan', '2012 Hyundai Elantra Touring Hatchback', '2012 Hyundai Genesis Sedan', '2012 Hyundai Santa Fe SUV', '2012 Hyundai Sonata Hybrid Sedan', '2012 Hyundai Sonata Sedan', '2012 Hyundai Tucson SUV', '2012 Hyundai Veloster Hatchback', '2012 Hyundai Veracruz SUV', '2012 Infiniti G Coupe IPL', '2012 Jaguar XK XKR', '2012 Jeep Compass SUV', '2012 Jeep Grand Cherokee SUV', '2012 Jeep Liberty SUV', '2012 Jeep Patriot SUV', '2012 Jeep Wrangler SUV', '2012 Lamborghini Aventador Coupe', '2012 Lamborghini Gallardo LP 570-4 Superleggera', '2012 Land Rover LR2 SUV', '2012 Land Rover Range Rover SUV', '2012 MINI Cooper Roadster Convertible', '2012 Maybach Landaulet Convertible', '2012 McLaren MP4-12C Coupe', '2012 Mercedes-Benz C-Class Sedan', '2012 Mercedes-Benz E-Class Sedan', '2012 Mercedes-Benz S-Class Sedan', '2012 Mercedes-Benz Sprinter Van', '2012 Mitsubishi Lancer Sedan', '2012 Nissan Juke Hatchback', '2012 Nissan Leaf Hatchback', '2012 Nissan NV Passenger Van', '2012 Porsche Panamera Sedan', '2012 Ram C/V Cargo Van Minivan', '2012 Rolls-Royce Ghost Sedan', '2012 Rolls-Royce Phantom Drophead Coupe Convertible', '2012 Rolls-Royce Phantom Sedan', '2012 Scion xD Hatchback', '2012 Suzuki Kizashi Sedan', '2012 Suzuki SX4 Hatchback', '2012 Suzuki SX4 Sedan', '2012 Tesla Model S Sedan', '2012 Toyota 4Runner SUV', '2012 Toyota Camry Sedan', '2012 Toyota Corolla Sedan', '2012 Toyota Sequoia SUV', '2012 Volkswagen Beetle Hatchback', '2012 Volkswagen Golf Hatchback', '2012 Volvo C30 Hatchback', '2012 smart fortwo Convertible']
        prompt = ";;".join(classnames)
        datasets_path = datasets_paths["stanford_cars"]
    elif "increasing" in mprompt:
        file_path = f"/data8/user/CoOp/data/imagenet/increasing_classes/{mprompt}.txt"
    else:
        file_path = f"/data8/user/CoOp/data/imagenet/random_classes/{mprompt}.txt"
    if len(prompt) == 0:
        prompt = extract_class_names(file_path)
        prompt = ";;".join(prompt)
    # prompt = ""
    if opt.start_t != -1:
        start_t = opt.start_t
    else:
        start_t = None
    
    
    adapter_weights_path = opt.adapter
    
    
    load_adp_weights = torch.load(adapter_weights_path)["state_dict"]


    if "fgvc" in prompt:
        clss = ['1991 Volkswagen Golf Hatchback', '1993 Geo Metro Convertible', '1993 Mercedes-Benz 300-Class Convertible', '1993 Volvo 240 Sedan', '1994 Audi 100 Sedan', '1994 Audi 100 Wagon', '1994 Audi V8 Sedan', '1997 Dodge Caravan Minivan', '1998 Eagle Talon Hatchback', '1998 Nissan 240SX Coupe', '1999 Plymouth Neon Coupe', '2000 AM General Hummer SUV', '2001 Acura Integra Type R', '2001 Lamborghini Diablo Coupe', '2002 Daewoo Nubira Wagon', '2006 Ford GT Coupe', '2007 Audi S4 Sedan', '2007 BMW 6 Series Convertible', '2007 BMW X5 SUV', '2007 Bentley Continental Flying Spur Sedan', '2007 Bentley Continental GT Coupe', '2007 Buick Rainier SUV', '2007 Cadillac Escalade EXT Crew Cab', '2007 Chevrolet Corvette Ron Fellows Edition Z06', '2007 Chevrolet Express Cargo Van', '2007 Chevrolet Express Van', '2007 Chevrolet Impala Sedan', '2007 Chevrolet Malibu Sedan', '2007 Chevrolet Monte Carlo Coupe', '2007 Chevrolet Silverado 1500 Classic Extended Cab', '2007 Dodge Caliber Wagon', '2007 Dodge Dakota Club Cab', '2007 Dodge Durango SUV', '2007 Ford F-150 Regular Cab', '2007 Ford Focus Sedan', '2007 Ford Freestar Minivan', '2007 Ford Mustang Convertible', '2007 Honda Odyssey Minivan', '2007 Hyundai Elantra Sedan', '2007 Suzuki Aerio Sedan', '2007 Volvo XC90 SUV', '2008 Acura TL Type-S', '2008 Audi RS 4 Convertible', '2008 Chrysler Crossfire Convertible', '2008 Chrysler PT Cruiser Convertible', '2008 Dodge Magnum Wagon', '2008 Isuzu Ascender SUV', '2008 Lamborghini Reventon Coupe', '2009 Bentley Arnage Sedan', '2009 Bugatti Veyron 16.4 Convertible', '2009 Bugatti Veyron 16.4 Coupe', '2009 Chevrolet TrailBlazer SS', '2009 Chrysler Aspen SUV', '2009 Dodge Charger SRT-8', '2009 Dodge Ram Pickup 3500 Quad Cab', '2009 Dodge Sprinter Cargo Van', '2009 Ford Expedition EL SUV', '2009 HUMMER H2 SUT Crew Cab', '2009 Mercedes-Benz SL-Class Coupe', '2009 Spyker C8 Convertible', '2009 Spyker C8 Coupe', '2010 BMW M5 Sedan', '2010 BMW M6 Convertible', '2010 Chevrolet Cobalt SS', '2010 Chevrolet HHR SS', '2010 Chevrolet Malibu Hybrid Sedan', '2010 Chrysler 300 SRT-8', '2010 Chrysler Sebring Convertible', '2010 Dodge Dakota Crew Cab', '2010 Dodge Ram Pickup 3500 Crew Cab', '2010 HUMMER H3T Crew Cab', '2011 Audi S6 Sedan', '2011 Audi TT Hatchback', '2011 Bentley Mulsanne Sedan', '2011 Dodge Challenger SRT8', '2011 Ford Ranger SuperCab', '2011 Infiniti QX56 SUV', '2011 Lincoln Town Car Sedan', '2011 Mazda Tribute SUV', '2012 Acura RL Sedan', '2012 Acura TL Sedan', '2012 Acura TSX Sedan', '2012 Acura ZDX Hatchback', '2012 Aston Martin V8 Vantage Convertible', '2012 Aston Martin V8 Vantage Coupe', '2012 Aston Martin Virage Convertible', '2012 Aston Martin Virage Coupe', '2012 Audi A5 Coupe', '2012 Audi R8 Coupe', '2012 Audi S4 Sedan', '2012 Audi S5 Convertible', '2012 Audi S5 Coupe', '2012 Audi TT RS Coupe', '2012 Audi TTS Coupe', '2012 BMW 1 Series Convertible', '2012 BMW 1 Series Coupe', '2012 BMW 3 Series Sedan', '2012 BMW 3 Series Wagon', '2012 BMW ActiveHybrid 5 Sedan', '2012 BMW M3 Coupe', '2012 BMW X3 SUV', '2012 BMW X6 SUV', '2012 BMW Z4 Convertible', '2012 Bentley Continental GT Coupe', '2012 Bentley Continental Supersports Conv. Convertible', '2012 Buick Enclave SUV', '2012 Buick Regal GS', '2012 Buick Verano Sedan', '2012 Cadillac CTS-V Sedan', '2012 Cadillac SRX SUV', '2012 Chevrolet Avalanche Crew Cab', '2012 Chevrolet Camaro Convertible', '2012 Chevrolet Corvette Convertible', '2012 Chevrolet Corvette ZR1', '2012 Chevrolet Silverado 1500 Extended Cab', '2012 Chevrolet Silverado 1500 Hybrid Crew Cab', '2012 Chevrolet Silverado 1500 Regular Cab', '2012 Chevrolet Silverado 2500HD Regular Cab', '2012 Chevrolet Sonic Sedan', '2012 Chevrolet Tahoe Hybrid SUV', '2012 Chevrolet Traverse SUV', '2012 Chrysler Town and Country Minivan', '2012 Dodge Caliber Wagon', '2012 Dodge Charger Sedan', '2012 Dodge Durango SUV', '2012 Dodge Journey SUV', '2012 FIAT 500 Abarth', '2012 FIAT 500 Convertible', '2012 Ferrari 458 Italia Convertible', '2012 Ferrari 458 Italia Coupe', '2012 Ferrari California Convertible', '2012 Ferrari FF Coupe', '2012 Fisker Karma Sedan', '2012 Ford E-Series Wagon Van', '2012 Ford Edge SUV', '2012 Ford F-150 Regular Cab', '2012 Ford F-450 Super Duty Crew Cab', '2012 Ford Fiesta Sedan', '2012 GMC Acadia SUV', '2012 GMC Canyon Extended Cab', '2012 GMC Savana Van', '2012 GMC Terrain SUV', '2012 GMC Yukon Hybrid SUV', '2012 Honda Accord Coupe', '2012 Honda Accord Sedan', '2012 Honda Odyssey Minivan', '2012 Hyundai Accent Sedan', '2012 Hyundai Azera Sedan', '2012 Hyundai Elantra Touring Hatchback', '2012 Hyundai Genesis Sedan', '2012 Hyundai Santa Fe SUV', '2012 Hyundai Sonata Hybrid Sedan', '2012 Hyundai Sonata Sedan', '2012 Hyundai Tucson SUV', '2012 Hyundai Veloster Hatchback', '2012 Hyundai Veracruz SUV', '2012 Infiniti G Coupe IPL', '2012 Jaguar XK XKR', '2012 Jeep Compass SUV', '2012 Jeep Grand Cherokee SUV', '2012 Jeep Liberty SUV', '2012 Jeep Patriot SUV', '2012 Jeep Wrangler SUV', '2012 Lamborghini Aventador Coupe', '2012 Lamborghini Gallardo LP 570-4 Superleggera', '2012 Land Rover LR2 SUV', '2012 Land Rover Range Rover SUV', '2012 MINI Cooper Roadster Convertible', '2012 Maybach Landaulet Convertible', '2012 McLaren MP4-12C Coupe', '2012 Mercedes-Benz C-Class Sedan', '2012 Mercedes-Benz E-Class Sedan', '2012 Mercedes-Benz S-Class Sedan', '2012 Mercedes-Benz Sprinter Van', '2012 Mitsubishi Lancer Sedan', '2012 Nissan Juke Hatchback', '2012 Nissan Leaf Hatchback', '2012 Nissan NV Passenger Van', '2012 Porsche Panamera Sedan', '2012 Ram C/V Cargo Van Minivan', '2012 Rolls-Royce Ghost Sedan', '2012 Rolls-Royce Phantom Drophead Coupe Convertible', '2012 Rolls-Royce Phantom Sedan', '2012 Scion xD Hatchback', '2012 Suzuki Kizashi Sedan', '2012 Suzuki SX4 Hatchback', '2012 Suzuki SX4 Sedan', '2012 Tesla Model S Sedan', '2012 Toyota 4Runner SUV', '2012 Toyota Camry Sedan', '2012 Toyota Corolla Sedan', '2012 Toyota Sequoia SUV', '2012 Volkswagen Beetle Hatchback', '2012 Volkswagen Golf Hatchback', '2012 Volvo C30 Hatchback', '2012 smart fortwo Convertible']
    else:
        clss = [      "A300",
                            "A310",
                            "A320",
                            "A330",
                            "A340",
                            "A380",
                            "ATR-42",
                            "ATR-72",
                            "An-12",
                            "BAE 146",
                            "BAE-125",
                            "Beechcraft 1900",
                            "Boeing 707",
                            "Boeing 717",
                            "Boeing 727",
                            "Boeing 737",
                            "Boeing 747",
                            "Boeing 757",
                            "Boeing 767",
                            "Boeing 777",
                            "C-130",
                            "C-47",
                            "CRJ-200",
                            "CRJ-700",
                            "Cessna 172",
                            "Cessna 208",
                            "Cessna Citation",
                            "Challenger 600",
                            "DC-10",
                            "DC-3",
                            "DC-6",
                            "DC-8",
                            "DC-9",
                            "DH-82",
                            "DHC-1",
                            "DHC-6",
                            "DR-400",
                            "Dash 8",
                            "Dornier 328",
                            "EMB-120",
                            "Embraer E-Jet",
                            "Embraer ERJ 145",
                            "Embraer Legacy 600",
                            "Eurofighter Typhoon",
                            "F-16",
                            "F/A-18",
                            "Falcon 2000",
                            "Falcon 900",
                            "Fokker 100",
                            "Fokker 50",
                            "Fokker 70",
                            "Global Express",
                            "Gulfstream",
                            "Hawk T1",
                            "Il-76",
                            "King Air",
                            "L-1011",
                            "MD-11",
                            "MD-80",
                            "MD-90",
                            "Metroliner",
                            "PA-28",
                            "SR-20",
                            "Saab 2000",
                            "Saab 340",
                            "Spitfire",
                            "Tornado",
                            "Tu-134",
                            "Tu-154",
                            "Yak-42"
                        ]
    if "flowers" not in prompt and "imagenet" not in prompt:
        clss += ['alpine sea holly', 'anthurium', 'artichoke', 'azalea', 'ball moss', 'balloon flower', 'barbeton daisy', 'bearded iris', 'bee balm', 'bird of paradise', 'bishop of llandaff', 'black-eyed susan', 'blackberry lily', 'blanket flower', 'bolero deep blue', 'bougainvillea', 'bromelia', 'buttercup', 'californian poppy', 'camellia', 'canna lily', 'canterbury bells', 'cape flower', 'carnation', 'cautleya spicata', 'clematis', "colt's foot", 'columbine', 'common dandelion', 'corn poppy', 'cyclamen', 'daffodil', 'desert-rose', 'english marigold', 'fire lily', 'foxglove', 'frangipani', 'fritillary', 'garden phlox', 'gaura', 'gazania', 'geranium', 'giant white arum lily', 'globe thistle', 'globe-flower', 'grape hyacinth', 'great masterwort', 'hard-leaved pocket orchid', 'hibiscus', 'hippeastrum', 'japanese anemone', 'king protea', 'lenten rose', 'lotus', 'love in the mist', 'magnolia', 'mallow', 'marigold', 'mexican aster', 'mexican petunia', 'monkshood', 'moon orchid', 'morning glory', 'orange dahlia', 'osteospermum', 'oxeye daisy', 'passion flower', 'pelargonium', 'peruvian lily', 'petunia', 'pincushion flower', 'pink primrose', 'pink-yellow dahlia', 'poinsettia', 'primula', 'prince of wales feathers', 'purple coneflower', 'red ginger', 'rose', 'ruby-lipped cattleya', 'siam tulip', 'silverbush', 'snapdragon', 'spear thistle', 'spring crocus', 'stemless gentian', 'sunflower', 'sweet pea', 'sweet william', 'sword lily', 'thorn apple', 'tiger lily', 'toad lily', 'tree mallow', 'tree poppy', 'trumpet creeper', 'wallflower', 'water lily', 'watercress', 'wild pansy', 'windflower', 'yellow iris']
    # clss = remove_specific_classes(all_classnames, classnames)
    prompt2 = ";;".join(clss)
    # prompt = ";;".join(all_classnames)
    # prompt2 = ""

    if opt.composable_diffusion or opt.negation_diffusion:
        prompt = [prompt, prompt2]
        prompt2 = ""
    
    if opt.H == 16:
        confs = datasets_path.split('/')
        conf = os.path.join('/'.join(confs[:-1]), 'config.yaml')
        config = OmegaConf.load(conf)
        decoder_model = load_model_from_config(config, datasets_path)   
        
    start_noise = None
    x_start = None
    if opt.start_t != -1:
        for k, ii in load_adp_weights.items():
            if "ctx" in k:
                initial = load_adp_weights[k].to(torch.float32)
        posterior = decoder_model.encode(initial)
        noisy = posterior.sample()
        # noisy = zn.view(-1, 1, 96)
        # prompt = ""
        start_noise = []
        for idx in range(opt.n_samples):
            start_noisy = model.q_sample(x_start=noisy, t=torch.tensor([start_t], device=device))
            start_noise.append(start_noisy)
        start_noise = torch.cat(start_noise)

    if opt.negative_prompt:
        prompt2 = prompt
        prompt = ""
        
    if opt.use_dsname:
        prompt = opt.prompt
        prompt2 = "fgvc_aircraft"
    
    if opt.unconditional:
        prompt2 = ""

    all_samples=list()
    with torch.no_grad():
        with model.ema_scope():
            uc = None
            if opt.scale != 1.0:
                # uc = model.get_learned_conditioning(opt.n_samples * [""])#, reduce_token_length=True).float()
                uc = model.get_learned_conditioning(opt.n_samples * [prompt2])
            for n in trange(opt.n_iter, desc="Sampling"):
                if opt.composable_diffusion or opt.negation_diffusion:
                    c = model.get_learned_conditioning(prompt)    
                    if opt.scale != 1.0:
                        uc = model.get_learned_conditioning(opt.n_samples * [""])
                else:
                    c = model.get_learned_conditioning(opt.n_samples * [prompt])#, reduce_token_length=True).float()

                if opt.H == 20480:
                    shape = [1, 15, opt.H]
                elif opt.H == 2048:
                    shape = [1, 1, opt.H]
                elif opt.H == 1024:
                    shape = [1, 128, 8]
                elif opt.H == 16:
                    shape = [1, 8, 16]
                elif opt.ctx == 16:
                    shape = [1, 16, opt.H]
                else:
                    shape = [1, 4, opt.H]
                if opt.composable_diffusion:
                    samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
                                                conditioning=c,
                                                batch_size=opt.n_samples,
                                                shape=shape,
                                                verbose=False,
                                                unconditional_guidance_scale=opt.scale,
                                                unconditional_conditioning=uc,
                                                x0=start_noise,
                                                eta=opt.ddim_eta,
                                                composable_diffusion=len(prompt)-1)
                elif opt.negation_diffusion:
                    samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
                                                conditioning=c,
                                                batch_size=opt.n_samples,
                                                shape=shape,
                                                verbose=False,
                                                unconditional_guidance_scale=opt.scale,
                                                unconditional_conditioning=uc,
                                                x0=start_noise,
                                                eta=opt.ddim_eta,
                                                negation_diffusion=len(prompt)-1)
                else:
                    samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
                                                    conditioning=c,
                                                    batch_size=opt.n_samples,
                                                    shape=shape,
                                                    verbose=False,
                                                    start_t=start_t,
                                                    x0=start_noise,
                                                    unconditional_guidance_scale=opt.scale,
                                                    unconditional_conditioning=uc,
                                                    eta=opt.ddim_eta)

                print(samples_ddim.shape)
                
                if opt.H == 16:
                    samples_ddim = decoder_model.decode(samples_ddim.squeeze(1))

                if opt.avg_samples:
                    samples_ddim = samples_ddim.mean(0).unsqueeze(0)
                if samples_ddim.shape[0] >= 1:
                    for kk in range(samples_ddim.shape[0]):
                        # samples_ddim_tmp = samples_ddim[kk]
                        if "increasing" in mprompt:
                            fold = "vit_b16_c4_ep10_batch2_ctxv1"
                        else:
                            fold = "vit_b16_c4_ep10_batch3_ctxv1"
                        
                        for k, ii in load_adp_weights.items():
                            if "ctx" in k:
                                load_adp_weights[k] = samples_ddim[kk].reshape(ii.shape)
                        
                        if opt.n_iter > 1:
                            dest_adapter_weights_path = adapter_weights_path.replace('model.pth.tar-50', f'model.pth.tar-{n+eph}')
                        elif opt.n_samples == 200 and samples_ddim.shape[0] == 200:
                            dest_adapter_weights_path = adapter_weights_path.replace('model.pth.tar-50', f'model.pth.tar-{kk+eph}')
                        else:
                            if samples_ddim.shape[0] == 1:
                                dest_adapter_weights_path = adapter_weights_path.replace('model.pth.tar-50', f'model.pth.tar-{eph}')
                            else:
                                dest_adapter_weights_path = adapter_weights_path.replace('model.pth.tar-50', f'model.pth.tar-{eph}{kk}')
                        outdict = {
                            'state_dict': load_adp_weights,
                            'epoch': eph,
                        }
                        torch.save(outdict, dest_adapter_weights_path)
                        