import os
import math
import random

from PIL import Image
import blobfile as bf
# from mpi4py import MPI
import pandas as pd
import numpy as np
import pickle
import json
import torchvision.transforms as T
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize

# from dassl.data.datasets import Datum
from torchvision.transforms.functional import InterpolationMode
import torch
from collections import defaultdict


# Lookup from a lowercase interpolation name to torchvision's InterpolationMode.
INTERPOLATION_MODES = dict(
    bilinear=InterpolationMode.BILINEAR,
    bicubic=InterpolationMode.BICUBIC,
    nearest=InterpolationMode.NEAREST,
)
# Per-channel RGB normalisation statistics (the same values preprocess_clip
# hands to Normalize).
PIXEL_MEAN = [0.48145466, 0.4578275, 0.40821073]
PIXEL_STD = [0.26862954, 0.26130258, 0.27577711]

# Dataset identifiers. The order is significant: map_idx_to_filename selects
# datasets[idx // num_seeds].
datasets = [
    "imagenet",
    "caltech101",
    "dtd",
    "eurosat",
    "fgvc_aircraft",
    "food101",
    "oxford_flowers",
    "oxford_pets",
    "stanford_cars",
    "sun397",
    "ucf101",
]


def pad_coop_classnames(coop_classnames, pad_token="<PAD>", max_length=500):
    """Right-pad a sequence of class names to a fixed length.

    :param coop_classnames: sequence of class-name strings (list or tuple).
    :param pad_token: filler appended after the real names.
    :param max_length: target length of the padded list (default 500, the
        value that used to be hard-coded).
    :return: ``(padded, num_names)`` where ``padded`` is a new list and
        ``num_names`` is ``len(coop_classnames)``, the count of real names.

    Inputs already longer than ``max_length`` are returned unpadded as a copy;
    they are not truncated (same as the previous behaviour).
    """
    num_names = len(coop_classnames)
    # max(..., 0) makes the "already long enough" case explicit instead of
    # relying on list * negative == [].
    padding = [pad_token] * max(max_length - num_names, 0)
    # list(...) always returns a fresh list (the input is never mutated) and
    # also accepts tuples.
    return list(coop_classnames) + padding, num_names

def subsample(dataset, split="base"):
    """Keep one half of the classes in ``dataset`` and relabel them densely.

    The distinct ``item.label`` values are sorted. The ``"base"`` split keeps
    the first ``ceil(n / 2)`` of them and the ``"new"`` split keeps the rest.
    Kept labels are remapped to ``0 .. k-1`` in sorted order.

    :param dataset: re-iterable collection (it is traversed twice) of items
        exposing ``impath``, ``label`` and ``classname`` attributes.
    :param split: ``"base"`` (default, the original behaviour) or ``"new"``.
    :return: list of newly built items for the kept classes, in input order.
    :raises ValueError: if ``split`` is neither ``"base"`` nor ``"new"``.
    """
    labels = sorted({item.label for item in dataset})
    half = math.ceil(len(labels) / 2)
    if split == "base":
        selected = labels[:half]
    elif split == "new":
        selected = labels[half:]
    else:
        raise ValueError(f"split must be 'base' or 'new', got {split!r}")
    relabeler = {old: new for new, old in enumerate(selected)}

    # The ``Datum`` import at the top of the file is commented out, so a bare
    # ``Datum(...)`` call raised NameError. Use a module-level ``Datum`` if one
    # exists; otherwise rebuild each item with its own type.
    datum_cls = globals().get("Datum")

    subset = []
    for item in dataset:
        # Dict membership instead of scanning the ``selected`` list per item.
        if item.label not in relabeler:
            continue
        factory = datum_cls if datum_cls is not None else type(item)
        subset.append(
            factory(
                impath=item.impath,
                label=relabeler[item.label],
                classname=item.classname,
            )
        )
    return subset

def get_classnames(prompt):
    """Return the class-name list for the dataset named in ``prompt``.

    The dataset is detected by substring, tested in this order (first match
    wins): "caltech", "pets", "sun", "ucf101", "imagenet", "eurosat", "dtd",
    "aircraft", "food", "flowers", "stanford". ImageNet names are read from a
    hard-coded ``classnames.txt`` via ``extract_class_names(..., use_full=False)``;
    every other list is inline.

    :param prompt: text containing a dataset identifier (e.g. "oxford_pets").
    :return: list of class-name strings.
    :raises ValueError: if ``prompt`` contains none of the known identifiers
        (previously this surfaced as an UnboundLocalError).
    """
    if "caltech" in prompt:
        classnames = ['accordion', 'airplane', 'anchor', 'ant', 'barrel', 'bass', 'beaver', 'binocular', 'bonsai', 'brain', 'brontosaurus', 'buddha', 'butterfly', 'camera', 'cannon', 'car_side', 'ceiling_fan', 'cellphone', 'chair', 'chandelier', 'cougar_body', 'cougar_face', 'crab', 'crayfish', 'crocodile', 'crocodile_head', 'cup', 'dalmatian', 'dollar_bill', 'dolphin', 'dragonfly', 'electric_guitar', 'elephant', 'emu', 'euphonium', 'ewer', 'face', 'ferry', 'flamingo', 'flamingo_head', 'garfield', 'gerenuk', 'gramophone', 'grand_piano', 'hawksbill', 'headphone', 'hedgehog', 'helicopter', 'ibis', 'inline_skate', 'joshua_tree', 'kangaroo', 'ketch', 'lamp', 'laptop', 'leopard', 'llama', 'lobster', 'lotus', 'mandolin', 'mayfly', 'menorah', 'metronome', 'minaret', 'motorbike', 'nautilus', 'octopus', 'okapi', 'pagoda', 'panda', 'pigeon', 'pizza', 'platypus', 'pyramid', 'revolver', 'rhino', 'rooster', 'saxophone', 'schooner', 'scissors', 'scorpion', 'sea_horse', 'snoopy', 'soccer_ball', 'stapler', 'starfish', 'stegosaurus', 'stop_sign', 'strawberry', 'sunflower', 'tick', 'trilobite', 'umbrella', 'watch', 'water_lilly', 'wheelchair', 'wild_cat', 'windsor_chair', 'wrench', 'yin_yang']
    elif "pets" in prompt:
        classnames = ['abyssinian', 'american_bulldog', 'american_pit_bull_terrier', 'basset_hound', 'beagle', 'bengal', 'birman', 'bombay', 'boxer', 'british_shorthair', 'chihuahua', 'egyptian_mau', 'english_cocker_spaniel', 'english_setter', 'german_shorthaired', 'great_pyrenees', 'havanese', 'japanese_chin', 'keeshond', 'leonberger', 'maine_coon', 'miniature_pinscher', 'newfoundland', 'persian', 'pomeranian', 'pug', 'ragdoll', 'russian_blue', 'saint_bernard', 'samoyed', 'scottish_terrier', 'shiba_inu', 'siamese', 'sphynx', 'staffordshire_bull_terrier', 'wheaten_terrier', 'yorkshire_terrier']
    elif "sun" in prompt:
        classnames = ['abbey', 'airplane_cabin', 'airport_terminal', 'alley', 'amphitheater', 'amusement_arcade', 'amusement_park', 'anechoic_chamber', 'aquarium', 'aqueduct', 'arch', 'archive', 'art_gallery', 'art_school', 'art_studio', 'assembly_line', 'attic', 'auditorium', 'auto_factory', 'backseat car_interior', 'badlands', 'baggage_claim', 'ball_pit', 'ballroom', 'bamboo_forest', 'banquet_hall', 'bar', 'barn', 'barndoor', 'barrel_storage wine_cellar', 'baseball stadium', 'baseball_field', 'basement', 'basilica', 'bathroom', 'batters_box', 'bayou', 'beach', 'beauty_salon', 'bedroom', 'berth', 'biology_laboratory', 'block waterfall', 'boardwalk', 'boat_deck', 'boathouse', 'bookstore', 'botanical_garden', 'bottle_storage wine_cellar', 'bowling_alley', 'boxing_ring', 'bridge', 'broadleaf forest', 'building_facade', 'bullring', 'burial_chamber', 'bus_interior', 'butchers_shop', 'butte', 'cafeteria', 'campsite', 'campus', 'candy_store', 'canyon', 'carrousel', 'castle', 'catacomb', 'cemetery', 'chalet', 'cheese_factory', 'chemistry_lab', 'childs_room', 'classroom', 'clean_room', 'cliff', 'closet', 'clothing_store', 'coast', 'cockpit', 'coffee_shop', 'computer_room', 'conference_center', 'conference_room', 'construction_site', 'control_room', 'coral_reef underwater', 'corn_field', 'corral', 'corridor', 'cottage_garden', 'courthouse', 'courtroom', 'courtyard', 'creek', 'crevasse', 'crosswalk', 'cultivated field', 'dam', 'delicatessen', 'dentists_office', 'dining_car', 'dining_room', 'discotheque', 'dock', 'door elevator', 'dorm_room', 'driveway', 'drugstore', 'east_asia temple', 'electrical_substation', 'elevator_shaft', 'engine_room', 'establishment poolroom', 'excavation', 'exterior balcony', 'exterior covered_bridge', 'exterior gazebo', 'fairway', 'fan waterfall', 'fastfood_restaurant', 'fire_escape', 'fire_station', 'fishpond', 'food_court', 'football stadium', 'forest_path', 'forest_road', 'formal_garden', 'fountain', 'frontseat car_interior', 'galley',
            'game_room', 'garbage_dump', 'gas_station', 'gift_shop', 'golf_course', 'harbor', 'hayfield', 'heliport', 'herb_garden', 'highway', 'hill', 'home dinette', 'home poolroom', 'home_office', 'hospital', 'hospital_room', 'hot_spring', 'hotel_room', 'house', 'ice_cream_parlor', 'ice_floe', 'ice_shelf', 'iceberg', 'igloo', 'indoor apse', 'indoor badminton_court', 'indoor bazaar', 'indoor bistro', 'indoor booth', 'indoor bow_window', 'indoor brewery', 'indoor casino', 'indoor cathedral', 'indoor cavern', 'indoor chicken_coop', 'indoor church', 'indoor cloister', 'indoor diner', 'indoor escalator', 'indoor factory', 'indoor firing_range', 'indoor florist_shop', 'indoor garage', 'indoor general_store', 'indoor greenhouse', 'indoor gymnasium', 'indoor hangar', 'indoor ice_skating_rink', 'indoor jacuzzi', 'indoor jail', 'indoor kennel', 'indoor library', 'indoor market', 'indoor mosque', 'indoor movie_theater', 'indoor museum', 'indoor parking_garage', 'indoor pilothouse', 'indoor podium', 'indoor pub', 'indoor shopping_mall', 'indoor stage', 'indoor swimming_pool', 'indoor synagogue', 'indoor tennis_court', 'indoor volleyball_court', 'indoor warehouse', 'indoor wrestling_ring', 'indoor_procenium theater', 'indoor_seats theater', 'industrial_area', 'interior balcony', 'interior elevator', 'islet', 'jail_cell', 'jewelry_shop', 'kasbah', 'kindergarden_classroom', 'kitchen', 'kitchenette', 'landfill', 'landing_deck', 'laundromat', 'lecture_room', 'lift_bridge', 'lighthouse', 'limousine_interior', 'living_room', 'lobby', 'lock_chamber', 'locker_room', 'mansion', 'manufactured_home', 'marsh', 'martial_arts_gym', 'mausoleum', 'medina', 'motel', 'mountain', 'mountain_snowy', 'music_store', 'music_studio', 'natural canal', 'natural lake', 'needleleaf forest', 'nursery', 'oast_house', 'ocean', 'office', 'office cubicle', 'office_building', 'oilrig', 'operating_room', 'orchard', 'outdoor apartment_building', 'outdoor arrival_gate', 'outdoor athletic_field', 'outdoor basketball_court',
            'outdoor bazaar', 'outdoor bow_window', 'outdoor cabin', 'outdoor cathedral', 'outdoor chicken_coop', 'outdoor church', 'outdoor control_tower', 'outdoor diner', 'outdoor doorway', 'outdoor driving_range', 'outdoor general_store', 'outdoor greenhouse', 'outdoor hangar', 'outdoor hot_tub', 'outdoor hotel', 'outdoor hunting_lodge', 'outdoor ice_skating_rink', 'outdoor inn', 'outdoor kennel', 'outdoor labyrinth', 'outdoor library', 'outdoor lido_deck', 'outdoor market', 'outdoor monastery', 'outdoor mosque', 'outdoor nuclear_power_plant', 'outdoor observatory', 'outdoor oil_refinery', 'outdoor outhouse', 'outdoor parking_garage', 'outdoor planetarium', 'outdoor podium', 'outdoor power_plant', 'outdoor swimming_pool', 'outdoor synagogue', 'outdoor tennis_court', 'outdoor tent', 'outdoor track', 'outdoor volleyball_court', 'pagoda', 'palace', 'pantry', 'park', 'parking_lot', 'parlor', 'pasture', 'patio', 'pavilion', 'pharmacy', 'phone_booth', 'physics_laboratory', 'picnic_area', 'platform subway_station', 'platform train_station', 'playground', 'playroom', 'plaza', 'plunge waterfall', 'pond', 'promenade_deck', 'public atrium', 'pulpit', 'putting_green', 'racecourse', 'raceway', 'raft', 'railroad_track', 'rainforest', 'reception', 'recreation_room', 'residential_neighborhood', 'restaurant', 'restaurant_kitchen', 'restaurant_patio', 'rice_paddy', 'riding_arena', 'river', 'rock_arch', 'rope_bridge', 'ruin', 'runway', 'sand desert', 'sandbar', 'sandbox', 'sauna', 'schoolhouse', 'sea_cliff', 'server_room', 'shed', 'shoe_shop', 'shop bakery', 'shopfront', 'shower', 'skatepark', 'ski_lodge', 'ski_resort', 'ski_slope', 'sky', 'skyscraper', 'slum', 'snowfield', 'south_asia temple', 'squash_court', 'stable', 'staircase', 'street', 'subway_interior', 'supermarket', 'sushi_bar', 'swamp', 'television_studio', 'thriftshop', 'throne_room', 'ticket_booth', 'toll_plaza', 'topiary_garden', 'tower', 'toyshop', 'train_railway', 'tree_farm', 'tree_house', 'trench', 'urban canal',
            'utility_room', 'valley', 'van_interior', 'vegetable_garden', 'vegetation desert', 'vehicle dinette', 'veranda', 'veterinarians_office', 'viaduct', 'videostore', 'village', 'vineyard', 'volcano', 'waiting_room', 'water moat', 'water_tower', 'watering_hole', 'wave', 'wet_bar', 'wheat_field', 'wild field', 'wind_farm', 'windmill', 'yard', 'youth_hostel']
    elif "ucf101" in prompt:
        classnames = ['Apply_Eye_Makeup', 'Apply_Lipstick', 'Archery', 'Baby_Crawling', 'Balance_Beam', 'Band_Marching', 'Baseball_Pitch', 'Basketball', 'Basketball_Dunk', 'Bench_Press', 'Biking', 'Billiards', 'Blow_Dry_Hair', 'Blowing_Candles', 'Body_Weight_Squats', 'Bowling', 'Boxing_Punching_Bag', 'Boxing_Speed_Bag', 'Breast_Stroke', 'Brushing_Teeth', 'Clean_And_Jerk', 'Cliff_Diving', 'Cricket_Bowling', 'Cricket_Shot', 'Cutting_In_Kitchen', 'Diving', 'Drumming', 'Fencing', 'Field_Hockey_Penalty', 'Floor_Gymnastics', 'Frisbee_Catch', 'Front_Crawl', 'Golf_Swing', 'Haircut', 'Hammer_Throw', 'Hammering', 'Handstand_Pushups', 'Handstand_Walking', 'Head_Massage', 'High_Jump', 'Horse_Race', 'Horse_Riding', 'Hula_Hoop', 'Ice_Dancing', 'Javelin_Throw', 'Juggling_Balls', 'Jump_Rope', 'Jumping_Jack', 'Kayaking', 'Knitting', 'Long_Jump', 'Lunges', 'Military_Parade', 'Mixing', 'Mopping_Floor', 'Nunchucks', 'Parallel_Bars', 'Pizza_Tossing', 'Playing_Cello', 'Playing_Daf', 'Playing_Dhol', 'Playing_Flute', 'Playing_Guitar', 'Playing_Piano', 'Playing_Sitar', 'Playing_Tabla', 'Playing_Violin', 'Pole_Vault', 'Pommel_Horse', 'Pull_Ups', 'Punch', 'Push_Ups', 'Rafting', 'Rock_Climbing_Indoor', 'Rope_Climbing', 'Rowing', 'Salsa_Spin', 'Shaving_Beard', 'Shotput', 'Skate_Boarding', 'Skiing', 'Skijet', 'Sky_Diving', 'Soccer_Juggling', 'Soccer_Penalty', 'Still_Rings', 'Sumo_Wrestling', 'Surfing', 'Swing', 'Table_Tennis_Shot', 'Tai_Chi', 'Tennis_Swing', 'Throw_Discus', 'Trampoline_Jumping', 'Typing', 'Uneven_Bars', 'Volleyball_Spiking', 'Walking_With_Dog', 'Wall_Pushups', 'Writing_On_Board', 'Yo_Yo']
    elif "imagenet" in prompt:
        file_path = "/data/user/CoOp/data/imagenet/classnames.txt"
        # use_full=False: only the first half of the file's names is returned.
        classnames = extract_class_names(file_path, use_full=False)
    elif "eurosat" in prompt:
        classnames = ['Annual Crop Land', 'Forest', 'Herbaceous Vegetation Land', 'Highway or Road', 'Industrial Buildings', 'Pasture Land', 'Permanent Crop Land', 'Residential Buildings', 'River', 'Sea or Lake']
    elif "dtd" in prompt:
        classnames = ['banded', 'blotchy', 'braided', 'bubbly', 'bumpy', 'chequered', 'cobwebbed', 'cracked', 'crosshatched', 'crystalline', 'dotted', 'fibrous', 'flecked', 'freckled', 'frilly', 'gauzy', 'grid', 'grooved', 'honeycombed', 'interlaced', 'knitted', 'lacelike', 'lined', 'marbled', 'matted', 'meshed', 'paisley', 'perforated', 'pitted', 'pleated', 'polka-dotted', 'porous', 'potholed', 'scaly', 'smeared', 'spiralled', 'sprinkled', 'stained', 'stratified', 'striped', 'studded', 'swirly', 'veined', 'waffled', 'woven', 'wrinkled', 'zigzagged']
    elif "aircraft" in prompt:
        classnames = [
            "A300",
            "A310",
            "A320",
            "A330",
            "A340",
            "A380",
            "ATR-42",
            "ATR-72",
            "An-12",
            "BAE 146",
            "BAE-125",
            "Beechcraft 1900",
            "Boeing 707",
            "Boeing 717",
            "Boeing 727",
            "Boeing 737",
            "Boeing 747",
            "Boeing 757",
            "Boeing 767",
            "Boeing 777",
            "C-130",
            "C-47",
            "CRJ-200",
            "CRJ-700",
            "Cessna 172",
            "Cessna 208",
            "Cessna Citation",
            "Challenger 600",
            "DC-10",
            "DC-3",
            "DC-6",
            "DC-8",
            "DC-9",
            "DH-82",
            "DHC-1",
            "DHC-6",
            "DR-400",
            "Dash 8",
            "Dornier 328",
            "EMB-120",
            "Embraer E-Jet",
            "Embraer ERJ 145",
            "Embraer Legacy 600",
            "Eurofighter Typhoon",
            "F-16",
            "F/A-18",
            "Falcon 2000",
            "Falcon 900",
            "Fokker 100",
            "Fokker 50",
            "Fokker 70",
            "Global Express",
            "Gulfstream",
            "Hawk T1",
            "Il-76",
            "King Air",
            "L-1011",
            "MD-11",
            "MD-80",
            "MD-90",
            "Metroliner",
            "PA-28",
            "SR-20",
            "Saab 2000",
            "Saab 340",
            "Spitfire",
            "Tornado",
            "Tu-134",
            "Tu-154",
            "Yak-42",
        ]
    elif "food" in prompt:
        classnames = ['apple_pie', 'baby_back_ribs', 'baklava', 'beef_carpaccio', 'beef_tartare', 'beet_salad', 'beignets', 'bibimbap', 'bread_pudding', 'breakfast_burrito', 'bruschetta', 'caesar_salad', 'cannoli', 'caprese_salad', 'carrot_cake', 'ceviche', 'cheese_plate', 'cheesecake', 'chicken_curry', 'chicken_quesadilla', 'chicken_wings', 'chocolate_cake', 'chocolate_mousse', 'churros', 'clam_chowder', 'club_sandwich', 'crab_cakes', 'creme_brulee', 'croque_madame', 'cup_cakes', 'deviled_eggs', 'donuts', 'dumplings', 'edamame', 'eggs_benedict', 'escargots', 'falafel', 'filet_mignon', 'fish_and_chips', 'foie_gras', 'french_fries', 'french_onion_soup', 'french_toast', 'fried_calamari', 'fried_rice', 'frozen_yogurt', 'garlic_bread', 'gnocchi', 'greek_salad', 'grilled_cheese_sandwich', 'grilled_salmon', 'guacamole', 'gyoza', 'hamburger', 'hot_and_sour_soup', 'hot_dog', 'huevos_rancheros', 'hummus', 'ice_cream', 'lasagna', 'lobster_bisque', 'lobster_roll_sandwich', 'macaroni_and_cheese', 'macarons', 'miso_soup', 'mussels', 'nachos', 'omelette', 'onion_rings', 'oysters', 'pad_thai', 'paella', 'pancakes', 'panna_cotta', 'peking_duck', 'pho', 'pizza', 'pork_chop', 'poutine', 'prime_rib', 'pulled_pork_sandwich', 'ramen', 'ravioli', 'red_velvet_cake', 'risotto', 'samosa', 'sashimi', 'scallops', 'seaweed_salad', 'shrimp_and_grits', 'spaghetti_bolognese', 'spaghetti_carbonara', 'spring_rolls', 'steak', 'strawberry_shortcake', 'sushi', 'tacos', 'takoyaki', 'tiramisu', 'tuna_tartare', 'waffles']
    elif "flowers" in prompt:
        classnames = ['alpine sea holly', 'anthurium', 'artichoke', 'azalea', 'ball moss', 'balloon flower', 'barbeton daisy', 'bearded iris', 'bee balm', 'bird of paradise', 'bishop of llandaff', 'black-eyed susan', 'blackberry lily', 'blanket flower', 'bolero deep blue', 'bougainvillea', 'bromelia', 'buttercup', 'californian poppy', 'camellia', 'canna lily', 'canterbury bells', 'cape flower', 'carnation', 'cautleya spicata', 'clematis', "colt's foot", 'columbine', 'common dandelion', 'corn poppy', 'cyclamen', 'daffodil', 'desert-rose', 'english marigold', 'fire lily', 'foxglove', 'frangipani', 'fritillary', 'garden phlox', 'gaura', 'gazania', 'geranium', 'giant white arum lily', 'globe thistle', 'globe-flower', 'grape hyacinth', 'great masterwort', 'hard-leaved pocket orchid', 'hibiscus', 'hippeastrum', 'japanese anemone', 'king protea', 'lenten rose', 'lotus', 'love in the mist', 'magnolia', 'mallow', 'marigold', 'mexican aster', 'mexican petunia', 'monkshood', 'moon orchid', 'morning glory', 'orange dahlia', 'osteospermum', 'oxeye daisy', 'passion flower', 'pelargonium', 'peruvian lily', 'petunia', 'pincushion flower', 'pink primrose', 'pink-yellow dahlia', 'poinsettia', 'primula', 'prince of wales feathers', 'purple coneflower', 'red ginger', 'rose', 'ruby-lipped cattleya', 'siam tulip', 'silverbush', 'snapdragon', 'spear thistle', 'spring crocus', 'stemless gentian', 'sunflower', 'sweet pea', 'sweet william', 'sword lily', 'thorn apple', 'tiger lily', 'toad lily', 'tree mallow', 'tree poppy', 'trumpet creeper', 'wallflower', 'water lily', 'watercress', 'wild pansy', 'windflower', 'yellow iris']
    elif "stanford" in prompt:
        # Two names used to be split across lines inside their string literals
        # (a SyntaxError); they are rejoined here as single literals.
        classnames = ['1991 Volkswagen Golf Hatchback', '1993 Geo Metro Convertible', '1993 Mercedes-Benz 300-Class Convertible', '1993 Volvo 240 Sedan', '1994 Audi 100 Sedan', '1994 Audi 100 Wagon', '1994 Audi V8 Sedan', '1997 Dodge Caravan Minivan', '1998 Eagle Talon Hatchback', '1998 Nissan 240SX Coupe', '1999 Plymouth Neon Coupe', '2000 AM General Hummer SUV', '2001 Acura Integra Type R', '2001 Lamborghini Diablo Coupe', '2002 Daewoo Nubira Wagon', '2006 Ford GT Coupe', '2007 Audi S4 Sedan', '2007 BMW 6 Series Convertible', '2007 BMW X5 SUV', '2007 Bentley Continental Flying Spur Sedan', '2007 Bentley Continental GT Coupe', '2007 Buick Rainier SUV', '2007 Cadillac Escalade EXT Crew Cab', '2007 Chevrolet Corvette Ron Fellows Edition Z06', '2007 Chevrolet Express Cargo Van', '2007 Chevrolet Express Van', '2007 Chevrolet Impala Sedan', '2007 Chevrolet Malibu Sedan', '2007 Chevrolet Monte Carlo Coupe', '2007 Chevrolet Silverado 1500 Classic Extended Cab', '2007 Dodge Caliber Wagon', '2007 Dodge Dakota Club Cab', '2007 Dodge Durango SUV', '2007 Ford F-150 Regular Cab', '2007 Ford Focus Sedan', '2007 Ford Freestar Minivan', '2007 Ford Mustang Convertible', '2007 Honda Odyssey Minivan', '2007 Hyundai Elantra Sedan', '2007 Suzuki Aerio Sedan', '2007 Volvo XC90 SUV', '2008 Acura TL Type-S', '2008 Audi RS 4 Convertible', '2008 Chrysler Crossfire Convertible', '2008 Chrysler PT Cruiser Convertible', '2008 Dodge Magnum Wagon', '2008 Isuzu Ascender SUV', '2008 Lamborghini Reventon Coupe', '2009 Bentley Arnage Sedan', '2009 Bugatti Veyron 16.4 Convertible', '2009 Bugatti Veyron 16.4 Coupe', '2009 Chevrolet TrailBlazer SS', '2009 Chrysler Aspen SUV', '2009 Dodge Charger SRT-8', '2009 Dodge Ram Pickup 3500 Quad Cab', '2009 Dodge Sprinter Cargo Van', '2009 Ford Expedition EL SUV', '2009 HUMMER H2 SUT Crew Cab', '2009 Mercedes-Benz SL-Class Coupe', '2009 Spyker C8 Convertible', '2009 Spyker C8 Coupe', '2010 BMW M5 Sedan', '2010 BMW M6 Convertible', '2010 Chevrolet Cobalt SS',
            '2010 Chevrolet HHR SS', '2010 Chevrolet Malibu Hybrid Sedan', '2010 Chrysler 300 SRT-8', '2010 Chrysler Sebring Convertible', '2010 Dodge Dakota Crew Cab', '2010 Dodge Ram Pickup 3500 Crew Cab', '2010 HUMMER H3T Crew Cab', '2011 Audi S6 Sedan', '2011 Audi TT Hatchback', '2011 Bentley Mulsanne Sedan', '2011 Dodge Challenger SRT8', '2011 Ford Ranger SuperCab', '2011 Infiniti QX56 SUV', '2011 Lincoln Town Car Sedan', '2011 Mazda Tribute SUV', '2012 Acura RL Sedan', '2012 Acura TL Sedan', '2012 Acura TSX Sedan', '2012 Acura ZDX Hatchback', '2012 Aston Martin V8 Vantage Convertible', '2012 Aston Martin V8 Vantage Coupe', '2012 Aston Martin Virage Convertible', '2012 Aston Martin Virage Coupe', '2012 Audi A5 Coupe', '2012 Audi R8 Coupe', '2012 Audi S4 Sedan', '2012 Audi S5 Convertible', '2012 Audi S5 Coupe', '2012 Audi TT RS Coupe', '2012 Audi TTS Coupe', '2012 BMW 1 Series Convertible', '2012 BMW 1 Series Coupe', '2012 BMW 3 Series Sedan', '2012 BMW 3 Series Wagon', '2012 BMW ActiveHybrid 5 Sedan', '2012 BMW M3 Coupe', '2012 BMW X3 SUV', '2012 BMW X6 SUV', '2012 BMW Z4 Convertible', '2012 Bentley Continental GT Coupe',
            '2012 Bentley Continental Supersports Conv. Convertible', '2012 Buick Enclave SUV', '2012 Buick Regal GS', '2012 Buick Verano Sedan', '2012 Cadillac CTS-V Sedan', '2012 Cadillac SRX SUV', '2012 Chevrolet Avalanche Crew Cab', '2012 Chevrolet Camaro Convertible', '2012 Chevrolet Corvette Convertible', '2012 Chevrolet Corvette ZR1', '2012 Chevrolet Silverado 1500 Extended Cab', '2012 Chevrolet Silverado 1500 Hybrid Crew Cab', '2012 Chevrolet Silverado 1500 Regular Cab', '2012 Chevrolet Silverado 2500HD Regular Cab', '2012 Chevrolet Sonic Sedan', '2012 Chevrolet Tahoe Hybrid SUV', '2012 Chevrolet Traverse SUV', '2012 Chrysler Town and Country Minivan', '2012 Dodge Caliber Wagon', '2012 Dodge Charger Sedan', '2012 Dodge Durango SUV', '2012 Dodge Journey SUV', '2012 FIAT 500 Abarth', '2012 FIAT 500 Convertible', '2012 Ferrari 458 Italia Convertible', '2012 Ferrari 458 Italia Coupe', '2012 Ferrari California Convertible', '2012 Ferrari FF Coupe', '2012 Fisker Karma Sedan', '2012 Ford E-Series Wagon Van', '2012 Ford Edge SUV', '2012 Ford F-150 Regular Cab', '2012 Ford F-450 Super Duty Crew Cab', '2012 Ford Fiesta Sedan', '2012 GMC Acadia SUV', '2012 GMC Canyon Extended Cab', '2012 GMC Savana Van', '2012 GMC Terrain SUV', '2012 GMC Yukon Hybrid SUV', '2012 Honda Accord Coupe', '2012 Honda Accord Sedan', '2012 Honda Odyssey Minivan', '2012 Hyundai Accent Sedan', '2012 Hyundai Azera Sedan', '2012 Hyundai Elantra Touring Hatchback', '2012 Hyundai Genesis Sedan', '2012 Hyundai Santa Fe SUV', '2012 Hyundai Sonata Hybrid Sedan', '2012 Hyundai Sonata Sedan', '2012 Hyundai Tucson SUV', '2012 Hyundai Veloster Hatchback', '2012 Hyundai Veracruz SUV', '2012 Infiniti G Coupe IPL', '2012 Jaguar XK XKR', '2012 Jeep Compass SUV', '2012 Jeep Grand Cherokee SUV', '2012 Jeep Liberty SUV', '2012 Jeep Patriot SUV', '2012 Jeep Wrangler SUV', '2012 Lamborghini Aventador Coupe', '2012 Lamborghini Gallardo LP 570-4 Superleggera', '2012 Land Rover LR2 SUV', '2012 Land Rover Range Rover SUV', '2012 MINI Cooper Roadster Convertible',
            '2012 Maybach Landaulet Convertible', '2012 McLaren MP4-12C Coupe', '2012 Mercedes-Benz C-Class Sedan', '2012 Mercedes-Benz E-Class Sedan', '2012 Mercedes-Benz S-Class Sedan', '2012 Mercedes-Benz Sprinter Van', '2012 Mitsubishi Lancer Sedan', '2012 Nissan Juke Hatchback', '2012 Nissan Leaf Hatchback', '2012 Nissan NV Passenger Van', '2012 Porsche Panamera Sedan', '2012 Ram C/V Cargo Van Minivan', '2012 Rolls-Royce Ghost Sedan', '2012 Rolls-Royce Phantom Drophead Coupe Convertible', '2012 Rolls-Royce Phantom Sedan', '2012 Scion xD Hatchback', '2012 Suzuki Kizashi Sedan', '2012 Suzuki SX4 Hatchback', '2012 Suzuki SX4 Sedan', '2012 Tesla Model S Sedan', '2012 Toyota 4Runner SUV', '2012 Toyota Camry Sedan', '2012 Toyota Corolla Sedan', '2012 Toyota Sequoia SUV', '2012 Volkswagen Beetle Hatchback', '2012 Volkswagen Golf Hatchback', '2012 Volvo C30 Hatchback', '2012 smart fortwo Convertible']
    else:
        raise ValueError(f"Unrecognised dataset in prompt: {prompt!r}")
    return classnames

def get_classnames_all(prompt, comb_dict):
    """Return the class names stored in ``comb_dict`` for the dataset in ``prompt``.

    The dataset is detected by substring (first match wins, same order as
    ``get_classnames``) and mapped to its ``comb_dict`` key. The result is
    ``list(comb_dict[key].values())``, so the per-dataset entry must be a
    mapping whose values are the class names.

    :param prompt: text containing a dataset identifier.
    :param comb_dict: mapping from dataset key (e.g. "oxford_pets") to a
        mapping of class names.
    :return: list of class names in the inner mapping's iteration order.
    :raises ValueError: if ``prompt`` contains none of the known identifiers
        (previously this surfaced as an UnboundLocalError).
    :raises KeyError: if the matched dataset key is missing from ``comb_dict``.
    """
    # (substring to look for, key in comb_dict), in precedence order.
    prompt_to_key = (
        ("caltech", "caltech101"),
        ("pets", "oxford_pets"),
        ("sun", "sun397"),
        ("ucf101", "ucf101"),
        ("imagenet", "imagenet"),
        ("eurosat", "eurosat"),
        ("dtd", "dtd"),
        ("aircraft", "fgvc_aircraft"),
        ("food", "food101"),
        ("flowers", "oxford_flowers"),
        ("stanford", "stanford_cars"),
    )
    for needle, key in prompt_to_key:
        if needle in prompt:
            return list(comb_dict[key].values())
    raise ValueError(f"Unrecognised dataset in prompt: {prompt!r}")

def extract_class_names(file_path, use_full=False):
    """Read class names from a text file of ``<id> <class name>`` lines.

    Each line is split on its first space and the stripped remainder is kept
    as the class name. Lines containing no space are skipped.

    :param file_path: path to the UTF-8 text file.
    :param use_full: when False (default), return only the first half
        (``len // 2``) of the names; when True, return all of them.
    :return: list of class-name strings in file order.
    """
    class_names = []
    # Explicit encoding: the default is locale-dependent and can fail on
    # non-ASCII class names.
    with open(file_path, "r", encoding="utf-8") as handle:
        for line in handle:
            parts = line.split(" ", 1)
            if len(parts) == 2:
                class_names.append(parts[1].strip())
    if use_full:
        return class_names
    return class_names[: len(class_names) // 2]

# Few-shot split directory per dataset. map_idx_to_filename falls back to
# "/data/user/CoOp/data/<name>/split_fewshot/" for names not listed here.
base_dirs = dict(
    imagenet="/data/user/CoOp/data/imagenet/increasing_classes_1000_subsample/split_fewshot/",
    caltech101="/data/user/CoOp/data/caltech-101/split_fewshot/",
    dtd="/data/user/CoOp/data/dtd/split_fewshot/",
    eurosat="/data/user/CoOp/data/eurosat/split_fewshot/",
    fgvc_aircraft="/data/user/CoOp/data/fgvc_aircraft/split_fewshot/",
    food101="/data/user/CoOp/data/food-101/split_fewshot/",
    oxford_flowers="/data/user/CoOp/data/oxford_flowers/split_fewshot/",
    oxford_pets="/data/user/CoOp/data/oxford_pets/split_fewshot/",
    stanford_cars="/data/user/CoOp/data/stanford_cars/split_fewshot/",
    sun397="/data/user/CoOp/data/sun397/split_fewshot/",
    ucf101="/data/user/CoOp/data/ucf101/split_fewshot/",
)


# Seeds per dataset (seed numbers run 1..num_seeds) and the resulting size of
# the flat index space accepted by map_idx_to_filename.
num_seeds = 200
total_datasets = len(datasets)
total_indices = num_seeds * total_datasets

def map_idx_to_filename(idx):
    """Resolve a flat index to its few-shot split file and CoOp checkpoint path.

    The index space is dataset-major: ``idx // num_seeds`` selects an entry of
    the module-level ``datasets`` list and ``idx % num_seeds`` selects the
    seed, with seed numbers starting at 1.

    :param idx: integer in ``[0, total_indices)``.
    :return: ``(file_path, model_path)``, i.e. the ``shot_16-seed_<k>.pkl``
        split file and the ``prompt_learner/model.pth.tar-50`` checkpoint for
        that seed.
    :raises ValueError: if ``idx`` lies outside ``[0, total_indices)``.
    """
    if idx < 0 or idx >= total_indices:
        raise ValueError(f"Index {idx} is out of range. It must be between 0 and {total_indices - 1}.")

    dataset_idx, seed_offset = divmod(idx, num_seeds)
    seed_number = seed_offset + 1
    dataset_name = datasets[dataset_idx]

    # Datasets missing from base_dirs fall back to a path built from the name.
    split_dir = base_dirs.get(dataset_name, f"/data/user/CoOp/data/{dataset_name}/split_fewshot/")
    file_path = os.path.join(split_dir, f"shot_16-seed_{seed_number}.pkl")

    # ImageNet runs live under a directory with an extra "_1000" suffix.
    model_path = (
        f"/data/user/CoOp/output/{dataset_name}/CoOp/vit_b16_ep50_16shots/nctx4_cscFalse_ctpend_1000/seed{seed_number}/prompt_learner/model.pth.tar-50"
        if "imagenet" in dataset_name
        else f"/data/user/CoOp/output/{dataset_name}/CoOp/vit_b16_ep50_16shots/nctx4_cscFalse_ctpend/seed{seed_number}/prompt_learner/model.pth.tar-50"
    )
    return file_path, model_path


class BatchColorize(object):
    """Turn a batch of integer label maps into float RGB images.

    Pixels equal to a palette index get that palette colour. Pixels equal to
    255 are painted (255, 255, 255); this is applied last, so it wins. Any
    other value stays 0.
    """

    def __init__(self, n=150):
        # Palette with one RGB row per label index 0..n-1.
        self.cmap = color_map(n)

    def __call__(self, gray_image):
        """Colourise ``gray_image``.

        :param gray_image: array indexed as ``[batch, row, col]``.
        :return: float32 array of shape ``(batch, 3, rows, cols)``.
        """
        dims = gray_image.shape
        colored = np.zeros((dims[0], 3, dims[1], dims[2]), dtype=np.float32)

        for label, rgb in enumerate(self.cmap):
            hit = label == gray_image
            for channel in range(3):
                colored[:, channel][hit] = rgb[channel]

        # 255 marks "void" pixels.
        void = 255 == gray_image
        for channel in range(3):
            colored[:, channel][void] = 255

        return colored
    
class BatchDeColorize(object):
    """Map a batch of palette-coloured images back to integer label maps."""

    def __init__(self, n=40):
        # Palette with one RGB row per label index 0..n-1.
        self.cmap = color_map(n)

    def __call__(self, rgb_image):
        """Recover labels from ``rgb_image``.

        :param rgb_image: array indexed as ``[batch, channel, row, col]`` with
            the colour in channels 0-2. Every channel is compared against a
            template whose channels past 2 (if any) are zero.
        :return: float32 label map of shape ``(rows, cols)`` for the FIRST
            batch element only. Pixels matching no palette colour become 255.
            NOTE(review): later batch elements are discarded — confirm callers
            only pass batch size 1.
        """
        dims = rgb_image.shape
        # -1 marks "not matched yet"; it is turned into 255 at the end.
        labels = np.full((dims[0], dims[2], dims[3]), -1, dtype=np.float32)

        for label, rgb in enumerate(self.cmap):
            template = np.zeros_like(rgb_image)
            for channel in range(3):
                template[:, channel] = rgb[channel]
            matched = np.prod(template == rgb_image, 1).astype(bool)
            labels[matched] = label

        labels[labels == -1] = 255

        return labels[0]


def color_map(N=256, normalized=False):
    """Build an ``(N, 3)`` palette by spreading index bits over R, G and B.

    For each index, bits 0/1/2 of successive 3-bit groups are written into the
    red/green/blue bytes from the most significant bit downwards, so index 1 is
    (128, 0, 0), index 2 is (0, 128, 0), index 4 is (0, 0, 128), and so on.

    :param N: number of palette entries.
    :param normalized: if True, return float32 values divided by 255;
        otherwise return uint8 values in 0..255.
    :return: numpy array of shape ``(N, 3)``.
    """
    palette = np.zeros((N, 3), dtype='float32' if normalized else 'uint8')
    for index in range(N):
        red = green = blue = 0
        code = index
        for shift in range(7, -1, -1):
            red |= (code & 1) << shift
            green |= ((code >> 1) & 1) << shift
            blue |= ((code >> 2) & 1) << shift
            code >>= 3
        palette[index] = (red, green, blue)

    return palette / 255 if normalized else palette


def _convert_image_to_rgb(image):
    return image.convert("RGB")
    
def preprocess_clip(n_px=224):
    """Build the CLIP-style evaluation transform pipeline.

    Steps: bicubic ``Resize(n_px)``, ``CenterCrop(n_px)``, conversion to RGB,
    ``ToTensor()``, then ``Normalize`` with the module-level ``PIXEL_MEAN`` /
    ``PIXEL_STD`` statistics.

    :param n_px: target size passed to both Resize and CenterCrop.
    :return: a torchvision ``Compose`` of the steps above.
    """
    return Compose([
        Resize(n_px, interpolation=InterpolationMode.BICUBIC),
        CenterCrop(n_px),
        _convert_image_to_rgb,
        ToTensor(),
        # Reuse the module constants so the statistics live in one place;
        # tuple(...) keeps the exact argument type the literal version passed.
        Normalize(tuple(PIXEL_MEAN), tuple(PIXEL_STD)),
    ])

def load_data(
    *,
    dataset_mode,
    data_dir,
    image_size,
    random_crop=True,
    random_flip=True,
    is_train=True,
    use_pose=False,
    class_cond=False,
    deterministic=False,
    batch_size=1,    
    shard=-1,
    use_map=False,
    use1d=False,
    use2d=False,
    use_ti=False,
    usemeta=False,
    use_one_shard=False,
    use_one_embedding=False,
    subsample=False,
    coop_loss=False,
    decoder_only=False,
    dataset_name=False,
    load_sd_lora=False,
    use_null_prompt=True,
    image_adapter=False,
    random_conditioning=True,
):
    """
    Load pre-computed weight/embedding tensors for ``dataset_mode`` and wrap
    them in an ``ImageDataset``.

    No images are read here: every branch ``torch.load``s dictionaries from
    hard-coded paths under ``/data/user/...`` and collects their values
    (``all_files``) and keys (``classes``). Some modes also build a second
    collection (``all_files2`` / ``classes2``) that is passed to the dataset.

    :param dataset_mode: one of 'coop_variations', 'prompt_diffusion', 'coop',
        'identity_prompts', 'joint', 'diffuse'; anything else raises
        NotImplementedError.
    :param data_dir: interpreted per mode -- a file name / key substring for
        the CoOp modes, or (when ``subsample`` is truthy) a shard sub-directory
        (training) or ``.pt`` file stem (evaluation) for the identity modes.
        Must be non-empty.
    :param image_size: forwarded to ``ImageDataset`` as its resolution.
    :param is_train: in identity_prompts / joint / diffuse, choose the training
        shard directory vs. a single evaluation file; also forwarded.
    :param subsample: in identity_prompts / joint (and diffuse training), use
        ``data_dir`` instead of the default shard location.
    :param use_one_embedding: in identity_prompts / joint training, keep only
        the highest-step value per identity.
    :param use_one_shard: in identity_prompts / joint training without
        ``use_one_embedding``, stop after the first shard file.
    :param decoder_only: in 'coop' mode, replace the loaded weights with the
        decoded-latent file and use the latent file as ``instances``.
    :param batch_size, class_cond, deterministic, shard, load_sd_lora: accepted
        for interface compatibility but not used in this function.
    The remaining flags are forwarded unchanged to ``ImageDataset``.
    :return: an ``ImageDataset`` (always built with shard=0, num_shards=1).
    :raises ValueError: if ``data_dir`` is empty.
    :raises NotImplementedError: for an unsupported ``dataset_mode``.
    """
    if not data_dir:
        raise ValueError("unspecified data directory")
    # Secondary tensors / keys; only some modes fill these in.
    all_files2 = None
    classes2 = None

    if dataset_mode == 'coop_variations':
        # Two fixed CoOp context files; the second is kept as the raw dict
        # (not a list) in all_files2.
        load_adp_weights = torch.load("/data/user/CoOp/scripts/coop/coop_cross_dataset_500_class_ctx_values_filtered.pt", map_location=torch.device('cpu'))
        load_adp_weights2 = torch.load("/data/user/CoOp/scripts/coop/coop_cross_dataset_500_class_ctx_values_removed_seed1.pt", map_location=torch.device('cpu'))
        all_files = list(load_adp_weights.values())
        classes = list(load_adp_weights.keys())
        all_files2 = load_adp_weights2
        classes2 = list(load_adp_weights2.keys())
        instances = None
        poses = None
    elif dataset_mode == 'prompt_diffusion':        
        # data_dir names the 'target' file; the paired file is found by
        # replacing 'target' with 'cond' in that name. If data_dir has no
        # 'target', the same file is loaded twice.
        load_adp_weights = torch.load(f"/data/user/CoOp/scripts/coop/{data_dir}", map_location=torch.device('cpu'))
        data_dir2 = data_dir.replace('target','cond')
        load_adp_weights2 = torch.load(f"/data/user/CoOp/scripts/coop/{data_dir2}", map_location=torch.device('cpu'))
        all_files = list(load_adp_weights.values())
        classes = list(load_adp_weights.keys())
        all_files2 = list(load_adp_weights2.values())
        
        classes2 = list(load_adp_weights2.keys())
        instances = None
        poses = None
    elif dataset_mode == 'coop':
        
        # Substring dispatch on data_dir. Order matters where one key contains
        # another: e.g. 'coop_10_class_ctx_values_dict_4ctx' and
        # 'latent_coop_cross_dataset_500...' are tested before their shorter
        # substrings 'coop_10_class_ctx' / 'coop_cross_dataset_500...'.
        if 'coop_imagenet_full_ctx_values_dict' in data_dir:
            load_adp_weights = torch.load("/data/user/CoOp/scripts/coop/coop_imagenet_full_ctx_values_dict.pt", map_location=torch.device('cpu'))
            load_adp_weights2 = torch.load("/data/user/CoOp/scripts/coop/coop_increasing_random_ctx_values_dict.pt", map_location=torch.device('cpu'))
        elif 'coop_imagenet_full_all_ctx_meta_values_dict' in data_dir:
            # NOTE(review): load_adp_weights2 is loaded here but only merged
            # below when data_dir contains 'coop_imagenet_full_ctx_values_dict',
            # which this key does not -- so it is unused. Confirm intent.
            load_adp_weights = torch.load("/data/user/CoOp/scripts/coop/coop_imagenet_full_all_ctx_meta_values_dict.pt", map_location=torch.device('cpu'))
            load_adp_weights2 = torch.load("/data/user/CoOp/scripts/coop/coop_increasing_random_all_ctx_meta_values_dict.pt", map_location=torch.device('cpu'))
        elif 'coop_increasing_random_ctx' in data_dir:
            load_adp_weights = torch.load("/data/user/CoOp/scripts/coop/coop_increasing_random_ctx_values_dict.pt", map_location=torch.device('cpu'))
        elif 'coop_increasing_random_all_ctx_meta_values_dict' in data_dir:
            load_adp_weights = torch.load("/data/user/CoOp/scripts/coop/coop_increasing_random_all_ctx_meta_values_dict.pt", map_location=torch.device('cpu'))
        elif 'coop_10_random_ctx' in data_dir:
            load_adp_weights = torch.load("/data/user/CoOp/scripts/coop/coop_10_random_ctx_values_dict.pt", map_location=torch.device('cpu'))
        elif 'coop_100k_10_random_ctx_16_values_dict' in data_dir:
            load_adp_weights = torch.load("/data/user/CoOp/scripts/coop/coop_100k_10_random_ctx_16_values_dict.pt", map_location=torch.device('cpu'))
        elif 'coop_10_class_ctx_values_dict_4ctx' in data_dir:
            load_adp_weights = torch.load("/data/user/CoOp/scripts/coop/coop_10_class_ctx_values_dict_4ctx.pt", map_location=torch.device('cpu'))
        elif 'coop_500_class_ctx_values_dict' in data_dir:
            load_adp_weights = torch.load("/data/user/CoOp/scripts/coop/coop_500_class_ctx_values_dict.pt", map_location=torch.device('cpu'))
        elif 'latent_coop_cross_dataset_500_class_ctx_values_dict' in data_dir:
            load_adp_weights = torch.load("/data/user/CoOp/scripts/coop/latent_coop_cross_dataset_500_class_ctx_values_dict.pt", map_location=torch.device('cpu'))
        elif 'coop_cross_dataset_500_class_ctx_values_dict' in data_dir:
            load_adp_weights = torch.load("/data/user/CoOp/scripts/coop/coop_cross_dataset_500_class_ctx_values_dict.pt", map_location=torch.device('cpu'))
        elif 'coop_cross_dataset_1000_class_ctx_values_dict' in data_dir:
            load_adp_weights = torch.load("/data/user/CoOp/scripts/coop/coop_cross_dataset_1000_class_ctx_values_dict.pt", map_location=torch.device('cpu'))
        elif 'coop_10_class_ctx' in data_dir:
            load_adp_weights = torch.load("/data/user/CoOp/scripts/coop/coop_10_class_ctx_values_dict.pt", map_location=torch.device('cpu'))
        elif 'coop_10_class_all_ctx_meta_values_dict' in data_dir:
            load_adp_weights = torch.load("/data/user/CoOp/scripts/coop/coop_10_class_all_ctx_meta_values_dict.pt", map_location=torch.device('cpu'))
        elif 'coop_ctx_values_dict' in data_dir:
            load_adp_weights = torch.load("/data/user/CoOp/scripts/coop/coop_ctx_values_dict.pt", map_location=torch.device('cpu'))
        else:
            # Fallback: treat data_dir as a file name under the CoOp scripts
            # directory, appending '.pt' unless it already contains '.pt'.
            if '.pt' in data_dir:
                load_adp_weights = torch.load(f"/data/user/CoOp/scripts/coop/{data_dir}", map_location=torch.device('cpu'))
            else:
                load_adp_weights = torch.load(f"/data/user/CoOp/scripts/coop/{data_dir}.pt", map_location=torch.device('cpu'))
        if decoder_only:
            # NOTE(review): this overrides whatever the dispatch above loaded,
            # so that earlier load is wasted when decoder_only is set.
            load_adp_weights = torch.load("/data/user/CoOp/scripts/coop/decoded_latent_coop_cross_dataset_500_class_ctx_values_dict.pt", map_location=torch.device('cpu'))
        all_files = list(load_adp_weights.values())
        classes = list(load_adp_weights.keys())
        if 'coop_imagenet_full_ctx_values_dict' in data_dir:
            # Append the second file's entries after the first file's.
            all_files += list(load_adp_weights2.values())
            classes += list(load_adp_weights2.keys())
        instances = None
        if decoder_only:
            # The (non-decoded) latent values become the dataset's instances.
            load_adp_weights = torch.load("/data/user/CoOp/scripts/coop/latent_coop_cross_dataset_500_class_ctx_values_dict.pt", map_location=torch.device('cpu'))
            instances = list(load_adp_weights.values())
        poses=None
    elif dataset_mode == 'identity_prompts': 
        if is_train:
            if subsample:
                shard_dir = f"/data/user/diffusers/examples/textual_inversion/{data_dir}"
            else: 
                shard_dir = "/data/user/diffusers/examples/textual_inversion/identities_shards"
            if use_one_embedding:
                all_files = []
                classes = []
                # shard_dir = '/data/user/diffusers/examples/textual_inversion/all100identities_shards/'
                # Create a dictionary to track the highest step for each identity
                identity_to_highest_step = defaultdict(lambda: (None, -1))  # (file, highest_step)

                for item in os.listdir(shard_dir):
                    # Load the weights
                    load_adp_weights = torch.load(os.path.join(shard_dir, item), map_location=torch.device('cpu'))
                    for key, value in load_adp_weights.items():
                        # Parse the identity and step from the key
                        # Example key: 'identity-99-step-500'
                        # -> identity='identity-99', step='500'. The extra
                        # split('-') below is a no-op after rsplit(..., 2).
                        identity, _, step = key.rsplit('-', 2)
                        # print(identity, step)
                        step = int(step.split('-')[-1])  # Extract step as integer
                        
                        # Update the dictionary if this step is higher
                        if step > identity_to_highest_step[identity][1]:
                            identity_to_highest_step[identity] = (value, step)

                # Extract the highest-step files and their corresponding identities
                for identity, (file, _) in identity_to_highest_step.items():
                    all_files.append(file)
                    classes.append(identity)
            else:
                # Concatenate every shard's values/keys (os.listdir order).
                all_files = []
                classes = []
                for item in os.listdir(shard_dir):
                    load_adp_weights = torch.load(os.path.join(shard_dir, item), map_location=torch.device('cpu'))
                    files = list(load_adp_weights.values())
                    clss = list(load_adp_weights.keys())
                    all_files.extend(files)
                    classes.extend(clss)
                    if use_one_shard:
                        # print(classes)
                        break
        else:
            if subsample:
                shard_dir = f"/data/user/diffusers/examples/textual_inversion/{data_dir}.pt"
            else:
                shard_dir = "/data/user/diffusers/examples/textual_inversion/val_sd_identity_prompts.pt"
            load_adp_weights = torch.load(shard_dir, map_location=torch.device('cpu'))
            all_files = list(load_adp_weights.values())
            classes = list(load_adp_weights.keys())
        instances = None
        poses=None
    elif dataset_mode == 'joint':
        # Same identity loading as 'identity_prompts', then extra tensors are
        # appended to all_files below.
        if is_train:
            if subsample:
                shard_dir = f"/data/user/diffusers/examples/textual_inversion/{data_dir}"
            else: 
                shard_dir = "/data/user/diffusers/examples/textual_inversion/identities_shards"
            if use_one_embedding:
                all_files = []
                classes = []
                # shard_dir = '/data/user/diffusers/examples/textual_inversion/all100identities_shards/'
                # Create a dictionary to track the highest step for each identity
                identity_to_highest_step = defaultdict(lambda: (None, -1))  # (file, highest_step)

                for item in os.listdir(shard_dir):
                    # Load the weights
                    load_adp_weights = torch.load(os.path.join(shard_dir, item), map_location=torch.device('cpu'))
                    for key, value in load_adp_weights.items():
                        # Parse the identity and step from the key
                        # Example key: 'identity-99-step-500'
                        identity, _, step = key.rsplit('-', 2)
                        # print(identity, step)
                        step = int(step.split('-')[-1])  # Extract step as integer
                        
                        # Update the dictionary if this step is higher
                        if step > identity_to_highest_step[identity][1]:
                            identity_to_highest_step[identity] = (value, step)

                # Extract the highest-step files and their corresponding identities
                for identity, (file, _) in identity_to_highest_step.items():
                    all_files.append(file)
                    classes.append(identity)
            else:
                all_files = []
                classes = []
                for item in os.listdir(shard_dir):
                    load_adp_weights = torch.load(os.path.join(shard_dir, item), map_location=torch.device('cpu'))
                    files = list(load_adp_weights.values())
                    clss = list(load_adp_weights.keys())
                    all_files.extend(files)
                    classes.extend(clss)
                    if use_one_shard:
                        # print(classes)
                        break
        else:
            if subsample:
                shard_dir = f"/data/user/diffusers/examples/textual_inversion/{data_dir}.pt"
            else:
                shard_dir = "/data/user/diffusers/examples/textual_inversion/val_sd_identity_prompts.pt"
            load_adp_weights = torch.load(shard_dir, map_location=torch.device('cpu'))
            all_files = list(load_adp_weights.values())
            classes = list(load_adp_weights.keys())
        instances = None
        poses=None
        # NOTE(review): this load has no map_location, unlike the others, so
        # tensors are restored to the device they were saved from.
        load_adp_weights = torch.load("/data/user/diffusers/examples/lora_inversion/diff_prompt_data_tensor_last_200.pt")
        
        # NOTE(review): the two extends below grow all_files but not classes,
        # so len(classes) < len(all_files) in this mode; the CoOp keys go to
        # classes2 instead, and all_files2 stays None. Confirm the dataset
        # expects this.
        all_files.extend(load_adp_weights)
        load_adp_weights = torch.load("/data/user/CoOp/scripts/coop/coop_cross_dataset_500_class_ctx_values_dict.pt", map_location=torch.device('cpu'))
        all_files.extend(list(load_adp_weights.values()))
        classes2 = list(load_adp_weights.keys())
        
    elif dataset_mode == 'diffuse':
        # Split per-identity values into the highest step (all_files2, a dict
        # keyed by identity) and all other steps (all_files / classes).
        if is_train:
            if subsample:
                shard_dir = f"/data/user/diffusers/examples/textual_inversion/{data_dir}"
            else: 
                shard_dir = "/data/user/diffusers/examples/textual_inversion/identities_shards"
            
            all_files = []
            classes = []
            # all_files2 = []
            classes2 = []
            # shard_dir = '/data/user/diffusers/examples/textual_inversion/all100identities_shards/'
            # Create a dictionary to track the highest step for each identity
            identity_to_highest_step = defaultdict(lambda: (None, -1))  # (file, highest_step)
            # Every (value, step) keyed by the full original key.
            identity_to_other_step = defaultdict(lambda: (None, -1))
            for item in os.listdir(shard_dir):
                # Load the weights
                load_adp_weights = torch.load(os.path.join(shard_dir, item), map_location=torch.device('cpu'))
                for key, value in load_adp_weights.items():
                    # Parse the identity and step from the key
                    # Example key: 'identity-99-step-500'
                    identity, _, step = key.rsplit('-', 2)
                    # print(identity, step)
                    step = int(step.split('-')[-1])  # Extract step as integer
                    
                    # Update the dictionary if this step is higher
                    if step > identity_to_highest_step[identity][1]:
                        identity_to_highest_step[identity] = (value, step)
                    identity_to_other_step[key] = (value, step)

            all_files2 = {}
            # Extract the highest-step files and their corresponding identities
            for identity, (file, step) in identity_to_highest_step.items():
                # Rebuilds the original key; assumes keys look exactly like
                # '<identity>-step-<n>' (KeyError otherwise).
                del identity_to_other_step[f'{identity}-step-{step}']
                # all_files2.append(file)
                classes2.append(identity)
                all_files2[identity] = file
            
            for identity, (file, _) in identity_to_other_step.items():
                all_files.append(file)
                classes.append(identity)
        else:
            # Evaluation: fixed files; the subsample flag is ignored here.
            shard_dir = "/data/user/diffusers/examples/textual_inversion/val_sd_identity_prompts_updated.pt"
            load_adp_weights = torch.load(shard_dir, map_location=torch.device('cpu'))
            all_files2 = load_adp_weights #list(load_adp_weights.values())
            classes2 = list(load_adp_weights.keys())

            shard_dir = "/data/user/diffusers/examples/textual_inversion/val_sd_step480_identity_prompts.pt"
            load_adp_weights = torch.load(shard_dir, map_location=torch.device('cpu'))
            all_files = list(load_adp_weights.values())
            classes = list(load_adp_weights.keys())            
        instances = None
        poses=None
        # NOTE(review): no map_location on this load, unlike the others.
        load_adp_weights = torch.load("/data/user/diffusers/examples/lora_inversion/diff_prompt_data_tensor_last_200.pt")
        
        # Captions (JSON dict values, in file order) must match the tensor
        # list in length (asserted below); index alignment is assumed.
        with open('/data/user/diffusers/examples/lora_inversion/diff_prompt_data_captions_last_200.json', 'r') as f:
            data = json.load(f)
        load_adp_names = list(data.values())
        # New lists after removing 200th and 201st elements
        filtered_tensors = []

        # Lists to store removed 201st elements
        # (collected but not used afterwards)
        removed_tensors = []
        assert len(load_adp_weights) == len(load_adp_names)
        # Iterate over the indices
        # NOTE(review): entries captioned 'nudity' are dropped entirely. The
        # two modulus tests are not aligned beyond the first group:
        # `% 200` skips i+1 = 200, 400, 600..., while `% 201` routes
        # i+1 = 201, 402, 603... into all_files2. Confirm the intended grouping.
        for i in range(len(load_adp_weights)):
            if 'nudity' == load_adp_names[i]:
                continue
            if (i + 1) % 200 == 0:
                continue  # Skip the 200th element
            elif (i + 1) % 201 == 0:
                removed_tensors.append(load_adp_weights[i])
                # Adds (or overwrites) a caption-keyed entry in all_files2.
                all_files2[load_adp_names[i]] = load_adp_weights[i].unsqueeze(0)
                continue  # Skip the 201st element from the main list

            filtered_tensors.append(load_adp_weights[i])
        all_files.extend(filtered_tensors)
        # all_files2.extend(removed_tensors)
    else:
        raise NotImplementedError('{} not implemented'.format(dataset_mode))

    print("Len of Dataset:", len(all_files))
    # print(MPI.COMM_WORLD.Get_rank())
    # print(MPI.COMM_WORLD.Get_size())

    # Sharding is disabled: the `shard` argument is ignored and every process
    # gets the full dataset (shard=0, num_shards=1).
    dataset = ImageDataset(
        dataset_mode,
        image_size,
        all_files,
        all_files2=all_files2,
        classes=classes,
        classes2=classes2,
        instances=instances,
        poses=poses,
        shard=0,
        num_shards=1,
        data_dir=data_dir,
        random_crop=random_crop,
        random_flip=random_flip,
        use_pose=use_pose,
        use_map=use_map,
        use1d=use1d,
        use2d=use2d,
        use_ti=use_ti,
        usemeta=usemeta,
        coop_loss=coop_loss,
        dataset_name=dataset_name,
        use_null_prompt=use_null_prompt,
        image_adapter=image_adapter,
        random_conditioning=random_conditioning,
        is_train=is_train
    )

    return dataset


def _list_image_files_recursively(data_dir):
    """Collect image file paths under ``data_dir``, descending into sub-directories.

    Entries are visited in sorted order. A name counts as an image when it
    contains a dot and its last dot-separated suffix (case-insensitive) is
    jpg, jpeg, png or gif.

    :param data_dir: directory path understood by ``blobfile``.
    :return: list of full paths, in depth-first sorted order.
    """
    image_exts = ["jpg", "jpeg", "png", "gif"]
    found = []
    for name in sorted(bf.listdir(data_dir)):
        path = bf.join(data_dir, name)
        if "." in name and name.split(".")[-1].lower() in image_exts:
            found.append(path)
        elif bf.isdir(path):
            found += _list_image_files_recursively(path)
    return found

def getClassName(classID, cats):
    """Return the ``'name'`` of the first entry in ``cats`` whose ``'id'`` equals ``classID``.

    :param classID: identifier to look up.
    :param cats: sequence of mappings with ``'id'`` and ``'name'`` keys.
    :return: the matching name, or the string ``"None"`` when nothing matches.
    """
    return next((cat['name'] for cat in cats if cat['id'] == classID), "None")


class ImageDataset(Dataset):
    def __init__(
        self,
        dataset_mode,
        resolution,
        image_paths,
        all_files2=None,
        classes=None,
        classes2=None,
        instances=None,
        poses=None,
        shard=0,
        num_shards=1,
        data_dir=None,
        random_crop=False,
        random_flip=True,
        use_pose=False,
        use_map=False,
        use1d=False,
        use2d=False,
        use_ti=False,
        usemeta=False,
        coop_loss=False,
        dataset_name=False,
        use_null_prompt=True,
        image_adapter=False,
        random_conditioning=True,
        is_train=True
    ):
        super().__init__()
        self.is_train = is_train
        self.use1d = use1d
        self.use2d = use2d
        self.use_ti = use_ti
        self.usemeta = usemeta
        self.dataset_mode = dataset_mode
        self.resolution = resolution
        self.coop_loss = coop_loss
        if self.dataset_mode == 'weights' or self.dataset_mode == 'prompts' or self.dataset_mode == 'ti_prompts' or self.dataset_mode == 'cocoop' or self.dataset_mode == 'coop' or dataset_mode == 'identity_prompts' or dataset_mode == 'sd_concept_prompts' or dataset_mode == 'joint' or dataset_mode == 'clip_words':
            self.local_images = image_paths
        else:
            self.local_images = image_paths[shard:][::num_shards]
        if self.dataset_mode == 'cocoop':
            self.local_classes = None
        elif self.dataset_mode == 'coop':
            self.local_classes = classes
        else:
            self.local_classes = None if classes is None else classes[shard:][::num_shards]
        if self.dataset_mode == 'coop' and instances is not None:
            self.local_instances = instances
        else:
            self.local_instances = None if instances is None else instances[shard:][::num_shards]
        if dataset_mode == 'diffuse' or dataset_mode == 'coop_variations' or dataset_mode == 'prompt_diffusion':
            self.local_images2 = all_files2
        self.local_poses = None if poses is None else poses[shard:][::num_shards]
        self.random_crop = random_crop
        self.random_flip = random_flip
        self.use_null_prompt = use_null_prompt
        self.use_pose = use_pose
        self.use_map = use_map
        self.data_dir = data_dir
        self.dataset_name = dataset_name
        self.image_adapter = image_adapter
        self.random_conditioning = random_conditioning
        
        if dataset_mode == 'prompts':
            self.captions = None
        elif dataset_mode in ['ti_prompts', 'identity_prompts', 'sd_concept_prompts', 'clip_words', 'coop_variations']:
            self.captions = classes
        elif dataset_mode == 'coop' or dataset_mode == 'prompt_diffusion':
            with open('/data/user/CoOp/data/combied_class_label_map.pkl', "rb") as f:
                data = pickle.load(f)
            self.combined_class_label_map = data
            self.captions = []
            for idx, clss in enumerate(classes):
                class_captions = clss.split("-") #seed{num_s}-{fold}_classes_{num_c}
                seed = class_captions[0]
                fold = class_captions[1]
                caption = None
                if "increasing" in fold:
                    txt_path = f"/data/user/CoOp/data/imagenet/increasing_classes/{fold}.txt"
                elif "10_random" in fold:
                    txt_path = f"/data/user/CoOp/data/imagenet/10_random_classes/{fold}.txt"
                elif '10_random_all_ctx_meta' in data_dir:
                    txt_path = f"/data/user/CoOp/data/imagenet/10_random_classes/10_{fold}.txt"
                elif "random" in fold:
                    txt_path = f"/data/user/CoOp/data/imagenet/random_classes/{fold}.txt"
                elif "imagenet" in fold:
                    txt_path = f"/data/user/CoOp/data/imagenet/increasing_classes/increasing_classes_1000.txt"
                elif "caltech" in fold:
                    classnames = ['accordion', 'airplane', 'anchor', 'ant', 'barrel', 'bass', 'beaver', 'binocular', 'bonsai', 'brain', 'brontosaurus', 'buddha', 'butterfly', 'camera', 'cannon', 'car_side', 'ceiling_fan', 'cellphone', 'chair', 'chandelier', 'cougar_body', 'cougar_face', 'crab', 'crayfish', 'crocodile', 'crocodile_head', 'cup', 'dalmatian', 'dollar_bill', 'dolphin', 'dragonfly', 'electric_guitar', 'elephant', 'emu', 'euphonium', 'ewer', 'face', 'ferry', 'flamingo', 'flamingo_head', 'garfield', 'gerenuk', 'gramophone', 'grand_piano', 'hawksbill', 'headphone', 'hedgehog', 'helicopter', 'ibis', 'inline_skate', 'joshua_tree', 'kangaroo', 'ketch', 'lamp', 'laptop', 'leopard', 'llama', 'lobster', 'lotus', 'mandolin', 'mayfly', 'menorah', 'metronome', 'minaret', 'motorbike', 'nautilus', 'octopus', 'okapi', 'pagoda', 'panda', 'pigeon', 'pizza', 'platypus', 'pyramid', 'revolver', 'rhino', 'rooster', 'saxophone', 'schooner', 'scissors', 'scorpion', 'sea_horse', 'snoopy', 'soccer_ball', 'stapler', 'starfish', 'stegosaurus', 'stop_sign', 'strawberry', 'sunflower', 'tick', 'trilobite', 'umbrella', 'watch', 'water_lilly', 'wheelchair', 'wild_cat', 'windsor_chair', 'wrench', 'yin_yang']
                    caption = classnames
                elif "pets" in fold:
                    classnames = ['abyssinian', 'american_bulldog', 'american_pit_bull_terrier', 'basset_hound', 'beagle', 'bengal', 'birman', 'bombay', 'boxer', 'british_shorthair', 'chihuahua', 'egyptian_mau', 'english_cocker_spaniel', 'english_setter', 'german_shorthaired', 'great_pyrenees', 'havanese', 'japanese_chin', 'keeshond', 'leonberger', 'maine_coon', 'miniature_pinscher', 'newfoundland', 'persian', 'pomeranian', 'pug', 'ragdoll', 'russian_blue', 'saint_bernard', 'samoyed', 'scottish_terrier', 'shiba_inu', 'siamese', 'sphynx', 'staffordshire_bull_terrier', 'wheaten_terrier', 'yorkshire_terrier']
                    caption = classnames
                elif "sun397" in fold:
                    classnames = ['abbey', 'airplane_cabin', 'airport_terminal', 'alley', 'amphitheater', 'amusement_arcade', 'amusement_park', 'anechoic_chamber', 'aquarium', 'aqueduct', 'arch', 'archive', 'art_gallery', 'art_school', 'art_studio', 'assembly_line', 'attic', 'auditorium', 'auto_factory', 'backseat car_interior', 'badlands', 'baggage_claim', 'ball_pit', 'ballroom', 'bamboo_forest', 'banquet_hall', 'bar', 'barn', 'barndoor', 'barrel_storage wine_cellar', 'baseball stadium', 'baseball_field', 'basement', 'basilica', 'bathroom', 'batters_box', 'bayou', 'beach', 'beauty_salon', 'bedroom', 'berth', 'biology_laboratory', 'block waterfall', 'boardwalk', 'boat_deck', 'boathouse', 'bookstore', 'botanical_garden', 'bottle_storage wine_cellar', 'bowling_alley', 'boxing_ring', 'bridge', 'broadleaf forest', 'building_facade', 'bullring', 'burial_chamber', 'bus_interior', 'butchers_shop', 'butte', 'cafeteria', 'campsite', 'campus', 'candy_store', 'canyon', 'carrousel', 'castle', 'catacomb', 'cemetery', 'chalet', 'cheese_factory', 'chemistry_lab', 'childs_room', 'classroom', 'clean_room', 'cliff', 'closet', 'clothing_store', 'coast', 'cockpit', 'coffee_shop', 'computer_room', 'conference_center', 'conference_room', 'construction_site', 'control_room', 'coral_reef underwater', 'corn_field', 'corral', 'corridor', 'cottage_garden', 'courthouse', 'courtroom', 'courtyard', 'creek', 'crevasse', 'crosswalk', 'cultivated field', 'dam', 'delicatessen', 'dentists_office', 'dining_car', 'dining_room', 'discotheque', 'dock', 'door elevator', 'dorm_room', 'driveway', 'drugstore', 'east_asia temple', 'electrical_substation', 'elevator_shaft', 'engine_room', 'establishment poolroom', 'excavation', 'exterior balcony', 'exterior covered_bridge', 'exterior gazebo', 'fairway', 'fan waterfall', 'fastfood_restaurant', 'fire_escape', 'fire_station', 'fishpond', 'food_court', 'football stadium', 'forest_path', 'forest_road', 'formal_garden', 'fountain', 'frontseat car_interior', 
'galley', 'game_room', 'garbage_dump', 'gas_station', 'gift_shop', 'golf_course', 'harbor', 'hayfield', 'heliport', 'herb_garden', 'highway', 'hill', 'home dinette', 'home poolroom', 'home_office', 'hospital', 'hospital_room', 'hot_spring', 'hotel_room', 'house', 'ice_cream_parlor', 'ice_floe', 'ice_shelf', 'iceberg', 'igloo', 'indoor apse', 'indoor badminton_court', 'indoor bazaar', 'indoor bistro', 'indoor booth', 'indoor bow_window', 'indoor brewery', 'indoor casino', 'indoor cathedral', 'indoor cavern', 'indoor chicken_coop', 'indoor church', 'indoor cloister', 'indoor diner', 'indoor escalator', 'indoor factory', 'indoor firing_range', 'indoor florist_shop', 'indoor garage', 'indoor general_store', 'indoor greenhouse', 'indoor gymnasium', 'indoor hangar', 'indoor ice_skating_rink', 'indoor jacuzzi', 'indoor jail', 'indoor kennel', 'indoor library', 'indoor market', 'indoor mosque', 'indoor movie_theater', 'indoor museum', 'indoor parking_garage', 'indoor pilothouse', 'indoor podium', 'indoor pub', 'indoor shopping_mall', 'indoor stage', 'indoor swimming_pool', 'indoor synagogue', 'indoor tennis_court', 'indoor volleyball_court', 'indoor warehouse', 'indoor wrestling_ring', 'indoor_procenium theater', 'indoor_seats theater', 'industrial_area', 'interior balcony', 'interior elevator', 'islet', 'jail_cell', 'jewelry_shop', 'kasbah', 'kindergarden_classroom', 'kitchen', 'kitchenette', 'landfill', 'landing_deck', 'laundromat', 'lecture_room', 'lift_bridge', 'lighthouse', 'limousine_interior', 'living_room', 'lobby', 'lock_chamber', 'locker_room', 'mansion', 'manufactured_home', 'marsh', 'martial_arts_gym', 'mausoleum', 'medina', 'motel', 'mountain', 'mountain_snowy', 'music_store', 'music_studio', 'natural canal', 'natural lake', 'needleleaf forest', 'nursery', 'oast_house', 'ocean', 'office', 'office cubicle', 'office_building', 'oilrig', 'operating_room', 'orchard', 'outdoor apartment_building', 'outdoor arrival_gate', 'outdoor athletic_field', 'outdoor 
basketball_court', 'outdoor bazaar', 'outdoor bow_window', 'outdoor cabin', 'outdoor cathedral', 'outdoor chicken_coop', 'outdoor church', 'outdoor control_tower', 'outdoor diner', 'outdoor doorway', 'outdoor driving_range', 'outdoor general_store', 'outdoor greenhouse', 'outdoor hangar', 'outdoor hot_tub', 'outdoor hotel', 'outdoor hunting_lodge', 'outdoor ice_skating_rink', 'outdoor inn', 'outdoor kennel', 'outdoor labyrinth', 'outdoor library', 'outdoor lido_deck', 'outdoor market', 'outdoor monastery', 'outdoor mosque', 'outdoor nuclear_power_plant', 'outdoor observatory', 'outdoor oil_refinery', 'outdoor outhouse', 'outdoor parking_garage', 'outdoor planetarium', 'outdoor podium', 'outdoor power_plant', 'outdoor swimming_pool', 'outdoor synagogue', 'outdoor tennis_court', 'outdoor tent', 'outdoor track', 'outdoor volleyball_court', 'pagoda', 'palace', 'pantry', 'park', 'parking_lot', 'parlor', 'pasture', 'patio', 'pavilion', 'pharmacy', 'phone_booth', 'physics_laboratory', 'picnic_area', 'platform subway_station', 'platform train_station', 'playground', 'playroom', 'plaza', 'plunge waterfall', 'pond', 'promenade_deck', 'public atrium', 'pulpit', 'putting_green', 'racecourse', 'raceway', 'raft', 'railroad_track', 'rainforest', 'reception', 'recreation_room', 'residential_neighborhood', 'restaurant', 'restaurant_kitchen', 'restaurant_patio', 'rice_paddy', 'riding_arena', 'river', 'rock_arch', 'rope_bridge', 'ruin', 'runway', 'sand desert', 'sandbar', 'sandbox', 'sauna', 'schoolhouse', 'sea_cliff', 'server_room', 'shed', 'shoe_shop', 'shop bakery', 'shopfront', 'shower', 'skatepark', 'ski_lodge', 'ski_resort', 'ski_slope', 'sky', 'skyscraper', 'slum', 'snowfield', 'south_asia temple', 'squash_court', 'stable', 'staircase', 'street', 'subway_interior', 'supermarket', 'sushi_bar', 'swamp', 'television_studio', 'thriftshop', 'throne_room', 'ticket_booth', 'toll_plaza', 'topiary_garden', 'tower', 'toyshop', 'train_railway', 'tree_farm', 'tree_house', 'trench', 'urban 
canal', 'utility_room', 'valley', 'van_interior', 'vegetable_garden', 'vegetation desert', 'vehicle dinette', 'veranda', 'veterinarians_office', 'viaduct', 'videostore', 'village', 'vineyard', 'volcano', 'waiting_room', 'water moat', 'water_tower', 'watering_hole', 'wave', 'wet_bar', 'wheat_field', 'wild field', 'wind_farm', 'windmill', 'yard', 'youth_hostel']
                    caption = classnames
                elif "ucf101" in fold:
                    classnames = ['Apply_Eye_Makeup', 'Apply_Lipstick', 'Archery', 'Baby_Crawling', 'Balance_Beam', 'Band_Marching', 'Baseball_Pitch', 'Basketball', 'Basketball_Dunk', 'Bench_Press', 'Biking', 'Billiards', 'Blow_Dry_Hair', 'Blowing_Candles', 'Body_Weight_Squats', 'Bowling', 'Boxing_Punching_Bag', 'Boxing_Speed_Bag', 'Breast_Stroke', 'Brushing_Teeth', 'Clean_And_Jerk', 'Cliff_Diving', 'Cricket_Bowling', 'Cricket_Shot', 'Cutting_In_Kitchen', 'Diving', 'Drumming', 'Fencing', 'Field_Hockey_Penalty', 'Floor_Gymnastics', 'Frisbee_Catch', 'Front_Crawl', 'Golf_Swing', 'Haircut', 'Hammer_Throw', 'Hammering', 'Handstand_Pushups', 'Handstand_Walking', 'Head_Massage', 'High_Jump', 'Horse_Race', 'Horse_Riding', 'Hula_Hoop', 'Ice_Dancing', 'Javelin_Throw', 'Juggling_Balls', 'Jump_Rope', 'Jumping_Jack', 'Kayaking', 'Knitting', 'Long_Jump', 'Lunges', 'Military_Parade', 'Mixing', 'Mopping_Floor', 'Nunchucks', 'Parallel_Bars', 'Pizza_Tossing', 'Playing_Cello', 'Playing_Daf', 'Playing_Dhol', 'Playing_Flute', 'Playing_Guitar', 'Playing_Piano', 'Playing_Sitar', 'Playing_Tabla', 'Playing_Violin', 'Pole_Vault', 'Pommel_Horse', 'Pull_Ups', 'Punch', 'Push_Ups', 'Rafting', 'Rock_Climbing_Indoor', 'Rope_Climbing', 'Rowing', 'Salsa_Spin', 'Shaving_Beard', 'Shotput', 'Skate_Boarding', 'Skiing', 'Skijet', 'Sky_Diving', 'Soccer_Juggling', 'Soccer_Penalty', 'Still_Rings', 'Sumo_Wrestling', 'Surfing', 'Swing', 'Table_Tennis_Shot', 'Tai_Chi', 'Tennis_Swing', 'Throw_Discus', 'Trampoline_Jumping', 'Typing', 'Uneven_Bars', 'Volleyball_Spiking', 'Walking_With_Dog', 'Wall_Pushups', 'Writing_On_Board', 'Yo_Yo']
                    caption = classnames
                elif "eurosat" in fold:
                    classnames = ['Annual Crop Land', 'Forest', 'Herbaceous Vegetation Land', 'Highway or Road', 'Industrial Buildings', 'Pasture Land', 'Permanent Crop Land', 'Residential Buildings', 'River', 'Sea or Lake']
                    caption = classnames
                elif "dtd" in fold:
                    classnames = ['banded', 'blotchy', 'braided', 'bubbly', 'bumpy', 'chequered', 'cobwebbed', 'cracked', 'crosshatched', 'crystalline', 'dotted', 'fibrous', 'flecked', 'freckled', 'frilly', 'gauzy', 'grid', 'grooved', 'honeycombed', 'interlaced', 'knitted', 'lacelike', 'lined', 'marbled', 'matted', 'meshed', 'paisley', 'perforated', 'pitted', 'pleated', 'polka-dotted', 'porous', 'potholed', 'scaly', 'smeared', 'spiralled', 'sprinkled', 'stained', 'stratified', 'striped', 'studded', 'swirly', 'veined', 'waffled', 'woven', 'wrinkled', 'zigzagged']
                    caption = classnames
                elif "fgvc_aircraft" in fold:
                    classnames = ['DHC-8-100', 'A330-200', '737-900', '737-700', 'CRJ-700', 'A340-600', 'Tornado', '747-100', '747-200', 'Metroliner', 'ATR-42', '757-200', '737-400', 'ATR-72', 'Challenger 600', 'DC-6', 'A320', 'DC-8', 'Cessna 560', 'E-170', 'MD-90', 'MD-80', 'Gulfstream IV', 'Dornier 328', '737-600', 'Boeing 717', '737-500', 'A321', 'Falcon 2000', 'PA-28', '737-800', 'BAE-125', 'Fokker 100', '737-200', 'Cessna 208', 'F-16A/B', 'A319', 'MD-11', 'EMB-120', '747-400', '737-300', 'F/A-18', 'Beechcraft 1900', '767-200', 'A300B4', '747-300', 'SR-20', 'BAE 146-300', 'DHC-1', 'A310', 'Il-76', '777-300', 'ERJ 145', 'Tu-134', 'DC-9-30', 'Spitfire', 'C-47', 'C-130', 'An-12', '767-400', 'CRJ-900', 'Falcon 900', 'Saab 2000', '767-300', 'Embraer Legacy 600', 'Saab 340', 'BAE 146-200', 'Cessna 172', 'DHC-6', 'ERJ 135', 'A340-200', 'E-190', 'A380', 'Yak-42', '757-300', 'Hawk T1', 'DC-3', '707-320', 'A330-300', 'A340-300', 'Tu-154', 'Cessna 525', '777-200', 'DHC-8-300', 'Fokker 70', 'DH-82', 'E-195', 'DR-400', 'L-1011', 'Global Express', 'MD-87', 'A340-500', 'Gulfstream V', 'CRJ-200', 'Model B200', '727-200', 'Eurofighter Typhoon', 'A318', 'DC-10', 'Fokker 50']
                    # classnames = [      "A300",
                    #                     "A310",
                    #                     "A320",
                    #                     "A330",
                    #                     "A340",
                    #                     "A380",
                    #                     "ATR-42",
                    #                     "ATR-72",
                    #                     "An-12",
                    #                     "BAE 146",
                    #                     "BAE-125",
                    #                     "Beechcraft 1900",
                    #                     "Boeing 707",
                    #                     "Boeing 717",
                    #                     "Boeing 727",
                    #                     "Boeing 737",
                    #                     "Boeing 747",
                    #                     "Boeing 757",
                    #                     "Boeing 767",
                    #                     "Boeing 777",
                    #                     "C-130",
                    #                     "C-47",
                    #                     "CRJ-200",
                    #                     "CRJ-700",
                    #                     "Cessna 172",
                    #                     "Cessna 208",
                    #                     "Cessna Citation",
                    #                     "Challenger 600",
                    #                     "DC-10",
                    #                     "DC-3",
                    #                     "DC-6",
                    #                     "DC-8",
                    #                     "DC-9",
                    #                     "DH-82",
                    #                     "DHC-1",
                    #                     "DHC-6",
                    #                     "DR-400",
                    #                     "Dash 8",
                    #                     "Dornier 328",
                    #                     "EMB-120",
                    #                     "Embraer E-Jet",
                    #                     "Embraer ERJ 145",
                    #                     "Embraer Legacy 600",
                    #                     "Eurofighter Typhoon",
                    #                     "F-16",
                    #                     "F/A-18",
                    #                     "Falcon 2000",
                    #                     "Falcon 900",
                    #                     "Fokker 100",
                    #                     "Fokker 50",
                    #                     "Fokker 70",
                    #                     "Global Express",
                    #                     "Gulfstream",
                    #                     "Hawk T1",
                    #                     "Il-76",
                    #                     "King Air",
                    #                     "L-1011",
                    #                     "MD-11",
                    #                     "MD-80",
                    #                     "MD-90",
                    #                     "Metroliner",
                    #                     "PA-28",
                    #                     "SR-20",
                    #                     "Saab 2000",
                    #                     "Saab 340",
                    #                     "Spitfire",
                    #                     "Tornado",
                    #                     "Tu-134",
                    #                     "Tu-154",
                    #                     "Yak-42"
                    #                 ]
                    caption = classnames                
                elif "food101" in fold:
                    classnames = ['apple_pie', 'baby_back_ribs', 'baklava', 'beef_carpaccio', 'beef_tartare', 'beet_salad', 'beignets', 'bibimbap', 'bread_pudding', 'breakfast_burrito', 'bruschetta', 'caesar_salad', 'cannoli', 'caprese_salad', 'carrot_cake', 'ceviche', 'cheese_plate', 'cheesecake', 'chicken_curry', 'chicken_quesadilla', 'chicken_wings', 'chocolate_cake', 'chocolate_mousse', 'churros', 'clam_chowder', 'club_sandwich', 'crab_cakes', 'creme_brulee', 'croque_madame', 'cup_cakes', 'deviled_eggs', 'donuts', 'dumplings', 'edamame', 'eggs_benedict', 'escargots', 'falafel', 'filet_mignon', 'fish_and_chips', 'foie_gras', 'french_fries', 'french_onion_soup', 'french_toast', 'fried_calamari', 'fried_rice', 'frozen_yogurt', 'garlic_bread', 'gnocchi', 'greek_salad', 'grilled_cheese_sandwich', 'grilled_salmon', 'guacamole', 'gyoza', 'hamburger', 'hot_and_sour_soup', 'hot_dog', 'huevos_rancheros', 'hummus', 'ice_cream', 'lasagna', 'lobster_bisque', 'lobster_roll_sandwich', 'macaroni_and_cheese', 'macarons', 'miso_soup', 'mussels', 'nachos', 'omelette', 'onion_rings', 'oysters', 'pad_thai', 'paella', 'pancakes', 'panna_cotta', 'peking_duck', 'pho', 'pizza', 'pork_chop', 'poutine', 'prime_rib', 'pulled_pork_sandwich', 'ramen', 'ravioli', 'red_velvet_cake', 'risotto', 'samosa', 'sashimi', 'scallops', 'seaweed_salad', 'shrimp_and_grits', 'spaghetti_bolognese', 'spaghetti_carbonara', 'spring_rolls', 'steak', 'strawberry_shortcake', 'sushi', 'tacos', 'takoyaki', 'tiramisu', 'tuna_tartare', 'waffles']
                    caption = classnames
                elif "flowers" in fold:
                    classnames = ['alpine sea holly', 'anthurium', 'artichoke', 'azalea', 'ball moss', 'balloon flower', 'barbeton daisy', 'bearded iris', 'bee balm', 'bird of paradise', 'bishop of llandaff', 'black-eyed susan', 'blackberry lily', 'blanket flower', 'bolero deep blue', 'bougainvillea', 'bromelia', 'buttercup', 'californian poppy', 'camellia', 'canna lily', 'canterbury bells', 'cape flower', 'carnation', 'cautleya spicata', 'clematis', "colt's foot", 'columbine', 'common dandelion', 'corn poppy', 'cyclamen', 'daffodil', 'desert-rose', 'english marigold', 'fire lily', 'foxglove', 'frangipani', 'fritillary', 'garden phlox', 'gaura', 'gazania', 'geranium', 'giant white arum lily', 'globe thistle', 'globe-flower', 'grape hyacinth', 'great masterwort', 'hard-leaved pocket orchid', 'hibiscus', 'hippeastrum', 'japanese anemone', 'king protea', 'lenten rose', 'lotus', 'love in the mist', 'magnolia', 'mallow', 'marigold', 'mexican aster', 'mexican petunia', 'monkshood', 'moon orchid', 'morning glory', 'orange dahlia', 'osteospermum', 'oxeye daisy', 'passion flower', 'pelargonium', 'peruvian lily', 'petunia', 'pincushion flower', 'pink primrose', 'pink-yellow dahlia', 'poinsettia', 'primula', 'prince of wales feathers', 'purple coneflower', 'red ginger', 'rose', 'ruby-lipped cattleya', 'siam tulip', 'silverbush', 'snapdragon', 'spear thistle', 'spring crocus', 'stemless gentian', 'sunflower', 'sweet pea', 'sweet william', 'sword lily', 'thorn apple', 'tiger lily', 'toad lily', 'tree mallow', 'tree poppy', 'trumpet creeper', 'wallflower', 'water lily', 'watercress', 'wild pansy', 'windflower', 'yellow iris']
                    caption = classnames
                elif "stanford" in fold:
                    classnames = ['1991 Volkswagen Golf Hatchback', '1993 Geo Metro Convertible', '1993 Mercedes-Benz 300-Class Convertible', '1993 Volvo 240 Sedan', '1994 Audi 100 Sedan', '1994 Audi 100 Wagon', '1994 Audi V8 Sedan', '1997 Dodge Caravan Minivan', '1998 Eagle Talon Hatchback', '1998 Nissan 240SX Coupe', '1999 Plymouth Neon Coupe', '2000 AM General Hummer SUV', '2001 Acura Integra Type R', '2001 Lamborghini Diablo Coupe', '2002 Daewoo Nubira Wagon', '2006 Ford GT Coupe', '2007 Audi S4 Sedan', '2007 BMW 6 Series Convertible', '2007 BMW X5 SUV', '2007 Bentley Continental Flying Spur Sedan', '2007 Bentley Continental GT Coupe', '2007 Buick Rainier SUV', '2007 Cadillac Escalade EXT Crew Cab', '2007 Chevrolet Corvette Ron Fellows Edition Z06', '2007 Chevrolet Express Cargo Van', '2007 Chevrolet Express Van', '2007 Chevrolet Impala Sedan', '2007 Chevrolet Malibu Sedan', '2007 Chevrolet Monte Carlo Coupe', '2007 Chevrolet Silverado 1500 Classic Extended Cab', '2007 Dodge Caliber Wagon', '2007 Dodge Dakota Club Cab', '2007 Dodge Durango SUV', '2007 Ford F-150 Regular Cab', '2007 Ford Focus Sedan', '2007 Ford Freestar Minivan', '2007 Ford Mustang Convertible', '2007 Honda Odyssey Minivan', '2007 Hyundai Elantra Sedan', '2007 Suzuki Aerio Sedan', '2007 Volvo XC90 SUV', '2008 Acura TL Type-S', '2008 Audi RS 4 Convertible', '2008 Chrysler Crossfire Convertible', '2008 Chrysler PT Cruiser Convertible', '2008 Dodge Magnum Wagon', '2008 Isuzu Ascender SUV', '2008 Lamborghini Reventon Coupe', '2009 Bentley Arnage Sedan', '2009 Bugatti Veyron 16.4 Convertible', '2009 Bugatti Veyron 16.4 Coupe', '2009 Chevrolet TrailBlazer SS', '2009 Chrysler Aspen SUV', '2009 Dodge Charger SRT-8', '2009 Dodge Ram Pickup 3500 Quad Cab', '2009 Dodge Sprinter Cargo Van', '2009 Ford Expedition EL SUV', '2009 HUMMER H2 SUT Crew Cab', '2009 Mercedes-Benz SL-Class Coupe', '2009 Spyker C8 Convertible', '2009 Spyker C8 Coupe', '2010 BMW M5 Sedan', '2010 BMW M6 Convertible', '2010 Chevrolet 
Cobalt SS', '2010 Chevrolet HHR SS', '2010 Chevrolet Malibu Hybrid Sedan', '2010 Chrysler 300 SRT-8', '2010 Chrysler Sebring Convertible', '2010 Dodge Dakota Crew Cab', '2010 Dodge Ram Pickup 3500 Crew Cab', '2010 HUMMER H3T Crew Cab', '2011 Audi S6 Sedan', '2011 Audi TT Hatchback', '2011 Bentley Mulsanne Sedan', '2011 Dodge Challenger SRT8', '2011 Ford Ranger SuperCab', '2011 Infiniti QX56 SUV', '2011 Lincoln Town Car Sedan', '2011 Mazda Tribute SUV', '2012 Acura RL Sedan', '2012 Acura TL Sedan', '2012 Acura TSX Sedan', '2012 Acura ZDX Hatchback', '2012 Aston Martin V8 Vantage Convertible', '2012 Aston Martin V8 Vantage Coupe', '2012 Aston Martin Virage Convertible', '2012 Aston Martin Virage Coupe', '2012 Audi A5 Coupe', '2012 Audi R8 Coupe', '2012 Audi S4 Sedan', '2012 Audi S5 Convertible', '2012 Audi S5 Coupe', '2012 Audi TT RS Coupe', '2012 Audi TTS Coupe', '2012 BMW 1 Series Convertible', '2012 BMW 1 Series Coupe', '2012 BMW 3 Series Sedan', '2012 BMW 3 Series Wagon', '2012 BMW ActiveHybrid 5 Sedan', '2012 BMW M3 Coupe', '2012 BMW X3 SUV', '2012 BMW X6 SUV', '2012 BMW Z4 Convertible', '2012 Bentley Continental GT Coupe', '2012 Bentley Continental Supersports Conv. 
Convertible', '2012 Buick Enclave SUV', '2012 Buick Regal GS', '2012 Buick Verano Sedan', '2012 Cadillac CTS-V Sedan', '2012 Cadillac SRX SUV', '2012 Chevrolet Avalanche Crew Cab', '2012 Chevrolet Camaro Convertible', '2012 Chevrolet Corvette Convertible', '2012 Chevrolet Corvette ZR1', '2012 Chevrolet Silverado 1500 Extended Cab', '2012 Chevrolet Silverado 1500 Hybrid Crew Cab', '2012 Chevrolet Silverado 1500 Regular Cab', '2012 Chevrolet Silverado 2500HD Regular Cab', '2012 Chevrolet Sonic Sedan', '2012 Chevrolet Tahoe Hybrid SUV', '2012 Chevrolet Traverse SUV', '2012 Chrysler Town and Country Minivan', '2012 Dodge Caliber Wagon', '2012 Dodge Charger Sedan', '2012 Dodge Durango SUV', '2012 Dodge Journey SUV', '2012 FIAT 500 Abarth', '2012 FIAT 500 Convertible', '2012 Ferrari 458 Italia Convertible', '2012 Ferrari 458 Italia Coupe', '2012 Ferrari California Convertible', '2012 Ferrari FF Coupe', '2012 Fisker Karma Sedan', '2012 Ford E-Series Wagon Van', '2012 Ford Edge SUV', '2012 Ford F-150 Regular Cab', '2012 Ford F-450 Super Duty Crew Cab', '2012 Ford Fiesta Sedan', '2012 GMC Acadia SUV', '2012 GMC Canyon Extended Cab', '2012 GMC Savana Van', '2012 GMC Terrain SUV', '2012 GMC Yukon Hybrid SUV', '2012 Honda Accord Coupe', '2012 Honda Accord Sedan', '2012 Honda Odyssey Minivan', '2012 Hyundai Accent Sedan', '2012 Hyundai Azera Sedan', '2012 Hyundai Elantra Touring Hatchback', '2012 Hyundai Genesis Sedan', '2012 Hyundai Santa Fe SUV', '2012 Hyundai Sonata Hybrid Sedan', '2012 Hyundai Sonata Sedan', '2012 Hyundai Tucson SUV', '2012 Hyundai Veloster Hatchback', '2012 Hyundai Veracruz SUV', '2012 Infiniti G Coupe IPL', '2012 Jaguar XK XKR', '2012 Jeep Compass SUV', '2012 Jeep Grand Cherokee SUV', '2012 Jeep Liberty SUV', '2012 Jeep Patriot SUV', '2012 Jeep Wrangler SUV', '2012 Lamborghini Aventador Coupe', '2012 Lamborghini Gallardo LP 570-4 Superleggera', '2012 Land Rover LR2 SUV', '2012 Land Rover Range Rover SUV', '2012 MINI Cooper Roadster Convertible', '2012 
Maybach Landaulet Convertible', '2012 McLaren MP4-12C Coupe', '2012 Mercedes-Benz C-Class Sedan', '2012 Mercedes-Benz E-Class Sedan', '2012 Mercedes-Benz S-Class Sedan', '2012 Mercedes-Benz Sprinter Van', '2012 Mitsubishi Lancer Sedan', '2012 Nissan Juke Hatchback', '2012 Nissan Leaf Hatchback', '2012 Nissan NV Passenger Van', '2012 Porsche Panamera Sedan', '2012 Ram C/V Cargo Van Minivan', '2012 Rolls-Royce Ghost Sedan', '2012 Rolls-Royce Phantom Drophead Coupe Convertible', '2012 Rolls-Royce Phantom Sedan', '2012 Scion xD Hatchback', '2012 Suzuki Kizashi Sedan', '2012 Suzuki SX4 Hatchback', '2012 Suzuki SX4 Sedan', '2012 Tesla Model S Sedan', '2012 Toyota 4Runner SUV', '2012 Toyota Camry Sedan', '2012 Toyota Corolla Sedan', '2012 Toyota Sequoia SUV', '2012 Volkswagen Beetle Hatchback', '2012 Volkswagen Golf Hatchback', '2012 Volvo C30 Hatchback', '2012 smart fortwo Convertible']
                    caption = classnames
                else:
                    raise NotImplementedError            
                if 'imagenet_full_ctx_values_dict' in data_dir and idx < 3:
                    use_full = True
                else:
                    use_full = False
                if 'increasing_classes_1001' in fold:
                    use_full = True
                # use_full = True
                if caption is None:
                    caption = self.extract_class_names(txt_path, use_full=use_full)
                # print(caption)
                self.captions.append(caption)
        elif dataset_mode == 'joint':
            # if self.data_dir == 'lora_prompts.pt' and self.local_images[0].shape[-1] == 768:
            import json
            with open('/data/user/diffusers/examples/lora_inversion/diff_prompt_data_captions_last_200.json', 'r') as f:
                data = json.load(f)
            self.captions = classes
            self.captions.extend(list(data.values()))
            coop_captions = []
            for idx, clss in enumerate(classes2):
                class_captions = clss.split("-") #seed{num_s}-{fold}_classes_{num_c}
                seed = class_captions[0]
                fold = class_captions[1]
                caption = None
                if "increasing" in fold:
                    txt_path = f"/data/user/CoOp/data/imagenet/increasing_classes/{fold}.txt"
                elif "10_random" in fold:
                    txt_path = f"/data/user/CoOp/data/imagenet/10_random_classes/{fold}.txt"
                elif '10_random_all_ctx_meta' in data_dir:
                    txt_path = f"/data/user/CoOp/data/imagenet/10_random_classes/10_{fold}.txt"
                elif "random" in fold:
                    txt_path = f"/data/user/CoOp/data/imagenet/random_classes/{fold}.txt"
                elif "imagenet" in fold:
                    txt_path = f"/data/user/CoOp/data/imagenet/increasing_classes/increasing_classes_1000.txt"
                elif "caltech" in fold:
                    classnames = ['accordion', 'airplane', 'anchor', 'ant', 'barrel', 'bass', 'beaver', 'binocular', 'bonsai', 'brain', 'brontosaurus', 'buddha', 'butterfly', 'camera', 'cannon', 'car_side', 'ceiling_fan', 'cellphone', 'chair', 'chandelier', 'cougar_body', 'cougar_face', 'crab', 'crayfish', 'crocodile', 'crocodile_head', 'cup', 'dalmatian', 'dollar_bill', 'dolphin', 'dragonfly', 'electric_guitar', 'elephant', 'emu', 'euphonium', 'ewer', 'face', 'ferry', 'flamingo', 'flamingo_head', 'garfield', 'gerenuk', 'gramophone', 'grand_piano', 'hawksbill', 'headphone', 'hedgehog', 'helicopter', 'ibis', 'inline_skate', 'joshua_tree', 'kangaroo', 'ketch', 'lamp', 'laptop', 'leopard', 'llama', 'lobster', 'lotus', 'mandolin', 'mayfly', 'menorah', 'metronome', 'minaret', 'motorbike', 'nautilus', 'octopus', 'okapi', 'pagoda', 'panda', 'pigeon', 'pizza', 'platypus', 'pyramid', 'revolver', 'rhino', 'rooster', 'saxophone', 'schooner', 'scissors', 'scorpion', 'sea_horse', 'snoopy', 'soccer_ball', 'stapler', 'starfish', 'stegosaurus', 'stop_sign', 'strawberry', 'sunflower', 'tick', 'trilobite', 'umbrella', 'watch', 'water_lilly', 'wheelchair', 'wild_cat', 'windsor_chair', 'wrench', 'yin_yang']
                    caption = classnames
                elif "pets" in fold:
                    classnames = ['abyssinian', 'american_bulldog', 'american_pit_bull_terrier', 'basset_hound', 'beagle', 'bengal', 'birman', 'bombay', 'boxer', 'british_shorthair', 'chihuahua', 'egyptian_mau', 'english_cocker_spaniel', 'english_setter', 'german_shorthaired', 'great_pyrenees', 'havanese', 'japanese_chin', 'keeshond', 'leonberger', 'maine_coon', 'miniature_pinscher', 'newfoundland', 'persian', 'pomeranian', 'pug', 'ragdoll', 'russian_blue', 'saint_bernard', 'samoyed', 'scottish_terrier', 'shiba_inu', 'siamese', 'sphynx', 'staffordshire_bull_terrier', 'wheaten_terrier', 'yorkshire_terrier']
                    caption = classnames
                elif "sun397" in fold:
                    classnames = ['abbey', 'airplane_cabin', 'airport_terminal', 'alley', 'amphitheater', 'amusement_arcade', 'amusement_park', 'anechoic_chamber', 'aquarium', 'aqueduct', 'arch', 'archive', 'art_gallery', 'art_school', 'art_studio', 'assembly_line', 'attic', 'auditorium', 'auto_factory', 'backseat car_interior', 'badlands', 'baggage_claim', 'ball_pit', 'ballroom', 'bamboo_forest', 'banquet_hall', 'bar', 'barn', 'barndoor', 'barrel_storage wine_cellar', 'baseball stadium', 'baseball_field', 'basement', 'basilica', 'bathroom', 'batters_box', 'bayou', 'beach', 'beauty_salon', 'bedroom', 'berth', 'biology_laboratory', 'block waterfall', 'boardwalk', 'boat_deck', 'boathouse', 'bookstore', 'botanical_garden', 'bottle_storage wine_cellar', 'bowling_alley', 'boxing_ring', 'bridge', 'broadleaf forest', 'building_facade', 'bullring', 'burial_chamber', 'bus_interior', 'butchers_shop', 'butte', 'cafeteria', 'campsite', 'campus', 'candy_store', 'canyon', 'carrousel', 'castle', 'catacomb', 'cemetery', 'chalet', 'cheese_factory', 'chemistry_lab', 'childs_room', 'classroom', 'clean_room', 'cliff', 'closet', 'clothing_store', 'coast', 'cockpit', 'coffee_shop', 'computer_room', 'conference_center', 'conference_room', 'construction_site', 'control_room', 'coral_reef underwater', 'corn_field', 'corral', 'corridor', 'cottage_garden', 'courthouse', 'courtroom', 'courtyard', 'creek', 'crevasse', 'crosswalk', 'cultivated field', 'dam', 'delicatessen', 'dentists_office', 'dining_car', 'dining_room', 'discotheque', 'dock', 'door elevator', 'dorm_room', 'driveway', 'drugstore', 'east_asia temple', 'electrical_substation', 'elevator_shaft', 'engine_room', 'establishment poolroom', 'excavation', 'exterior balcony', 'exterior covered_bridge', 'exterior gazebo', 'fairway', 'fan waterfall', 'fastfood_restaurant', 'fire_escape', 'fire_station', 'fishpond', 'food_court', 'football stadium', 'forest_path', 'forest_road', 'formal_garden', 'fountain', 'frontseat car_interior', 
'galley', 'game_room', 'garbage_dump', 'gas_station', 'gift_shop', 'golf_course', 'harbor', 'hayfield', 'heliport', 'herb_garden', 'highway', 'hill', 'home dinette', 'home poolroom', 'home_office', 'hospital', 'hospital_room', 'hot_spring', 'hotel_room', 'house', 'ice_cream_parlor', 'ice_floe', 'ice_shelf', 'iceberg', 'igloo', 'indoor apse', 'indoor badminton_court', 'indoor bazaar', 'indoor bistro', 'indoor booth', 'indoor bow_window', 'indoor brewery', 'indoor casino', 'indoor cathedral', 'indoor cavern', 'indoor chicken_coop', 'indoor church', 'indoor cloister', 'indoor diner', 'indoor escalator', 'indoor factory', 'indoor firing_range', 'indoor florist_shop', 'indoor garage', 'indoor general_store', 'indoor greenhouse', 'indoor gymnasium', 'indoor hangar', 'indoor ice_skating_rink', 'indoor jacuzzi', 'indoor jail', 'indoor kennel', 'indoor library', 'indoor market', 'indoor mosque', 'indoor movie_theater', 'indoor museum', 'indoor parking_garage', 'indoor pilothouse', 'indoor podium', 'indoor pub', 'indoor shopping_mall', 'indoor stage', 'indoor swimming_pool', 'indoor synagogue', 'indoor tennis_court', 'indoor volleyball_court', 'indoor warehouse', 'indoor wrestling_ring', 'indoor_procenium theater', 'indoor_seats theater', 'industrial_area', 'interior balcony', 'interior elevator', 'islet', 'jail_cell', 'jewelry_shop', 'kasbah', 'kindergarden_classroom', 'kitchen', 'kitchenette', 'landfill', 'landing_deck', 'laundromat', 'lecture_room', 'lift_bridge', 'lighthouse', 'limousine_interior', 'living_room', 'lobby', 'lock_chamber', 'locker_room', 'mansion', 'manufactured_home', 'marsh', 'martial_arts_gym', 'mausoleum', 'medina', 'motel', 'mountain', 'mountain_snowy', 'music_store', 'music_studio', 'natural canal', 'natural lake', 'needleleaf forest', 'nursery', 'oast_house', 'ocean', 'office', 'office cubicle', 'office_building', 'oilrig', 'operating_room', 'orchard', 'outdoor apartment_building', 'outdoor arrival_gate', 'outdoor athletic_field', 'outdoor 
basketball_court', 'outdoor bazaar', 'outdoor bow_window', 'outdoor cabin', 'outdoor cathedral', 'outdoor chicken_coop', 'outdoor church', 'outdoor control_tower', 'outdoor diner', 'outdoor doorway', 'outdoor driving_range', 'outdoor general_store', 'outdoor greenhouse', 'outdoor hangar', 'outdoor hot_tub', 'outdoor hotel', 'outdoor hunting_lodge', 'outdoor ice_skating_rink', 'outdoor inn', 'outdoor kennel', 'outdoor labyrinth', 'outdoor library', 'outdoor lido_deck', 'outdoor market', 'outdoor monastery', 'outdoor mosque', 'outdoor nuclear_power_plant', 'outdoor observatory', 'outdoor oil_refinery', 'outdoor outhouse', 'outdoor parking_garage', 'outdoor planetarium', 'outdoor podium', 'outdoor power_plant', 'outdoor swimming_pool', 'outdoor synagogue', 'outdoor tennis_court', 'outdoor tent', 'outdoor track', 'outdoor volleyball_court', 'pagoda', 'palace', 'pantry', 'park', 'parking_lot', 'parlor', 'pasture', 'patio', 'pavilion', 'pharmacy', 'phone_booth', 'physics_laboratory', 'picnic_area', 'platform subway_station', 'platform train_station', 'playground', 'playroom', 'plaza', 'plunge waterfall', 'pond', 'promenade_deck', 'public atrium', 'pulpit', 'putting_green', 'racecourse', 'raceway', 'raft', 'railroad_track', 'rainforest', 'reception', 'recreation_room', 'residential_neighborhood', 'restaurant', 'restaurant_kitchen', 'restaurant_patio', 'rice_paddy', 'riding_arena', 'river', 'rock_arch', 'rope_bridge', 'ruin', 'runway', 'sand desert', 'sandbar', 'sandbox', 'sauna', 'schoolhouse', 'sea_cliff', 'server_room', 'shed', 'shoe_shop', 'shop bakery', 'shopfront', 'shower', 'skatepark', 'ski_lodge', 'ski_resort', 'ski_slope', 'sky', 'skyscraper', 'slum', 'snowfield', 'south_asia temple', 'squash_court', 'stable', 'staircase', 'street', 'subway_interior', 'supermarket', 'sushi_bar', 'swamp', 'television_studio', 'thriftshop', 'throne_room', 'ticket_booth', 'toll_plaza', 'topiary_garden', 'tower', 'toyshop', 'train_railway', 'tree_farm', 'tree_house', 'trench', 'urban 
canal', 'utility_room', 'valley', 'van_interior', 'vegetable_garden', 'vegetation desert', 'vehicle dinette', 'veranda', 'veterinarians_office', 'viaduct', 'videostore', 'village', 'vineyard', 'volcano', 'waiting_room', 'water moat', 'water_tower', 'watering_hole', 'wave', 'wet_bar', 'wheat_field', 'wild field', 'wind_farm', 'windmill', 'yard', 'youth_hostel']
                    caption = classnames
                elif "ucf101" in fold:
                    classnames = ['Apply_Eye_Makeup', 'Apply_Lipstick', 'Archery', 'Baby_Crawling', 'Balance_Beam', 'Band_Marching', 'Baseball_Pitch', 'Basketball', 'Basketball_Dunk', 'Bench_Press', 'Biking', 'Billiards', 'Blow_Dry_Hair', 'Blowing_Candles', 'Body_Weight_Squats', 'Bowling', 'Boxing_Punching_Bag', 'Boxing_Speed_Bag', 'Breast_Stroke', 'Brushing_Teeth', 'Clean_And_Jerk', 'Cliff_Diving', 'Cricket_Bowling', 'Cricket_Shot', 'Cutting_In_Kitchen', 'Diving', 'Drumming', 'Fencing', 'Field_Hockey_Penalty', 'Floor_Gymnastics', 'Frisbee_Catch', 'Front_Crawl', 'Golf_Swing', 'Haircut', 'Hammer_Throw', 'Hammering', 'Handstand_Pushups', 'Handstand_Walking', 'Head_Massage', 'High_Jump', 'Horse_Race', 'Horse_Riding', 'Hula_Hoop', 'Ice_Dancing', 'Javelin_Throw', 'Juggling_Balls', 'Jump_Rope', 'Jumping_Jack', 'Kayaking', 'Knitting', 'Long_Jump', 'Lunges', 'Military_Parade', 'Mixing', 'Mopping_Floor', 'Nunchucks', 'Parallel_Bars', 'Pizza_Tossing', 'Playing_Cello', 'Playing_Daf', 'Playing_Dhol', 'Playing_Flute', 'Playing_Guitar', 'Playing_Piano', 'Playing_Sitar', 'Playing_Tabla', 'Playing_Violin', 'Pole_Vault', 'Pommel_Horse', 'Pull_Ups', 'Punch', 'Push_Ups', 'Rafting', 'Rock_Climbing_Indoor', 'Rope_Climbing', 'Rowing', 'Salsa_Spin', 'Shaving_Beard', 'Shotput', 'Skate_Boarding', 'Skiing', 'Skijet', 'Sky_Diving', 'Soccer_Juggling', 'Soccer_Penalty', 'Still_Rings', 'Sumo_Wrestling', 'Surfing', 'Swing', 'Table_Tennis_Shot', 'Tai_Chi', 'Tennis_Swing', 'Throw_Discus', 'Trampoline_Jumping', 'Typing', 'Uneven_Bars', 'Volleyball_Spiking', 'Walking_With_Dog', 'Wall_Pushups', 'Writing_On_Board', 'Yo_Yo']
                    caption = classnames
                elif "eurosat" in fold:
                    classnames = ['Annual Crop Land', 'Forest', 'Herbaceous Vegetation Land', 'Highway or Road', 'Industrial Buildings', 'Pasture Land', 'Permanent Crop Land', 'Residential Buildings', 'River', 'Sea or Lake']
                    caption = classnames
                elif "dtd" in fold:
                    classnames = ['banded', 'blotchy', 'braided', 'bubbly', 'bumpy', 'chequered', 'cobwebbed', 'cracked', 'crosshatched', 'crystalline', 'dotted', 'fibrous', 'flecked', 'freckled', 'frilly', 'gauzy', 'grid', 'grooved', 'honeycombed', 'interlaced', 'knitted', 'lacelike', 'lined', 'marbled', 'matted', 'meshed', 'paisley', 'perforated', 'pitted', 'pleated', 'polka-dotted', 'porous', 'potholed', 'scaly', 'smeared', 'spiralled', 'sprinkled', 'stained', 'stratified', 'striped', 'studded', 'swirly', 'veined', 'waffled', 'woven', 'wrinkled', 'zigzagged']
                    caption = classnames
                elif "fgvc_aircraft" in fold:
                    classnames = ['DHC-8-100', 'A330-200', '737-900', '737-700', 'CRJ-700', 'A340-600', 'Tornado', '747-100', '747-200', 'Metroliner', 'ATR-42', '757-200', '737-400', 'ATR-72', 'Challenger 600', 'DC-6', 'A320', 'DC-8', 'Cessna 560', 'E-170', 'MD-90', 'MD-80', 'Gulfstream IV', 'Dornier 328', '737-600', 'Boeing 717', '737-500', 'A321', 'Falcon 2000', 'PA-28', '737-800', 'BAE-125', 'Fokker 100', '737-200', 'Cessna 208', 'F-16A/B', 'A319', 'MD-11', 'EMB-120', '747-400', '737-300', 'F/A-18', 'Beechcraft 1900', '767-200', 'A300B4', '747-300', 'SR-20', 'BAE 146-300', 'DHC-1', 'A310', 'Il-76', '777-300', 'ERJ 145', 'Tu-134', 'DC-9-30', 'Spitfire', 'C-47', 'C-130', 'An-12', '767-400', 'CRJ-900', 'Falcon 900', 'Saab 2000', '767-300', 'Embraer Legacy 600', 'Saab 340', 'BAE 146-200', 'Cessna 172', 'DHC-6', 'ERJ 135', 'A340-200', 'E-190', 'A380', 'Yak-42', '757-300', 'Hawk T1', 'DC-3', '707-320', 'A330-300', 'A340-300', 'Tu-154', 'Cessna 525', '777-200', 'DHC-8-300', 'Fokker 70', 'DH-82', 'E-195', 'DR-400', 'L-1011', 'Global Express', 'MD-87', 'A340-500', 'Gulfstream V', 'CRJ-200', 'Model B200', '727-200', 'Eurofighter Typhoon', 'A318', 'DC-10', 'Fokker 50']
                    # classnames = [      "A300",
                    #                     "A310",
                    #                     "A320",
                    #                     "A330",
                    #                     "A340",
                    #                     "A380",
                    #                     "ATR-42",
                    #                     "ATR-72",
                    #                     "An-12",
                    #                     "BAE 146",
                    #                     "BAE-125",
                    #                     "Beechcraft 1900",
                    #                     "Boeing 707",
                    #                     "Boeing 717",
                    #                     "Boeing 727",
                    #                     "Boeing 737",
                    #                     "Boeing 747",
                    #                     "Boeing 757",
                    #                     "Boeing 767",
                    #                     "Boeing 777",
                    #                     "C-130",
                    #                     "C-47",
                    #                     "CRJ-200",
                    #                     "CRJ-700",
                    #                     "Cessna 172",
                    #                     "Cessna 208",
                    #                     "Cessna Citation",
                    #                     "Challenger 600",
                    #                     "DC-10",
                    #                     "DC-3",
                    #                     "DC-6",
                    #                     "DC-8",
                    #                     "DC-9",
                    #                     "DH-82",
                    #                     "DHC-1",
                    #                     "DHC-6",
                    #                     "DR-400",
                    #                     "Dash 8",
                    #                     "Dornier 328",
                    #                     "EMB-120",
                    #                     "Embraer E-Jet",
                    #                     "Embraer ERJ 145",
                    #                     "Embraer Legacy 600",
                    #                     "Eurofighter Typhoon",
                    #                     "F-16",
                    #                     "F/A-18",
                    #                     "Falcon 2000",
                    #                     "Falcon 900",
                    #                     "Fokker 100",
                    #                     "Fokker 50",
                    #                     "Fokker 70",
                    #                     "Global Express",
                    #                     "Gulfstream",
                    #                     "Hawk T1",
                    #                     "Il-76",
                    #                     "King Air",
                    #                     "L-1011",
                    #                     "MD-11",
                    #                     "MD-80",
                    #                     "MD-90",
                    #                     "Metroliner",
                    #                     "PA-28",
                    #                     "SR-20",
                    #                     "Saab 2000",
                    #                     "Saab 340",
                    #                     "Spitfire",
                    #                     "Tornado",
                    #                     "Tu-134",
                    #                     "Tu-154",
                    #                     "Yak-42"
                    #                 ]
                    caption = classnames                
                elif "food101" in fold:
                    classnames = ['apple_pie', 'baby_back_ribs', 'baklava', 'beef_carpaccio', 'beef_tartare', 'beet_salad', 'beignets', 'bibimbap', 'bread_pudding', 'breakfast_burrito', 'bruschetta', 'caesar_salad', 'cannoli', 'caprese_salad', 'carrot_cake', 'ceviche', 'cheese_plate', 'cheesecake', 'chicken_curry', 'chicken_quesadilla', 'chicken_wings', 'chocolate_cake', 'chocolate_mousse', 'churros', 'clam_chowder', 'club_sandwich', 'crab_cakes', 'creme_brulee', 'croque_madame', 'cup_cakes', 'deviled_eggs', 'donuts', 'dumplings', 'edamame', 'eggs_benedict', 'escargots', 'falafel', 'filet_mignon', 'fish_and_chips', 'foie_gras', 'french_fries', 'french_onion_soup', 'french_toast', 'fried_calamari', 'fried_rice', 'frozen_yogurt', 'garlic_bread', 'gnocchi', 'greek_salad', 'grilled_cheese_sandwich', 'grilled_salmon', 'guacamole', 'gyoza', 'hamburger', 'hot_and_sour_soup', 'hot_dog', 'huevos_rancheros', 'hummus', 'ice_cream', 'lasagna', 'lobster_bisque', 'lobster_roll_sandwich', 'macaroni_and_cheese', 'macarons', 'miso_soup', 'mussels', 'nachos', 'omelette', 'onion_rings', 'oysters', 'pad_thai', 'paella', 'pancakes', 'panna_cotta', 'peking_duck', 'pho', 'pizza', 'pork_chop', 'poutine', 'prime_rib', 'pulled_pork_sandwich', 'ramen', 'ravioli', 'red_velvet_cake', 'risotto', 'samosa', 'sashimi', 'scallops', 'seaweed_salad', 'shrimp_and_grits', 'spaghetti_bolognese', 'spaghetti_carbonara', 'spring_rolls', 'steak', 'strawberry_shortcake', 'sushi', 'tacos', 'takoyaki', 'tiramisu', 'tuna_tartare', 'waffles']
                    caption = classnames
                elif "flowers" in fold:
                    classnames = ['alpine sea holly', 'anthurium', 'artichoke', 'azalea', 'ball moss', 'balloon flower', 'barbeton daisy', 'bearded iris', 'bee balm', 'bird of paradise', 'bishop of llandaff', 'black-eyed susan', 'blackberry lily', 'blanket flower', 'bolero deep blue', 'bougainvillea', 'bromelia', 'buttercup', 'californian poppy', 'camellia', 'canna lily', 'canterbury bells', 'cape flower', 'carnation', 'cautleya spicata', 'clematis', "colt's foot", 'columbine', 'common dandelion', 'corn poppy', 'cyclamen', 'daffodil', 'desert-rose', 'english marigold', 'fire lily', 'foxglove', 'frangipani', 'fritillary', 'garden phlox', 'gaura', 'gazania', 'geranium', 'giant white arum lily', 'globe thistle', 'globe-flower', 'grape hyacinth', 'great masterwort', 'hard-leaved pocket orchid', 'hibiscus', 'hippeastrum', 'japanese anemone', 'king protea', 'lenten rose', 'lotus', 'love in the mist', 'magnolia', 'mallow', 'marigold', 'mexican aster', 'mexican petunia', 'monkshood', 'moon orchid', 'morning glory', 'orange dahlia', 'osteospermum', 'oxeye daisy', 'passion flower', 'pelargonium', 'peruvian lily', 'petunia', 'pincushion flower', 'pink primrose', 'pink-yellow dahlia', 'poinsettia', 'primula', 'prince of wales feathers', 'purple coneflower', 'red ginger', 'rose', 'ruby-lipped cattleya', 'siam tulip', 'silverbush', 'snapdragon', 'spear thistle', 'spring crocus', 'stemless gentian', 'sunflower', 'sweet pea', 'sweet william', 'sword lily', 'thorn apple', 'tiger lily', 'toad lily', 'tree mallow', 'tree poppy', 'trumpet creeper', 'wallflower', 'water lily', 'watercress', 'wild pansy', 'windflower', 'yellow iris']
                    caption = classnames
                elif "stanford" in fold:
                    classnames = ['1991 Volkswagen Golf Hatchback', '1993 Geo Metro Convertible', '1993 Mercedes-Benz 300-Class Convertible', '1993 Volvo 240 Sedan', '1994 Audi 100 Sedan', '1994 Audi 100 Wagon', '1994 Audi V8 Sedan', '1997 Dodge Caravan Minivan', '1998 Eagle Talon Hatchback', '1998 Nissan 240SX Coupe', '1999 Plymouth Neon Coupe', '2000 AM General Hummer SUV', '2001 Acura Integra Type R', '2001 Lamborghini Diablo Coupe', '2002 Daewoo Nubira Wagon', '2006 Ford GT Coupe', '2007 Audi S4 Sedan', '2007 BMW 6 Series Convertible', '2007 BMW X5 SUV', '2007 Bentley Continental Flying Spur Sedan', '2007 Bentley Continental GT Coupe', '2007 Buick Rainier SUV', '2007 Cadillac Escalade EXT Crew Cab', '2007 Chevrolet Corvette Ron Fellows Edition Z06', '2007 Chevrolet Express Cargo Van', '2007 Chevrolet Express Van', '2007 Chevrolet Impala Sedan', '2007 Chevrolet Malibu Sedan', '2007 Chevrolet Monte Carlo Coupe', '2007 Chevrolet Silverado 1500 Classic Extended Cab', '2007 Dodge Caliber Wagon', '2007 Dodge Dakota Club Cab', '2007 Dodge Durango SUV', '2007 Ford F-150 Regular Cab', '2007 Ford Focus Sedan', '2007 Ford Freestar Minivan', '2007 Ford Mustang Convertible', '2007 Honda Odyssey Minivan', '2007 Hyundai Elantra Sedan', '2007 Suzuki Aerio Sedan', '2007 Volvo XC90 SUV', '2008 Acura TL Type-S', '2008 Audi RS 4 Convertible', '2008 Chrysler Crossfire Convertible', '2008 Chrysler PT Cruiser Convertible', '2008 Dodge Magnum Wagon', '2008 Isuzu Ascender SUV', '2008 Lamborghini Reventon Coupe', '2009 Bentley Arnage Sedan', '2009 Bugatti Veyron 16.4 Convertible', '2009 Bugatti Veyron 16.4 Coupe', '2009 Chevrolet TrailBlazer SS', '2009 Chrysler Aspen SUV', '2009 Dodge Charger SRT-8', '2009 Dodge Ram Pickup 3500 Quad Cab', '2009 Dodge Sprinter Cargo Van', '2009 Ford Expedition EL SUV', '2009 HUMMER H2 SUT Crew Cab', '2009 Mercedes-Benz SL-Class Coupe', '2009 Spyker C8 Convertible', '2009 Spyker C8 Coupe', '2010 BMW M5 Sedan', '2010 BMW M6 Convertible', '2010 Chevrolet 
Cobalt SS', '2010 Chevrolet HHR SS', '2010 Chevrolet Malibu Hybrid Sedan', '2010 Chrysler 300 SRT-8', '2010 Chrysler Sebring Convertible', '2010 Dodge Dakota Crew Cab', '2010 Dodge Ram Pickup 3500 Crew Cab', '2010 HUMMER H3T Crew Cab', '2011 Audi S6 Sedan', '2011 Audi TT Hatchback', '2011 Bentley Mulsanne Sedan', '2011 Dodge Challenger SRT8', '2011 Ford Ranger SuperCab', '2011 Infiniti QX56 SUV', '2011 Lincoln Town Car Sedan', '2011 Mazda Tribute SUV', '2012 Acura RL Sedan', '2012 Acura TL Sedan', '2012 Acura TSX Sedan', '2012 Acura ZDX Hatchback', '2012 Aston Martin V8 Vantage Convertible', '2012 Aston Martin V8 Vantage Coupe', '2012 Aston Martin Virage Convertible', '2012 Aston Martin Virage Coupe', '2012 Audi A5 Coupe', '2012 Audi R8 Coupe', '2012 Audi S4 Sedan', '2012 Audi S5 Convertible', '2012 Audi S5 Coupe', '2012 Audi TT RS Coupe', '2012 Audi TTS Coupe', '2012 BMW 1 Series Convertible', '2012 BMW 1 Series Coupe', '2012 BMW 3 Series Sedan', '2012 BMW 3 Series Wagon', '2012 BMW ActiveHybrid 5 Sedan', '2012 BMW M3 Coupe', '2012 BMW X3 SUV', '2012 BMW X6 SUV', '2012 BMW Z4 Convertible', '2012 Bentley Continental GT Coupe', '2012 Bentley Continental Supersports Conv. 
Convertible', '2012 Buick Enclave SUV', '2012 Buick Regal GS', '2012 Buick Verano Sedan', '2012 Cadillac CTS-V Sedan', '2012 Cadillac SRX SUV', '2012 Chevrolet Avalanche Crew Cab', '2012 Chevrolet Camaro Convertible', '2012 Chevrolet Corvette Convertible', '2012 Chevrolet Corvette ZR1', '2012 Chevrolet Silverado 1500 Extended Cab', '2012 Chevrolet Silverado 1500 Hybrid Crew Cab', '2012 Chevrolet Silverado 1500 Regular Cab', '2012 Chevrolet Silverado 2500HD Regular Cab', '2012 Chevrolet Sonic Sedan', '2012 Chevrolet Tahoe Hybrid SUV', '2012 Chevrolet Traverse SUV', '2012 Chrysler Town and Country Minivan', '2012 Dodge Caliber Wagon', '2012 Dodge Charger Sedan', '2012 Dodge Durango SUV', '2012 Dodge Journey SUV', '2012 FIAT 500 Abarth', '2012 FIAT 500 Convertible', '2012 Ferrari 458 Italia Convertible', '2012 Ferrari 458 Italia Coupe', '2012 Ferrari California Convertible', '2012 Ferrari FF Coupe', '2012 Fisker Karma Sedan', '2012 Ford E-Series Wagon Van', '2012 Ford Edge SUV', '2012 Ford F-150 Regular Cab', '2012 Ford F-450 Super Duty Crew Cab', '2012 Ford Fiesta Sedan', '2012 GMC Acadia SUV', '2012 GMC Canyon Extended Cab', '2012 GMC Savana Van', '2012 GMC Terrain SUV', '2012 GMC Yukon Hybrid SUV', '2012 Honda Accord Coupe', '2012 Honda Accord Sedan', '2012 Honda Odyssey Minivan', '2012 Hyundai Accent Sedan', '2012 Hyundai Azera Sedan', '2012 Hyundai Elantra Touring Hatchback', '2012 Hyundai Genesis Sedan', '2012 Hyundai Santa Fe SUV', '2012 Hyundai Sonata Hybrid Sedan', '2012 Hyundai Sonata Sedan', '2012 Hyundai Tucson SUV', '2012 Hyundai Veloster Hatchback', '2012 Hyundai Veracruz SUV', '2012 Infiniti G Coupe IPL', '2012 Jaguar XK XKR', '2012 Jeep Compass SUV', '2012 Jeep Grand Cherokee SUV', '2012 Jeep Liberty SUV', '2012 Jeep Patriot SUV', '2012 Jeep Wrangler SUV', '2012 Lamborghini Aventador Coupe', '2012 Lamborghini Gallardo LP 570-4 Superleggera', '2012 Land Rover LR2 SUV', '2012 Land Rover Range Rover SUV', '2012 MINI Cooper Roadster Convertible', '2012 
Maybach Landaulet Convertible', '2012 McLaren MP4-12C Coupe', '2012 Mercedes-Benz C-Class Sedan', '2012 Mercedes-Benz E-Class Sedan', '2012 Mercedes-Benz S-Class Sedan', '2012 Mercedes-Benz Sprinter Van', '2012 Mitsubishi Lancer Sedan', '2012 Nissan Juke Hatchback', '2012 Nissan Leaf Hatchback', '2012 Nissan NV Passenger Van', '2012 Porsche Panamera Sedan', '2012 Ram C/V Cargo Van Minivan', '2012 Rolls-Royce Ghost Sedan', '2012 Rolls-Royce Phantom Drophead Coupe Convertible', '2012 Rolls-Royce Phantom Sedan', '2012 Scion xD Hatchback', '2012 Suzuki Kizashi Sedan', '2012 Suzuki SX4 Hatchback', '2012 Suzuki SX4 Sedan', '2012 Tesla Model S Sedan', '2012 Toyota 4Runner SUV', '2012 Toyota Camry Sedan', '2012 Toyota Corolla Sedan', '2012 Toyota Sequoia SUV', '2012 Volkswagen Beetle Hatchback', '2012 Volkswagen Golf Hatchback', '2012 Volvo C30 Hatchback', '2012 smart fortwo Convertible']
                    caption = classnames
                else:
                    raise NotImplementedError            
                if 'imagenet_full_ctx_values_dict' in data_dir and idx < 3:
                    use_full = True
                else:
                    use_full = False
                if 'increasing_classes_1001' in fold:
                    use_full = True
                # use_full = True
                if caption is None:
                    caption = self.extract_class_names(txt_path, use_full=use_full)
                # print(caption)
                coop_captions.append(caption)
                self.captions.extend(caption)
                
        elif dataset_mode == 'diffuse':
            # if self.data_dir == 'lora_prompts.pt' and self.local_images[0].shape[-1] == 768:
            import json
            with open('/data/user/diffusers/examples/lora_inversion/diff_prompt_data_captions_last_200.json', 'r') as f:
                data = json.load(f)
            self.captions = classes
            captions = list(data.values())
            # New lists after removing 200th and 201st elements
            filtered_captions = []

            # Lists to store removed 201st elements
            removed_captions = []

            # Iterate over the indices
            for i in range(len(captions)):
                if 'nudity' == captions[i]:
                    continue
                if (i + 1) % 200 == 0:
                    continue  # Skip the 200th element
                elif (i + 1) % 201 == 0:
                    removed_captions.append(captions[i])
                    continue  # Skip the 201st element from the main list
                
                filtered_captions.append(captions[i])
            self.captions.extend(filtered_captions)
            self.captions2 = classes2
            self.captions2.extend(removed_captions)

        
            
            
    def __len__(self):
        """Return the dataset size: how many entries ``self.local_images`` holds."""
        images = self.local_images
        return len(images)

    def extract_class_names(self, file_path, use_full=False):
        """Read class names from a text file of ``"<id> <class name>"`` lines.

        Each line is split on its first single space. The remainder, stripped
        of surrounding whitespace, is the class name. Lines without a space are
        ignored. The file is decoded as UTF-8, so the result does not depend on
        the platform's default locale encoding.

        Args:
            file_path: Path to the class-list text file.
            use_full: When False (default), return only the first
                ``len(names) // 2`` names. When True, return all of them.

        Returns:
            A list of class-name strings, in file order.

        NOTE(review): ``subsample`` (file header) keeps ``ceil(n / 2)`` base
        classes, while this keeps ``n // 2``. The two disagree for odd class
        counts. Confirm which split is intended before relying on both.
        """
        class_names = []
        with open(file_path, 'r', encoding='utf-8') as file:
            for line in file:
                parts = line.split(' ', 1)
                if len(parts) == 2:
                    class_names.append(parts[1].strip())
        if not use_full:
            num_c = len(class_names)
            return class_names[:num_c // 2]
        return class_names

    def __getitem__(self, idx):
        """Build the sample dict for position ``idx``; its layout depends on ``self.dataset_mode``.

        * 'cocoop' / 'coop' / 'ti_prompts' / 'identity_prompts' /
          'sd_concept_prompts' / 'joint': ``{'image', 'caption', 'label'}``,
          where ``label`` is the same tensor as ``image``. It may be extended with
          ``'cond'`` (``self.image_adapter``), ``'latent'``
          (``self.local_instances``) and ``'coop'`` (``self.coop_loss``).
        * 'diffuse' / 'coop_variations': ``{'image', 'caption', 'label'}``, where
          ``label`` comes from ``self.local_images2``, keyed by a string derived
          from the caption.
        * 'prompt_diffusion': ``{'image', 'label'}`` (no caption key), where
          ``label`` is ``self.local_images2[idx]``.
        * 'prompts': ``{'image', 'caption', 'label'}``, read from a per-item dict.

        Captions are kept when ``random.random() <= 0.8`` and replaced by ""
        otherwise. This is presumably caption dropout for classifier-free
        guidance; confirm with the training code.

        NOTE(review): any other ``dataset_mode`` matches no branch, so the
        method implicitly returns ``None``.
        """
        if self.dataset_mode == 'cocoop' or self.dataset_mode == 'coop' or self.dataset_mode == 'ti_prompts' or self.dataset_mode == 'identity_prompts' or self.dataset_mode == 'sd_concept_prompts' or self.dataset_mode == 'joint':
            # Entries of local_images are expected to be tensors (``.detach()`` is called below).
            data_dict = self.local_images[idx]
            # class_captions = self.local_classes[idx].split("-") #seed{num_s}-{fold}_classes_{num_c}
            # seed = class_captions[0]
            # fold = class_captions[1]
            # if "increasing" in fold:
            #     txt_path = f"/data/user/CoOp/data/imagenet/increasing_classes/{fold}.txt"
            # elif "random" in fold:
            #     txt_path = f"/data/user/CoOp/data/imagenet/random_classes/{fold}.txt"
            # else:
            #     raise NotImplementedError
            # txt_path = "/data/user/CoOp/data/imagenet/increasing_classes/increasing_classes_1000.txt"
            # caption = self.extract_class_names(txt_path)
            caption = self.captions[idx]
            data = data_dict.detach()
            # print(data.shape, len(caption))
            # Flatten 1-D vectors, and tensors whose first dim is 4, into a single row.
            if len(data.shape) == 1 or data.shape[0] == 4:
                data = data.reshape(1, -1)
            # Right-pad narrow rows with 1280 zeros.
            # NOTE(review): this yields width 2048 only when the width was 768, and
            # torch.cat requires data to have exactly one row here. Confirm both hold.
            if data.shape[-1] < 2048:
                data = torch.cat([data,torch.zeros((1,1280))],-1)
            
            # Drop the first 4 rows.
            # NOTE(review): if data was reshaped to a single row above, this slice
            # leaves zero rows. usemeta presumably expects multi-row input; confirm.
            if self.usemeta:
                data = data[4:, :]
            # print("After", data.shape, len(caption))
            # Pad each row with 64 zeros and fold into a (48, 48, 1) tensor.
            # NOTE(review): this requires exactly 4 rows here (cat with a (4, 64)
            # tensor) and 48*48 elements after padding. Confirm with the data layout.
            if self.use2d:
                data = torch.cat([data, torch.zeros(4, 64)], -1).reshape(48, 48, 1)
            # Build the text condition ``caps``. Every branch keeps it with probability 0.8
            # and otherwise substitutes "".
            if self.dataset_mode == 'identity_prompts' or self.dataset_mode == 'sd_concept_prompts' or self.dataset_mode == 'joint':
                if not 'step' in caption:
                    caps = caption if random.random() <= 0.8 else ""
                else:
                    # Drop the last two '-'-separated fields. Note that ``caption`` becomes a
                    # list here; the image_adapter block below indexes into it.
                    caption = caption.split('-')
                    caps = '-'.join(caption[:-2]) if random.random() <= 0.8 else ""
            elif self.use_ti:
                # Collapse caption to one string, then turn '-' separators into spaces.
                caption = ''.join(caption)
                caption = caption.split('-')
                caps = ' '.join(caption) if random.random() <= 0.8 else ""
            elif self.dataset_name:
                # Condition on the dataset name: the second '-'-separated field of the class
                # tag, with everything from '_classes' onward removed.
                capp = self.local_classes[idx]
                dsname = capp.split("-")[1].split('_classes')[0]
                caps = dsname if random.random() <= 0.8 else ""
            else:
                # Default: join caption entries with ';;'.
                # NOTE(review): if caption is a plain str rather than a list of class names,
                # this inserts ';;' between individual characters. Confirm its type.
                caps = ';;'.join(caption) if random.random() <= 0.8 else ""
            sample = {
                'image': data,
                # 'caption': ','.join(caption[:10]) if random.random() <= 0.8 else "",
                'caption': caps,
                # 'caption': ';;'.join(caption),
                "label": data
                }
            if self.image_adapter:
                # Pick an identity id from the caption (index -3 when training, -2 otherwise)
                # and load a reference image from that identity's folder.
                # NOTE(review): caption is a list only if it was split above. If it is still
                # a str, these indices select single characters. Confirm.
                if self.is_train:
                    idn = caption[-3]
                else:
                    idn = caption[-2]
                ddir = f'/data/CelebA-HQ/organized_images/{idn}'
                image_paths = sorted(os.listdir(ddir))
                if self.random_conditioning:
                    image_path = random.choice(image_paths) 
                else:
                    image_path = image_paths[0]
                # preprocess_clip_ = preprocess_clip()
                # image = preprocess_clip_(Image.open(os.path.join(ddir, image_path)))#.to(device)
                # sample['cond'] = image
                # Raw pixel array of the reference image; no preprocessing is applied here.
                sample['cond'] = np.array(Image.open(os.path.join(ddir, image_path)))
                if self.use_null_prompt:
                    sample['caption'] = ""
                else:
                    pass
                    
            if self.local_instances is not None:
                sample["latent"] = self.local_instances[idx].detach()
            if self.coop_loss:
                # Attach one few-shot example (image tensor, label, padded class-name list)
                # and the path returned by map_idx_to_filename under sample["coop"].
                from dassl.utils import read_image
                preprocessed, coop_model_path = map_idx_to_filename(idx)#os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl")
                # print(preprocessed)
                # NOTE(review): when the pickle does not exist, the sample silently gets no
                # "coop" key. Batches mixing both kinds of samples may fail to collate.
                if os.path.exists(preprocessed):
                    # print(f"Loading preprocessed few-shot data from {preprocessed}")
                    # pickle.load can execute arbitrary code; only point this at trusted,
                    # locally generated split files.
                    # (This rebinds ``data``; sample['image'] already holds the tensor.)
                    with open(preprocessed, "rb") as file:
                        data = pickle.load(file)
                        # train, val = data["train"], data["train"] #data["val"]
                    # new_data = subsample(data["train"])
                    # Items of data["train"] are read for .impath, .label and .classname.
                    train = random.choice(data["train"])
                    impath = train.impath
                    coop_label = train.label
                    
                    # labels = set()
                    # for item in data["train"]:
                    #     labels.add(item.classname)
                    # coop_cnames = list(labels)
                    coop_cnames = get_classnames_all(impath, self.combined_class_label_map)
                    
                    # print(coop_cnames)
                    # NOTE(review): coop_cname is not used anywhere below.
                    coop_cname = train.classname
                    img0 = read_image(impath)
                    # Build transform that doesn't apply any data augmentation
                    # (bicubic resize to 224x224, ToTensor, Normalize with PIXEL_MEAN/PIXEL_STD).
                    # NOTE(review): this is rebuilt and stored on self on every call; it could be
                    # created once in __init__.
                    interp_mode = INTERPOLATION_MODES['bicubic']
                    to_tensor = []
                    to_tensor += [T.Resize((224, 224), interpolation=interp_mode)]
                    to_tensor += [T.ToTensor()]
                    # if "normalize" in cfg.INPUT.TRANSFORMS:
                    
                    normalize = T.Normalize(
                        mean=PIXEL_MEAN, std=PIXEL_STD
                    )
                    to_tensor += [normalize]
                    self.to_tensor = T.Compose(to_tensor)
                    if False:#self.transform is not None:
                        # NOTE(review): dead branch. ``output`` is never defined in this method,
                        # so re-enabling the condition would raise NameError.
                        if isinstance(self.transform, (list, tuple)):
                            for i, tfm in enumerate(self.transform):
                                img = self._transform_image(tfm, img0)
                                keyname = "img"
                                if (i + 1) > 1:
                                    keyname += str(i + 1)
                                output[keyname] = img
                        else:
                            img = self._transform_image(self.transform, img0)
                            output["img"] = img
                    else:
                        img0 = self.to_tensor(img0)
                    # Pad coop_classnames to equal lengths and get original lengths
                    # (pad_coop_classnames pads to 500 entries and returns the unpadded count).
                    padded_coop_classnames, coop_classnames_lengths = pad_coop_classnames(coop_cnames)

                    # if self.is_train:
                    sample["coop"] = {
                        "coop_img": img0,
                        "coop_label": coop_label,
                        "coop_classnames": padded_coop_classnames,
                        'coop_classnames_length': coop_classnames_lengths,
                        "coop_model_path": coop_model_path,
                    }
                    # else:
                    #     sample["coop"] = random.choice(val)
                    # print(sample["coop"])
            return sample
        elif self.dataset_mode == 'diffuse':
            # Paired-tensor mode. The target ('label') is looked up in self.local_images2 by
            # the caption, with its last two '-'-fields removed when it contains 'step',
            # and is reshaped to match the input.
            data_dict = self.local_images[idx]
            
            caption = self.captions[idx]
            data = data_dict.detach()
            # print(data.shape, len(caption))
            if len(data.shape) == 1:
                data = data.reshape(1, -1)
            
            if not 'step' in caption:
                capp = caption
                caps = capp if random.random() <= 0.8 else ""
            else:
                caption = caption.split('-')
                capp = '-'.join(caption[:-2])
                caps = capp if random.random() <= 0.8 else ""
            # print(self.local_images2.keys())
            data_dict2 = self.local_images2[capp]
            data2 = data_dict2.detach().view(data.shape)
            sample = {
                'image': data,
                # 'caption': ','.join(caption[:10]) if random.random() <= 0.8 else "",
                'caption': caps,
                # 'caption': ';;'.join(caption),
                "label": data2
                }
            return sample
        elif self.dataset_mode == 'prompt_diffusion':
            # Index-aligned pair: input and target are both flattened to a single row.
            data_dict = self.local_images[idx]
            
            capp = self.captions[idx]
            data = data_dict.detach().view(1, -1)
            # print(data.shape, len(caption))
            
            # caption = caption.split('-')
            # capp = '-'.join(caption[1:])
            # NOTE(review): ``caps`` is computed but never used (the 'caption' key is
            # commented out). The random.random() call still advances the RNG.
            caps = capp if random.random() <= 0.8 else ""
            # print(self.local_images2.keys())
            data_dict2 = self.local_images2[idx]
            data2 = data_dict2.detach().view(1, -1)
            # print(data.shape, data2.shape)
            sample = {
                'image': data,
                # 'caption': ','.join(caption[:10]) if random.random() <= 0.8 else "",
                # 'caption': caps,
                # 'caption': ';;'.join(caption),
                "label": data2
                }
            return sample
        elif self.dataset_mode == 'coop_variations':
            # The caption minus its first '-'-separated field is both the text condition
            # and the key into self.local_images2 for the target tensor.
            data_dict = self.local_images[idx]
            
            caption = self.captions[idx]
            data = data_dict.detach().view(1, -1)
            # print(data.shape, len(caption))
            
            caption = caption.split('-')
            capp = '-'.join(caption[1:])
            caps = capp if random.random() <= 0.8 else ""
            # print(self.local_images2.keys())
            data_dict2 = self.local_images2[capp]
            data2 = data_dict2.detach().view(data.shape)
            sample = {
                'image': data,
                # 'caption': ','.join(caption[:10]) if random.random() <= 0.8 else "",
                'caption': caps,
                # 'caption': ';;'.join(caption),
                "label": data2
                }
            return sample
        elif self.dataset_mode == 'prompts':
            # Entries are dicts with a "bin_content" tensor and a "caption".
            # 'image' is flattened to one row; 'label' keeps the original shape.
            data_dict = self.local_images[idx]
            data = data_dict["bin_content"].detach()
            caption = data_dict["caption"]
            sample = {
                'image': data.reshape(1, -1),
                'caption': caption if random.random() <= 0.8 else "",
                "label": data
                }
            return sample


def resize_arr(pil_list, image_size, keep_aspect=True):
    """Resize an image together with its aligned label maps; return numpy arrays.

    ``pil_list`` is ``(image, class_map, instance_map, pose_map)`` of PIL
    images. ``instance_map`` and ``pose_map`` may be ``None``.

    The image is first halved repeatedly with BOX filtering while its shorter
    side is at least ``2 * image_size``. This is a hand-rolled stand-in for
    PIL's ``reducing_gap`` that improves downsample quality. It is then
    resized with BICUBIC, either so its shorter side equals ``image_size``
    (``keep_aspect=True``) or to exactly ``image_size x image_size``. The
    label maps are resized to the final image size with NEAREST, so label
    values are never blended.

    Returns ``(image, class_map, instance_map, pose_map)`` as numpy arrays;
    missing optional maps stay ``None``.
    """
    image, class_map, instance_map, pose_map = pil_list

    while min(*image.size) >= 2 * image_size:
        halved = tuple(side // 2 for side in image.size)
        image = image.resize(halved, resample=Image.BOX)

    if not keep_aspect:
        target = (image_size, image_size)
    else:
        ratio = image_size / min(*image.size)
        target = tuple(round(side * ratio) for side in image.size)
    image = image.resize(target, resample=Image.BICUBIC)

    def _fit(label_map):
        # NEAREST keeps label ids discrete.
        return label_map.resize(image.size, resample=Image.NEAREST)

    class_map = _fit(class_map)
    instance_map = None if instance_map is None else _fit(instance_map)
    pose_map = None if pose_map is None else _fit(pose_map)

    def _as_array(pil_img):
        return None if pil_img is None else np.array(pil_img)

    return (
        np.array(image),
        np.array(class_map),
        _as_array(instance_map),
        _as_array(pose_map),
    )


def center_crop_arr(pil_list, image_size):
    """Scale an image and its label maps, then take a centered square crop.

    ``pil_list`` is ``(image, class_map, instance_map, pose_map)`` of PIL
    images, where ``instance_map`` / ``pose_map`` may be ``None``. The image is
    halved with BOX filtering while its shorter side is at least
    ``2 * image_size`` (a manual substitute for PIL's ``reducing_gap``). It is
    then resized with BICUBIC so its shorter side equals ``image_size``. Label
    maps follow with NEAREST resampling. All outputs are cropped to the same
    centered ``image_size x image_size`` window.

    Returns a 4-tuple of numpy arrays ``(image, class, instance, pose)``;
    missing optional maps stay ``None``.
    """
    image, class_map, instance_map, pose_map = pil_list

    while min(*image.size) >= 2 * image_size:
        image = image.resize(tuple(s // 2 for s in image.size), resample=Image.BOX)

    factor = image_size / min(*image.size)
    image = image.resize(
        tuple(round(s * factor) for s in image.size), resample=Image.BICUBIC
    )

    resized = [class_map.resize(image.size, resample=Image.NEAREST)]
    for optional in (instance_map, pose_map):
        if optional is None:
            resized.append(None)
        else:
            resized.append(optional.resize(image.size, resample=Image.NEAREST))

    arrays = [np.array(image)] + [None if m is None else np.array(m) for m in resized]

    top = (arrays[0].shape[0] - image_size) // 2
    left = (arrays[0].shape[1] - image_size) // 2
    window = (slice(top, top + image_size), slice(left, left + image_size))
    return tuple(None if a is None else a[window] for a in arrays)


def random_crop_arr(pil_list, image_size, min_crop_frac=0.8, max_crop_frac=1.0):
    """Randomly rescale, then randomly crop an image and its label maps.

    A target length for the shorter side is drawn uniformly (inclusive) from
    ``[ceil(image_size / max_crop_frac), ceil(image_size / min_crop_frac)]``.
    The image is halved with BOX filtering while its shorter side is at least
    twice that target, then resized with BICUBIC so the shorter side matches
    it. The class map and the optional instance / pose maps (``None``
    allowed) are resized to the same size with NEAREST. A random
    ``image_size x image_size`` window, shared by all outputs, is then cut out.

    Returns a 4-tuple of numpy arrays ``(image, class, instance, pose)``;
    missing optional maps stay ``None``.
    """
    shortest = math.ceil(image_size / max_crop_frac)
    longest = math.ceil(image_size / min_crop_frac)
    target_short = random.randrange(shortest, longest + 1)

    image, class_map, instance_map, pose_map = pil_list

    while min(*image.size) >= 2 * target_short:
        image = image.resize(tuple(s // 2 for s in image.size), resample=Image.BOX)

    factor = target_short / min(*image.size)
    image = image.resize(
        tuple(round(s * factor) for s in image.size), resample=Image.BICUBIC
    )

    def _match(label_map):
        # Optional maps pass through as None; NEAREST keeps label ids discrete.
        if label_map is None:
            return None
        return label_map.resize(image.size, resample=Image.NEAREST)

    class_map = class_map.resize(image.size, resample=Image.NEAREST)
    instance_map = _match(instance_map)
    pose_map = _match(pose_map)

    arrays = [
        np.array(image),
        np.array(class_map),
        None if instance_map is None else np.array(instance_map),
        None if pose_map is None else np.array(pose_map),
    ]

    top = random.randrange(arrays[0].shape[0] - image_size + 1)
    left = random.randrange(arrays[0].shape[1] - image_size + 1)
    rows = slice(top, top + image_size)
    cols = slice(left, left + image_size)
    return tuple(None if a is None else a[rows, cols] for a in arrays)
