import os
from concurrent.futures import ThreadPoolExecutor
import time
import argparse
from argparse import Namespace

from PIL import Image

import numpy as np

from src.model.foliage import Foliage
from src.utilities import Utility



def create_patch_images_for_disease(output_path, disease_rate, disease="frogeye", thread_num=2,
                                    num_patch_images_per_disease=5):
    """Generate and save a series of leaf-patch images for one disease.

    The configured disease rate is turned into ``num_patch_images_per_disease``
    values by ``utility.get_normalized_disease_rates``. Those values are written
    to ``output_images/distribution/<disease>_distributionWB.txt`` (relative to
    the current working directory). One patch image per value is then rendered
    with ``foliage.get_patch_of_leaves`` and saved as a PNG.

    NOTE: this function reads the module-level globals ``utility`` and
    ``foliage``, which are only created in the ``__main__`` block of this file.

    Args:
        output_path: Base output directory. Images are written to
            ``<output_path>/<disease>/<thread_num>_<i>.png``; the per-disease
            directory must already exist (see ``create_dir_if_not_exist``).
        disease_rate: Disease-rate setting from the config, passed through
            unchanged to ``utility.get_normalized_disease_rates``.
        disease: Disease name, used for rendering and in output file names.
        thread_num: Prefix for image file names, so that concurrent workers
            writing into the same directory do not overwrite each other.
        num_patch_images_per_disease: How many images to generate
            (defaults to 5, the previously hard-coded value).
    """
    # Split the configured rate range into one value per image to generate.
    normalized_disease_rates = utility.get_normalized_disease_rates(
        disease_rate, num_cut=num_patch_images_per_disease)

    # Record the rates that were used. Create the folder first, because
    # nothing else in this script creates it; exist_ok keeps concurrent
    # workers from failing when another one created it first.
    distribution_dir = "output_images/distribution/"
    os.makedirs(distribution_dir, exist_ok=True)
    file_path = distribution_dir + disease + "_distributionWB.txt"
    utility.save_to_file(file_path, normalized_disease_rates)

    # Loop-invariant: every image for this disease goes into the same folder.
    disease_dir = os.path.join(output_path, disease)

    for i in range(num_patch_images_per_disease):
        patch_image: Image.Image = foliage.get_patch_of_leaves(normalized_disease_rates[i], disease)

        output_image_path = os.path.join(disease_dir, f"{thread_num}_{i}.png")
        patch_image.save(output_image_path)
        # Report only after the save has actually succeeded.
        print("Saved image at: ", output_image_path)


def create_dir_if_not_exist(path, disease_list):
    """Make sure ``<path>/<disease>`` exists for every entry of *disease_list*.

    Missing parent directories are created as well. Directories that already
    exist are left alone, so calling this repeatedly is harmless.
    """
    targets = (os.path.join(path, name) for name in disease_list)
    for target in targets:
        os.makedirs(target, exist_ok=True)

def parse_arguments() -> Namespace:
    """Parse the command line for this script.

    Returns:
        A Namespace whose ``config`` attribute is the path passed via the
        required ``--config`` / ``-c`` option.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--config',
        '-c',
        required=True,
        help="Plant specific configuration file. A json file",
    )
    parsed = cli.parse_args()
    return parsed


if __name__ == '__main__':
    start_time = time.time()
    # `utility` and `foliage` must stay module-level names:
    # create_patch_images_for_disease reads them as globals.
    utility = Utility()
    args = parse_arguments()

    config = utility.json_parser(args.config)
    diseases = utility.string_to_list(config.get("diseases"))
    disease_rate = config.get("disease_rate")

    base_output_path = config.get("output_path")

    foliage = Foliage(config)

    create_dir_if_not_exist(base_output_path, diseases)
    with ThreadPoolExecutor(max_workers=20) as executor:
        futures = []
        for idx, disease in enumerate(diseases):
            # idx is passed as thread_num, which becomes the image file-name prefix.
            future = executor.submit(create_patch_images_for_disease, base_output_path, disease_rate, disease, idx)
            futures.append(future)
            # Staggers task submission; the reason is not visible in this file
            # (NOTE(review): possibly to desynchronise randomness in Foliage — confirm).
            time.sleep(0.2)
        # Without .result(), an exception raised in a worker is stored on its
        # Future and silently discarded. Calling it re-raises the failure here,
        # so the script does not report success after a failed disease.
        for future in futures:
            future.result()
    end_time = time.time()
    elapsed_time = end_time - start_time
    print(f"Total elapsed time: {elapsed_time: 0.2f}")

