import numpy as np
import torch as t
import os
from PIL import Image
from utils import *
from glob import glob



def bffhq_gen(args):
    # valid = {}
    # _dir = args.data_storage+'valid/'
    # valid_list = os.listdir(_dir)
    # ret = {}
    # label = t.zeros(len(valid_list))
    # b_label = t.zeros(len(valid_list))
    # data = t.zeros((len(valid_list),3,224,224))
    # for idx,fidx in enumerate(valid_list):
    #     label[idx] = float(fidx.split('_')[1])
    #     b_label[idx] = float(fidx.split('_')[2].split('.')[0])
    #     img = Image.open(_dir + fidx)
    #     img = img.resize((224,224)).convert('RGB')
    #     data[idx] = t.tensor((np.array(img)/255.).transpose((2,0,1)))
    # label = label.clone()
    
    # valid['data'] = data
    # valid['label'] = label
    # valid['b_label'] = b_label
    
    # # train dataset
    # _dir = args.data_storage+'0.5pct/'
    # train_align_list = glob(os.path.join(_dir, 'align',"*","*"))
    # train_conflict_list = glob(os.path.join(_dir, 'conflict',"*","*"))
    # train_list = train_align_list + train_conflict_list

    # label = t.zeros(len(train_list))
    # b_label = t.zeros(len(train_list))
    # data = t.zeros((len(train_list),3,224,224))
    # for idx,fidx in enumerate(train_list):
    #     label[idx] = float(fidx.split('_')[1])
    #     b_label[idx] = float(fidx.split('_')[2].split('.')[0])
    #     img = Image.open(fidx)
    #     img = img.resize((224,224)).convert('RGB')
    #     data[idx] = t.tensor((np.array(img)/255.).transpose((2,0,1)))
    
    
    # ret, train = {},{}
    
    # train['data'] = data
    # train['label'] = label
    # train['b_label'] = b_label
    
    # ret['train'] = train
    # ret['valid'] = valid
    
    # data_name = args.data
    # save_data(ret, args.save_dir+data_name)


    # # unbiased test set
    # _dir = args.data_storage+'test/'
    # test_list = os.listdir(_dir)
    # ret = {}
    
    # label = t.zeros(len(test_list))
    # b_label = t.zeros(len(test_list))
    # data = t.zeros((len(test_list),3,224,224))
    # for idx,fidx in enumerate(test_list):
    #     label[idx] = float(fidx.split('_')[1])
    #     b_label[idx] = float(fidx.split('_')[2].split('.')[0])
    #     img = Image.open(_dir + fidx)
    #     img = img.resize((224,224)).convert('RGB')
    #     data[idx] = t.tensor((np.array(img)/255.).transpose((2,0,1)))
        
    # label = label.clone()
    
    # ret['data'] = data
    # ret['label'] = label
    # ret['b_label'] = b_label
    # data_name = args.data + '_test'
    # save_data(ret,args.save_dir+data_name)


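    # test set: loads every file under test/ (no filtering on label / b_label)
    # NOTE(review): this section was labelled "bias-conflicting test set", but the
    # code is the same as the commented-out "unbiased test set" section above and
    # saves under the same name (args.data + '_test') -- confirm which split is
    # intended here.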
    _dir = args.data_storage+'test/'
    test_list = os.listdir(_dir)
    ret = {}
    
    label = t.zeros(len(test_list))
    b_label = t.zeros(len(test_list))
    data = t.zeros((len(test_list),3,224,224))
    for idx,fidx in enumerate(test_list):
        label[idx] = float(fidx.split('_')[1])
        b_label[idx] = float(fidx.split('_')[2].split('.')[0])
        img = Image.open(_dir + fidx)
        img = img.resize((224,224)).convert('RGB')
        data[idx] = t.tensor((np.array(img)/255.).transpose((2,0,1)))
        
    label = label.clone()
    
    ret['data'] = data
    ret['label'] = label
    ret['b_label'] = b_label
    data_name = args.data + '_test'
    save_data(ret,args.save_dir+data_name)

    # bias-aligned test set