import argparse
import os
import random
import shutil

import numpy as np
import torch

# The star import deliberately stays ahead of the explicit imports below so
# that they keep taking precedence over any same-named export of `utils`.
from utils import *

from tqdm import tqdm
from torchvision import transforms
from torchvision.datasets import ImageFolder

# Module-wide compute device: CUDA when torch reports it as available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

@torch.no_grad()
def create_collage_ds_for_dataset(args):
    """Write model-selected patch collages for the class folders in ``args.train_dir``.

    For every class index in ``[args.st_cls, args.end_cls)`` the class images are
    loaded in batches of ``args.factor ** 2``. ``MultiRandomCrop`` expands each
    image into candidate crops; per image the crop with the lowest cross-entropy
    loss w.r.t. the class index is kept, resized to
    ``args.diff_input_size // args.factor`` and the batch is tiled into a single
    collage saved as ``<collage_save_dir>/<class>/<batch index, 5 digits>.jpg``.

    Args:
        args: Parsed command-line namespace. ``args.end_cls`` is replaced by
            ``args.nclass`` when it is ``None``.

    Raises:
        AssertionError: If ``args.model_ckpt`` is ``None`` or ``args.end_cls``
            exceeds ``args.nclass``.
    """
    bs = args.factor ** 2

    assert args.model_ckpt is not None, "model_ckpt should be provided"

    # Resolve the default *before* validating it: comparing None with an int
    # raises TypeError, which made the default configuration unusable.
    if args.end_cls is None:
        args.end_cls = args.nclass
    assert args.end_cls <= args.nclass, "end_cls should be less than nclass"

    model = get_network(args, pretrained=True)
    model.eval()
    loss_ce = torch.nn.CrossEntropyLoss(reduction='none')

    resize_trans = transforms.Resize(args.input_size)
    resize_trans_collage = transforms.Resize(args.diff_input_size // args.factor)

    normalize_trans, __ = get_normalize_trans(args)
    train_trans = transforms.Compose([
        transforms.ToTensor(),
        MultiRandomCrop(num_crop=args.num_crop, size=args.input_size, factor=args.factor_in_orig_img),
        normalize_trans,
    ])

    class_names = sorted(os.listdir(args.train_dir))

    if args.init_acc:
        # NOTE(review): `normalize` is not defined in this file (presumably it is
        # star-imported from utils); `normalize_trans` may have been intended - confirm.
        val_trans = transforms.Compose([
            transforms.Resize(args.init_resize),
            transforms.CenterCrop(args.input_size),
            transforms.ToTensor(),
            normalize,
        ])
        val_ds = ImageFolder(args.val_dir, transform=val_trans)
        val_dl = torch.utils.data.DataLoader(val_ds, batch_size=256, shuffle=False)
        acc = eval(model, val_dl)
        print(f"Initial Acc: {acc}, subset: {args.subset} ")

    for cls_ind in range(args.st_cls, args.end_cls):
        cls_name = class_names[cls_ind]
        cls_dir = os.path.join(args.train_dir, cls_name)
        print(f'reading from {cls_dir}')
        cls_ds = SingleFolderDS(cls_dir, transform=train_trans)
        cls_dl = torch.utils.data.DataLoader(cls_ds, batch_size=bs, shuffle=True, drop_last=True)

        # Create exactly the directory that is written to below. The previous
        # f'./{dir}/...' form pointed somewhere else for absolute save dirs.
        save_cls_path = os.path.join(args.collage_save_dir, cls_name)
        os.makedirs(save_cls_path, exist_ok=True)

        for b_ind, (x, _) in tqdm(enumerate(cls_dl), total=len(cls_dl)):
            x = x.to(device)
            s_b, s_p, s_c, s_h, s_w = x.shape

            # Every crop of every image in this loader belongs to the current class.
            y = torch.full((s_b * s_p,), cls_ind, dtype=torch.long, device=device)

            flat = x.reshape(s_b * s_p, s_c, s_h, s_w)
            loss = loss_ce(model(resize_trans(flat)), y).reshape(s_b, s_p)

            # Per image, keep the crop the model finds easiest (lowest loss).
            min_patch_idx = torch.argmin(loss, dim=1)
            x_selected = x[torch.arange(s_b, device=x.device), min_patch_idx]
            x_selected = resize_trans_collage(x_selected)

            x_collage = create_collage(args, x_selected)
            img = tensor_to_img(args, x_collage)
            img.save(f"{save_cls_path}/{str(b_ind).zfill(5)}.jpg")

            if args.ipc is not None and (b_ind + 1) == args.ipc:
                break

        print(f"Done with {cls_name}")

@torch.no_grad()
def create_patch_ds_for_dataset_one_by_one(args):
    """Save, per class, the individual lowest-loss patches ranked by their loss.

    For every class index in ``[args.st_cls, args.end_cls)`` each training image
    is expanded into candidate crops by ``MultiRandomCrop``; the crop with the
    lowest cross-entropy loss w.r.t. the class index is kept. All kept patches
    of the class are then sorted by ascending loss and the first ``args.ipc``
    (all of them when ``args.ipc`` is ``None``) are written to
    ``<collage_save_dir>/<class>/00000.jpg``, ``00001.jpg``, ...

    Args:
        args: Parsed command-line namespace. ``args.st_cls`` / ``args.end_cls``
            are replaced by ``0`` / ``args.nclass`` when they are ``None``.
            ``args.ipc`` is left untouched.

    Raises:
        AssertionError: If ``args.end_cls`` exceeds ``args.nclass``.
    """
    # Resolve defaults before validating: the old order compared None with an
    # int (TypeError) and wrote the start default to a misspelled attribute.
    if args.st_cls is None:
        args.st_cls = 0
    if args.end_cls is None:
        args.end_cls = args.nclass
    assert args.end_cls <= args.nclass, "end_cls should be less than nclass"

    model = get_network(args, args.arch, pretrained=True)
    model.eval()
    loss_ce = torch.nn.CrossEntropyLoss(reduction='none')

    resize_trans = transforms.Resize(args.input_size)
    resize_trans_collage = transforms.Resize(args.input_size // args.factor_in_orig_img)

    normalize_trans, __ = get_normalize_trans(args)
    train_trans = transforms.Compose([
        transforms.ToTensor(),
        MultiRandomCrop(num_crop=args.num_crop, size=args.input_size, factor=args.factor_in_orig_img),
        normalize_trans,
    ])

    class_names = sorted(os.listdir(args.train_dir))

    if args.init_acc:
        # NOTE(review): `normalize` is not defined in this file (presumably it is
        # star-imported from utils); `normalize_trans` may have been intended - confirm.
        val_trans = transforms.Compose([
            transforms.Resize(args.init_resize),
            transforms.CenterCrop(args.input_size),
            transforms.ToTensor(),
            normalize,
        ])
        val_ds = ImageFolder(args.val_dir, transform=val_trans)
        val_dl = torch.utils.data.DataLoader(val_ds, batch_size=256, shuffle=False)
        acc = eval(model, val_dl)
        print(f"Initial Acc: {acc}, subset: {args.subset} ")

    for cls_ind in range(args.st_cls, args.end_cls):
        cls_name = class_names[cls_ind]
        cls_dir = os.path.join(args.train_dir, cls_name)
        print(f'reading from {cls_dir}')
        cls_ds = SingleFolderDS(cls_dir, transform=train_trans)
        cls_dl = torch.utils.data.DataLoader(cls_ds, batch_size=64, shuffle=True, num_workers=8)

        # Create exactly the directory that is written to below. The previous
        # f'./{dir}/...' form pointed somewhere else for absolute save dirs.
        save_cls_path = os.path.join(args.collage_save_dir, cls_name)
        os.makedirs(save_cls_path, exist_ok=True)

        # (loss, PIL image) for the best patch of every image of this class.
        patch_loss_lst = []
        for x, _ in tqdm(cls_dl, total=len(cls_dl)):
            x = x.to(device)
            s_b, s_p, s_c, s_h, s_w = x.shape

            # Every crop of every image in this loader belongs to the current class.
            y = torch.full((s_b * s_p,), cls_ind, dtype=torch.long, device=device)

            flat = x.reshape(s_b * s_p, s_c, s_h, s_w)
            loss = loss_ce(model(resize_trans(flat)), y).reshape(s_b, s_p)

            min_patch_idx = torch.argmin(loss, dim=1)
            rows = torch.arange(s_b, device=x.device)
            # One host transfer per batch instead of one per patch.
            min_patch_loss = loss[rows, min_patch_idx].tolist()

            x_selected = resize_trans_collage(x[rows, min_patch_idx])
            for patch_loss, patch in zip(min_patch_loss, x_selected):
                patch_loss_lst.append((patch_loss, tensor_to_img(args, patch)))

        patch_loss_lst.sort(key=lambda item: item[0])

        # Use a local limit: overwriting args.ipc here used to cap every later
        # class at the number of patches found for the first one.
        keep = len(patch_loss_lst) if args.ipc is None else min(args.ipc, len(patch_loss_lst))

        for img_cnt, (_, img) in enumerate(patch_loss_lst[:keep]):
            img.save(os.path.join(save_cls_path, f"{str(img_cnt).zfill(5)}.jpg"))

        print(f"Done with {cls_name}")





@torch.no_grad()
def unchop_the_patches(collage_save_dir, diff_input_size, factor,  class_name, read_prefix, write_prefix, ipc=None):
    """Re-assemble stored patches of one class into ``factor x factor`` collages.

    Patches are read from ``<collage_save_dir>/<class_name>`` through
    ``SingleFolderDS`` (``read_prefix`` is forwarded as its ``prefix`` argument),
    grouped in shuffled batches of ``factor ** 2`` (an incomplete last batch is
    dropped), tiled row-major onto a ``diff_input_size`` square canvas and saved
    into the same folder as ``<write_prefix>_<batch index, 5 digits>.jpg``.

    Args:
        collage_save_dir: Root folder holding one sub-folder per class.
        diff_input_size: Side length in pixels of the collage canvas.
        factor: Number of patches per collage row and column.
        class_name: Name of the class sub-folder to process.
        read_prefix: Passed to ``SingleFolderDS`` as ``prefix``.
        write_prefix: Prefix of the collage file names that are written.
        ipc: Optional number of collages after which to stop.
    """
    cls_dir = os.path.join(collage_save_dir, class_name)

    print(f'reading from {cls_dir}')
    cls_ds = SingleFolderDS(cls_dir, transform=transforms.Compose([transforms.ToTensor()]), prefix=read_prefix)
    cls_dl = torch.utils.data.DataLoader(cls_ds, batch_size=factor ** 2, shuffle=True, drop_last=True)

    # Use the same path for creation and for writing; the previous
    # f'./{dir}/...' form created a stray folder under the cwd for absolute dirs.
    os.makedirs(cls_dir, exist_ok=True)

    for b_ind, (x, _) in tqdm(enumerate(cls_dl), total=len(cls_dl)):
        im_sz = x.shape[-1]
        tiles = x.reshape(factor, factor, 3, im_sz, im_sz)

        x_collage = torch.zeros(3, diff_input_size, diff_input_size)
        for i in range(factor):
            for j in range(factor):
                x_collage[:, i*im_sz:(i+1)*im_sz, j*im_sz:(j+1)*im_sz] = tiles[i, j]

        # CHW float tensor -> HWC uint8 array for PIL.
        img_np = (x_collage.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
        Image.fromarray(img_np).save(f"{cls_dir}/{write_prefix}_{str(b_ind).zfill(5)}.jpg")

        if ipc is not None and (b_ind + 1) == ipc:
            break

    print(f"Done with {class_name}")


def delete_prefix(collage_save_dir, cls_name, prefix):
    """Remove every entry of ``<collage_save_dir>/<cls_name>`` whose name starts with ``prefix``."""
    cls_dir = os.path.join(collage_save_dir, cls_name)

    # The listing is materialised (sorted) before anything is removed.
    for entry in sorted(os.listdir(cls_dir)):
        if entry.startswith(prefix):
            os.remove(os.path.join(cls_dir, entry))

class MultiRandomCrop(torch.nn.Module):
    """Transform that turns one image tensor into a stack of random square crops.

    With ``num_crop > 0`` the image is passed ``num_crop`` times through a
    ``RandomResizedCrop`` producing ``size // factor`` sized outputs, and the
    results are stacked along a new leading dimension. Otherwise the image is
    returned unchanged apart from a new leading dimension of length one.
    """

    def __init__(self, num_crop=5, size=224, factor=2):
        super().__init__()
        self.num_crop = num_crop
        self.size = size
        self.factor = factor

        patch_side = self.size // self.factor
        # Not used by forward(); kept as an attribute of the module.
        self.resize_trans = transforms.Resize((patch_side, patch_side))
        self.cropper = transforms.RandomResizedCrop(
            patch_side,
            scale=(0.3, 1.0),
            ratio=(1, 1),
            antialias=True,
        )

    def forward(self, image):
        # Guard clause: no cropping requested, only add the crop dimension.
        if self.num_crop <= 0:
            return image.unsqueeze(0)

        crops = [self.cropper(image) for _ in range(self.num_crop)]
        return torch.stack(crops, 0)

    def __repr__(self) -> str:
        return f"{type(self).__name__}(num_crop={self.num_crop}, size={self.size})"
    

@torch.no_grad()
def eval(model, dl):
    """Return the top-1 accuracy of ``model`` over the loader ``dl``, in percent.

    The model is switched to evaluation mode; every batch is moved to the
    module-level ``device`` before the forward pass.
    """
    model.eval()
    n_correct = 0
    n_seen = 0

    for inputs, targets in tqdm(dl, total=len(dl)):
        inputs = inputs.to(device)
        targets = targets.to(device)
        predicted = model(inputs).argmax(1)
        n_correct += (predicted == targets).sum().item()
        n_seen += inputs.shape[0]

    return n_correct / n_seen * 100.

def create_collage(args, x_selected):
    """Tile ``args.factor ** 2`` square patches row-major onto one square canvas.

    Args:
        args: Namespace providing ``factor`` (patches per row/column) and
            ``diff_input_size`` (side length of the canvas).
        x_selected: Tensor holding ``factor ** 2`` patches of shape
            ``(C, s, s)``; patch ``k`` lands in row ``k // factor`` and column
            ``k % factor``.

    Returns:
        A new ``(C, diff_input_size, diff_input_size)`` tensor created with
        ``torch.zeros``. Pixels not covered by a patch stay zero.
    """
    factor = args.factor
    # Take the channel count from the input instead of assuming RGB.
    n_channels = x_selected.shape[-3]
    im_sz = x_selected.shape[-1]

    x_whole = torch.zeros(n_channels, args.diff_input_size, args.diff_input_size)
    grid = x_selected.reshape(factor, factor, n_channels, im_sz, im_sz)
    for i in range(factor):
        for j in range(factor):
            x_whole[:, i*im_sz:(i+1)*im_sz, j*im_sz:(j+1)*im_sz] = grid[i, j]

    return x_whole

def create_collage_from_folder(args):
    """Build collages from already separated patch images, class by class.

    For each class folder in ``args.separate_patch_dir`` the files are taken in
    sorted order and grouped into consecutive runs of ``args.factor ** 2``.
    Each run is tiled by ``create_collage`` and written to
    ``<collage_save_dir>/<class>/<collage index, 5 digits>.jpg``. A trailing
    run that is too short to fill a collage is skipped, so a class with fewer
    than ``factor ** 2`` files produces no output.

    Args:
        args: Parsed command-line namespace (uses ``factor``,
            ``separate_patch_dir``, ``collage_save_dir`` and whatever
            ``create_collage`` reads).
    """
    patch_num = args.factor * args.factor
    to_tensor = transforms.ToTensor()
    cls_names = sorted(os.listdir(args.separate_patch_dir))
    os.makedirs(args.collage_save_dir, exist_ok=True)

    for cls_name in tqdm(cls_names, total=len(cls_names)):
        # Create exactly the directory that is written to below. The previous
        # f'./{dir}/...' form pointed somewhere else for absolute save dirs.
        save_cls_path = os.path.join(args.collage_save_dir, cls_name)
        os.makedirs(save_cls_path, exist_ok=True)

        cls_path = os.path.join(args.separate_patch_dir, cls_name)
        img_paths = [os.path.join(cls_path, img_name) for img_name in sorted(os.listdir(cls_path))]

        # Work one collage at a time so only factor**2 images are held in
        # memory, and ignore an incomplete trailing group instead of crashing.
        for collage_ind in range(len(img_paths) // patch_num):
            group = img_paths[collage_ind * patch_num:(collage_ind + 1) * patch_num]

            patches = []
            for img_path in group:
                # The context manager releases the file handle right away;
                # convert() guarantees the three channels the JPEG output needs.
                with Image.open(img_path) as pil_img:
                    patches.append(to_tensor(pil_img.convert("RGB")))

            x_collage = create_collage(args, torch.stack(patches))
            img = x_collage.permute(1, 2, 0).cpu().numpy() * 255.
            img = Image.fromarray(img.astype(np.uint8))
            img.save(os.path.join(save_cls_path, f"{str(collage_ind).zfill(5)}.jpg"))



@torch.no_grad()
def create_patche_per_cls(args):
    """Write model-selected patch collages using per-class datasets from ``get_ds_for_cls``.

    The network from ``get_network(args)`` is initialised with the weights in
    ``args.model_ckpt``. For every class index in ``[args.st_cls, args.end_cls)``
    the class images are loaded in batches of ``args.factor ** 2``; per image
    the candidate crop with the lowest cross-entropy loss is kept and the batch
    is tiled into a collage saved as ``<collage_save_dir>/<class>/<batch index>.jpg``.

    Args:
        args: Parsed command-line namespace. ``args.end_cls`` is replaced by
            ``args.nclass`` when it is ``None``.

    Raises:
        AssertionError: If ``args.end_cls`` exceeds ``args.nclass``.
    """
    bs = args.factor ** 2

    # Resolve the default *before* validating it: comparing None with an int
    # raises TypeError.
    if args.end_cls is None:
        args.end_cls = args.nclass
    assert args.end_cls <= args.nclass, "end_cls should be less than nclass"

    model = get_network(args)
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
    model.load_state_dict(torch.load(args.model_ckpt, map_location=device))
    model.eval()
    loss_ce = torch.nn.CrossEntropyLoss(reduction='none')

    resize_trans = transforms.Resize(args.input_size)

    for cls_ind in range(args.st_cls, args.end_cls):
        cls_ds, cls_name = get_ds_for_cls(args, cls_ind)
        cls_dl = torch.utils.data.DataLoader(cls_ds, batch_size=bs, shuffle=True, drop_last=True)

        # Create exactly the directory that is written to below. The previous
        # f'./{dir}/...' form pointed somewhere else for absolute save dirs.
        save_cls_path = os.path.join(args.collage_save_dir, cls_name)
        os.makedirs(save_cls_path, exist_ok=True)

        for b_ind, (x, y) in tqdm(enumerate(cls_dl), total=len(cls_dl)):
            x, y = x.to(device), y.to(device)

            s_b, s_p, s_c, s_h, s_w = x.shape
            # The reshape below lays crops out image-major, so each label must be
            # repeated once per crop of its own image (repeat_interleave), not
            # tiled across the batch.
            y = y.repeat_interleave(s_p)

            flat = x.reshape(s_b * s_p, s_c, s_h, s_w)
            loss = loss_ce(model(resize_trans(flat)), y).reshape(s_b, s_p)

            min_patch_idx = torch.argmin(loss, dim=1)
            x_selected = x[torch.arange(s_b, device=x.device), min_patch_idx]

            x_collage = create_collage(args, x_selected)
            # Every other call site in this file passes `args` first.
            img = tensor_to_img(args, x_collage)
            img.save(f"{save_cls_path}/{b_ind}.jpg")

            if args.ipc is not None and (b_ind + 1) == args.ipc:
                break

        print(f"Done with {cls_name}")




@torch.no_grad()
def create_collage_ds_for_dataset_nomodel(args):
    """Write collages of randomly grouped, centre-cropped images (no model involved).

    Images come from the training split returned by ``get_dataset``. For every
    class index in ``[args.st_cls, args.end_cls)`` the samples of that class are
    loaded in shuffled batches of ``args.factor ** 2``, resized to
    ``args.diff_input_size // args.factor`` and tiled into one collage saved as
    ``<collage_save_dir>/<folder name>/<batch index>.jpg``.

    Args:
        args: Parsed command-line namespace. ``args.end_cls`` is replaced by
            ``args.nclass`` when it is ``None``.

    Raises:
        AssertionError: If ``args.end_cls`` exceeds ``args.nclass``.
    """
    bs = args.factor ** 2

    # Resolve the default *before* validating it: comparing None with an int
    # raises TypeError.
    if args.end_cls is None:
        args.end_cls = args.nclass
    assert args.end_cls <= args.nclass, "end_cls should be less than nclass"

    normalize_trans, __ = get_normalize_trans(args)
    train_trans = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(args.init_resize),
        transforms.CenterCrop(args.input_size),
        normalize_trans,
    ])

    train_ds, val_ds, folder_names, ___, classes = get_dataset(args, train_trans=train_trans)

    patch_size = args.diff_input_size // args.factor
    patch_resize = transforms.Resize((patch_size, patch_size))

    # Convert once; the targets do not change between classes.
    targets = np.asarray(train_ds.dataset.targets)

    for cls_ind in range(args.st_cls, args.end_cls):
        # flatnonzero always yields a 1-D result; argwhere(...).squeeze()
        # collapsed to a 0-d array when a class had exactly one sample.
        idxs = np.flatnonzero(targets == classes[cls_ind]).tolist()
        cls_ds = torch.utils.data.Subset(train_ds.dataset, idxs)
        cls_dl = torch.utils.data.DataLoader(cls_ds, batch_size=bs, shuffle=True, drop_last=True)
        cls_name = folder_names[cls_ind]

        # Create exactly the directory that is written to below. The previous
        # f'./{dir}/...' form pointed somewhere else for absolute save dirs.
        save_cls_path = os.path.join(args.collage_save_dir, cls_name)
        os.makedirs(save_cls_path, exist_ok=True)

        for b_ind, (x, _) in tqdm(enumerate(cls_dl), total=len(cls_dl)):
            x = patch_resize(x.to(device))

            x_collage = create_collage(args, x)
            img = tensor_to_img(args, x_collage)
            img.save(f"{save_cls_path}/{b_ind}.jpg")

            if args.ipc is not None and (b_ind + 1) == args.ipc:
                break

        print(f"Done with {cls_name}")



@torch.no_grad()
def chop_the_patches(args, prefix=None):
    """Split stored collages back into individual patch images.

    For every class index in ``[args.st_cls, args.end_cls)`` (indexing the
    sorted sub-folders of ``args.collage_save_dir``) each collage is passed
    through ``ChopCollageTrans`` and its ``args.factor ** 2`` patches are saved
    to ``<chopped_save_dir>/<class>/`` with a running 5-digit counter as file
    name, optionally preceded by ``<prefix>_``. Writing stops for a class once
    ``args.ipc`` patches have been saved (when ``args.ipc`` is not ``None``).

    Args:
        args: Parsed command-line namespace. ``args.end_cls`` is replaced by
            ``args.nclass`` when it is ``None``.
        prefix: Optional file-name prefix for the written patches.

    Raises:
        AssertionError: If ``args.end_cls`` exceeds ``args.nclass``.
    """
    # Resolve the default *before* validating it: comparing None with an int
    # raises TypeError.
    if args.end_cls is None:
        args.end_cls = args.nclass
    assert args.end_cls <= args.nclass, "end_cls should be less than nclass"

    resize_size = args.diff_input_size // args.factor
    print(f"########### resize_size: {resize_size}  ###########")

    class_names = sorted(os.listdir(args.collage_save_dir))
    # `transforms` is the name this file imports; the previous
    # `torchvision.transforms` relied on `torchvision` leaking in from elsewhere.
    trans = transforms.Compose([transforms.ToTensor(), ChopCollageTrans(args.diff_input_size, args.factor)])

    patches_per_collage = args.factor ** 2
    name_prefix = "" if prefix is None else f"{prefix}_"

    for cls_ind in range(args.st_cls, args.end_cls):
        cls_name = class_names[cls_ind]
        cls_ds = SingleFolderDS(os.path.join(args.collage_save_dir, cls_name), transform=trans)
        cls_dl = torch.utils.data.DataLoader(cls_ds, batch_size=1, shuffle=False)

        # Create exactly the directory that is written to below. The previous
        # f'./{dir}/...' form pointed somewhere else for absolute save dirs.
        save_cls_path = os.path.join(args.chopped_save_dir, cls_name)
        os.makedirs(save_cls_path, exist_ok=True)

        save_cnt = 0
        reached_ipc = False
        for x, _ in tqdm(cls_dl, total=len(cls_dl)):
            # Drop only the batch axis (batch_size is 1). A bare squeeze() would
            # also remove the patch axis when factor == 1.
            x = x.squeeze(0)

            for i in range(patches_per_collage):
                # x[i] is indexed as one CHW patch; convert to HWC uint8 for PIL.
                patch = (x[i].cpu().numpy() * 255).astype(np.uint8).transpose(1, 2, 0)
                Image.fromarray(patch).save(f"{save_cls_path}/{name_prefix}{str(save_cnt).zfill(5)}.jpg")
                save_cnt += 1

                if args.ipc is not None and save_cnt == args.ipc:
                    reached_ipc = True
                    break

            if reached_ipc:
                break

        print(f"Done with {cls_name}")

@torch.no_grad()
def separate_patches(args):
    """Copy the first ``args.ipc`` ``.jpg`` files of every class folder to a new root.

    Every sub-folder of ``args.patch_read_add`` is mirrored under
    ``args.patch_write_add``. Files are taken in sorted name order; when
    ``args.ipc`` is ``None`` (or not positive) all ``.jpg`` files are copied.

    Files are copied byte-for-byte. The earlier implementation decoded and
    re-encoded each image, which recompressed the JPEG data on every run.

    Args:
        args: Namespace providing ``patch_read_add``, ``patch_write_add`` and
            ``ipc``.
    """
    class_names = sorted(os.listdir(args.patch_read_add))

    for cls_name in class_names:
        src_cls_path = os.path.join(args.patch_read_add, cls_name)
        print(f'reading from {src_cls_path}')

        # Create exactly the directory that is written to below. The previous
        # f'./{dir}/...' form pointed somewhere else for absolute target dirs.
        save_cls_path = os.path.join(args.patch_write_add, cls_name)
        os.makedirs(save_cls_path, exist_ok=True)

        image_names = sorted(fn for fn in os.listdir(src_cls_path) if fn.endswith('.jpg'))
        # Same cut-off as the old "(count == ipc) -> break": only a positive
        # ipc limits the number of copied files.
        if args.ipc is not None and args.ipc > 0:
            image_names = image_names[:args.ipc]

        for img_name in image_names:
            shutil.copyfile(os.path.join(src_cls_path, img_name), os.path.join(save_cls_path, img_name))

        print(f"Done with {cls_name}")


if __name__ == "__main__":
    # Command-line options, registered in the order listed here.
    cli_options = [
        # cropping / dataset selection
        ("--num_crop", dict(type=int, default=20)),
        ("--subset", dict(type=str, default="imagenette")),
        ("--val_dir", dict(type=str)),
        ("--train_dir", dict(type=str)),
        ("--dataset_name_dict", dict(type=str, default="./dataset_dicts/imagenet_dic.pth")),
        ("--init_acc", dict(action="store_true")),
        ("--get_from_blob", dict(action="store_true")),
        ("--small_dataset", dict(action="store_true")),
        ("--size", dict(type=int, default=224)),
        ("--ipc", dict(type=int)),
        # class range
        ("--st_cls", dict(type=int, default=0)),
        ("--end_cls", dict(type=int)),
        ("--nclass", dict(type=int, default=1000)),
        # NOTE(review): type=list splits a command-line string into single
        # characters; confirm whether this option is ever passed on the CLI.
        ("--classes", dict(type=list)),
        # image geometry
        ("--input_size", dict(type=int, default=224)),
        ("--init_resize", dict(type=int, default=256)),
        ("--diff_input_size", dict(type=int, default=512)),
        ("--val_ipc", dict(type=int, default=50)),
        ("--factor", dict(type=int, default=2)),
        ("--factor_in_orig_img", dict(type=int, default=2)),
        ("--collage_factor", dict(type=int, default=5)),
        ("--chopped_collage_size", dict(type=int)),
        # locations
        ("--root_dir", dict(type=str)),
        ("--collage_save_dir", dict(type=str)),
        ("--chopped_save_dir", dict(type=str)),
        ("--separate_patch_dir", dict(type=str)),
        # model
        ("--arch", dict(type=str, default='resnet18')),
        ("--model_ckpt", dict(type=str)),
        ("--patch_read_add", dict(type=str)),
        ("--patch_write_add", dict(type=str)),
        ("--force_rded_net", dict(action="store_true")),
        ("--seed", dict(type=int, default=42)),
    ]

    parser = argparse.ArgumentParser()
    for flag, options in cli_options:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    # Seed every random number generator used by the pipeline.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    set_dataset_specs(args)

    create_patch_ds_for_dataset_one_by_one(args)
    # separate_patches(args)
