import os
# import pandas as pd
import csv
import numpy as np


def create_new_folder(path_new):
    """Create directory *path_new*, including missing parents, if it is absent.

    The path is printed afterwards whether it was just created or already
    existed. If *path_new* exists but is a regular file, ``os.makedirs``
    raises ``FileExistsError`` instead of silently pretending it is a dir.
    """
    # exist_ok=True replaces the exists()-then-makedirs() pair, which could
    # fail if another process created the directory between the two calls.
    os.makedirs(path_new, exist_ok=True)
    print("The new dir: ", path_new)

## For attacked traffic signs
def main_attack(data_dir='/data/open-datasets/traffic/val',
                class_list=('sticker',),
                save_name='attack_class_label_sticker.csv'):
    """Write a label CSV for the attacked traffic-sign images under *data_dir*.

    Every file in ``data_dir/<attack>`` (for each ``<attack>`` in
    *class_list*) is expected to be named ``<object>_..._col<k>...`` where
    ``<object>`` is the sign class and ``<k>`` indexes the colour table below.
    One row per file is written to ``data_dir/save_name`` with the columns
    image, class, class_label, shape, shape_label, color, color_label.

    Args:
        data_dir: root directory holding one sub-folder per attack type.
        class_list: attack sub-folders to scan (other attack folders used
            previously include 'elastic', 'g_blur', 'g_noise', 'splatter').
        save_name: name of the CSV file, created inside *data_dir*.

    Raises:
        ValueError: if a file's object prefix has no shape or class label.
    """
    color_list = ['yellow', 'green', 'blue', 'silver', 'red',
                  'orange', 'purple', 'brown', 'royalblue', 'grey']
    # object name -> (shape_label, shape)
    shape_of = {
        'stopsign': (0, 'octagon'),
        'pedestrian': (1, 'triangle'),
        'warning': (1, 'triangle'),
        'deerCrossing': (2, 'rhombus'),
        'handicappedCrossing': (2, 'rhombus'),
        'leftCurve': (2, 'rhombus'),
        'workersAhead': (2, 'rhombus'),
        'oneway': (3, 'rectangle_long'),
        'speedlimit25mph': (4, 'rectangle_wide'),
    }
    # Same class indices as main_normal, so clean and attacked CSVs share one
    # label space. Previously only 4 classes were mapped and the others reused
    # the label of the preceding file (or crashed on the first file).
    class_label_of = {
        'deerCrossing': 0, 'handicappedCrossing': 1, 'leftCurve': 2,
        'oneway': 3, 'warning': 4, 'speedlimit25mph': 5,
        'stopsign': 6, 'workersAhead': 7,
    }

    save_path = os.path.join(data_dir, save_name)
    # newline='' is required by the csv module to avoid blank rows on Windows.
    with open(save_path, 'w', encoding='UTF8', newline='') as f_out:
        writer = csv.writer(f_out)
        header = ['image', 'class', 'class_label', 'shape', 'shape_label',
                  'color', 'color_label']
        writer.writerow(header)
        for attack in class_list:
            # sorted() makes the row order reproducible across filesystems.
            for f_name in sorted(os.listdir(os.path.join(data_dir, attack))):
                obj = f_name.split("_")[0]
                if obj not in shape_of or obj not in class_label_of:
                    raise ValueError(
                        f"Unknown traffic-sign class {obj!r} in file {f_name!r}")
                shape_label, shape = shape_of[obj]
                obj_label = class_label_of[obj]

                # Strip the extension first so "..._col4.png" still yields "4".
                stem = os.path.splitext(f_name)[0]
                color_label = stem.split("_col")[-1].split("_")[0]
                color = color_list[int(color_label)]

                # Image path is stored relative to data_dir.
                row = [os.path.join(attack, f_name), obj, obj_label,
                       shape, shape_label, color, color_label]
                writer.writerow(row)


## For normal traffic signs
def main_normal(save_root):
    """Write ``class_label.csv`` for the clean (non-attacked) images in *save_root*.

    *save_root* must contain one sub-folder per class in ``class_list``. For
    each ``.png`` file inside, the extension-free name is split on ``_`` and
    fields 1, 2 and 3 are expected to be ``col<c>``, ``scal<s>`` and
    ``rot<r>``. The integers c, s and r index the colour table, the scale
    grid and the rotation grid respectively. Non-png files are skipped.

    Args:
        save_root: dataset split directory; the CSV is written here as well.

    Raises:
        ValueError: if a class folder has no entry in the shape/class tables.
    """
    class_list = ["deerCrossing", "handicappedCrossing",
                  "leftCurve", "oneway",
                  "warning", "speedlimit25mph",
                  "stopsign", "workersAhead"]
    save_path = os.path.join(save_root, 'class_label.csv')
    color_list = ['yellow', 'green', 'blue', 'silver', 'red',
                  'orange', 'purple', 'brown', 'royalblue', 'grey']
    # 11 evenly spaced values each (indices 0..10). linspace avoids the
    # np.arange float-step pitfall where rounding may add or drop the end point.
    rotate_range = np.linspace(-0.4, 0.4, 11)
    scale_range = np.linspace(-0.5, 0.5, 11)

    # class name -> (shape_label, shape)
    shape_of = {
        'stopsign': (0, 'octagon'),
        'pedestrian': (1, 'triangle'),
        'warning': (1, 'triangle'),
        'deerCrossing': (2, 'rhombus'),
        'handicappedCrossing': (2, 'rhombus'),
        'leftCurve': (2, 'rhombus'),
        'workersAhead': (2, 'rhombus'),
        'oneway': (3, 'rectangle_long'),
        'speedlimit25mph': (4, 'rectangle_wide'),
    }
    class_label_of = {
        'deerCrossing': 0, 'handicappedCrossing': 1, 'leftCurve': 2,
        'oneway': 3, 'warning': 4, 'speedlimit25mph': 5,
        'stopsign': 6, 'workersAhead': 7,
    }

    # newline='' is required by the csv module to avoid blank rows on Windows.
    with open(save_path, 'w', encoding='UTF8', newline='') as f_out:
        writer = csv.writer(f_out)
        header = ['image',
                  'class', 'class_label',
                  'shape', 'shape_label',
                  'color', 'color_label',
                  'scale', 'scale_label',
                  'rotate', 'rotate_label']
        writer.writerow(header)
        for c in class_list:
            if c not in shape_of or c not in class_label_of:
                raise ValueError(f"No shape/class label defined for {c!r}")
            shape_label, shape = shape_of[c]
            class_label = class_label_of[c]

            # sorted() makes the row order reproducible across filesystems.
            for f_name in sorted(os.listdir(os.path.join(save_root, c))):
                if not f_name.endswith('.png'):
                    continue
                # Labels - data_class_v4 naming: <x>_col<c>_scal<s>_rot<r>.png
                fields = f_name.split(".")[0].split("_")
                color_label = fields[1].split("col")[1]
                scale_label = fields[2].split("scal")[1]
                rotate_label = fields[3].split("rot")[1]

                color = color_list[int(color_label)]
                scale = scale_range[int(scale_label)]
                rotate = rotate_range[int(rotate_label)]
                # Image path is stored relative to save_root.
                row = [os.path.join(c, f_name),
                       c, class_label,
                       shape, shape_label,
                       color, color_label,
                       scale, scale_label,
                       rotate, rotate_label]
                writer.writerow(row)


##------main------------
if __name__ == "__main__":
    # Build class_label.csv for the clean training split of the dataset.
    main_normal(save_root='/data/datasets/traffic/traffic_8x10x10x10/train')
