import os
import math
import zlib
import uuid
import shelve
import tarfile
import argparse
import requests
import pandas as pd

from io import BytesIO
from tqdm import tqdm
from PIL import Image
from multiprocessing import Pool
from torchvision import transforms

# Preprocessing pipeline shared by all downloads: Resize(224) with bicubic
# interpolation followed by CenterCrop(224). Built once at import time so
# every worker process reuses the same composed transform.
transform = transforms.Compose(
    [
        transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(224),
    ]
)

def download(row):
    """Download one image, preprocess it and cache it on disk as a PNG.

    The output file name is derived from the CRC32 of the image URL, so the
    same URL always maps to the same file and an existing file is treated as
    an already-completed download.

    Args:
        row: Mapping-like record (a DataFrame row when used with
            ``DataFrame.apply(axis=1)``) providing ``"image"`` (the URL) and
            ``"dir"`` (the output root directory).

    Returns:
        The same ``row`` with two fields set:
        ``"status"`` -- 200 for a cache hit, the HTTP status code of the
        response otherwise, or 404 when the request or the image
        decoding/saving failed;
        ``"file"`` -- the path of the saved image relative to ``row["dir"]``,
        or ``None`` when no image was stored.
    """
    rfile = f"images/{zlib.crc32(row['image'].encode('utf-8')) & 0xffffffff}.png"
    file = f"{row['dir']}/{rfile}"

    if os.path.isfile(file):
        row["status"] = 200
        row["file"] = rfile
        return row

    try:
        response = requests.get(row["image"], stream=False, timeout=10, allow_redirects=True)
    except Exception:
        # Best effort: any request failure is reported as 404, but unlike a
        # bare ``except`` this lets KeyboardInterrupt/SystemExit propagate.
        row["status"] = 404
        row["file"] = None
        return row

    row["status"] = response.status_code
    # Always define "file" so every returned row carries the same fields.
    row["file"] = None

    if not response.ok:
        return row

    # Write to a per-process temporary name and rename afterwards, so a failed
    # or interrupted save can never leave a partial file that a later run
    # would mistake for a finished download.
    tmp = f"{file}.{os.getpid()}.part"
    try:
        image = Image.open(BytesIO(response.content)).convert("RGB")
        # The temporary name has no ".png" suffix, so the format is explicit.
        transform(image).save(tmp, format="PNG")
        os.replace(tmp, file)
    except Exception:
        try:
            os.remove(tmp)
        except OSError:
            pass
        row["status"] = 404
        return row

    row["file"] = rfile
    return row

def apply(args):
    """Run a row-wise function over one chunk of a DataFrame.

    Takes a single tuple so it can be used directly as a pool task.

    Args:
        args: ``(index, df, function)`` -- the chunk's identifier, the chunk
            itself, and the callable applied to each row (``axis=1``).

    Returns:
        ``(index, result)`` where ``result`` is the output of
        ``df.apply(function, axis=1)``.
    """
    chunk_id, chunk, row_function = args
    processed = chunk.apply(row_function, axis=1)
    return chunk_id, processed

def multiprocess(df, function, dir, hash, chunk_size=50):
    """Apply ``function`` to every row of ``df`` in parallel, with resume support.

    The frame is cut into chunks of ``chunk_size`` rows which are processed by
    a process pool. Each finished chunk is stored in a ``shelve`` database at
    ``{dir}/.{hash}``; chunks already present there are skipped, so an
    interrupted run can be continued.

    Args:
        df: Input DataFrame; rows are passed to ``function`` one at a time.
        function: Row-wise callable handed to ``DataFrame.apply(axis=1)``.
        dir: Directory holding the resume database.
        hash: Name (without the leading dot) of the resume database.
        chunk_size: Rows per task. Chunk indices are the database keys, so a
            run must be resumed with the same value it was started with.

    Returns:
        DataFrame with the columns ``"image"`` (the ``"file"`` value produced
        for each row, missing where none was produced) and ``"caption"``.
        An empty input yields an empty frame with those two columns.
    """
    # Nothing to do: avoid pd.concat([]) raising on an empty list.
    if len(df) == 0:
        return pd.DataFrame(columns=["image", "caption"])

    total = math.ceil(len(df) / chunk_size)

    with shelve.open(f"{dir}/.{hash}") as file, tqdm(total=total) as bar:
        finished = set(map(int, file.keys()))
        # Account for chunks completed by a previous run.
        bar.update(len(finished))

        data = [
            (index, df[start:start + chunk_size], function)
            for index, start in enumerate(range(0, len(df), chunk_size))
            if index not in finished
        ]

        if data:
            with Pool() as pool:
                for result in pool.imap_unordered(apply, data, 2):
                    # Persist immediately so progress survives an interruption.
                    file[str(result[0])] = result
                    bar.update()

        # Restore the original chunk order before concatenating.
        keys = sorted(int(key) for key in file.keys())
        df = pd.concat([file[str(key)][1] for key in keys])

    # reindex (rather than df[[...]]) tolerates results in which no row ever
    # received a "file" value, e.g. when every download failed.
    df = df.reindex(columns=["file", "caption"])
    return df.rename(columns={"file": "image"})

def run(options):
    """Download the images listed in a caption file and write ``train.csv``.

    Reads ``options.file`` as a tab-separated file whose columns are taken to
    be caption and image URL, keeps the rows in ``[options.start,
    options.end)``, downloads them into ``options.dir`` via ``multiprocess``
    and ``download``, and writes the resulting table to
    ``{options.dir}/train.csv``.

    Args:
        options: Namespace providing ``file``, ``dir``, ``start``, ``end`` and
            ``hash``.
    """
    image_dir = os.path.join(options.dir, "images")
    for path in (options.dir, image_dir):
        os.makedirs(path, exist_ok=True)

    table = pd.read_csv(options.file, sep="\t", names=["caption", "image"])
    # Each row carries its output directory so the worker needs no extra state.
    table["dir"] = options.dir
    selected = table[options.start:options.end]

    result = multiprocess(selected, function=download, dir=options.dir, hash=options.hash)
    result.to_csv(f"{options.dir}/train.csv", index=False)

def build_parser():
    """Create the command-line parser for the download script.

    Each option is registered with separate short and long flags. Passing a
    single string such as ``"-f,--file"`` registers one literal option of that
    name, which leaves the long form (``--file``) unrecognised.

    Returns:
        An ``argparse.ArgumentParser`` exposing ``file``, ``dir``, ``start``
        and ``end``.
    """
    parser = argparse.ArgumentParser()

    # file and dir have no usable default, so fail early with a clear message.
    parser.add_argument("-f", "--file", dest="file", type=str, required=True, help="File")
    parser.add_argument("-d", "--dir", dest="dir", type=str, required=True, help="Directory")
    parser.add_argument("-s", "--start", dest="start", type=int, default=0, help="Start index")
    parser.add_argument("-e", "--end", dest="end", type=int, default=1000000000000, help="End index")

    return parser


if __name__ == "__main__":
    options = build_parser().parse_args()
    # Deterministic name for the resume database: identical arguments map to
    # the same hash, so rerunning the same command continues the earlier run.
    options.hash = str(uuid.uuid5(uuid.NAMESPACE_DNS, f"{options.file}-{options.dir}-{options.start}-{options.end}"))

    run(options)
