import torch 
import torchvision.transforms as transforms 
import os
import numpy as np  
import torchvision  
import random   
from PIL import Image   
from resnet import *        
from transformers import CLIPTokenizer
import argparse 
import torchvision
from models import ConvNet
from tqdm import tqdm       
import torchvision.models as models

# nltk.download('wordnet')

# Module-wide compute device: CUDA when PyTorch reports it as available, CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class ShufflePatches(torch.nn.Module):
    """Shuffle an image by permuting strips along both of its spatial axes.

    ``forward`` cuts the last axis into ``factor`` strips and shuffles them,
    swaps the last two axes, repeats the strip shuffle, and swaps the axes
    back. The output always has the same shape as the input.
    """

    def __init__(self, factor):
        super().__init__()
        self.factor = factor

    def shuffle_weight(self, img, factor):
        """Split ``img`` into ``factor`` strips along its last axis and shuffle them.

        Args:
            img: tensor whose last axis is the one being cut.
            factor: number of strips.

        Returns:
            A tensor with the same shape as ``img`` whose strips appear in a
            random order (drawn with Python's ``random`` module).
        """
        strip = img.shape[-1] // factor
        patches = []
        for idx in range(factor):
            # Keep the strip index and the pixel offset in separate variables:
            # the "is this the last strip?" test must look at the index.
            start = idx * strip
            if idx != factor - 1:
                patches.append(img[..., start:start + strip])
            else:
                # The last strip absorbs the remainder when the axis length
                # is not divisible by ``factor``.
                patches.append(img[..., start:])
        random.shuffle(patches)
        return torch.cat(patches, -1)

    def forward(self, img):
        """Shuffle strips along the last axis, then along the second axis.

        ``img`` must be 3-D because the two trailing axes are swapped with
        ``permute(0, 2, 1)``.
        """
        img = self.shuffle_weight(img, self.factor)
        img = img.permute(0, 2, 1)
        img = self.shuffle_weight(img, self.factor)
        img = img.permute(0, 2, 1)
        return img
    

def rand_bbox(size, lam):
    """Draw a random box covering roughly ``1 - lam`` of the ``size[2] x size[3]`` plane.

    NOTE: another ``rand_bbox`` with the same logic is defined further down in
    this module; being bound last, that later definition is the one callers
    get after import.

    Args:
        size: 4-element shape; only ``size[2]`` and ``size[3]`` are read.
        lam: mixing ratio; each box side is ``sqrt(1 - lam)`` times the
            corresponding extent before clipping.

    Returns:
        ``(bbx1, bby1, bbx2, bby2)`` clipped to ``[0, size[2]]`` / ``[0, size[3]]``.
    """
    W, H = size[2], size[3]
    cut_rat = np.sqrt(1.0 - lam)
    half_w = int(W * cut_rat) // 2
    half_h = int(H * cut_rat) // 2

    # Box centre, drawn uniformly (first along size[2], then along size[3]).
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    def _clamp(value, upper):
        return np.clip(value, 0, upper)

    return (
        _clamp(cx - half_w, W),
        _clamp(cy - half_h, H),
        _clamp(cx + half_w, W),
        _clamp(cy + half_h, H),
    )

    

def cutmix(images, cutmix):
    """Apply CutMix to a batch in place and report how it was mixed.

    NOTE: a one-argument ``cutmix(images)`` is defined further down in this
    module and replaces this definition after import. The parameter named
    ``cutmix`` also shadows the function name inside this body.

    Args:
        images: 4-D batch tensor; a rectangular region of every sample is
            overwritten with the same region of another sample of the batch.
            The tensor is modified in place.
        cutmix: both parameters of the Beta distribution ``lam`` is drawn from.

    Returns:
        ``(images, rand_index, lam, [bbx1, bby1, bbx2, bby2])`` where
        ``rand_index`` (on the CPU) is the permutation used to pick the donor
        sample for each position.
    """
    # The permutation is drawn on the CPU and then moved to wherever the batch
    # lives. The previous unconditional ``.cuda()`` failed on CPU-only
    # machines and for batches that were not on the current CUDA device.
    rand_index = torch.randperm(images.size()[0]).to(images.device)
    lam = np.random.beta(cutmix, cutmix)
    bbx1, bby1, bbx2, bby2 = rand_bbox(images.size(), lam)

    # Indexing with a tensor on the right-hand side produces a copy, so the
    # in-place assignment never reads pixels it has already overwritten.
    images[:, :, bbx1:bbx2, bby1:bby2] = images[rand_index, :, bbx1:bbx2, bby1:bby2]
    return images, rand_index.cpu(), lam, [bbx1, bby1, bbx2, bby2]


class ClsFolder(torch.utils.data.Dataset):
    """Dataset over the files of a single class directory.

    Every entry returned by ``os.listdir(cls_dir)`` is treated as an image and
    paired with the same label ``cls_ind``.

    Args:
        cls_dir: directory holding the images of one class.
        cls_ind: label returned with every sample.
        mem: when true, all images are decoded once in the constructor and
            kept in ``self.samples``; otherwise they are opened on access.
        shuffle: when true, the file list is shuffled with Python's ``random``.
        transform: optional callable applied to each RGB image; when ``None``
            the image is returned as loaded.
    """

    def __init__(self, cls_dir, cls_ind, mem=False, shuffle=False, transform=None):
        self.transform = transform
        self.mem = mem

        file_names = os.listdir(cls_dir)
        if shuffle:
            random.shuffle(file_names)

        self.image_paths = [os.path.join(cls_dir, name) for name in file_names]
        self.targets = [cls_ind] * len(self.image_paths)
        # Only decode up front when in-memory caching was requested.
        self.samples = [self._load_rgb(path) for path in self.image_paths] if mem else []

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return an RGB copy, closing the file handle right away."""
        with Image.open(path) as img:
            return img.convert("RGB")

    def __getitem__(self, index):
        if self.mem:
            sample = self.samples[index]
        else:
            sample = self._load_rgb(self.image_paths[index])

        # ``transform`` defaults to None; calling it unconditionally used to
        # raise TypeError for datasets built without a transform.
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, self.targets[index]

    def __len__(self):
        return len(self.targets)
    


class MultiRandomCrop(torch.nn.Module):
    """Expand one image into a stack of square crops of side ``size // factor``.

    With ``num_crop > 0`` the stack holds that many independent
    ``RandomResizedCrop`` samples; otherwise it holds a single plain resize of
    the whole image.
    """

    def __init__(self, num_crop=5, size=224, factor=2):
        super().__init__()
        self.num_crop = num_crop
        self.size = size
        self.factor = factor

        side = self.size // self.factor
        # Fallback used when no random crops are requested.
        self.resize_trans = transforms.Resize((side, side))
        # ratio=(1, 1) keeps every sampled crop square.
        self.cropper = transforms.RandomResizedCrop(side, ratio=(1, 1), antialias=True)

    def forward(self, image):
        """Return the crops (or the single resized image) stacked on a new axis 0."""
        if self.num_crop <= 0:
            return self.resize_trans(image).unsqueeze(0)

        crops = [self.cropper(image) for _ in range(self.num_crop)]
        return torch.stack(crops, 0)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(num_crop={self.num_crop}, size={self.size})"
    


class ChopCollageTrans(torch.nn.Module):
    """Chop a collage image into a regular grid of square tiles.

    Args:
        size: side length of the collage the crop boxes are laid out for.
        factor: number of tiles per side; each tile has side ``size // factor``.
    """

    def __init__(self, size, factor):
        super().__init__()
        # ``size`` is stored so that ``__repr__`` can report it.
        self.size = size
        self.factor = factor
        tile = size // factor
        # One (top, left, height, width) box per tile, in row-major order.
        self.bbs = [
            (top, left, tile, tile)
            for top in range(0, size, tile)
            for left in range(0, size, tile)
        ]

    def forward(self, image):
        """Return all tiles of ``image`` stacked along a new leading axis.

        A singleton axis is inserted in front of ``image`` before cropping,
        so every stacked tile still carries that extra leading axis.
        """
        image = image.unsqueeze(0)
        patches = [torchvision.transforms.functional.crop(image, *bb) for bb in self.bbs]
        return torch.stack(patches, 0)

    def __repr__(self) -> str:
        # Report only attributes this class actually defines; the previous
        # version read ``num_crop``, which does not exist here.
        detail = f"(size={self.size}, factor={self.factor})"
        return f"{self.__class__.__name__}{detail}"
    

# Default per-channel normalisation used by the helpers below that do not
# build their own (the mean/std triple matches the "imagenet" branch of
# get_normalize_trans).
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Inverse of ``normalize`` in two steps: multiply by the std first
# (mean 0, std 1/s), then add the mean back (mean -m, std 1).
denormalize = transforms.Compose([
    transforms.Normalize(
        mean=[0.0] * 3,
        std=[1 / s for s in (0.229, 0.224, 0.225)],
    ),
    transforms.Normalize(
        mean=[-m for m in (0.485, 0.456, 0.406)],
        std=[1.0] * 3,
    ),
])



def get_normalize_trans(args):
    """Build the normalise / de-normalise transform pair for ``args.subset``.

    Subsets whose name starts with ``"imagenet"`` and the ``'tiny'`` subset
    share one mean/std triple; ``'cifar100'`` and ``'cifar10'`` share another.

    Args:
        args: namespace with a string attribute ``subset``.

    Returns:
        ``(normalize, denormalize)`` where ``denormalize`` undoes
        ``normalize`` (multiply by std, then add the mean).

    Raises:
        ValueError: if ``args.subset`` is none of the supported subsets.
    """
    subset = args.subset
    if subset.startswith("imagenet") or subset == 'tiny':
        mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    elif subset in ('cifar100', 'cifar10'):
        mean, std = [0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]
    else:
        # Previously this fell through to an UnboundLocalError on return.
        raise ValueError(f"No normalization statistics known for subset {subset!r}")

    normalize = transforms.Normalize(mean=mean, std=std)
    denormalize = transforms.Compose([
        # Step 1: x * std   (zero mean, std = 1 / s)
        transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1 / s for s in std]),
        # Step 2: x + mean  (mean = -m, unit std)
        transforms.Normalize(mean=[-m for m in mean], std=[1.0, 1.0, 1.0]),
    ])
    return normalize, denormalize

def set_dataset_specs(args):
    """Fill in the dataset-dependent fields of ``args`` in place.

    For a recognised ``args.subset`` this sets ``nclass``, ``input_size`` and
    ``init_resize``. An unrecognised subset leaves those three attributes
    untouched. In every case ``classes`` becomes ``range(args.nclass)``,
    ``val_ipc`` becomes 50, and ``end_cls`` defaults to ``nclass`` when it is
    ``None``.
    """
    ten_class_imagenet = (
        "imagenet-a", "imagenet-b", "imagenet-c", "imagenet-d",
        "imagenet-e", "imagenet-birds", "imagenet-fruits", "imagenet-cats",
        "imagenet-10", "imagenette", 'imagenet-woof',
    )

    # subset name -> (nclass, input_size, init_resize)
    specs = {name: (10, 224, 256) for name in ten_class_imagenet}
    specs.update({
        'imagenet': (1000, 224, 256),
        'imagenet100': (100, 224, 256),
        'tiny': (200, 64, 64),
        'cifar100': (100, 32, 32),
        'cifar10': (10, 32, 32),
    })

    spec = specs.get(args.subset)
    if spec is not None:
        args.nclass, args.input_size, args.init_resize = spec

    args.classes = range(args.nclass)
    args.val_ipc = 50
    if args.end_cls is None:
        args.end_cls = args.nclass


def get_network(args, pretrained=True):
    """Build the model selected by ``args.arch`` / ``args.subset``.

    The model is wrapped in ``torch.nn.DataParallel``, optionally initialised
    from ``args.model_ckpt`` and moved to the module-level ``device``.

    Args:
        args: namespace providing ``arch``, ``subset``, ``nclass``,
            ``input_size`` and (when ``pretrained`` is true) ``model_ckpt``.
        pretrained: forwarded to the torchvision constructors; it also gates
            loading of ``args.model_ckpt``.

    Returns:
        The wrapped model on ``device``.

    Raises:
        ValueError: if no model is defined for the arch/subset combination.
    """
    arch, subset = args.arch, args.subset
    model = None

    if arch == 'resnet18':
        if subset == 'imagenet':
            print("Loading pretrained resnet18")
            # NOTE(review): newer torchvision releases prefer ``weights=`` over
            # ``pretrained=``; kept as-is for compatibility with the installed
            # version, confirm before upgrading.
            model = models.resnet18(pretrained=pretrained)

        elif subset.startswith("imagenet"):
            model = ResNet18(args.nclass)

        elif subset == 'tiny':
            model = models.resnet18(num_classes=200)
            # Replace the stem with a 3x3 stride-1 conv and drop the max-pool.
            # ``torch.nn`` is spelled out so this does not depend on a name
            # leaking in through ``from resnet import *``.
            model.conv1 = torch.nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            model.maxpool = torch.nn.Identity()

    elif arch == 'convnet-3':
        model = ConvNet(num_classes=args.nclass, net_depth=3, net_width=128,
                        im_size=(args.input_size, args.input_size))

    elif arch == 'convnet-4':
        model = ConvNet(num_classes=args.nclass, net_depth=4, net_width=128,
                        im_size=(args.input_size, args.input_size))

    elif subset == 'imagenet':
        model = models.__dict__[arch](pretrained=pretrained)
        print(f'model: {args.arch} loaded from torchvision')

    if model is None:
        # Previously this surfaced as an UnboundLocalError a few lines below.
        raise ValueError(f"No model defined for arch={arch!r} with subset={subset!r}")

    model = torch.nn.DataParallel(model)

    if pretrained and args.model_ckpt is not None:
        # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
        # map_location="cpu" lets checkpoints saved on a GPU load on CPU-only
        # hosts; the model is moved to ``device`` right after.
        model.load_state_dict(torch.load(args.model_ckpt, map_location="cpu"))
        print(f"###### Loaded model from {args.model_ckpt} ############" )

    return model.to(device)



def get_imagenet_classes(args):
    """Return the ImageNet class indices that make up ``args.subset``.

    Args:
        args: namespace with a string attribute ``subset``.

    Returns:
        A new list of integer class indices (all 1000 for ``'imagenet'``).

    Raises:
        ValueError: if no index list is defined for ``args.subset``. Such
            subsets previously ended in an UnboundLocalError.
    """
    subset = args.subset
    if subset == 'imagenet':
        return list(range(1000))

    index_lists = {
        'imagenette': [0, 217, 482, 491, 497, 566, 569, 571, 574, 701],
        'imagenet-woof': [193, 182, 258, 162, 155, 167, 159, 273, 207, 229],
        'imagenet100': [
            15, 45, 54, 57, 64, 74, 90, 99, 119, 120, 122, 131, 137, 151, 155, 157, 158, 166, 167, 169,
            176, 180, 209, 211, 222, 228, 234, 236, 242, 246, 267, 268, 272, 275, 277, 281, 299, 305,
            313, 317, 331, 342, 368, 374, 407, 421, 431, 449, 452, 455, 479, 494, 498, 503, 508, 544,
            560, 570, 592, 593, 599, 606, 608, 619, 620, 653, 659, 662, 665, 667, 674, 682, 703, 708,
            717, 724, 748, 758, 765, 766, 772, 775, 796, 798, 830, 854, 857, 858, 872, 876, 882, 904,
            908, 936, 938, 953, 959, 960, 993, 994,
        ],
    }
    if subset not in index_lists:
        raise ValueError(f"No ImageNet class list defined for subset {subset!r}")
    return index_lists[subset]


def get_ds_for_cls(args, cls_ind, rnd_crop=True):
    """Build an in-memory dataset for one class of the training split.

    The class folder is the ``cls_ind``-th entry (sorted, hidden entries
    skipped) of ``<args.root_dir>/<args.subset>/train``.

    Returns:
        ``(dataset, class_folder_name)``; the dataset's file order is shuffled.
    """
    train_dir = os.path.join(args.root_dir, args.subset, "train")

    visible = [name for name in os.listdir(train_dir) if not name.startswith('.')]
    cls = sorted(visible)[cls_ind]
    cls_dir = os.path.join(train_dir, cls)

    # ToTensor -> (optional multi-crop) -> module-level ``normalize``.
    steps = [transforms.ToTensor()]
    if rnd_crop:
        steps.append(
            MultiRandomCrop(num_crop=args.num_crop, size=args.diff_input_size, factor=args.factor)
        )
    steps.append(normalize)

    train_ds = ClsFolder(cls_dir, cls_ind, shuffle=True, transform=transforms.Compose(steps), mem=True)
    return train_ds, cls


def get_ds_for_cls_collage_chopping(args, cls_ind):
    """Build an in-memory dataset that chops one class's collages into tiles.

    The class folder is the ``cls_ind``-th entry (sorted, hidden entries
    skipped) of ``args.collage_save_dir``.

    Returns:
        ``(dataset, class_folder_name)``; file order is left unshuffled.
    """
    visible = [name for name in os.listdir(args.collage_save_dir) if not name.startswith('.')]
    cls = sorted(visible)[cls_ind]
    cls_dir = os.path.join(args.collage_save_dir, cls)

    # ToTensor -> tile chopping -> module-level ``normalize``.
    pipeline = [
        transforms.ToTensor(),
        ChopCollageTrans(args.diff_input_size, args.factor),
        normalize,
    ]
    train_ds = ClsFolder(cls_dir, cls_ind, shuffle=False, transform=transforms.Compose(pipeline), mem=True)
    return train_ds, cls


def get_ds_for_cls_resizing(args, cls_ind):
    """Build an in-memory dataset over one class folder of ``args.collage_save_dir``.

    Images are only converted to tensors and normalised with the module-level
    ``normalize``; no resizing or cropping transform is applied here.

    Returns:
        ``(dataset, class_folder_name)``; file order is left unshuffled.
    """
    visible = [name for name in os.listdir(args.collage_save_dir) if not name.startswith('.')]
    cls = sorted(visible)[cls_ind]
    cls_dir = os.path.join(args.collage_save_dir, cls)

    pipeline = [transforms.ToTensor(), normalize]
    train_ds = ClsFolder(cls_dir, cls_ind, shuffle=False, transform=transforms.Compose(pipeline), mem=True)
    return train_ds, cls


def get_dataset(args, train_trans):
    """Create the train/validation datasets restricted to the subset's classes.

    Args:
        args: namespace providing ``subset``, ``root_dir``, ``init_resize``,
            ``input_size`` and ``dataset_name_dict`` (path of a saved dict
            mapping folder name to class name).
        train_trans: transform for the training split. For the CIFAR subsets
            ``None`` falls back to the validation transform.

    Returns:
        ``(train_ds, val_ds, folder_names, class_names, classes)``. Both
        datasets are ``Subset`` views holding only samples whose original
        label is in ``classes``, and their labels are remapped to
        ``0 .. len(classes) - 1`` through ``target_transform`` on the
        underlying dataset.

    Raises:
        ValueError: if ``args.subset`` is not a supported dataset.
    """
    normalize, _ = get_normalize_trans(args)

    class_name_dic = torch.load(args.dataset_name_dict)
    sorted_keys = sorted(class_name_dic.keys())

    if args.subset.startswith("imagenet") or args.subset == 'tiny':
        val_trans = transforms.Compose([transforms.Resize((args.init_resize, args.init_resize)),
                                        transforms.CenterCrop(args.input_size), transforms.ToTensor(), normalize])

        if args.subset.startswith("imagenet"):
            # Every imagenet-* subset reads from the full 'imagenet' folder
            # and is narrowed down by class index below.
            root_dir = os.path.join(args.root_dir, 'imagenet')
            val_dir = os.path.join(root_dir, "val")
            classes = get_imagenet_classes(args)
        else:
            root_dir = os.path.join(args.root_dir, 'tiny-imagenet-200')
            val_dir = os.path.join(root_dir, "val", "images")
            classes = list(range(200))
        train_dir = os.path.join(root_dir, "train")

        # Built exactly once: ImageFolder walks the whole directory tree, and
        # the ImageNet branch used to construct both datasets twice.
        train_ds = torchvision.datasets.ImageFolder(train_dir, transform=train_trans)
        val_ds = torchvision.datasets.ImageFolder(val_dir, transform=val_trans)

    elif args.subset in ('cifar100', 'cifar10'):
        val_trans = transforms.Compose([transforms.ToTensor(), normalize])
        if train_trans is None:
            train_trans = val_trans

        if args.subset == 'cifar100':
            cifar_cls, n_classes = torchvision.datasets.CIFAR100, 100
        else:
            cifar_cls, n_classes = torchvision.datasets.CIFAR10, 10
        train_ds = cifar_cls(root='./data', train=True, download=True, transform=train_trans)
        val_ds = cifar_cls(root='./data', train=False, download=True, transform=val_trans)
        classes = list(range(n_classes))

    else:
        raise ValueError(f"Unsupported subset {args.subset!r}")

    # flatnonzero always yields a 1-D index array. The earlier
    # squeeze(argwhere(...)) collapsed to a 0-d array when exactly one sample
    # matched, which breaks len() on the resulting Subset.
    train_idx = np.flatnonzero(np.isin(train_ds.targets, classes))
    val_idx = np.flatnonzero(np.isin(val_ds.targets, classes))
    dst_train = torch.utils.data.Subset(train_ds, train_idx)
    dst_test = torch.utils.data.Subset(val_ds, val_idx)

    # Original label -> position within ``classes``.
    tar_trans_dic = {orig_label: new_label for new_label, orig_label in enumerate(classes)}
    tar_trans = tar_trans_dic.__getitem__

    dst_test.dataset.target_transform = tar_trans
    dst_train.dataset.target_transform = tar_trans

    folder_names = [sorted_keys[c] for c in classes]
    class_names = [class_name_dic[sorted_keys[c]] for c in classes]

    return dst_train, dst_test, folder_names, class_names, classes


def get_folder_cls_names(args, subset):
    """Return the folder names and class names stored in ``args.dataset_name_dict``.

    The saved dict maps folder name to class name; folders are taken in sorted
    order. Only when ``subset`` is exactly ``True`` are both lists restricted
    to the indices returned by ``get_imagenet_classes(args)``.

    Returns:
        ``(folder_names, class_names)`` as two parallel lists.
    """
    name_by_folder = torch.load(args.dataset_name_dict)
    ordered_folders = sorted(list(name_by_folder.keys()))

    if subset is True:
        selected = get_imagenet_classes(args)
        folder_names = [ordered_folders[c] for c in selected]
        class_names = [name_by_folder[ordered_folders[c]] for c in selected]
        return folder_names, class_names

    folder_names = list(ordered_folders)
    class_names = [name_by_folder[folder] for folder in ordered_folders]
    return folder_names, class_names


def get_syn_ds(args, folder_name):
    """Open ``folder_name`` as an ``ImageFolder`` of synthetic images.

    Images are converted to tensors and normalised with the module-level
    ``normalize``; no augmentation is applied. ``args`` is accepted but not
    used.
    """
    pipeline = transforms.Compose([transforms.ToTensor(), normalize])
    return torchvision.datasets.ImageFolder(folder_name, transform=pipeline)


def get_rded_syn_ds(args):
    """Open ``args.collage_save_dir`` as an ``ImageFolder``.

    Each image is converted to a tensor, resized to
    ``args.input_size x args.input_size`` and normalised with the
    module-level ``normalize``.
    """
    pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((args.input_size, args.input_size )),
        normalize,
    ])
    return torchvision.datasets.ImageFolder(args.collage_save_dir, transform=pipeline)


def get_batch(data_loader, dl_iter):
    """Fetch the next ``(x, y)`` batch, restarting the loader when it runs out.

    Args:
        data_loader: iterable that a fresh iterator is created from once
            ``dl_iter`` is exhausted.
        dl_iter: the iterator currently being consumed.

    Returns:
        ``(x, y, dl_iter)``. The returned iterator is the one to pass in on
        the next call; it is a new object when a restart happened.
    """
    exhausted = object()  # unique marker, cannot collide with a real batch
    batch = next(dl_iter, exhausted)
    if batch is exhausted:
        dl_iter = iter(data_loader)
        batch = next(dl_iter)

    x, y = batch
    return x, y, dl_iter
   

def tensor_to_img(args, x):
    """Convert a normalised image tensor into a PIL image.

    Args:
        args: namespace used to select the de-normalisation for ``args.subset``.
        x: normalised image tensor; after ``squeeze()`` it must be
            channel-first 3-D, because it is permuted with ``(1, 2, 0)``.

    Returns:
        A ``PIL.Image`` built from the de-normalised pixels scaled to 0..255.
    """
    _, denormalize = get_normalize_trans(args)
    # detach() lets tensors that still track gradients be exported.
    pixels = denormalize(x).detach().squeeze().permute(1, 2, 0).cpu().numpy()
    # Clamp to [0, 1] before scaling: casting out-of-range floats to uint8
    # wraps around and shows up as speckle artefacts in the image.
    pixels = (np.clip(pixels, 0.0, 1.0) * 255).astype(np.uint8)
    return Image.fromarray(pixels)


def rand_bbox(size, lam):
    """Sample a random box whose sides are ``sqrt(1 - lam)`` of the plane's extents.

    Args:
        size: 4-element shape; only ``size[2]`` and ``size[3]`` are read.
        lam: mixing ratio used to derive the box side lengths.

    Returns:
        ``(bbx1, bby1, bbx2, bby2)``: the box corners around a uniformly drawn
        centre, clipped to ``[0, size[2]]`` and ``[0, size[3]]`` respectively.
    """
    W, H = size[2], size[3]
    cut_rat = np.sqrt(1. - lam)
    cut_w, cut_h = int(W * cut_rat), int(H * cut_rat)

    # Centre of the box, drawn uniformly (size[2] axis first).
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    # (unclipped coordinate, upper bound) for each of the four outputs.
    corners = (
        (cx - cut_w // 2, W),
        (cy - cut_h // 2, H),
        (cx + cut_w // 2, W),
        (cy + cut_h // 2, H),
    )
    bbx1, bby1, bbx2, bby2 = (np.clip(value, 0, upper) for value, upper in corners)
    return bbx1, bby1, bbx2, bby2


def cutmix(images):
    """Apply CutMix to a batch in place with ``lam ~ Beta(1, 1)``.

    A rectangular region of every sample is overwritten with the same region
    of another, randomly chosen sample of the batch.

    Args:
        images: 4-D batch tensor; modified in place.

    Returns:
        The same ``images`` tensor after mixing.
    """
    # The permutation is drawn on the CPU and then moved to wherever the batch
    # lives. The previous unconditional ``.cuda()`` failed on CPU-only
    # machines and for batches that were not on the current CUDA device.
    rand_index = torch.randperm(images.size()[0]).to(images.device)
    lam = np.random.beta(1, 1)
    bbx1, bby1, bbx2, bby2 = rand_bbox(images.size(), lam)

    # Indexing with a tensor on the right-hand side produces a copy, so the
    # in-place assignment never reads pixels it has already overwritten.
    images[:, :, bbx1:bbx2, bby1:bby2] = images[rand_index, :, bbx1:bbx2, bby1:bby2]
    return images


def create_token_names(args, save=False):
    """Create placeholder and initializer token names for each class of the subset.

    For every selected class the placeholder is the class name wrapped in
    ``<...>``. The initializer is the part of the class name after its last
    underscore, replaced by ``'photo'`` when the tokenizer splits it into more
    than one token.

    Args:
        args: namespace providing ``subset``, ``dataset_name_dict`` and
            ``pretrained_model_name_or_path``.
        save: when true, both lists are also written, one name per line, to
            ``<subset>_place_holder_names.txt`` and
            ``<subset>_initializer_names.txt`` in the working directory.

    Returns:
        ``(place_holder_names, initializer_names)`` as two parallel lists.

    Raises:
        ValueError: if ``args.subset`` is not supported. This is checked
            before the tokenizer is loaded.
    """
    fixed_class_counts = {'tiny': 200, 'cifar100': 100, 'cifar10': 10}
    if args.subset.startswith("imagenet"):
        classes = get_imagenet_classes(args)
    elif args.subset in fixed_class_counts:
        classes = list(range(fixed_class_counts[args.subset]))
    else:
        # Previously this surfaced later as an UnboundLocalError.
        raise ValueError(f"Unsupported subset {args.subset!r}")

    tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
    code_name_dic = torch.load(args.dataset_name_dict)
    sorted_keys = sorted(code_name_dic.keys())
    selected_names = [code_name_dic[sorted_keys[c]] for c in classes]

    place_holder_names = [f'<{name}>' for name in selected_names]

    initializer_names = []
    for name in selected_names:
        init_name = name.split('_')[-1]
        # Fall back to a generic word when the name is not a single token.
        if len(tokenizer.encode(init_name, add_special_tokens=False)) > 1:
            init_name = 'photo'
        initializer_names.append(init_name)

    if save:
        with open(f"{args.subset}_place_holder_names.txt", "w") as f:
            f.writelines(name + "\n" for name in place_holder_names)

        with open(f"{args.subset}_initializer_names.txt", "w") as f:
            f.writelines(name + "\n" for name in initializer_names)

    return place_holder_names, initializer_names


def create_tiny_val_img_folder(root_dir='/home/public/'):
    '''
    Move the validation images into one sub-folder per class.

    Reads ``<root_dir>/tiny-imagenet-200/val/val_annotations.txt`` (tab
    separated; the first column is the image file name, the second the class
    folder) and moves every listed image found in ``val/images`` into
    ``val/images/<class folder>``. Images that are no longer at the top level
    are skipped, so the function can safely be run more than once.
    '''
    val_dir = os.path.join(root_dir, 'tiny-imagenet-200', 'val')
    img_dir = os.path.join(val_dir, 'images')

    val_img_dict = {}
    with open(os.path.join(val_dir, 'val_annotations.txt'), 'r') as fp:
        for line in fp:
            # Strip the line ending so a two-column line does not leave a
            # newline inside the folder name.
            words = line.rstrip('\r\n').split('\t')
            if len(words) < 2:
                # Blank or malformed line: nothing to map.
                continue
            val_img_dict[words[0]] = words[1]

    # Create each class folder if needed and move the image into it.
    for img, folder in val_img_dict.items():
        newpath = os.path.join(img_dir, folder)
        os.makedirs(newpath, exist_ok=True)
        src = os.path.join(img_dir, img)
        if os.path.exists(src):
            os.rename(src, os.path.join(newpath, img))


def seperate_ipcs(syn_root_folder, save_root, ipc):
    """Re-save ``ipc`` images of every class folder into a new directory tree.

    Args:
        syn_root_folder: directory with one sub-folder per class. Only
            entries whose name contains the letter ``'n'`` are treated as
            class folders (presumably to match WordNet-style ids; verify
            against the data layout).
        save_root: destination root, created if missing. Images keep their
            class folder and file name.
        ipc: number of images to take from each class.

    Raises:
        IndexError: if a class folder holds fewer than ``ipc`` images.
    """
    class_names = sorted(f for f in os.listdir(syn_root_folder) if 'n' in f)
    os.makedirs(save_root, exist_ok=True)

    for class_name in tqdm(class_names, total=len(class_names)):
        src_dir = os.path.join(syn_root_folder, class_name)
        dst_dir = os.path.join(save_root, class_name)

        # os.listdir returns entries in arbitrary order; sorting makes the
        # selected images reproducible across machines and runs.
        img_names = sorted(os.listdir(src_dir))
        if len(img_names) < ipc:
            raise IndexError(
                f"class {class_name!r} has only {len(img_names)} images, need {ipc}"
            )
        os.makedirs(dst_dir, exist_ok=True)

        for img_ind in range(ipc):
            img_name = img_names[img_ind]
            # The context manager closes the source file as soon as the image
            # has been written out.
            with Image.open(os.path.join(src_dir, img_name)) as img:
                img.save(os.path.join(dst_dir, img_name))
            

if __name__ == "__main__":
    # Command-line entry point: re-save `--ipc` images per class from
    # `--syn_root_folder` into `--save_root` (see seperate_ipcs).
    #
    # The other one-off utilities in this module are not wired to the CLI and
    # have to be called by hand:
    #   * create_tiny_val_img_folder(root_dir=...)
    #   * create_token_names(args, save=True), which needs args.subset,
    #     args.dataset_name_dict and args.pretrained_model_name_or_path
    parser = argparse.ArgumentParser()
    # All three arguments are mandatory; marking them as required turns a
    # missing value into a usage error instead of a crash further down.
    parser.add_argument('--syn_root_folder', type=str, required=True)
    parser.add_argument('--save_root', type=str, required=True)
    parser.add_argument('--ipc', type=int, required=True)
    args = parser.parse_args()

    seperate_ipcs(args.syn_root_folder, args.save_root, args.ipc)