import argparse
import os
import shutil
import subprocess
from pathlib import Path

import pandas as pd

# Scratch directories used when the user does not pass --celeba / --lfw:
# the __main__ block creates them, downloads the dataset into them, and
# removes them once the InterRec dataset has been written.
CELEBA_DIR = Path("tmp_celeba/")
LFW_DIR = Path("tmp_lfw/")

# LFW image file names that are always excluded from InterRec, in addition
# to every image that lfw_issues.csv does not mark as a good image.
# NOTE(review): presumably these come from the published LFW errata list
# (mislabeled images) — confirm against the LFW website.
LFW_ERRATA = {
    'Abdullah_Gul_Janica_Kostelic_0001.jpg',
    'Claire_Hentzen_0001.jpg',
    'Claire_Hentzen_0002.jpg',
    'Debra_Messing_0002.jpg',
    'Elisabeth_Schumacher_0001.jpg',
    'Flor_Montulo_0002.jpg',
    'Jim_OBrien_0001.jpg',
    'Jim_OBrien_0002.jpg',
    'Jim_OBrien_0003.jpg',
    'Mahmoud_Abbas_0012.jpg',
    'Martha_Bowen_0002.jpg',
    'Michael_Schumacher_0008.jpg',
    'Nora_Bendijo_0002.jpg',
    'Recep_Tayyip_Erdogan_0004.jpg',
}


def get_interrec_dataset(interrec, celeba, lfw):
    """Build the InterRec image folder and the combined metadata CSV.

    Reads ``lfw_labels.csv``, ``lfw_issues.csv`` and ``celeba_labels.csv``
    from *interrec*, copies the selected LFW and CelebA images into
    ``<interrec>/images``, and writes ``<interrec>/interrec_labels.csv``.

    Args:
        interrec: InterRec dataset directory holding the CSV metadata files.
        celeba: Directory containing the CelebA jpg files named in the
            ``im1`` / ``im2`` columns of ``celeba_labels.csv``.
        lfw: LFW root directory with one sub-folder per identity.

        Each argument may be a ``str`` or a ``pathlib.Path``.

    Raises:
        FileNotFoundError: if a metadata CSV, an LFW identity folder, or a
            CelebA image referenced by the metadata does not exist.
    """
    interrec, celeba, lfw = Path(interrec), Path(celeba), Path(lfw)

    # Create the output image directory; exist_ok replaces the racy
    # isdir()-then-mkdir() check.
    interrec_images = interrec / "images"
    interrec_images.mkdir(exist_ok=True)

    # Load images metadata (provided by InterRec dataset)
    lfw_labels = pd.read_csv(interrec / 'lfw_labels.csv')
    lfw_issues = pd.read_csv(interrec / 'lfw_issues.csv')
    celeba_labels = pd.read_csv(interrec / 'celeba_labels.csv')

    # Copy every image of each listed LFW identity, skipping images whose
    # good_image value is not "Yes" (missing values compare != "Yes" and are
    # therefore skipped too) and the hard-coded LFW_ERRATA entries.
    bad_images = set(lfw_issues['image'][lfw_issues['good_image'] != "Yes"]) | LFW_ERRATA
    for name in lfw_labels['name']:
        folder = lfw / name.strip().replace(' ', '_')
        for image in os.listdir(folder):
            if image not in bad_images:
                shutil.copyfile(folder / image, interrec_images / image)

    # Copy the two CelebA images listed per identity, renamed to the same
    # <Name>_0001.jpg / <Name>_0002.jpg pattern the LFW files use.
    for _, row in celeba_labels.iterrows():
        name = row['name'].strip().replace(' ', '_')
        shutil.copyfile(celeba / row['im1'], interrec_images / f"{name}_0001.jpg")
        shutil.copyfile(celeba / row['im2'], interrec_images / f"{name}_0002.jpg")

    # Combine the metadata files and save the full InterRec metadata csv.
    # .assign() returns a new frame, so no columns are written into a column
    # subset of another frame (which triggers pandas' SettingWithCopyWarning).
    celeba_meta = celeba_labels[
        ['name', 'gender', 'birth_date', 'origin_country', 'skin_tone']
    ].assign(source="celeba", num_images=2)
    lfw_meta = lfw_labels.assign(source='lfw')
    interrec_labels = pd.concat([lfw_meta, celeba_meta])
    interrec_labels.to_csv(interrec / 'interrec_labels.csv', index=False)


def download_celeba(path):
    """Download the ``jessicali9530/celeba-dataset`` Kaggle dataset into *path* and unzip it.

    Requires the ``kaggle`` CLI and ``unzip`` to be available on PATH.
    The downloaded ``celeba-dataset.zip`` is deleted after extraction.

    Raises:
        subprocess.CalledProcessError: if the download or the unzip exits
            with a non-zero status (previously ignored by ``os.system``).
        FileNotFoundError: if ``kaggle`` or ``unzip`` is not installed.
    """
    # cwd=path instead of os.chdir(): the process-wide working directory is
    # never changed, so nothing needs restoring if a command fails.
    subprocess.run(
        ["kaggle", "datasets", "download", "-d", "jessicali9530/celeba-dataset"],
        cwd=path,
        check=True,
    )
    subprocess.run(["unzip", "celeba-dataset.zip"], cwd=path, check=True)
    (Path(path) / "celeba-dataset.zip").unlink()


def download_lfw(path):
    """Download ``lfw.tgz`` from the LFW website into *path* and extract it.

    Requires ``wget`` and ``tar`` to be available on PATH. The archive is
    deleted after extraction.

    Raises:
        subprocess.CalledProcessError: if the download or the extraction
            exits with a non-zero status (previously ignored by ``os.system``).
        FileNotFoundError: if ``wget`` or ``tar`` is not installed.
    """
    # cwd=path instead of os.chdir(): the process-wide working directory is
    # never changed, so nothing needs restoring if a command fails.
    subprocess.run(
        ["wget", "http://vis-www.cs.umass.edu/lfw/lfw.tgz"], cwd=path, check=True
    )
    subprocess.run(["tar", "-xf", "lfw.tgz"], cwd=path, check=True)
    (Path(path) / "lfw.tgz").unlink()


# Command-line interface. --celeba and --lfw are optional: when omitted, the
# __main__ block downloads that dataset into a scratch directory instead.
parser = argparse.ArgumentParser(
    description=(
        "Build the InterRec dataset (an images folder plus interrec_labels.csv) "
        "from the InterRec CSV metadata and the LFW and CelebA images."
    )
)

parser.add_argument(
    '--interrec', type=str, required=True, help="Path to InterRec dataset containing csv metadata files"
)
parser.add_argument(
    '--celeba',
    type=str,
    help="Path to celeba dataset containing the jpg images (downloaded via the Kaggle CLI if omitted)",
)
parser.add_argument(
    '--lfw',
    type=str,
    help="Path to lfw dataset containing folders of images for each identity (downloaded if omitted)",
)

if __name__ == "__main__":
    args = parser.parse_args()
    celeba_flag = args.celeba is not None
    lfw_flag = args.lfw is not None

    # Validate every input before any (large) download starts. parser.error()
    # prints usage and exits with status 2; unlike `assert`, it is not
    # stripped when Python runs with -O.
    if not os.path.isdir(args.interrec):
        parser.error(f"--interrec is not a directory: {args.interrec}")
    interrec = Path(args.interrec)
    for csv_name in ('lfw_labels.csv', 'lfw_issues.csv', 'celeba_labels.csv'):
        if not os.path.isfile(interrec / csv_name):
            parser.error(f"missing InterRec metadata file: {interrec / csv_name}")
    if celeba_flag and not os.path.isdir(args.celeba):
        parser.error(f"--celeba is not a directory: {args.celeba}")
    if lfw_flag and not os.path.isdir(args.lfw):
        parser.error(f"--lfw is not a directory: {args.lfw}")

    # Scratch directories created by this run. They are removed in `finally`
    # so a failed run does not leave them behind (os.mkdir would otherwise
    # raise FileExistsError on the next attempt). A directory is recorded
    # only after os.mkdir succeeds, so a pre-existing one is never deleted.
    scratch_dirs = []
    try:
        if celeba_flag:
            celeba = Path(args.celeba)
        else:
            celeba = CELEBA_DIR / "img_align_celeba/img_align_celeba/"
            os.mkdir(CELEBA_DIR)
            scratch_dirs.append(CELEBA_DIR)
            download_celeba(CELEBA_DIR)

        if lfw_flag:
            lfw = Path(args.lfw)
        else:
            lfw = LFW_DIR / "lfw"
            os.mkdir(LFW_DIR)
            scratch_dirs.append(LFW_DIR)
            download_lfw(LFW_DIR)

        get_interrec_dataset(interrec, celeba, lfw)
    finally:
        for scratch in scratch_dirs:
            # Best-effort cleanup: a cleanup error must not mask an exception
            # that is already propagating.
            shutil.rmtree(scratch, ignore_errors=True)
