import torch 
import argparse
import numpy as np  
from utils  import *
from tqdm import tqdm       

# Run on the GPU when PyTorch reports one as available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

@torch.no_grad()
def eval(model, dl):
    """Return the top-1 accuracy of ``model`` over ``dl`` as a percentage.

    The name shadows the ``eval`` builtin; it is kept because other code in
    this file calls it under this name.

    Args:
        model: Network to score. It is switched to eval mode first.
        dl: Sized iterable yielding ``(inputs, labels)`` batches.

    Returns:
        ``100 * correct / total`` where a sample is correct when the arg-max
        of the model output along dim 1 equals its label. Returns ``0.0``
        when ``dl`` yields no samples instead of dividing by zero.
    """
    model.eval()
    num_samples = 0
    crrct = 0

    for x, y in tqdm(dl, total=len(dl)):
        x, y = x.to(device), y.to(device)
        out = model(x)
        crrct += (out.argmax(1) == y).sum().item()
        num_samples += x.shape[0]

    if num_samples == 0:
        # Empty loader: there is nothing to score.
        return 0.0

    return crrct / num_samples * 100.

def create_collage(args, x_selected):
    """Tile ``factor * factor`` square 3-channel patches into one canvas.

    Patches are laid out row-major: patch ``i * factor + j`` lands in grid
    row ``i``, column ``j``. The canvas is a zero-initialised
    ``(3, diff_input_size, diff_input_size)`` tensor, so any area not
    covered by the grid stays zero.

    Args:
        args: Namespace providing ``factor`` and ``diff_input_size``.
        x_selected: Tensor holding ``factor**2`` patches whose last
            dimension is the patch side length.

    Returns:
        The filled canvas tensor.
    """
    grid = args.factor
    side = x_selected.shape[-1]
    canvas = torch.zeros(3, args.diff_input_size, args.diff_input_size)

    tiles = x_selected.reshape(grid, grid, 3, side, side)
    # (row, col, C, h, w) -> (C, row, h, col, w) -> (C, row*h, col*w)
    mosaic = tiles.permute(2, 0, 3, 1, 4).reshape(3, grid * side, grid * side)
    canvas[:, :grid * side, :grid * side] = mosaic

    return canvas



def create_collage_from_folder(args):
    """Build collages from per-class folders of already-extracted patches.

    For every sub-directory of ``args.separate_patch_dir`` the files are
    sorted by name and consumed in consecutive groups of ``factor**2``.
    Each group is tiled with ``create_collage`` and written as a
    zero-padded, five-digit ``.jpg`` under
    ``args.collage_save_dir/<class name>``.

    Trailing files that do not fill a complete group are ignored, and a
    class folder with fewer than ``factor**2`` files produces no output.

    Args:
        args: Namespace providing ``factor``, ``separate_patch_dir``,
            ``collage_save_dir`` and whatever ``create_collage`` reads.
    """
    patch_num = args.factor * args.factor
    to_tensor = transforms.ToTensor()
    cls_names = sorted(os.listdir(args.separate_patch_dir))
    os.makedirs(args.collage_save_dir, exist_ok=True)

    for cls_name in tqdm(cls_names, total=len(cls_names)):
        # Create exactly the directory that img.save() writes into below.
        save_cls_path = os.path.join(args.collage_save_dir, cls_name)
        os.makedirs(save_cls_path, exist_ok=True)

        cls_path = os.path.join(args.separate_patch_dir, cls_name)
        img_paths = [os.path.join(cls_path, img_name) for img_name in sorted(os.listdir(cls_path))]

        # Only complete groups can be tiled; leftovers are skipped.
        num_collages = len(img_paths) // patch_num
        for collage_ind in range(num_collages):
            group = img_paths[collage_ind * patch_num:(collage_ind + 1) * patch_num]
            patches = []
            for img_path in group:
                # Close each file as soon as its pixels are converted.
                with Image.open(img_path) as patch_img:
                    patches.append(to_tensor(patch_img))

            x_collage = create_collage(args, torch.stack(patches))
            img = x_collage.permute(1, 2, 0).cpu().numpy() * 255.
            img = Image.fromarray(img.astype(np.uint8))
            img.save(os.path.join(save_cls_path, f"{str(collage_ind).zfill(5)}.jpg"))

        

@torch.no_grad()
def random_data_sel(args):
    """Save shuffled batches of each class as one image per batch.

    For every class index in ``[args.st_cls, args.end_cls)`` a shuffled
    loader with batch size ``factor**2`` (incomplete last batch dropped) is
    built from ``get_ds_for_cls(..., rnd_crop=False)``. Each batch is passed
    to ``tensor_to_img`` and saved as ``<batch index>.jpg`` under
    ``args.collage_save_dir/<class name>``. Saving stops for a class once
    ``args.ipc`` batches were written, when ``ipc`` is set.

    Args:
        args: Parsed command-line namespace. ``end_cls`` is filled in with
            ``nclass`` when it is ``None``.
    """
    bs = args.factor ** 2

    # Resolve the default before validating: comparing None with an int
    # raises TypeError.
    if args.end_cls is None:
        args.end_cls = args.nclass
    assert args.end_cls <= args.nclass, "end_cls should be less than nclass"

    for cls_ind in range(args.st_cls, args.end_cls):
        cls_ds, cls_name = get_ds_for_cls(args, cls_ind, rnd_crop=False)
        cls_dl = torch.utils.data.DataLoader(cls_ds, batch_size=bs, shuffle=True, drop_last=True)
        # Create exactly the directory that img.save() writes into below.
        save_cls_path = os.path.join(args.collage_save_dir, cls_name)
        os.makedirs(save_cls_path, exist_ok=True)

        for b_ind, (x, _) in tqdm(enumerate(cls_dl), total=len(cls_dl)):
            x = x.to(device)

            # NOTE(review): most other callers in this file use
            # tensor_to_img(args, x); confirm which signature utils exposes.
            img = tensor_to_img(x)
            img.save(f"{save_cls_path}/{b_ind}.jpg")

            if args.ipc is not None and (b_ind+1) == args.ipc:
                break

        print(f"Done with {cls_name}")


@torch.no_grad()
def create_patche_per_cls(args):
    """Build collages from the lowest-loss patch of each image, per class.

    A network is created with ``get_network`` and its weights are loaded
    from ``args.model_ckpt``. For every class index in
    ``[args.st_cls, args.end_cls)`` shuffled batches of ``factor**2`` items
    are drawn; each item carries several candidate patches (second tensor
    dimension). Every patch is resized to ``args.input_size`` and scored
    with per-sample cross-entropy. The patch with the smallest loss is kept
    for each item, and the kept patches are tiled with ``create_collage``
    and saved as ``<batch index>.jpg`` under
    ``args.collage_save_dir/<class name>``.

    Args:
        args: Parsed command-line namespace. ``end_cls`` is filled in with
            ``nclass`` when it is ``None``.
    """
    bs = args.factor ** 2
    model = get_network(args)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    # NOTE: depending on the PyTorch version torch.load may unpickle
    # arbitrary objects; only point --model_ckpt at trusted files.
    model.load_state_dict(torch.load(args.model_ckpt, map_location=device))
    model.eval()
    loss_ce = torch.nn.CrossEntropyLoss(reduction='none')

    resize_trans = transforms.Resize(args.input_size)

    # Resolve the default before validating: comparing None with an int
    # raises TypeError.
    if args.end_cls is None:
        args.end_cls = args.nclass
    assert args.end_cls <= args.nclass, "end_cls should be less than nclass"

    for cls_ind in range(args.st_cls, args.end_cls):
        cls_ds, cls_name = get_ds_for_cls(args, cls_ind)
        cls_dl = torch.utils.data.DataLoader(cls_ds, batch_size=bs, shuffle=True, drop_last=True)
        # Create exactly the directory that img.save() writes into below.
        save_cls_path = os.path.join(args.collage_save_dir, cls_name)
        os.makedirs(save_cls_path, exist_ok=True)

        for b_ind, (x, y) in tqdm(enumerate(cls_dl), total=len(cls_dl)):
            x, y = x.to(device), y.to(device)

            s_b, s_p, s_c, s_h, s_w = x.shape
            # Flattening is image-major (all patches of image 0 first), so
            # each label is repeated s_p times in place to stay aligned.
            y = y.repeat_interleave(s_p)
            x = x.reshape(s_b * s_p, s_c, s_h, s_w)
            preds = model(resize_trans(x))
            loss = loss_ce(preds, y)

            loss = loss.reshape(s_b, s_p)
            min_patch_idx = torch.argmin(loss, dim=1)
            x = x.reshape(s_b, s_p, s_c, s_h, s_w)
            # One (image, best patch) pair per batch item.
            x_selected = x[torch.arange(s_b, device=x.device), min_patch_idx].clone()
            x_collage = create_collage(args, x_selected)
            # NOTE(review): most other callers in this file use
            # tensor_to_img(args, x); confirm which signature utils exposes.
            img = tensor_to_img(x_collage)
            img.save(f"{save_cls_path}/{b_ind}.jpg")

            if args.ipc is not None and (b_ind+1) == args.ipc:
                break

        print(f"Done with {cls_name}")



@torch.no_grad()
def create_collage_ds_for_dataset(args):
    """Build lowest-loss-patch collages for each class of the full dataset.

    The dataset is loaded with a transform that applies ``ToTensor``,
    ``MultiRandomCrop`` and normalisation. For every class index in
    ``[args.st_cls, args.end_cls)`` the samples whose target equals
    ``classes[cls_ind]`` are drawn in shuffled batches of ``factor**2``.
    Every candidate patch is resized to ``args.input_size`` and scored with
    per-sample cross-entropy under the network from ``get_network``. The
    lowest-loss patch of each item is kept, and the kept patches are tiled
    with ``create_collage`` and saved as ``<batch index>.jpg`` under
    ``args.collage_save_dir/<class name>``.

    When ``args.init_acc`` is set, validation accuracy is printed first.

    Args:
        args: Parsed command-line namespace. ``end_cls`` is filled in with
            ``nclass`` when it is ``None``.
    """
    bs = args.factor ** 2

    model = get_network(args)
    # if args.model_ckpt is not None:
    #     model.load_state_dict(torch.load(args.model_ckpt), strict=True)
    #     print(f"Model loaded from {args.model_ckpt}")

    model.eval()
    loss_ce = torch.nn.CrossEntropyLoss(reduction='none')

    resize_trans = transforms.Resize(args.input_size)

    # Resolve the default before validating: comparing None with an int
    # raises TypeError.
    if args.end_cls is None:
        args.end_cls = args.nclass
    assert args.end_cls <= args.nclass, "end_cls should be less than nclass"

    normalize_trans, _ = get_normalize_trans(args)

    train_trans = transforms.Compose([transforms.ToTensor(),
                                        MultiRandomCrop(num_crop=args.num_crop, size=args.diff_input_size, factor=args.factor),
                                        normalize_trans])

    train_ds, val_ds, folder_names, _, classes = get_dataset(args, train_trans=train_trans)

    if args.init_acc:
        val_dl = torch.utils.data.DataLoader(val_ds, batch_size=256, shuffle=False)
        acc = eval(model, val_dl)
        print(f"Initial Acc: {acc}, subset: {args.subset} ")

    for cls_ind in range(args.st_cls, args.end_cls):
        # flatnonzero always yields a 1-D index array, including when a
        # class has exactly one sample.
        idxs = np.flatnonzero(np.asarray(train_ds.dataset.targets) == classes[cls_ind])
        cls_ds = torch.utils.data.Subset(train_ds.dataset, idxs)
        cls_dl = torch.utils.data.DataLoader(cls_ds, batch_size=bs, shuffle=True, drop_last=True)
        cls_name = folder_names[cls_ind]

        # Create exactly the directory that img.save() writes into below.
        save_cls_path = os.path.join(args.collage_save_dir, cls_name)
        os.makedirs(save_cls_path, exist_ok=True)

        for b_ind, (x, y) in tqdm(enumerate(cls_dl), total=len(cls_dl)):
            x, y = x.to(device), y.to(device)

            s_b, s_p, s_c, s_h, s_w = x.shape
            # Flattening is image-major (all patches of image 0 first), so
            # each label is repeated s_p times in place to stay aligned.
            y = y.repeat_interleave(s_p)
            x = x.reshape(s_b * s_p, s_c, s_h, s_w)
            preds = model(resize_trans(x))
            loss = loss_ce(preds, y)

            loss = loss.reshape(s_b, s_p)
            min_patch_idx = torch.argmin(loss, dim=1)
            x = x.reshape(s_b, s_p, s_c, s_h, s_w)
            # One (image, best patch) pair per batch item.
            x_selected = x[torch.arange(s_b, device=x.device), min_patch_idx].clone()
            x_collage = create_collage(args, x_selected)
            img = tensor_to_img(args, x_collage)
            img.save(f"{save_cls_path}/{b_ind}.jpg")

            if args.ipc is not None and (b_ind+1) == args.ipc:
                break

        print(f"Done with {cls_name}")



@torch.no_grad()
def create_collage_ds_for_dataset_nomodel(args):
    """Build collages from resized whole images, without any model scoring.

    The dataset is loaded with ``ToTensor``, a resize to
    ``args.init_resize``, a centre crop to ``args.input_size`` and
    normalisation. For every class index in ``[args.st_cls, args.end_cls)``
    the samples whose target equals ``classes[cls_ind]`` are drawn in
    shuffled batches of ``factor**2``. Each image is resized to a
    ``diff_input_size // factor`` square, and the batch is tiled with
    ``create_collage`` and saved as ``<batch index>.jpg`` under
    ``args.collage_save_dir/<class name>``.

    Args:
        args: Parsed command-line namespace. ``end_cls`` is filled in with
            ``nclass`` when it is ``None``.
    """
    bs = args.factor ** 2

    # Resolve the default before validating: comparing None with an int
    # raises TypeError.
    if args.end_cls is None:
        args.end_cls = args.nclass
    assert args.end_cls <= args.nclass, "end_cls should be less than nclass"

    normalize_trans, _ = get_normalize_trans(args)

    train_trans = transforms.Compose([transforms.ToTensor(), transforms.Resize(args.init_resize),
                                       transforms.CenterCrop(args.input_size), normalize_trans])

    train_ds, _, folder_names, _, classes = get_dataset(args, train_trans=train_trans)

    patch_size = args.diff_input_size // args.factor
    patch_resize = transforms.Resize((patch_size, patch_size))

    for cls_ind in range(args.st_cls, args.end_cls):
        # flatnonzero always yields a 1-D index array, including when a
        # class has exactly one sample.
        idxs = np.flatnonzero(np.asarray(train_ds.dataset.targets) == classes[cls_ind])
        cls_ds = torch.utils.data.Subset(train_ds.dataset, idxs)
        cls_dl = torch.utils.data.DataLoader(cls_ds, batch_size=bs, shuffle=True, drop_last=True)
        cls_name = folder_names[cls_ind]

        # Create exactly the directory that img.save() writes into below.
        save_cls_path = os.path.join(args.collage_save_dir, cls_name)
        os.makedirs(save_cls_path, exist_ok=True)

        for b_ind, (x, _) in tqdm(enumerate(cls_dl), total=len(cls_dl)):
            x = patch_resize(x.to(device))

            x_collage = create_collage(args, x)
            img = tensor_to_img(args, x_collage)
            img.save(f"{save_cls_path}/{b_ind}.jpg")

            if args.ipc is not None and (b_ind+1) == args.ipc:
                break

        print(f"Done with {cls_name}")



@torch.no_grad()
def chop_the_patches(args):
    """Split collage items back into individual resized patch images.

    For every class index in ``[args.st_cls, args.end_cls)`` items are read
    one at a time from ``get_ds_for_cls_collage_chopping``. After the batch
    dimension is dropped, the first ``factor**2`` entries along the leading
    dimension are each resized to a square of side
    ``input_size // chopped_collage_size`` and saved with a running counter
    as file name under ``args.chopped_save_dir/<class name>``. Saving stops
    for a class once ``args.ipc`` images were written, when ``ipc`` is set.

    Args:
        args: Parsed command-line namespace. ``end_cls`` is filled in with
            ``nclass`` when it is ``None``.

    Raises:
        ValueError: If ``args.chopped_collage_size`` is missing or not
            positive.
    """
    # Resolve the default before validating: comparing None with an int
    # raises TypeError.
    if args.end_cls is None:
        args.end_cls = args.nclass
    assert args.end_cls <= args.nclass, "end_cls should be less than nclass"

    # --chopped_collage_size has no default; fail with a clear message
    # rather than an opaque TypeError from the floor division below.
    if args.chopped_collage_size is None or args.chopped_collage_size <= 0:
        raise ValueError("chopped_collage_size must be a positive integer")

    resize_size = args.input_size // args.chopped_collage_size
    print(f"########### resize_size: {resize_size}  ###########")
    resize_trans = transforms.Resize((resize_size, resize_size))

    for cls_ind in range(args.st_cls, args.end_cls):
        cls_ds, cls_name = get_ds_for_cls_collage_chopping(args, cls_ind)
        cls_dl = torch.utils.data.DataLoader(cls_ds, batch_size=1, shuffle=False)

        # Create exactly the directory that img.save() writes into below.
        save_cls_path = os.path.join(args.chopped_save_dir, cls_name)
        os.makedirs(save_cls_path, exist_ok=True)
        save_cnt = 0
        reached_ipc = False
        for x, _ in tqdm(cls_dl, total=len(cls_dl)):
            # batch_size is 1, so drop the batch dimension before indexing.
            x = x.squeeze(0).to(device)
            for i in range(args.factor**2):
                tmp_x = resize_trans(x[i])
                img = tensor_to_img(args, tmp_x)
                img.save(f"{save_cls_path}/{save_cnt}.jpg")
                save_cnt += 1

                if args.ipc is not None and save_cnt == args.ipc:
                    reached_ipc = True
                    break

            if reached_ipc:
                break

        print(f"Done with {cls_name}")



@torch.no_grad()
def resize_images(args):
    """Resize per-class images and save them with a running counter.

    For every class index in ``[args.st_cls, args.end_cls)`` items are read
    one at a time from ``get_ds_for_cls_resizing``. The first ``factor**2``
    entries along the leading dimension of each loaded batch are resized to
    a square of side ``input_size // chopped_collage_size`` and saved under
    ``args.chopped_save_dir/<class name>``. Saving stops for a class once
    ``args.ipc`` images were written, when ``ipc`` is set.

    Args:
        args: Parsed command-line namespace. ``end_cls`` is filled in with
            ``nclass`` when it is ``None``.

    Raises:
        ValueError: If ``args.chopped_collage_size`` is missing or not
            positive.
    """
    # Resolve the default before validating: comparing None with an int
    # raises TypeError.
    if args.end_cls is None:
        args.end_cls = args.nclass
    assert args.end_cls <= args.nclass, "end_cls should be less than nclass"

    # --chopped_collage_size has no default; fail with a clear message
    # rather than an opaque TypeError from the floor division below.
    if args.chopped_collage_size is None or args.chopped_collage_size <= 0:
        raise ValueError("chopped_collage_size must be a positive integer")

    resize_size = args.input_size // args.chopped_collage_size
    print(f"########### resize_size: {resize_size}  ###########")
    resize_trans = transforms.Resize((resize_size, resize_size))

    for cls_ind in range(args.st_cls, args.end_cls):
        cls_ds, cls_name = get_ds_for_cls_resizing(args, cls_ind)
        cls_dl = torch.utils.data.DataLoader(cls_ds, batch_size=1, shuffle=False)

        # Create exactly the directory that img.save() writes into below.
        save_cls_path = os.path.join(args.chopped_save_dir, cls_name)
        os.makedirs(save_cls_path, exist_ok=True)
        save_cnt = 0
        reached_ipc = False
        for x, _ in tqdm(cls_dl, total=len(cls_dl)):
            x = x.to(device)
            # NOTE(review): unlike chop_the_patches, the batch dimension
            # (size 1) is not squeezed here, so x[i] indexes the batch axis.
            # Confirm against the item shape get_ds_for_cls_resizing yields.
            for i in range(args.factor**2):
                tmp_x = resize_trans(x[i])
                img = tensor_to_img(args, tmp_x)
                img.save(f"{save_cls_path}/{save_cnt}.jpg")
                save_cnt += 1

                if args.ipc is not None and save_cnt == args.ipc:
                    reached_ipc = True
                    break

            if reached_ipc:
                break

        print(f"Done with {cls_name}")

if __name__ == "__main__":
    # Entry points selectable with --task. The default is the call that was
    # previously hard-coded here, so running without --task is unchanged.
    tasks = {
        "create_collage_ds_for_dataset_nomodel": create_collage_ds_for_dataset_nomodel,
        "create_patche_per_cls": create_patche_per_cls,
        "random_data_sel": random_data_sel,
        "chop_the_patches": chop_the_patches,
        "resize_images": resize_images,
        "create_collage_ds_for_dataset": create_collage_ds_for_dataset,
        "create_collage_from_folder": create_collage_from_folder,
    }

    parser = argparse.ArgumentParser()
    parser.add_argument("--num_crop", type=int, default=20)
    parser.add_argument("--subset", type=str, default="imagenette")
    parser.add_argument("--dataset_name_dict", type=str, default="./dataset_dicts/imagenet_dic.pth")

    parser.add_argument("--init_acc", action="store_true")

    parser.add_argument("--size", type=int, default=224)
    parser.add_argument("--ipc", type=int, default=None)

    parser.add_argument("--st_cls", type=int, default=0)
    parser.add_argument("--end_cls", type=int, default=None)

    parser.add_argument("--nclass", type=int, default=1000)
    # NOTE(review): type=list turns the raw string into a list of single
    # characters; confirm whether this option is ever passed on the CLI.
    parser.add_argument("--classes", type=list)
    parser.add_argument("--input_size", type=int, default=224)
    parser.add_argument("--init_resize", type=int, default=256)
    parser.add_argument("--diff_input_size", type=int, default=512)
    parser.add_argument("--val_ipc", type=int, default=50)
    parser.add_argument("--factor", type=int, default=4)
    parser.add_argument("--chopped_collage_size", type=int)

    parser.add_argument("--root_dir", type=str, default='/home/public/')
    parser.add_argument("--collage_save_dir", type=str)
    parser.add_argument("--chopped_save_dir", type=str)
    parser.add_argument("--separate_patch_dir", type=str)

    parser.add_argument("--arch", type=str, default='resnet18')
    parser.add_argument("--model_ckpt", type=str, default=None)

    parser.add_argument("--task", type=str, default="chop_the_patches", choices=sorted(tasks))

    args = parser.parse_args()

    set_dataset_specs(args)
    tasks[args.task](args)
