import argparse
import os 
import cv2 
import pickle
import pandas as pd 

from utils import operator_retrain, operator_shuffle, operator_add_noise


def generate_nose_BL2_img(images_folder: str, df_landmarks: pd.DataFrame, output_dir: str, proportion: float, noise_level: float=5.):
    """Write copies of the images with noise applied to a box around the nose.

    Args:
        images_folder: Directory the source images are read from.
        df_landmarks: Landmark table. Its index values are used as image file
            names and it must provide the ``nose_x`` and ``nose_y`` columns.
        output_dir: Directory the processed images are written to, under the
            same file names.
        proportion: Fraction of ``df_landmarks`` to process. Rows at positions
            ``0 .. int(len(df_landmarks) * proportion)`` (inclusive) are handled.
        noise_level: Forwarded unchanged to ``operator_add_noise``.
    """
    box_w, box_h = 20, 40
    half_w = box_w // 2
    # The box is not vertically centred: it reaches further above the nose
    # landmark (box_h // 2) than below it (box_h // 4).
    reach_up = box_h // 2
    reach_down = box_h // 4
    last_position = int(len(df_landmarks) * proportion)

    for position, (filename, landmarks) in enumerate(df_landmarks.iterrows()):
        if position > last_position:
            break

        centre_x = int(landmarks["nose_x"])
        centre_y = int(landmarks["nose_y"])

        # NOTE(review): the array is RGB from here on, but cv2.imwrite below
        # interprets arrays as BGR, so the saved file is likely channel-swapped
        # -- confirm whether that is intended.
        picture = cv2.cvtColor(
            cv2.imread(os.path.join(images_folder, filename)),
            cv2.COLOR_BGR2RGB,
        )
        pic_h, pic_w = picture.shape[:2]

        # [top, bottom, left, right], clipped to the image bounds.
        area = [
            max(0, centre_y - reach_up),
            min(pic_h, centre_y + reach_down),
            max(0, centre_x - half_w),
            min(pic_w, centre_x + half_w),
        ]
        picture = operator_add_noise(area=area, img=picture, noise_level=noise_level)

        out_path = os.path.join(output_dir, filename)
        cv2.imwrite(out_path, picture)
        print(f"Processed {filename}, saved to {out_path}")
        


def generate_eye_BL2_img(images_folder: str, df_landmarks: pd.DataFrame, output_dir: str, proportion: float, noise_level: float=5.):
    """Write copies of the images with noise applied to a box covering both eyes.

    Args:
        images_folder: Directory the source images are read from.
        df_landmarks: Landmark table. Its index values are used as image file
            names and it must provide the ``lefteye_x``, ``lefteye_y``,
            ``righteye_x`` and ``righteye_y`` columns.
        output_dir: Directory the processed images are written to, under the
            same file names.
        proportion: Fraction of ``df_landmarks`` to process. Rows at positions
            ``0 .. int(len(df_landmarks) * proportion)`` (inclusive) are handled.
        noise_level: Forwarded unchanged to ``operator_add_noise``.

    Raises:
        FileNotFoundError: If an image cannot be read from ``images_folder``.
    """
    w = 30  # horizontal margin budget: w // 2 is added on each side of the eyes
    h = 20  # vertical margin budget: h // 4 above, h // 2 below the eyes
    last_idx = int(len(df_landmarks) * proportion)

    for idx, (filename, row) in enumerate(df_landmarks.iterrows()):
        if idx > last_idx:
            break

        lefteye_x = int(row["lefteye_x"])
        lefteye_y = int(row["lefteye_y"])
        righteye_x = int(row["righteye_x"])
        righteye_y = int(row["righteye_y"])

        img_path = os.path.join(images_folder, filename)
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        # NOTE(review): the array is RGB from here on, but cv2.imwrite below
        # interprets arrays as BGR, so the saved file is likely channel-swapped
        # -- confirm whether that is intended.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # Smallest / largest eye row, regardless of which eye sits higher.
        min_eye_y = min(lefteye_y, righteye_y)
        max_eye_y = max(lefteye_y, righteye_y)

        # Clip the box to the image, as the nose/eye variants in this file do.
        # Without this a landmark near the border yields negative indices,
        # which numpy slicing interprets as offsets from the far edge.
        height, width = img.shape[:2]
        left   = max(0, lefteye_x - w // 2)
        right  = min(width, righteye_x + w // 2)
        top    = max(0, min_eye_y - h // 4)
        bottom = min(height, max_eye_y + h // 2)

        area = [top, bottom, left, right]

        img = operator_add_noise(area=area, img=img, noise_level=noise_level)

        out_path = os.path.join(output_dir, filename)
        cv2.imwrite(out_path, img)

        print(f"Processed {filename}, saved to {out_path}")
        


def generate_noseeye_BL2_img(images_folder: str, df_landmarks: pd.DataFrame, output_dir: str, proportion: float, noise_level: float=5.):
    """Write copies of the images with noise applied to the nose box and the eye box.

    The nose box is processed first; the eye box is then processed on the
    result of the first ``operator_add_noise`` call.

    Args:
        images_folder: Directory the source images are read from.
        df_landmarks: Landmark table. Its index values are used as image file
            names and it must provide the ``nose_x``, ``nose_y``,
            ``lefteye_x``, ``lefteye_y``, ``righteye_x`` and ``righteye_y``
            columns.
        output_dir: Directory the processed images are written to, under the
            same file names.
        proportion: Fraction of ``df_landmarks`` to process. Rows at positions
            ``0 .. int(len(df_landmarks) * proportion)`` (inclusive) are handled.
        noise_level: Forwarded unchanged to ``operator_add_noise``.

    Raises:
        FileNotFoundError: If an image cannot be read from ``images_folder``.
    """
    W_BOX = 20  # nose box width
    H_BOX = 40  # nose box height budget: H_BOX // 2 above, H_BOX // 4 below
    w = 30  # eye box horizontal margin budget (w // 2 per side)
    h = 20  # eye box vertical margin budget (h // 4 above, h // 2 below)
    last_idx = int(len(df_landmarks) * proportion)

    for idx, (filename, row) in enumerate(df_landmarks.iterrows()):
        if idx > last_idx:
            break

        nose_x = int(row["nose_x"])
        nose_y = int(row["nose_y"])

        lefteye_x = int(row["lefteye_x"])
        lefteye_y = int(row["lefteye_y"])
        righteye_x = int(row["righteye_x"])
        righteye_y = int(row["righteye_y"])

        img_path = os.path.join(images_folder, filename)
        img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        # NOTE(review): the array is RGB from here on, but cv2.imwrite below
        # interprets arrays as BGR, so the saved file is likely channel-swapped
        # -- confirm whether that is intended.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        height, width = img.shape[:2]

        # Nose box, clipped to the image bounds.
        left_n   = max(0, nose_x - W_BOX // 2)
        right_n  = min(width, nose_x + W_BOX // 2)
        top_n    = max(0, nose_y - H_BOX // 2)
        bottom_n = min(height, nose_y + H_BOX // 4)

        nose_area = [top_n, bottom_n, left_n, right_n]
        img = operator_add_noise(area=nose_area, img=img, noise_level=noise_level)

        # min/max handles both eye orderings, so the lower bound of the box is
        # always defined (an if/else that only set it in one branch would
        # raise NameError or reuse the previous row's value).
        min_eye_y = min(lefteye_y, righteye_y)
        max_eye_y = max(lefteye_y, righteye_y)

        # Eye box, clipped to the image bounds.
        left   = max(0, lefteye_x - w // 2)
        right  = min(width, righteye_x + w // 2)
        top    = max(0, min_eye_y - h // 4)
        bottom = min(height, max_eye_y + h // 2)

        eye_area = [top, bottom, left, right]
        img = operator_add_noise(area=eye_area, img=img, noise_level=noise_level)

        out_path = os.path.join(output_dir, filename)
        cv2.imwrite(out_path, img)

        print(f"Processed {filename}, saved to {out_path}")
        

 

def generate_nose_retrain_img(images_folder: str, df_landmarks: pd.DataFrame, output_dir: str, proportion: float):
    """Write "retrain" copies of the images, processing a box around the nose.

    Besides the images, the box used for every file is pickled to
    ``data_cv/nose_retrain_area.dict`` as ``{filename: [top, bottom, left, right]}``.

    Args:
        images_folder: Directory the source images are read from.
        df_landmarks: Landmark table. Its index values are used as image file
            names and it must provide the ``nose_x`` and ``nose_y`` columns.
        output_dir: Directory the processed images are written to, under the
            same file names.
        proportion: Fraction of ``df_landmarks`` to process. Rows at positions
            ``0 .. int(len(df_landmarks) * proportion)`` (inclusive) are handled.
    """
    box_w, box_h = 20, 40
    half_w = box_w // 2
    # The box is not vertically centred: it reaches further above the nose
    # landmark (box_h // 2) than below it (box_h // 4).
    reach_up = box_h // 2
    reach_down = box_h // 4
    last_position = int(len(df_landmarks) * proportion)

    area_dict = {}
    for position, (filename, landmarks) in enumerate(df_landmarks.iterrows()):
        if position > last_position:
            break

        centre_x = int(landmarks["nose_x"])
        centre_y = int(landmarks["nose_y"])

        # NOTE(review): the array is RGB from here on, but cv2.imwrite below
        # interprets arrays as BGR, so the saved file is likely channel-swapped
        # -- confirm whether that is intended.
        picture = cv2.cvtColor(
            cv2.imread(os.path.join(images_folder, filename)),
            cv2.COLOR_BGR2RGB,
        )
        pic_h, pic_w = picture.shape[:2]

        # [top, bottom, left, right], clipped to the image bounds.
        area = [
            max(0, centre_y - reach_up),
            min(pic_h, centre_y + reach_down),
            max(0, centre_x - half_w),
            min(pic_w, centre_x + half_w),
        ]
        area_dict[filename] = area

        picture = operator_retrain(area=area, img=picture)

        out_path = os.path.join(output_dir, filename)
        cv2.imwrite(out_path, picture)
        print(f"Processed {filename}, saved to {out_path}")

    with open('data_cv/nose_retrain_area.dict', 'wb') as f:
        pickle.dump(area_dict, f)
        
        
    

def generate_nose_shuffle_img(images_folder: str, df_landmarks: pd.DataFrame, output_dir: str, proportion: float):
    """Write "shuffle" copies of the images, processing a box around the nose.

    Besides the images, the box used for every file is pickled to
    ``data_cv/nose_shuffle_area.dict`` as ``{filename: [top, bottom, left, right]}``.

    Args:
        images_folder: Directory the source images are read from.
        df_landmarks: Landmark table. Its index values are used as image file
            names and it must provide the ``nose_x`` and ``nose_y`` columns.
        output_dir: Directory the processed images are written to, under the
            same file names.
        proportion: Fraction of ``df_landmarks`` to process. Rows at positions
            ``0 .. int(len(df_landmarks) * proportion)`` (inclusive) are handled.
    """
    box_w, box_h = 20, 40
    half_w = box_w // 2
    # The box is not vertically centred: it reaches further above the nose
    # landmark (box_h // 2) than below it (box_h // 4).
    reach_up = box_h // 2
    reach_down = box_h // 4
    last_position = int(len(df_landmarks) * proportion)

    area_dict = {}
    for position, (filename, landmarks) in enumerate(df_landmarks.iterrows()):
        if position > last_position:
            break

        centre_x = int(landmarks["nose_x"])
        centre_y = int(landmarks["nose_y"])

        # Unlike the other generators, the operator here receives the image
        # exactly as cv2.imread returned it (no colour conversion yet).
        picture = cv2.imread(os.path.join(images_folder, filename))
        pic_h, pic_w = picture.shape[:2]

        # [top, bottom, left, right], clipped to the image bounds.
        area = [
            max(0, centre_y - reach_up),
            min(pic_h, centre_y + reach_down),
            max(0, centre_x - half_w),
            min(pic_w, centre_x + half_w),
        ]
        area_dict[filename] = area

        picture = operator_shuffle(area=area, img=picture)

        # NOTE(review): the channels are swapped to RGB right before
        # cv2.imwrite, which interprets arrays as BGR, so the saved file is
        # likely channel-swapped -- confirm whether that is intended.
        picture = cv2.cvtColor(picture, cv2.COLOR_BGR2RGB)

        out_path = os.path.join(output_dir, filename)
        cv2.imwrite(out_path, picture)
        print(f"Processed {filename}, saved to {out_path}")

    with open('data_cv/nose_shuffle_area.dict', 'wb') as f:
        pickle.dump(area_dict, f)




def generate_eye_shuffle_img(images_folder: str, df_landmarks: pd.DataFrame, output_dir: str, proportion: float):
    """Write "shuffle" copies of the images, processing a box covering both eyes.

    Besides the images, the box used for every file is pickled to
    ``data_cv/eye_shuffle_area.dict`` as ``{filename: [top, bottom, left, right]}``.

    Args:
        images_folder: Directory the source images are read from.
        df_landmarks: Landmark table. Its index values are used as image file
            names and it must provide the ``lefteye_x``, ``lefteye_y``,
            ``righteye_x`` and ``righteye_y`` columns.
        output_dir: Directory the processed images are written to, under the
            same file names.
        proportion: Fraction of ``df_landmarks`` to process. Rows at positions
            ``0 .. int(len(df_landmarks) * proportion)`` (inclusive) are handled.

    Raises:
        FileNotFoundError: If an image cannot be read from ``images_folder``.
    """
    w = 30  # horizontal margin budget: w // 2 is added on each side of the eyes
    h = 20  # vertical margin budget: h // 4 above, h // 2 below the eyes
    last_idx = int(len(df_landmarks) * proportion)

    area_dict = {}
    for idx, (filename, row) in enumerate(df_landmarks.iterrows()):
        if idx > last_idx:
            break

        lefteye_x = int(row["lefteye_x"])
        lefteye_y = int(row["lefteye_y"])
        righteye_x = int(row["righteye_x"])
        righteye_y = int(row["righteye_y"])

        img_path = os.path.join(images_folder, filename)
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {img_path}")

        # min/max handles both eye orderings, so the lower bound of the box is
        # always defined (an if/else that only set it in one branch would
        # raise NameError or reuse the previous row's value).
        min_eye_y = min(lefteye_y, righteye_y)
        max_eye_y = max(lefteye_y, righteye_y)

        # Eye box, clipped to the image bounds.
        height, width = img.shape[:2]
        left   = max(0, lefteye_x - w // 2)
        right  = min(width, righteye_x + w // 2)
        top    = max(0, min_eye_y - h // 4)
        bottom = min(height, max_eye_y + h // 2)

        area = [top, bottom, left, right]
        area_dict[filename] = area

        img = operator_shuffle(area=area, img=img)

        # NOTE(review): the channels are swapped to RGB right before
        # cv2.imwrite, which interprets arrays as BGR, so the saved file is
        # likely channel-swapped -- confirm whether that is intended.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        out_path = os.path.join(output_dir, filename)
        cv2.imwrite(out_path, img)

        print(f"Processed {filename}, saved to {out_path}")

    with open('data_cv/eye_shuffle_area.dict', 'wb') as f:
        pickle.dump(area_dict, f)
        


def generate_eye_retrain_img(images_folder: str, df_landmarks: pd.DataFrame, output_dir: str, proportion: float):
    """Write "retrain" copies of the images, processing a box covering both eyes.

    Besides the images, the box used for every file is pickled to
    ``data_cv/eye_retrain_area.dict`` as ``{filename: [top, bottom, left, right]}``.

    Args:
        images_folder: Directory the source images are read from.
        df_landmarks: Landmark table. Its index values are used as image file
            names and it must provide the ``lefteye_x``, ``lefteye_y``,
            ``righteye_x`` and ``righteye_y`` columns.
        output_dir: Directory the processed images are written to, under the
            same file names.
        proportion: Fraction of ``df_landmarks`` to process. Rows at positions
            ``0 .. int(len(df_landmarks) * proportion)`` (inclusive) are handled.

    Raises:
        FileNotFoundError: If an image cannot be read from ``images_folder``.
    """
    w = 30  # horizontal margin budget: w // 2 is added on each side of the eyes
    h = 20  # vertical margin budget: h // 4 above, h // 2 below the eyes
    last_idx = int(len(df_landmarks) * proportion)

    area_dict = {}
    for idx, (filename, row) in enumerate(df_landmarks.iterrows()):
        if idx > last_idx:
            break

        lefteye_x = int(row["lefteye_x"])
        lefteye_y = int(row["lefteye_y"])
        righteye_x = int(row["righteye_x"])
        righteye_y = int(row["righteye_y"])

        img_path = os.path.join(images_folder, filename)
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        # NOTE(review): the array is RGB from here on, but cv2.imwrite below
        # interprets arrays as BGR, so the saved file is likely channel-swapped
        # -- confirm whether that is intended.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # Smallest / largest eye row, regardless of which eye sits higher.
        min_eye_y = min(lefteye_y, righteye_y)
        max_eye_y = max(lefteye_y, righteye_y)

        # Clip the box to the image, as the nose/eye variants in this file do.
        # Without this a landmark near the border yields negative indices
        # (which numpy slicing interprets as offsets from the far edge) and an
        # out-of-image box would be stored in the pickled dictionary.
        height, width = img.shape[:2]
        left   = max(0, lefteye_x - w // 2)
        right  = min(width, righteye_x + w // 2)
        top    = max(0, min_eye_y - h // 4)
        bottom = min(height, max_eye_y + h // 2)

        area = [top, bottom, left, right]
        area_dict[filename] = area

        img = operator_retrain(area=area, img=img)

        out_path = os.path.join(output_dir, filename)
        cv2.imwrite(out_path, img)

        print(f"Processed {filename}, saved to {out_path}")

    with open('data_cv/eye_retrain_area.dict', 'wb') as f:
        pickle.dump(area_dict, f)


def generate_noseeye_shuffle_img(images_folder: str, df_landmarks: pd.DataFrame, output_dir: str, proportion: float):
    """Write "shuffle" copies of the images, processing the nose box and the eye box.

    The nose box is processed first; the eye box is then processed on the
    result of the first ``operator_shuffle`` call. The boxes used for every
    file are pickled to ``data_cv/noseeye_shuffle_area.dict``; each value is
    the nose box ``[top, bottom, left, right]`` concatenated with the eye box
    ``[top, bottom, left, right]``.

    Args:
        images_folder: Directory the source images are read from.
        df_landmarks: Landmark table. Its index values are used as image file
            names and it must provide the ``nose_x``, ``nose_y``,
            ``lefteye_x``, ``lefteye_y``, ``righteye_x`` and ``righteye_y``
            columns.
        output_dir: Directory the processed images are written to, under the
            same file names.
        proportion: Fraction of ``df_landmarks`` to process. Rows at positions
            ``0 .. int(len(df_landmarks) * proportion)`` (inclusive) are handled.

    Raises:
        FileNotFoundError: If an image cannot be read from ``images_folder``.
    """
    W_BOX = 20  # nose box width
    H_BOX = 40  # nose box height budget: H_BOX // 2 above, H_BOX // 4 below
    w = 30  # eye box horizontal margin budget (w // 2 per side)
    h = 20  # eye box vertical margin budget (h // 4 above, h // 2 below)
    last_idx = int(len(df_landmarks) * proportion)

    area_dict = {}
    for idx, (filename, row) in enumerate(df_landmarks.iterrows()):
        if idx > last_idx:
            break

        nose_x = int(row["nose_x"])
        nose_y = int(row["nose_y"])

        lefteye_x = int(row["lefteye_x"])
        lefteye_y = int(row["lefteye_y"])
        righteye_x = int(row["righteye_x"])
        righteye_y = int(row["righteye_y"])

        img_path = os.path.join(images_folder, filename)
        img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        # NOTE(review): the array is RGB from here on, but cv2.imwrite below
        # interprets arrays as BGR, so the saved file is likely channel-swapped
        # -- confirm whether that is intended.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        height, width = img.shape[:2]

        # Nose box, clipped to the image bounds.
        left_n   = max(0, nose_x - W_BOX // 2)
        right_n  = min(width, nose_x + W_BOX // 2)
        top_n    = max(0, nose_y - H_BOX // 2)
        bottom_n = min(height, nose_y + H_BOX // 4)

        nose_area = [top_n, bottom_n, left_n, right_n]
        area_dict[filename] = nose_area

        img = operator_shuffle(area=nose_area, img=img)

        # min/max handles both eye orderings, so the lower bound of the box is
        # always defined (an if/else that only set it in one branch would
        # raise NameError or reuse the previous row's value).
        min_eye_y = min(lefteye_y, righteye_y)
        max_eye_y = max(lefteye_y, righteye_y)

        # Eye box, clipped to the image bounds.
        left   = max(0, lefteye_x - w // 2)
        right  = min(width, righteye_x + w // 2)
        top    = max(0, min_eye_y - h // 4)
        bottom = min(height, max_eye_y + h // 2)

        eye_area = [top, bottom, left, right]
        # Stored entry becomes nose box followed by eye box.
        area_dict[filename] = area_dict[filename] + eye_area

        img = operator_shuffle(area=eye_area, img=img)

        out_path = os.path.join(output_dir, filename)
        cv2.imwrite(out_path, img)

        print(f"Processed {filename}, saved to {out_path}")

    with open('data_cv/noseeye_shuffle_area.dict', 'wb') as f:
        pickle.dump(area_dict, f)
        
        

def generate_noseeye_retrain_img(images_folder: str, df_landmarks: pd.DataFrame, output_dir: str, proportion: float):
    """Write "retrain" copies of the images, processing the nose box and the eye box.

    The nose box is processed first; the eye box is then processed on the
    result of the first ``operator_retrain`` call. The boxes used for every
    file are pickled to ``data_cv/noseeye_retrain_area.dict``; each value is
    the nose box ``[top, bottom, left, right]`` concatenated with the eye box
    ``[top, bottom, left, right]``.

    Args:
        images_folder: Directory the source images are read from.
        df_landmarks: Landmark table. Its index values are used as image file
            names and it must provide the ``nose_x``, ``nose_y``,
            ``lefteye_x``, ``lefteye_y``, ``righteye_x`` and ``righteye_y``
            columns.
        output_dir: Directory the processed images are written to, under the
            same file names.
        proportion: Fraction of ``df_landmarks`` to process. Rows at positions
            ``0 .. int(len(df_landmarks) * proportion)`` (inclusive) are handled.

    Raises:
        FileNotFoundError: If an image cannot be read from ``images_folder``.
    """
    W_BOX = 20  # nose box width
    H_BOX = 40  # nose box height budget: H_BOX // 2 above, H_BOX // 4 below
    w = 30  # eye box horizontal margin budget (w // 2 per side)
    h = 20  # eye box vertical margin budget (h // 4 above, h // 2 below)
    last_idx = int(len(df_landmarks) * proportion)

    area_dict = {}
    for idx, (filename, row) in enumerate(df_landmarks.iterrows()):
        if idx > last_idx:
            break

        nose_x = int(row["nose_x"])
        nose_y = int(row["nose_y"])

        lefteye_x = int(row["lefteye_x"])
        lefteye_y = int(row["lefteye_y"])
        righteye_x = int(row["righteye_x"])
        righteye_y = int(row["righteye_y"])

        img_path = os.path.join(images_folder, filename)
        img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        # NOTE(review): the array is RGB from here on, but cv2.imwrite below
        # interprets arrays as BGR, so the saved file is likely channel-swapped
        # -- confirm whether that is intended.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        height, width = img.shape[:2]

        # Nose box, clipped to the image bounds.
        left_n   = max(0, nose_x - W_BOX // 2)
        right_n  = min(width, nose_x + W_BOX // 2)
        top_n    = max(0, nose_y - H_BOX // 2)
        bottom_n = min(height, nose_y + H_BOX // 4)

        nose_area = [top_n, bottom_n, left_n, right_n]
        area_dict[filename] = nose_area

        img = operator_retrain(area=nose_area, img=img)

        # min/max handles both eye orderings, so the lower bound of the box is
        # always defined (an if/else that only set it in one branch would
        # raise NameError or reuse the previous row's value).
        min_eye_y = min(lefteye_y, righteye_y)
        max_eye_y = max(lefteye_y, righteye_y)

        # Eye box, clipped to the image bounds.
        left   = max(0, lefteye_x - w // 2)
        right  = min(width, righteye_x + w // 2)
        top    = max(0, min_eye_y - h // 4)
        bottom = min(height, max_eye_y + h // 2)

        eye_area = [top, bottom, left, right]
        # Stored entry becomes nose box followed by eye box.
        area_dict[filename] = area_dict[filename] + eye_area

        img = operator_retrain(area=eye_area, img=img)

        out_path = os.path.join(output_dir, filename)
        cv2.imwrite(out_path, img)

        print(f"Processed {filename}, saved to {out_path}")

    with open('data_cv/noseeye_retrain_area.dict', 'wb') as f:
        pickle.dump(area_dict, f)



if __name__ == "__main__":
    # Command-line interface: which facial region to perturb, the noise level
    # for the BL2 variant, and the fraction of the landmark table to process.
    parser = argparse.ArgumentParser()
    parser.add_argument('--where_to_unl', default='nose', choices=['nose', 'eye', 'noseeye'], type=str,
                        help='where to unlearn')
    parser.add_argument('--BL2_noise_level', default=5., type=float,
                        help='noise level of BL2')
    parser.add_argument('--proportion', default=.01, type=float,
                        help='proportion of data for our task')
    args = parser.parse_args()

    where_to_unl = args.where_to_unl
    BL2_noise_level = args.BL2_noise_level
    proportion = args.proportion

    images_folder = 'data_cv/img_align_celeba'

    # One output directory per generated variant, keyed by the chosen region.
    BL2_output_dir = f'data_cv/BL2_celeba_{where_to_unl}'
    shuffle_output_dir = f'data_cv/celeba_{where_to_unl}/shuffle'
    retrain_output_dir = f'data_cv/celeba_{where_to_unl}/retrain'
    for directory in (BL2_output_dir, shuffle_output_dir, retrain_output_dir):
        os.makedirs(directory, exist_ok=True)

    # Landmark table indexed by image file name; -1 entries are replaced by 0.
    df_landmarks = (
        pd.read_csv('data_cv/list_landmarks_align_celeba.csv')
        .set_index('image_id')
        .replace(to_replace=-1, value=0)
    )

    # (shuffle, retrain, BL2) generators for each region; argparse `choices`
    # guarantees that where_to_unl is one of these keys.
    generators = {
        'nose': (generate_nose_shuffle_img, generate_nose_retrain_img, generate_nose_BL2_img),
        'eye': (generate_eye_shuffle_img, generate_eye_retrain_img, generate_eye_BL2_img),
        'noseeye': (generate_noseeye_shuffle_img, generate_noseeye_retrain_img, generate_noseeye_BL2_img),
    }
    make_shuffle, make_retrain, make_BL2 = generators[where_to_unl]

    make_shuffle(images_folder=images_folder, df_landmarks=df_landmarks,
                 output_dir=shuffle_output_dir, proportion=proportion)
    make_retrain(images_folder=images_folder, df_landmarks=df_landmarks,
                 output_dir=retrain_output_dir, proportion=proportion)
    make_BL2(images_folder=images_folder, df_landmarks=df_landmarks,
             output_dir=BL2_output_dir, noise_level=BL2_noise_level, proportion=proportion)