import os
import csv
import numpy as np


def create_new_folder(path_new):
    """Create directory ``path_new`` (and any missing parents) and report it.

    Safe to call when the directory already exists. ``exist_ok=True`` replaces
    the former "check exists, then create" sequence, which could fail if
    another process created the directory between the check and the
    ``makedirs`` call.

    Args:
        path_new: Directory path to create.

    Raises:
        FileExistsError: if ``path_new`` exists but is not a directory.
    """
    os.makedirs(path_new, exist_ok=True)
    print("The new dir: ",path_new)


# For real_data
def main_real(save_root):
    """Write ``<save_root>/class_label.csv`` with labels for every sign image.

    ``save_root`` must contain one sub-folder per class in ``class_list``.
    Every ``.png`` file in those folders must be named
    ``<prefix>_<color>[...].png``: the token after the first underscore of
    the part before the first dot is taken as the sign color.

    One row is written per image, with columns
    ``image, class, class_label, shape, shape_label, color, color_label``.
    ``image`` is the path relative to ``save_root`` (``<class>/<file>``).
    Rows are grouped by class in ``class_list`` order, and files within a
    class are sorted by name so the output is reproducible.

    Args:
        save_root: Dataset root directory; the CSV is written inside it.

    Raises:
        ValueError: if a ``.png`` name has no color token or names an
            unknown color.
        FileNotFoundError: if a class sub-folder is missing (from
            ``os.listdir``).
    """
    class_list = ["deercrossing", "leftcurve", "oneway", "pedestrian",
                  "speedlimit25mph", "stop", "warning", "workersahead"]
    color_list = ['blue', 'green', 'grey', 'orange', 'purple', 'red', 'white', 'yellow']

    # class_label / color_label are simply the position in the lists above.
    class_labels = {name: idx for idx, name in enumerate(class_list)}
    color_labels = {color: idx for idx, color in enumerate(color_list)}
    # (shape name, shape_label) for each class.
    shapes = {
        "stop": ('octagon', 0),
        "warning": ('triangle', 1),
        "deercrossing": ('rhombus', 2),
        "pedestrian": ('rhombus', 2),
        "leftcurve": ('rhombus', 2),
        "workersahead": ('rhombus', 2),
        "oneway": ('rectangle_long', 3),
        "speedlimit25mph": ('rectangle_wide', 4),
    }

    save_path = os.path.join(save_root, 'class_label.csv')
    # newline='' is required by the csv module; without it every row is
    # followed by an extra blank line on Windows.
    with open(save_path, 'w', encoding='UTF8', newline='') as f_out:
        writer = csv.writer(f_out)
        header = ['image',
                  'class', 'class_label',
                  'shape', 'shape_label',
                  'color', 'color_label']
        writer.writerow(header)
        for c in class_list:
            class_label = class_labels[c]
            shape, shape_label = shapes[c]

            # Sorted so the CSV row order does not depend on the filesystem.
            for f_name in sorted(os.listdir(os.path.join(save_root, c))):
                if f_name.split('.')[-1] != 'png':
                    continue
                stem_parts = f_name.split('.')[0].split('_')
                if len(stem_parts) < 2:
                    raise ValueError(
                        f"Cannot parse color from image name {f_name!r} "
                        f"in class {c!r}; expected '<prefix>_<color>.png'")
                color = stem_parts[1]
                if color not in color_labels:
                    raise ValueError(
                        f"Unknown color {color!r} in image name {f_name!r} "
                        f"(class {c!r}); expected one of {color_list}")
                writer.writerow([os.path.join(c, f_name),
                                 c, class_label,
                                 shape, shape_label,
                                 color, color_labels[color]])


def count_images_in_folder(root_folder,
                           exts=('.png',),
                           samples_per_scenario=56):
    """Recursively count image files under ``root_folder`` and print a summary.

    For every directory visited by ``os.walk``, prints
    ``"<dir> contains <n> samples."``. This count covers only files directly
    in that directory. After the walk, prints the grand total and
    ``total / samples_per_scenario``.

    Args:
        root_folder: Directory to walk.
        exts: Iterable of file extensions, including the leading dot, matched
            case-insensitively. A single string such as ``'.png'`` is also
            accepted. Defaults to an immutable tuple instead of a shared
            mutable set.
        samples_per_scenario: Divisor for the "total scenarios" figure. The
            default of 56 keeps the previously hard-coded value.
            NOTE(review): presumably the number of images per scenario;
            confirm.

    Returns:
        Total number of matching files (int).
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(exts, str):
        exts = (exts,)
    wanted = {e.lower() for e in exts}

    count = 0
    for dirpath, _, filenames in os.walk(root_folder):
        file_count = sum(1 for f in filenames
                         if os.path.splitext(f)[1].lower() in wanted)
        count += file_count
        print(f"{dirpath} contains {file_count} samples.")
    print(f"Total samples: {count}, total scenarios: {count/samples_per_scenario}.")
    return count

    

#----------main------------
if __name__ == "__main__":
    # Generate class_label.csv for the real-world image set, then print
    # per-folder and total .png counts.
    image_root = './real_images'
    main_real(image_root)
    num_imgs = count_images_in_folder(image_root)
