import torch 
import torchvision.transforms as transforms 
import os
import numpy as np  
import torchvision  
import random   
from PIL import Image   
from resnet import *        
from models import ConvNet
from transformers import CLIPTokenizer
import argparse 
import torchvision
from tqdm import tqdm       
import torchvision.models as models
from torch.utils.data import random_split
from torch.utils.data import Dataset  
from torchvision.datasets import DatasetFolder      
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
from torchvision.datasets.folder import default_loader
from pathlib import Path
from torchvision.datasets import ImageFolder  
# from sklearn.model_selection import train_test_split 
from torch.utils.data import  Subset  
import json     
from diffusers import AutoPipelineForImage2Image, StableDiffusionXLPipeline
# from customXLPipeline import CustomXLPipeline   
import torch.distributed as dist
from alexnet import AlexNet, AlexNetCIFAR
from vgg import VGG11
from vgg_cifar import VGG11CIFAR
from vit import ViT



# Default compute device for this module: CUDA when available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# File extensions (lower-case) accepted as images when scanning class folders.
IMG_EXTENSIONS = (
    ".jpg", ".jpeg", ".png", ".ppm", ".bmp",
    ".pgm", ".tif", ".tiff", ".webp",
)



class ImageFolderDiffusion(Dataset):
    """ImageFolder-style dataset whose images are regenerated on the fly by an image-to-image pipeline.

    Each item is resized to 320x320. A prompt ``"a photo of a <class name>"`` is built
    from ``dataset_dic`` and the source image is run through ``pipe``. The generated
    image (optionally passed through ``transform``) is returned with its label.

    Args:
        root: Root directory laid out as ``root/<class_folder>/<image>``.
        dataset_dic: Path to a ``torch.load``-able dict mapping class folder name to a
            human-readable class name. Underscores are replaced by spaces in prompts.
        pipe: Pipeline called as ``pipe(prompt, image=..., ...)`` that returns an object with
            an ``images`` list (e.g. a diffusers image-to-image pipeline).
        samples_per_class: Images drawn per class. ``-1`` keeps every image; ``None`` uses
            the number of images in class 0. Smaller classes keep all their images.
        transform: Optional transform applied to the generated image.
    """

    def __init__(self, root, dataset_dic, pipe, samples_per_class=-1, transform=None):
        self.samples_per_class = samples_per_class
        self.transform = transform
        self.pipe = pipe
        # NOTE: torch.load unpickles the file; only load trusted dictionaries.
        self.dataset_dic = torch.load(dataset_dic)

        # Load the dataset using ImageFolder
        self.dataset = ImageFolder(root)

        # Group image paths by integer label (insertion order follows ImageFolder).
        self.class_to_samples = {}
        for path, label in self.dataset.samples:
            self.class_to_samples.setdefault(label, []).append(path)

        if self.samples_per_class is None:
            self.samples_per_class = len(self.class_to_samples[0])

        # Use the resolved value. The raw argument is still None here when it was
        # defaulted above, which used to crash the per-class comparison.
        self.samples = self._subsample(self.class_to_samples, self.samples_per_class)

    @staticmethod
    def _subsample(class_to_samples, samples_per_class):
        """Return ``(path, label)`` pairs with up to ``samples_per_class`` random paths per label (``-1`` keeps all)."""
        samples = []
        for label, paths in class_to_samples.items():
            if samples_per_class == -1:
                chosen = paths
                print('no subsampling')
            else:
                # Classes smaller than the budget keep all of their images (shuffled).
                chosen = random.sample(paths, min(samples_per_class, len(paths)))
                if len(paths) < samples_per_class:
                    print(f"Class {label} has less than {samples_per_class} samples. Using {len(chosen)} samples.")
            samples.extend((path, label) for path in chosen)
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        path, label = self.samples[index]
        with Image.open(path) as src:
            image = src.convert('RGB').resize((320, 320))

        # The class folder is the image's parent directory.
        cls_name = os.path.basename(os.path.dirname(path))
        prompt = f'a photo of a {self.dataset_dic[cls_name].replace("_", " ")}'

        with torch.no_grad():
            augmentation = self.pipe(prompt, image=image, num_inference_steps=5, strength=.9,
                                     guidance_scale=0., num_images_per_prompt=1,
                                     output_type='pil').images[0]

        # Always return the generated image. Previously it was discarded when transform was None.
        if self.transform is not None:
            augmentation = self.transform(augmentation)
        return augmentation, label
    

class ChopCollageTrans(torch.nn.Module):
    """Cut a square image tensor into a grid of equal square tiles.

    Args:
        size: Side length (pixels) of the input image.
        factor: Number of tiles along each side. When ``size`` is divisible by ``factor``
            this yields ``factor * factor`` tiles of side ``size // factor``.

    ``forward`` expects a ``(C, H, W)`` tensor (TODO confirm at call sites). It returns the
    tiles in row-major order, stacked along a new first axis. Each tile keeps a leading
    singleton axis: ``(num_tiles, 1, C, tile, tile)``.
    """

    def __init__(self, size, factor):
        super().__init__()
        self.size = size
        self.factor = factor
        step = size // factor
        # (top, left, height, width) boxes for torchvision's functional ``crop``, row by row.
        self.bbs = [(top, left, step, step)
                    for top in range(0, size, step) for left in range(0, size, step)]

    def forward(self, image):
        # Add a leading batch axis; each crop keeps it (original output layout).
        image = image.unsqueeze(0)
        patches = [torchvision.transforms.functional.crop(image, *bb) for bb in self.bbs]
        return torch.stack(patches, 0)

    def __repr__(self) -> str:
        # Previously referenced never-set attributes (num_crop, size) and raised AttributeError.
        return f"{self.__class__.__name__}(size={self.size}, factor={self.factor})"



def center_crop_PIL(img, crop_width, crop_height):
    """Return the centered ``crop_width`` x ``crop_height`` region of a PIL image.

    :param img: Input image; only ``img.size`` and ``img.crop`` are used.
    :param crop_width: Width of the crop box in pixels.
    :param crop_height: Height of the crop box in pixels.
    :return: Result of ``img.crop`` for the centered box.
    """
    img_width, img_height = img.size

    # Box corners; integer division keeps coordinates on whole pixels.
    left = (img_width - crop_width) // 2
    top = (img_height - crop_height) // 2
    right = (img_width + crop_width) // 2
    bottom = (img_height + crop_height) // 2

    return img.crop((left, top, right, bottom))



class JSONDataset(Dataset):
    """Image-classification dataset whose file list is read from a JSON array of paths.

    Each listed path is re-rooted under ``root_dir``, keeping only its last two
    ``/``-separated components (``<class_folder>/<file>``). The class folder name
    defines the label; classes are indexed in sorted order.
    """

    def __init__(self, json_file, root_dir, transform=None, target_transform=None):
        with open(json_file, 'r') as handle:
            self.img_paths = json.load(handle)

        self.classes = sorted({self._get_class_name(entry) for entry in self.img_paths})
        self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
        for name, idx in self.class_to_idx.items():
            print(f'{name}: {idx}')

        # Keep only "<class_folder>/<file>" from each entry and prepend root_dir.
        self.image_path_root_concat = []
        for entry in self.img_paths:
            parts = entry.split('/')
            self.image_path_root_concat.append(os.path.join(root_dir, parts[-2], parts[-1]))
        print(self.image_path_root_concat)

        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        target_path = self.image_path_root_concat[idx]
        sample = Image.open(target_path).convert('RGB')
        target = self.class_to_idx[self._get_class_name(target_path)]
        sample = self.transform(sample) if self.transform else sample
        target = self.target_transform(target) if self.target_transform else target
        return sample, target

    @staticmethod
    def _get_class_name(path):
        """Return the name of the directory that directly contains ``path``."""
        parent_dir = os.path.dirname(path)
        return os.path.basename(parent_dir)
    
  
def split_imagefolder_dataset(dataset, eval_ratio, seed):
    """Randomly split ``dataset`` into train and eval subsets.

    Replaces the former ``sklearn`` ``train_test_split`` call, whose import is disabled
    in this module (it raised NameError).

    Args:
        dataset: Any dataset supporting ``len()`` (e.g. an ``ImageFolder``).
        eval_ratio (float): Fraction of samples assigned to the eval split, in ``[0, 1]``.
            The eval size is rounded up, as in ``train_test_split``.
        seed (int): Seed for the shuffle, so the split is reproducible.

    Returns:
        tuple[Subset, Subset]: ``(train_dataset, eval_dataset)``.

    Raises:
        ValueError: If ``eval_ratio`` is outside ``[0, 1]``.
    """
    import math

    if not 0.0 <= eval_ratio <= 1.0:
        raise ValueError(f"eval_ratio must be in [0, 1], got {eval_ratio}")

    indices = list(range(len(dataset)))
    # Private RNG: reproducible per seed without disturbing the global ``random`` state.
    random.Random(seed).shuffle(indices)

    num_eval = math.ceil(eval_ratio * len(indices))
    eval_indices = indices[:num_eval]
    train_indices = indices[num_eval:]

    return Subset(dataset, train_indices), Subset(dataset, eval_indices)



class ImageFolderSubsample(Dataset):
    """``ImageFolder`` view that keeps a random subset of images per class.

    Args:
        root: Root directory laid out as ``root/<class_folder>/<image>``.
        samples_per_class: Images drawn per class. ``-1`` keeps every image; ``None`` uses
            the number of images in class 0. Smaller classes keep all their images (shuffled).
        transform: Optional transform applied to each loaded RGB PIL image.
    """

    def __init__(self, root, samples_per_class, transform=None):
        self.samples_per_class = samples_per_class
        self.transform = transform

        # Load the dataset using ImageFolder
        self.dataset = ImageFolder(root)

        # Group image paths by integer label (insertion order follows ImageFolder).
        self.class_to_samples = {}
        for path, label in self.dataset.samples:
            self.class_to_samples.setdefault(label, []).append(path)

        if self.samples_per_class is None:
            self.samples_per_class = len(self.class_to_samples[0])

        # Use the resolved value. The raw argument is still None here when it was
        # defaulted above, which used to crash the per-class comparison.
        self.samples = self._subsample(self.class_to_samples, self.samples_per_class)

    @staticmethod
    def _subsample(class_to_samples, samples_per_class):
        """Return ``(path, label)`` pairs with up to ``samples_per_class`` random paths per label (``-1`` keeps all)."""
        samples = []
        for label, paths in class_to_samples.items():
            if samples_per_class == -1:
                chosen = paths
                print('no subsampling')
            else:
                chosen = random.sample(paths, min(samples_per_class, len(paths)))
                if len(paths) < samples_per_class:
                    print(f"Class {label} has less than {samples_per_class} samples. Using {len(chosen)} samples.")
            samples.extend((path, label) for path in chosen)
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        path, label = self.samples[index]
        with Image.open(path) as src:
            image = src.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image, label
    


class ImageFolderSubsampleTwins(Dataset):
    """Subsampled ``ImageFolder`` that also returns a random same-class "twin" image.

    Args:
        root: Root directory laid out as ``root/<class_folder>/<image>``.
        samples_per_class: Images drawn per class. ``-1`` keeps every image; ``None`` uses
            the number of images in class 0. Smaller classes keep all their images (shuffled).
        transform: Optional transform applied to both the image and its twin.
    """

    def __init__(self, root, samples_per_class, transform=None):
        self.samples_per_class = samples_per_class
        self.transform = transform

        # Load the dataset using ImageFolder
        self.dataset = ImageFolder(root)

        # Group image paths by integer label (insertion order follows ImageFolder).
        self.class_to_samples = {}
        for path, label in self.dataset.samples:
            self.class_to_samples.setdefault(label, []).append(path)

        if self.samples_per_class is None:
            self.samples_per_class = len(self.class_to_samples[0])

        # Use the resolved value. The raw argument is still None here when it was
        # defaulted above, which used to crash the per-class comparison.
        self.samples = self._subsample(self.class_to_samples, self.samples_per_class)

    @staticmethod
    def _subsample(class_to_samples, samples_per_class):
        """Return ``(path, label)`` pairs with up to ``samples_per_class`` random paths per label (``-1`` keeps all)."""
        samples = []
        for label, paths in class_to_samples.items():
            if samples_per_class == -1:
                chosen = paths
            else:
                chosen = random.sample(paths, min(samples_per_class, len(paths)))
            samples.extend((path, label) for path in chosen)
        return samples

    @staticmethod
    def _load(path):
        """Open ``path`` as an RGB PIL image, closing the file handle."""
        with Image.open(path) as src:
            return src.convert('RGB')

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        path, label = self.samples[index]
        image = self._load(path)
        # The twin is drawn from *all* images of the class, not only the subsampled ones.
        twin_image = self._load(random.choice(self.class_to_samples[label]))

        if self.transform is not None:
            image = self.transform(image)
            twin_image = self.transform(twin_image)

        return image, twin_image, label



class AugmentedFolderDataset(DatasetFolder):
    """``DatasetFolder`` over the union of a main image root and an augmentation root.

    Both roots use the ``root/<class_folder>/<image>`` layout. Every sample, main or
    augmented, is labelled with the *main* root's class-to-index mapping. An augmentation
    root that holds only some of the classes is therefore still labelled consistently with
    ``self.classes``.

    Args:
        root_main: Root of the original images; defines ``classes``/``class_to_idx``.
        root_aug: Root of the augmented images.
        loader: Callable that loads a sample from a path.
        extensions: Accepted file extensions; defaults to ``IMG_EXTENSIONS``. Ignored when
            ``is_valid_file`` is given.
        transform: Optional transform applied to loaded samples.
        target_transform: Optional transform applied to labels.
        is_valid_file: Optional predicate on file paths, used instead of the extension filter.
        allow_empty: Forwarded to ``make_dataset``.

    Raises:
        ValueError: If ``root_aug`` contains a class folder that ``root_main`` does not.
    """

    def __init__(
        self,
        root_main: Union[str, Path],
        root_aug: Union[str, Path],
        loader: Callable[[str], Any] = default_loader,
        extensions: Optional[Tuple[str, ...]] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
        allow_empty: bool = False,
    ) -> None:
        # DatasetFolder.__init__ is bypassed on purpose because it scans a single root.
        # Set the attributes it would otherwise provide; ``root`` is read by __repr__.
        self.root = root_main
        self.root_main = root_main
        self.root_aug = root_aug
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
        self.extensions = extensions

        classes, class_to_idx = self.find_classes(self.root_main)
        classes_aug, _ = self.find_classes(self.root_aug)

        print(classes, class_to_idx)

        missing = sorted(set(classes_aug) - set(class_to_idx))
        if missing:
            raise ValueError(f"Class folders in {root_aug} not present in {root_main}: {missing}")
        # Reuse the main indices. Each root's own sorted order could otherwise assign
        # different labels to the same folder name.
        class_to_idx_aug = {name: class_to_idx[name] for name in classes_aug}

        # torchvision's make_dataset takes either an extension filter or is_valid_file, not both.
        valid_extensions = None if is_valid_file is not None else (extensions or IMG_EXTENSIONS)

        samples_main = self.make_dataset(
            self.root_main,
            class_to_idx=class_to_idx,
            extensions=valid_extensions,
            is_valid_file=is_valid_file,
            allow_empty=allow_empty,
        )
        samples_aug = self.make_dataset(
            self.root_aug,
            class_to_idx=class_to_idx_aug,
            extensions=valid_extensions,
            is_valid_file=is_valid_file,
            allow_empty=allow_empty,
        )

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples_main + samples_aug
        self.targets = [s[1] for s in self.samples]
        self.imgs = self.samples




class SingleFolderDS(Dataset):
    """Flat folder of ``.jpg``/``.jpeg``/``.png`` files (case-insensitive), yielded with their paths.

    Args:
        folder_path: Directory to scan (not recursive). Files are taken in sorted name order.
        transform: Optional transform applied to each RGB PIL image.

    Items are ``(image, image_path)``.
    """

    # Accepted extensions, compared against the lower-cased file name.
    _EXTENSIONS = ('.jpg', '.png', '.jpeg')

    def __init__(self, folder_path, transform=None):
        self.folder_path = folder_path
        self.transform = transform
        # Both conditions must hold. The old `isfile and jpg or png or jpeg` precedence
        # let directories named "*.png"/"*.jpeg" through.
        self.image_paths = [
            os.path.join(self.folder_path, name)
            for name in sorted(os.listdir(self.folder_path))
            if name.lower().endswith(self._EXTENSIONS)
            and os.path.isfile(os.path.join(self.folder_path, name))
        ]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        image_path = self.image_paths[idx]
        with Image.open(image_path) as src:
            image = src.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, image_path

def get_embedder(multires, i=0):
    """Build a sinusoidal positional encoder for 1-D inputs.

    Args:
        multires: Number of frequency bands; frequencies are ``2**0 ... 2**(multires - 1)``.
        i: Pass ``-1`` to disable encoding and get an identity module instead.

    Returns:
        tuple: ``(embed_fn, out_dim)``.
        - For ``i == -1`` this is ``(torch.nn.Identity(), 3)``.
        - Otherwise ``embed_fn(x)`` concatenates ``[x, sin(f x), cos(f x), ...]`` over all
          frequencies ``f`` along the last axis, and ``out_dim == 1 + 2 * multires``.
    """
    if i == -1:
        # Fully qualified: ``nn`` is not imported explicitly in this module.
        # NOTE(review): reports 3 dims while the encoder below assumes 1-D input — confirm.
        return torch.nn.Identity(), 3

    embedder_obj = Embedder(
        include_input=True,
        input_dims=1,
        max_freq_log2=multires - 1,
        num_freqs=multires,
        log_sampling=True,
        periodic_fns=[torch.sin, torch.cos],
    )
    # A bound method behaves like the previous ``lambda x: embedder_obj.embed(x)``.
    return embedder_obj.embed, embedder_obj.out_dim


class Embedder:
    """Concatenate an input with periodic functions of frequency-scaled copies of it.

    Required keyword options: ``input_dims``, ``include_input``, ``max_freq_log2``,
    ``num_freqs``, ``log_sampling`` and ``periodic_fns``. After construction,
    ``embed_fns`` holds the callables in output order and ``out_dim`` the resulting
    feature width.
    """

    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.create_embedding_fn()

    def create_embedding_fn(self):
        opts = self.kwargs
        width = opts['input_dims']
        fns = [lambda x: x] if opts['include_input'] else []

        top = opts['max_freq_log2']
        count = opts['num_freqs']

        # Log sampling spaces the exponents evenly; otherwise the frequencies themselves.
        if opts['log_sampling']:
            bands = 2. ** torch.linspace(0., top, steps=count)
        else:
            bands = torch.linspace(2. ** 0., 2. ** top, steps=count)

        # Default arguments bind each (p_fn, freq) pair at creation time.
        fns.extend(
            (lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
            for freq in bands
            for p_fn in opts['periodic_fns']
        )

        self.embed_fns = fns
        # Every function contributes ``input_dims`` features.
        self.out_dim = width * len(fns)

    def embed(self, inputs):
        """Apply every embedding function to ``inputs`` and concatenate along the last axis."""
        pieces = [fn(inputs) for fn in self.embed_fns]
        return torch.cat(pieces, -1)


@torch.no_grad()
def eval(model, dl):  # NOTE: shadows the builtin ``eval``; name kept for existing callers.
    """Return the top-1 accuracy (percent) of ``model`` over every batch of ``dl``.

    ``model`` is switched to eval mode and left there. ``dl`` must yield ``(x, y)``
    batches; both are moved to the module-level ``device``.

    Raises:
        ValueError: If ``dl`` yields no samples (previously a ZeroDivisionError).
    """
    model.eval()
    num_samples = 0
    correct = 0

    for x, y in tqdm(dl, total=len(dl)):
        x, y = x.to(device), y.to(device)
        preds = model(x).argmax(1)
        # Compare on the device; ``.item()`` is the single host sync per batch.
        correct += (preds == y).sum().item()
        num_samples += x.shape[0]

    if num_samples == 0:
        raise ValueError("eval(): data loader produced no samples")
    return correct / num_samples * 100.

@torch.no_grad()
def eval_mp(model, dl):
    """Per-process variant of ``eval``: top-1 accuracy (percent) of ``model`` on this process's ``dl``.

    Moves ``model`` to ``cuda:<rank>`` (or CPU) and switches it to eval mode. Only the
    batches this process sees are counted; results are not reduced across ranks here.

    Raises:
        ValueError: If ``dl`` yields no samples (previously a ZeroDivisionError).
    """
    # NOTE(review): uses the global rank as the CUDA index. That assumes a single node;
    # multi-node jobs would need the local rank — confirm.
    device = f'cuda:{dist.get_rank()}' if torch.cuda.is_available() else 'cpu'
    model.to(device)
    model.eval()

    num_samples = 0
    correct = 0
    for x, y in tqdm(dl, total=len(dl)):
        x, y = x.to(device), y.to(device)
        # Compare on the device; ``.item()`` is the single host sync per batch.
        correct += (model(x).argmax(1) == y).sum().item()
        num_samples += x.shape[0]

    if num_samples == 0:
        raise ValueError("eval_mp(): data loader produced no samples")
    return correct / num_samples * 100.


def rand_bbox(size, lam):
    """Sample a random box whose unclipped area is about ``1 - lam`` of the last two dims of ``size``.

    The centre is drawn uniformly with ``np.random``. Returns ``(bbx1, bby1, bbx2, bby2)``;
    x-coordinates run along ``size[2]`` and y-coordinates along ``size[3]``, clipped to bounds.
    """
    dim_x, dim_y = size[2], size[3]
    ratio = np.sqrt(1.0 - lam)
    half_x = int(dim_x * ratio) // 2
    half_y = int(dim_y * ratio) // 2

    # Uniform centre: draw x first, then y (keeps the RNG sequence unchanged).
    centre_x = np.random.randint(dim_x)
    centre_y = np.random.randint(dim_y)

    bbx1 = np.clip(centre_x - half_x, 0, dim_x)
    bby1 = np.clip(centre_y - half_y, 0, dim_y)
    bbx2 = np.clip(centre_x + half_x, 0, dim_x)
    bby2 = np.clip(centre_y + half_y, 0, dim_y)
    return bbx1, bby1, bbx2, bby2

class ClsFolder(torch.utils.data.Dataset):
    """All entries of a single class directory, each labelled ``cls_ind``.

    Args:
        cls_dir: Directory whose entries are the images of one class.
        cls_ind: Label assigned to every sample.
        mem: If True, decode every image to RGB once up front and keep it in memory.
        shuffle: If True, shuffle the file order with the global ``random`` state.
        transform: Optional callable applied to each PIL image in ``__getitem__``.
    """

    def __init__(self, cls_dir, cls_ind, mem=False, shuffle=False, transform=None):
        self.transform = transform
        self.mem = mem

        file_ls = os.listdir(cls_dir)
        if shuffle:
            random.shuffle(file_ls)

        self.image_paths = [os.path.join(cls_dir, name) for name in file_ls]
        self.targets = [cls_ind] * len(self.image_paths)
        # Pre-decoded images when ``mem`` is set; empty otherwise.
        self.samples = [self._load(p) for p in self.image_paths] if mem else []

    @staticmethod
    def _load(path):
        """Open ``path`` as an RGB PIL image, closing the file handle."""
        with Image.open(path) as src:
            return src.convert("RGB")

    def __getitem__(self, index):
        sample = self.samples[index] if self.mem else self._load(self.image_paths[index])
        # transform defaults to None; calling it unconditionally used to raise TypeError.
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, self.targets[index]

    def __len__(self):
        return len(self.targets)


# ImageNet per-channel mean/std; used as the default normalization throughout this module.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Inverse of ``normalize``: dividing by 1/std multiplies the std back in, then the mean is added back.
denormalize = transforms.Compose([
    transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1 / s for s in (0.229, 0.224, 0.225)]),
    transforms.Normalize(mean=[-m for m in (0.485, 0.456, 0.406)], std=[1.0, 1.0, 1.0]),
])



def get_normalize_trans(args):
    """Return ``(normalize, denormalize)`` transforms for the channel statistics of ``args.subset``.

    ImageNet statistics are used for ``imagenet*`` subsets and ``'tiny'``; CIFAR statistics
    for ``'cifar10'``/``'cifar100'``. ``denormalize`` inverts ``normalize``.

    Raises:
        ValueError: For any other subset (previously an UnboundLocalError).
    """
    if args.subset.startswith("imagenet") or args.subset == 'tiny':
        mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    elif args.subset == 'cifar100' or args.subset == 'cifar10':
        mean, std = [0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]
    else:
        raise ValueError(f"No normalization statistics defined for subset {args.subset!r}")

    normalize = transforms.Normalize(mean=mean, std=std)
    # Undo the std scaling first (divide by 1/std), then the mean shift.
    denormalize = transforms.Compose([
        transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1 / s for s in std]),
        transforms.Normalize(mean=[-m for m in mean], std=[1.0, 1.0, 1.0]),
    ])
    return normalize, denormalize

def set_dataset_specs(args):
    """Fill dataset-dependent fields on ``args`` in place, based on ``args.subset``.

    - 10-class ImageNet subsets: ``nclass = 10``; ``input_size``/``init_resize`` default to
      224/256 only when they are ``None``.
    - Other known subsets: ``nclass``, ``input_size`` and ``init_resize`` are overwritten.
    - Unknown subsets leave them untouched and rely on ``args.nclass`` already being set.

    In every case ``classes = range(nclass)`` and ``val_ipc = 50`` are set, and ``end_cls``
    defaults to ``nclass`` when it is ``None``.
    """
    ten_class_imagenet = (
        "imagenet-a", "imagenet-b", "imagenet-c", "imagenet-d",
        "imagenet-e", "imagenet-birds", "imagenet-fruits", "imagenet-cats",
        "imagenet-10", "imagenette", 'imagenet-woof',
    )
    # subset -> (nclass, input_size, init_resize)
    fixed_specs = {
        'imagenet': (1000, 224, 256),
        'imagenet100': (100, 224, 256),
        'tiny': (200, 64, 64),
        'cifar100': (100, 32, 32),
        'cifar10': (10, 32, 32),
    }

    if args.subset in ten_class_imagenet:
        args.nclass = 10
        if args.input_size is None:
            args.input_size = 224
        if args.init_resize is None:
            args.init_resize = 256
    elif args.subset in fixed_specs:
        args.nclass, args.input_size, args.init_resize = fixed_specs[args.subset]

    args.classes = range(args.nclass)
    args.val_ipc = 50
    if args.end_cls is None:
        args.end_cls = args.nclass


def get_network(args, arch, pretrained=True, data_parallel=True):
    """Build classifier ``arch`` for ``args.subset``, optionally load a checkpoint, and move it to ``device``.

    Args:
        args: Namespace providing ``subset``, ``nclass``, ``input_size`` and ``model_ckpt``.
            ``force_rded_net`` is also read for ``resnet18`` on ``imagenet-*`` subsets.
        arch: One of ``convnet-{3,4,5}``, ``alexnet``, ``vgg11``, ``vit``, a torchvision
            ResNet name (CIFAR subsets), ``resnet18``, ``resnet50`` (ImageNet subsets), or
            any torchvision model name when ``args.subset == 'imagenet'``.
        pretrained: Forwarded to torchvision constructors; also gates loading ``args.model_ckpt``.
        data_parallel: Wrap the model in ``torch.nn.DataParallel``.

    Returns:
        The model, moved to the module-level ``device``.

    Raises:
        ValueError: If no model is defined for this ``(arch, args.subset)`` pair
            (previously an UnboundLocalError).
    """

    def pruning_classifier(model=None, classes=None):
        """Keep only rows ``classes`` of the model's last two parameters (classifier weight/bias)."""
        classes = [] if classes is None else classes
        try:
            model_named_parameters = [name for name, _ in model.named_parameters()]
            last_two = {model_named_parameters[-1], model_named_parameters[-2]}
            for name, param in model.named_parameters():
                if name in last_two:
                    param.data = param[classes]
            print("Changed the number of classes.")
        except Exception as err:  # best effort as before, but no longer a bare except
            print("ERROR in changing the number of classes.", repr(err))
        return model

    subset = args.subset
    on_cifar = subset == 'cifar10' or subset == 'cifar100'
    model = None

    # The custom architectures are checked first. In the original sequence of independent
    # `if`s they overwrote any earlier torchvision model anyway, so this yields the same
    # model while avoiding a wasted (or, for convnet/vit, failing) torchvision lookup.
    if arch in ('convnet-3', 'convnet-4', 'convnet-5'):
        model = ConvNet(num_classes=args.nclass, net_depth=int(arch[-1]), net_width=128,
                        im_size=(args.input_size, args.input_size))
    elif arch == 'alexnet':
        if subset.startswith('cifar'):
            model = AlexNetCIFAR(3, args.nclass)
        else:
            model = AlexNet(3, args.nclass, (args.input_size, args.input_size))
    elif arch == 'vgg11':
        if subset.startswith('cifar'):
            model = VGG11CIFAR(3, args.nclass)
            print('VGG CIFAR is loaded ##########')
        else:
            model = VGG11(3, args.nclass)
    elif arch == 'vit':
        model = ViT(image_size=(args.input_size, args.input_size), patch_size=16, num_classes=args.nclass,
                    dim=512, depth=10, heads=8, mlp_dim=512, dropout=0.1, emb_dropout=0.1)
    elif subset == 'imagenet' and not arch.startswith('resnet'):
        model = models.__dict__[arch](pretrained=pretrained)
        print(f'model: {arch} loaded from torchvision')
    elif on_cifar and arch.startswith('resnet'):
        # Small-image stem: 3x3 stride-1 first conv and no max-pool.
        model = models.__dict__[arch](pretrained=False)
        model.conv1 = torch.nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        model.maxpool = torch.nn.Identity()
        model = pruning_classifier(model, list(range(args.nclass)))
    elif arch == 'resnet18':
        if subset == 'imagenet':
            print("Loading pretrained resnet18")
            model = models.resnet18(pretrained=pretrained)
        elif subset.startswith("imagenet"):
            # `== False` kept on purpose: None must still take the torchvision branch.
            if args.force_rded_net == False:  # noqa: E712
                model = ResNet18(args.nclass)
                print('force rded net is False')
            else:
                model = models.__dict__[arch](pretrained=pretrained)
                model = pruning_classifier(model, list(range(args.nclass)))
                print('force rded net is True')
        elif subset == 'tiny':
            model = models.resnet18(num_classes=200)
            model.conv1 = torch.nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            model.maxpool = torch.nn.Identity()
    elif arch == 'resnet50' and subset.startswith("imagenet"):
        model = ResNet50(args.nclass)

    if model is None:
        raise ValueError(f"No model defined for arch={arch!r} with subset={subset!r}")

    if pretrained and args.model_ckpt is not None:
        print(f"###### Loading model from {args.model_ckpt} ############" )
        # Load onto the CPU so GPU-saved checkpoints also work on CPU-only hosts;
        # the model is moved to ``device`` below. NOTE: torch.load unpickles — trusted files only.
        state_dict = torch.load(args.model_ckpt, map_location='cpu')

        # Unwrap {'model': state_dict} checkpoints *before* the prefix check. Otherwise the
        # check inspects the wrapper key and misses a 'module.' prefix on the real keys.
        if 'model' in state_dict:
            state_dict = state_dict['model']

        # Remove the 'module.' prefix added by (Distributed)DataParallel wrappers.
        prefix = 'module.'
        state_dict = {(k[len(prefix):] if k.startswith(prefix) else k): v
                      for k, v in state_dict.items()}

        model.load_state_dict(state_dict)
        print(f"###### Loaded model from {args.model_ckpt} ############" )

    if data_parallel:
        model = torch.nn.DataParallel(model)

    model = model.to(device)
    return model

def get_imagenet_classes(args):
    """Return the ImageNet-1k class indices that make up ``args.subset``.

    A fresh list is returned on every call, so callers may mutate it.

    Raises:
        ValueError: If no class list is defined here for ``args.subset`` (previously an
            UnboundLocalError, e.g. for ``'imagenet-a'``).
    """
    subset = args.subset
    if subset == 'imagenette':
        return [0, 217, 482, 491, 497, 566, 569, 571, 574, 701]
    if subset == 'imagenet-woof':
        return [193, 182, 258, 162, 155, 167, 159, 273, 207, 229]
    if subset == 'imagenet':
        return list(range(1000))
    if subset == 'imagenet100':
        return [15, 45, 54, 57, 64, 74, 90, 99, 119, 120, 122, 131, 137, 151, 155, 157, 158, 166, 167, 169,
                176, 180, 209, 211, 222, 228, 234, 236, 242, 246, 267, 268, 272, 275, 277, 281, 299, 305,
                313, 317, 331, 342, 368, 374, 407, 421, 431, 449, 452, 455, 479, 494, 498, 503, 508, 544,
                560, 570, 592, 593, 599, 606, 608, 619, 620, 653, 659, 662, 665, 667, 674, 682, 703, 708,
                717, 724, 748, 758, 765, 766, 772, 775, 796, 798, 830, 854, 857, 858, 872, 876, 882, 904,
                908, 936, 938, 953, 959, 960, 993, 994]
    raise ValueError(f"No ImageNet class list defined for subset {subset!r}")


def get_ds_for_cls(args, cls_ind, rnd_crop=True):
    """Load the ``cls_ind``-th class folder of ``<root_dir>/<subset>/train`` as an in-memory dataset.

    Class folders are the sorted, non-hidden entries of the train directory. Images are
    shuffled and pass through ToTensor, an optional ``MultiRandomCrop`` and ``normalize``.

    Returns:
        ``(dataset, class_folder_name)``.
    """
    train_dir = os.path.join(args.root_dir, args.subset, "train")

    visible = [name for name in os.listdir(train_dir) if not name.startswith('.')]
    cls = sorted(visible)[cls_ind]
    cls_dir = os.path.join(train_dir, cls)

    pipeline = [transforms.ToTensor()]
    if rnd_crop:
        # NOTE(review): MultiRandomCrop is not defined or explicitly imported in the visible
        # part of this module; presumably it comes via a star import — confirm.
        pipeline.append(MultiRandomCrop(num_crop=args.num_crop, size=args.diff_input_size, factor=args.factor))
    pipeline.append(normalize)

    dataset = ClsFolder(cls_dir, cls_ind, shuffle=True, transform=transforms.Compose(pipeline), mem=True)
    return dataset, cls


def get_ds_for_cls_collage_chopping(args, cls_ind):
    """Load the ``cls_ind``-th collage folder of ``args.collage_save_dir``, chopping each image into tiles.

    Folders are the sorted, non-hidden entries of the save directory. Images are kept in
    memory in directory-listing order and go through ToTensor, ``ChopCollageTrans`` and
    ``normalize``.

    Returns:
        ``(dataset, class_folder_name)``.
    """
    folders = sorted(name for name in os.listdir(args.collage_save_dir) if not name.startswith('.'))
    cls = folders[cls_ind]
    cls_dir = os.path.join(args.collage_save_dir, cls)

    chop = ChopCollageTrans(args.diff_input_size, args.factor)
    pipeline = transforms.Compose([transforms.ToTensor(), chop, normalize])

    dataset = ClsFolder(cls_dir, cls_ind, shuffle=False, transform=pipeline, mem=True)
    return dataset, cls


def get_ds_for_cls_resizing(args, cls_ind):
    """Load the ``cls_ind``-th collage folder of ``args.collage_save_dir`` as whole normalized images.

    Folders are the sorted, non-hidden entries of the save directory. Images are kept in
    memory and only go through ToTensor and ``normalize``; no resizing happens here.

    Returns:
        ``(dataset, class_folder_name)``.
    """
    folders = sorted(name for name in os.listdir(args.collage_save_dir) if not name.startswith('.'))
    cls = folders[cls_ind]
    cls_dir = os.path.join(args.collage_save_dir, cls)

    pipeline = transforms.Compose([transforms.ToTensor(), normalize])

    dataset = ClsFolder(cls_dir, cls_ind, shuffle=False, transform=pipeline, mem=True)
    return dataset, cls


def get_dataset(args, train_trans, eval_ratio=0.0):
    """Build train/test (and optionally eval) splits restricted to the classes of ``args.subset``.

    Args:
        args: Namespace with ``subset``, ``root_dir`` and ``dataset_name_dict``.
            - ``init_resize``/``input_size`` are read for ImageNet/Tiny.
            - ``augment`` (and what ``add_folder_to_dataset`` reads) when ``eval_ratio > 0``.
        train_trans: Transform for the training split. For CIFAR, ``None`` means "use the
            eval transform".
        eval_ratio: Fraction of the training split held out as an un-augmented eval set.

    Returns:
        ``(dst_train, dst_test, dst_eval, folder_names, class_names, classes)``.
        Labels in every split are remapped to ``0..len(classes)-1``; ``dst_eval`` is
        ``None`` when ``eval_ratio == 0``.

    Raises:
        ValueError: For unsupported subsets.
    """
    normalize, _ = get_normalize_trans(args)

    class_name_dic = torch.load(args.dataset_name_dict)
    sorted_keys = sorted(class_name_dic.keys())

    if args.subset.startswith("imagenet") or args.subset == 'tiny':
        val_trans = transforms.Compose([transforms.Resize((args.init_resize, args.init_resize)),
                                        transforms.CenterCrop(args.input_size), transforms.ToTensor(), normalize])
        if args.subset.startswith("imagenet"):
            root_dir = os.path.join(args.root_dir, 'imagenet')
            val_dir = os.path.join(root_dir, "val")
            classes = get_imagenet_classes(args)
        else:
            root_dir = os.path.join(args.root_dir, 'tiny-imagenet-200')
            val_dir = os.path.join(root_dir, "val", "images")
            classes = list(range(200))
        train_dir = os.path.join(root_dir, "train")

        train_ds = torchvision.datasets.ImageFolder(train_dir, transform=train_trans)
        train_ds_no_aug = torchvision.datasets.ImageFolder(train_dir, transform=val_trans)
        # Built once (the ImageNet branch used to scan the val folder twice).
        val_ds = torchvision.datasets.ImageFolder(val_dir, transform=val_trans)

    elif args.subset == 'cifar100' or args.subset == 'cifar10':
        if args.subset == 'cifar100':
            cifar, num_classes = torchvision.datasets.CIFAR100, 100
        else:
            cifar, num_classes = torchvision.datasets.CIFAR10, 10
        val_trans = transforms.Compose([transforms.ToTensor(), normalize])
        if train_trans is None:
            train_trans = val_trans

        train_ds = cifar(root='./data', train=True, download=True, transform=train_trans)
        # Un-augmented copy of the training set for the eval split (was undefined for CIFAR -> NameError).
        train_ds_no_aug = cifar(root='./data', train=True, download=True, transform=val_trans)
        val_ds = cifar(root='./data', train=False, download=True, transform=val_trans)
        classes = list(range(num_classes))

    else:
        raise ValueError(f"Unsupported subset {args.subset!r}")

    def keep_classes(ds):
        # flatnonzero always gives a 1-D index array. np.squeeze(np.argwhere(...))
        # collapsed to a 0-D array when exactly one sample matched.
        return torch.utils.data.Subset(ds, np.flatnonzero(np.isin(ds.targets, classes)))

    dst_train = keep_classes(train_ds)
    dst_train_no_aug = keep_classes(train_ds_no_aug)
    dst_test = keep_classes(val_ds)

    # Remap original label ids to 0..len(classes)-1. Unlike a lambda, a bound dict method
    # can be pickled (e.g. by DataLoader workers started with 'spawn').
    tar_trans = {cls: new for new, cls in enumerate(classes)}.__getitem__
    dst_test.dataset.target_transform = tar_trans
    dst_train.dataset.target_transform = tar_trans
    dst_train_no_aug.dataset.target_transform = tar_trans

    folder_names = [sorted_keys[c] for c in classes]
    class_names = [class_name_dic[sorted_keys[c]] for c in classes]

    dst_eval = None
    if eval_ratio > 0:
        split_rnd_gen = torch.Generator().manual_seed(42)
        num_eval = int(len(dst_train) * eval_ratio)
        num_train = len(dst_train) - num_eval

        train_split, eval_split = random_split(range(len(dst_train)), [num_train, num_eval], generator=split_rnd_gen)
        train_ind = list(train_split.indices)
        val_ind = list(eval_split.indices)

        # NOTE(review): extra augmented images are only added when an eval split is requested — confirm intended.
        if args.augment:
            print(f'pre augmented train: {len(train_ind)}')
            pre_aug_size = len(dst_train)
            dst_train, added_img_num = add_folder_to_dataset(args, dst_train)
            # The appended samples occupy positions pre_aug_size.. of the extended subset.
            train_ind += list(range(pre_aug_size, pre_aug_size + added_img_num))

        dst_train = torch.utils.data.Subset(dst_train, train_ind)
        dst_eval = torch.utils.data.Subset(dst_train_no_aug, val_ind)
        print(f"Train: {len(dst_train)}, Eval: {len(dst_eval)}")

    return dst_train, dst_test, dst_eval, folder_names, class_names, classes


def get_folder_cls_names(args, subset):
    """Return the folder keys and readable class names for the dataset.

    The mapping saved at ``args.dataset_name_dict`` (folder key -> class name)
    is loaded and its keys are sorted. Only when ``subset`` is literally
    ``True`` is the result restricted to the indices returned by
    ``get_imagenet_classes(args)``; otherwise every key is returned.

    Returns:
        ``(folder_names, class_names)`` as two parallel lists.
    """
    name_lookup = torch.load(args.dataset_name_dict)
    ordered_keys = sorted(name_lookup)

    if subset is True:
        picked_keys = [ordered_keys[idx] for idx in get_imagenet_classes(args)]
    else:
        picked_keys = list(ordered_keys)

    readable = [name_lookup[key] for key in picked_keys]
    return picked_keys, readable


def add_folder_to_dataset(args, dataset):
    """Append extra per-class images from ``args.aug_root`` to a ``Subset``'s dataset.

    ``args.aug_root`` holds one sub-folder per class. Hidden entries (names
    starting with '.') are ignored. Folders are sorted by name, and the i-th
    folder is labelled ``classes[i]``:

    - for ImageNet subsets, ``classes`` is ``get_imagenet_classes(args)``;
    - otherwise it is ``0..num_folders-1``.

    Every image is appended to ``dataset.dataset.samples`` and
    ``dataset.dataset.targets``. Its new position is appended to
    ``dataset.indices`` so the subset exposes it.

    Args:
        args: Must provide ``aug_root`` and ``subset`` (plus whatever
            ``get_imagenet_classes`` reads).
        dataset: A ``torch.utils.data.Subset``-like object whose ``.dataset``
            has mutable ``samples`` / ``targets`` lists (e.g. ``ImageFolder``).
            Mutated in place.

    Returns:
        ``(dataset, added_img_cnt)``: the same subset object and the number of
        images appended.

    Raises:
        ValueError: if ``aug_root`` has more class folders than target classes.
            The check runs before anything is mutated.
    """
    folder_names = sorted(
        f for f in os.listdir(args.aug_root) if not f.startswith('.'))

    if args.subset.startswith("imagenet"):
        classes = get_imagenet_classes(args)
        print(f'classes: {classes}')
    else:
        # Previously `classes` was left unbound here (NameError).
        # NOTE(review): assumes the aug folders sort in label order
        # (e.g. cls_000, cls_001, ...); confirm for each non-ImageNet subset.
        classes = list(range(len(folder_names)))

    if len(folder_names) > len(classes):
        raise ValueError(
            f"{args.aug_root} has {len(folder_names)} class folders but only "
            f"{len(classes)} target classes are defined")

    base = dataset.dataset
    prev_size = len(base.targets)
    added_img_cnt = 0

    for cls_ind, folder in enumerate(folder_names):
        label = classes[cls_ind]
        cls_dir = os.path.join(args.aug_root, folder)
        # Sorting makes the appended order deterministic across filesystems.
        for fname in sorted(os.listdir(cls_dir)):
            if fname.startswith('.'):
                continue
            base.samples.append((os.path.join(cls_dir, fname), label))
            base.targets.append(label)
            added_img_cnt += 1

    # Build the new indices as int64 explicitly. Concatenating an empty Python
    # list would promote the whole index array to float64 and break indexing.
    new_indices = np.arange(prev_size, prev_size + added_img_cnt, dtype=np.int64)
    dataset.indices = np.concatenate(
        (np.asarray(dataset.indices, dtype=np.int64), new_indices))

    return dataset, added_img_cnt


def get_syn_ds(args, folder_name):
    """Load a folder of (synthetic) images as a torchvision ``ImageFolder``.

    Each image is converted with ``ToTensor`` and then passed through the
    module-level ``normalize`` transform. ``args`` is accepted for interface
    compatibility but is not read.
    """
    pipeline = transforms.Compose([transforms.ToTensor(), normalize])
    return torchvision.datasets.ImageFolder(folder_name, transform=pipeline)

def get_batch(data_loader, dl_iter):
    """Fetch the next ``(x, y)`` pair, restarting the loader when it runs out.

    Args:
        data_loader: Re-iterable source of ``(x, y)`` pairs.
        dl_iter: Current iterator over ``data_loader``.

    Returns:
        ``(x, y, dl_iter)``. The iterator is a fresh one whenever the
        previous iterator was exhausted; callers must keep using the
        returned iterator.
    """
    try:
        batch = next(dl_iter)
    except StopIteration:
        # Epoch boundary: start a new pass over the loader.
        dl_iter = iter(data_loader)
        batch = next(dl_iter)
    x, y = batch
    return x, y, dl_iter
   

def tensor_to_img(args, x):
    """Convert a normalized image tensor into a PIL image.

    Args:
        args: Forwarded to ``get_normalize_trans`` to obtain the
            denormalization transform.
        x: Image tensor. After ``squeeze()`` it must be 3-D with channels
            first, because it is permuted to channels-last for PIL
            (presumably a single image, optionally with a batch dim of 1).

    Returns:
        A ``PIL.Image.Image`` built from uint8 pixel values.
    """
    __, denormalize = get_normalize_trans(args)
    # detach(): .numpy() refuses tensors that still require grad.
    img_np = denormalize(x).detach().squeeze().permute(1, 2, 0).cpu().numpy()
    # Denormalized values can drift slightly outside [0, 1]. Without clipping,
    # the uint8 cast wraps them around or gives platform-dependent values.
    img_np = (np.clip(img_np, 0.0, 1.0) * 255).astype(np.uint8)
    return Image.fromarray(img_np)


def rand_bbox(size, lam):
    """Sample a random CutMix box for a tensor of shape ``size``.

    Each side of the box is ``sqrt(1 - lam)`` times the corresponding extent
    along axes 2 and 3 of ``size``. The centre is drawn uniformly with
    ``np.random.randint`` (axis 2 first, then axis 3), and the edges are
    clipped to ``[0, extent]``.

    Args:
        size: Shape-like sequence with at least four entries.
        lam: Mixing ratio in ``[0, 1]``.

    Returns:
        ``(bbx1, bby1, bbx2, bby2)``:
        - ``bbx1``/``bbx2`` bound axis 2 of ``size``;
        - ``bby1``/``bby2`` bound axis 3.
    """
    extent_a, extent_b = size[2], size[3]
    side_frac = np.sqrt(1. - lam)
    span_a = int(extent_a * side_frac)
    span_b = int(extent_b * side_frac)

    # Uniform centre. The draw order matters for seeded reproducibility.
    center_a = np.random.randint(extent_a)
    center_b = np.random.randint(extent_b)

    def _edges(center, span, extent):
        # Symmetric half-span around the centre, clipped to the valid range.
        half = span // 2
        return np.clip(center - half, 0, extent), np.clip(center + half, 0, extent)

    a_lo, a_hi = _edges(center_a, span_a, extent_a)
    b_lo, b_hi = _edges(center_b, span_b, extent_b)
    return a_lo, b_lo, a_hi, b_hi


def cutmix(images):
    """Apply CutMix to a batch of images in place and return it.

    Steps:
    1. Draw a random permutation of the batch.
    2. Draw ``lam`` from ``Beta(1, 1)``.
    3. Sample one box with ``rand_bbox``.
    4. Overwrite that box in every image with the same region from its
       permuted partner.

    Only the images are mixed. ``lam`` and the permutation are not returned,
    so callers cannot mix labels to match.

    Args:
        images: 4-D tensor indexed as ``images[batch, channel, axis2, axis3]``.
            Modified in place.

    Returns:
        The same tensor object ``images``.
    """
    # Keep drawing from the CPU generator (same stream as before), but move the
    # permutation to the batch's device. The previous hard-coded ``.cuda()``
    # crashed on CPU-only machines and for CPU-resident batches.
    rand_index = torch.randperm(images.size(0)).to(images.device)
    lam = np.random.beta(1, 1)
    bbx1, bby1, bbx2, bby2 = rand_bbox(images.size(), lam)
    # The advanced-indexed right-hand side is copied before assignment, so the
    # in-place write never reads pixels it has already overwritten.
    images[:, :, bbx1:bbx2, bby1:bby2] = images[rand_index, :, bbx1:bbx2, bby1:bby2]
    return images


def create_token_names(args, save=False):
    """Build per-class placeholder tokens and initializer words.

    The class list depends on ``args.subset``:
    - ImageNet variants (``'imagenet...'``): ``get_imagenet_classes(args)``;
    - ``'tiny'``: 200 classes;
    - ``'cifar100'``: 100 classes;
    - ``'cifar10'``: 10 classes.

    For each selected class, its name is looked up in the sorted-key mapping
    loaded from ``args.dataset_name_dict``. From that name:
    - the placeholder is ``'<name>'``;
    - the initializer is the last ``'_'``-separated word of the name.

    Initializer words that the CLIP tokenizer encodes into more than one token
    are replaced by ``'photo'``. Presumably the initializer must map to a
    single token embedding; confirm with the training script.

    Args:
        args: Needs ``subset``, ``dataset_name_dict`` and
            ``pretrained_model_name_or_path``.
        save: If true, also write both lists, one entry per line, to
            ``{subset}_place_holder_names.txt`` and
            ``{subset}_initializer_names.txt``.

    Returns:
        ``(place_holder_names, initializer_names)`` as parallel lists.

    Raises:
        ValueError: for an unsupported ``args.subset``. Previously this
            surfaced as an UnboundLocalError. The check runs before the
            tokenizer is loaded.
    """
    fixed_class_counts = {'tiny': 200, 'cifar100': 100, 'cifar10': 10}
    if args.subset.startswith("imagenet"):
        classes = get_imagenet_classes(args)
    elif args.subset in fixed_class_counts:
        classes = list(range(fixed_class_counts[args.subset]))
    else:
        raise ValueError(f"Unsupported subset for token names: {args.subset!r}")

    tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
    code_name_dic = torch.load(args.dataset_name_dict)
    sorted_keys = sorted(code_name_dic.keys())
    class_labels = [code_name_dic[sorted_keys[c]] for c in classes]

    place_holder_names = [f'<{label}>' for label in class_labels]

    initializer_names = []
    for label in class_labels:
        init_name = label.split('_')[-1]
        if len(tokenizer.encode(init_name, add_special_tokens=False)) > 1:
            init_name = 'photo'
        initializer_names.append(init_name)

    if save:
        with open(f"{args.subset}_place_holder_names.txt", "w") as f:
            f.writelines(name + "\n" for name in place_holder_names)

        with open(f"{args.subset}_initializer_names.txt", "w") as f:
            f.writelines(name + "\n" for name in initializer_names)

    return place_holder_names, initializer_names


def create_tiny_val_img_folder(root_dir):
    """Move Tiny-ImageNet validation images into one sub-folder per class.

    Reads ``<root_dir>/tiny-imagenet-200/val/val_annotations.txt``. The file
    is tab-separated: the first column is the image file name and the second
    is the class folder name. Each listed image is moved from ``val/images/``
    into ``val/images/<class>/``, the one-folder-per-class layout that
    torchvision's ``ImageFolder`` expects.

    Images missing from the flat directory (e.g. already moved) are skipped,
    and existing class folders are reused, so re-running is harmless. Blank
    or malformed annotation lines are ignored.
    """
    val_dir = os.path.join(root_dir, 'tiny-imagenet-200', 'val')
    img_dir = os.path.join(val_dir, 'images')

    val_img_dict = {}
    with open(os.path.join(val_dir, 'val_annotations.txt'), 'r') as fp:
        for line in fp:
            # Strip the line ending so it never leaks into the folder name.
            words = line.rstrip('\r\n').split('\t')
            if len(words) < 2 or not words[0] or not words[1]:
                continue
            val_img_dict[words[0]] = words[1]

    # Create class folders as needed and move each image into its folder.
    for img, folder in val_img_dict.items():
        newpath = os.path.join(img_dir, folder)
        os.makedirs(newpath, exist_ok=True)
        src = os.path.join(img_dir, img)
        if os.path.exists(src):
            os.rename(src, os.path.join(newpath, img))


def download_and_save_cifar(dataset_name='cifar10', root_dir='./data'):
    """Download CIFAR-10/100 and dump every image as a JPEG file.

    Both splits are written under ``<root_dir>/<dataset_name>/train`` and
    ``<root_dir>/<dataset_name>/test``. Each image goes to a per-label folder
    ``cls_<label zero-padded to 3>`` with file name ``<index>.jpg``.

    Args:
        dataset_name: ``'cifar10'`` or ``'cifar100'``.
        root_dir: Download directory, also used as the output root.

    Raises:
        ValueError: for any other ``dataset_name``. This was an ``assert``,
            which is stripped under ``python -O``. The ``else`` branch then
            silently downloaded CIFAR-100 for any name.
    """
    if dataset_name not in ('cifar10', 'cifar100'):
        raise ValueError("Dataset must be 'cifar10' or 'cifar100'")

    # Choose dataset
    dataset_cls = (torchvision.datasets.CIFAR10 if dataset_name == 'cifar10'
                   else torchvision.datasets.CIFAR100)
    dataset = dataset_cls(root=root_dir, train=True, download=True)
    test_dataset = dataset_cls(root=root_dir, train=False, download=True)

    # Create directories
    train_dir = os.path.join(root_dir, dataset_name, 'train')
    test_dir = os.path.join(root_dir, dataset_name, 'test')
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(test_dir, exist_ok=True)

    def save_images(data, dir_path):
        # Without a transform the CIFAR datasets yield (PIL image, int label).
        for idx, (img, label) in enumerate(data):
            label_dir = os.path.join(dir_path, f'cls_{str(label).zfill(3)}')
            os.makedirs(label_dir, exist_ok=True)
            img_path = os.path.join(label_dir, f"{idx}.jpg")
            img.save(img_path, "JPEG")

    # Save train and test images
    print("Saving training images...")
    save_images(dataset, train_dir)
    print("Saving test images...")
    save_images(test_dataset, test_dir)

    print(f"Dataset saved in {root_dir}/{dataset_name}")