import os
import cv2
from tqdm import tqdm
import xml.etree.ElementTree as ET
import glob
import xml.dom.minidom as md

# Dataset layout: annotations are read from <root>/Annotations/*.xml and the
# clamped copies are written to <root>/Annotations_modified/.
root = 'data/night_sunny'
target_dir = os.path.join(root, 'Annotations')
# glob's result order depends on the file system; sort for a reproducible run.
xml_filenames = sorted(glob.glob(os.path.join(target_dir, '*.xml')))
out_dir = os.path.join(root, 'Annotations_modified')
# exist_ok replaces the check-then-create pattern (os.path.exists + makedirs),
# which can fail if the directory appears between the two calls.
os.makedirs(out_dir, exist_ok=True)

def write_xml_w_indent(xml_root, out_path, encode='utf-8'):
    """Serialize an ElementTree element to *out_path* through minidom.

    Args:
        xml_root: ElementTree element to write. It is serialized to a fresh
            minidom document, so the element itself is not modified.
        out_path: Destination file path.
        encode: Encoding named in the XML declaration AND used to encode the
            bytes written to disk.

    Paths containing 'daytime_foggy' are written with no added newlines or
    indentation. Every other path is pretty-printed with one element per
    line and tab indentation.
    """
    document = md.parseString(ET.tostring(xml_root, encode))
    if 'daytime_foggy' in out_path:
        # NOTE(review): presumably these inputs are already laid out, so their
        # existing whitespace is written back verbatim — confirm.
        newl, addindent = '', ''
    else:
        newl, addindent = '\n', '\t'
        # writexml prints whitespace-only text nodes (left over from an input
        # that was already indented) as extra blank lines; drop them first.
        _strip_layout_whitespace(document.documentElement)
    # Open with the encoding the XML declaration announces; the platform
    # default (e.g. cp1252 on Windows) would mis-encode non-ASCII text.
    with open(out_path, 'w', encoding=encode) as f:
        document.writexml(f, encoding=encode, newl=newl, addindent=addindent)


def _strip_layout_whitespace(node):
    """Recursively remove whitespace-only text nodes from elements that also contain child elements."""
    children = list(node.childNodes)
    has_elements = any(c.nodeType == md.Node.ELEMENT_NODE for c in children)
    for child in children:
        if child.nodeType == md.Node.ELEMENT_NODE:
            _strip_layout_whitespace(child)
        elif (has_elements and child.nodeType == md.Node.TEXT_NODE
              and not child.data.strip()):
            node.removeChild(child)


def _read_image_size(xml_filename):
    """Return (height, width) of the JPEG paired with *xml_filename*, or None.

    The image is looked up as <dataset>/JPEGImages/<stem>.jpg, where
    <dataset> is the parent of the directory holding the XML file.
    cv2.imread returns None (rather than raising) when it cannot read the
    file, so None is passed through to the caller.
    """
    ann_dir, xml_name = os.path.split(xml_filename)
    stem = os.path.splitext(xml_name)[0]
    jpg_filename = os.path.join(os.path.dirname(ann_dir), 'JPEGImages', stem + '.jpg')
    img = cv2.imread(jpg_filename)
    if img is None:
        return None
    return img.shape[:2]


def _size_from_annotation(annotation):
    """Return (height, width) from <size>, or None if missing, non-numeric or non-positive.

    A zero size must be rejected: clamping against it would collapse every
    box coordinate to 0.
    """
    size = annotation.find('size')
    if size is None:
        return None
    try:
        h = int(float(size.findtext('height')))
        w = int(float(size.findtext('width')))
    except (TypeError, ValueError):  # TypeError: child missing (findtext -> None)
        return None
    if h <= 0 or w <= 0:
        return None
    return h, w


def _ensure_size(annotation, h, w):
    """Record h/w in <size>, creating <size> (with a depth of 3) when absent."""
    size = annotation.find('size')
    if size is None:
        size = ET.SubElement(annotation, 'size')
    if size.find('depth') is None:
        # cv2.imread is called with its default flag, which loads 3-channel images.
        ET.SubElement(size, 'depth').text = str(3)
    for tag, value in (('width', w), ('height', h)):
        child = size.find(tag)
        if child is None:
            child = ET.SubElement(size, tag)
        child.text = str(value)


# For every annotation: make sure its image size is known, clamp each
# bounding-box coordinate into [0, width] / [0, height], and write the result
# into out_dir under the same file name.
for xml_filename in tqdm(xml_filenames):
    tree = ET.parse(xml_filename)

    # A distinct name keeps the module-level dataset path `root` intact.
    xml_root = tree.getroot()
    if xml_root.tag == 'annotation':
        annotation = xml_root
    else:
        annotation = xml_root.find('annotation')
    if annotation is None:
        tqdm.write('Skipping %s: no <annotation> element' % xml_filename)
        continue

    dims = _size_from_annotation(annotation)
    if dims is None:
        # Missing or unusable <size>: take the real dimensions from the image.
        dims = _read_image_size(xml_filename)
        if dims is None:
            tqdm.write('Skipping %s: no usable <size> and image unreadable' % xml_filename)
            continue
        _ensure_size(annotation, *dims)
    h, w = dims

    # findall() returns a (possibly empty) list, never None.
    for obj in annotation.findall('object'):
        for bbox in obj.findall('bndbox'):
            for tag, limit in (('xmin', w), ('xmax', w), ('ymin', h), ('ymax', h)):
                coord = bbox.find(tag)
                if coord is None or not coord.text:
                    continue
                # Values may be stored as floats: truncate, then clamp to [0, limit].
                coord.text = str(max(min(int(float(coord.text)), limit), 0))

    # basename respects the platform separator; split('/') breaks on Windows.
    out_path = os.path.join(out_dir, os.path.basename(xml_filename))
    write_xml_w_indent(xml_root, out_path)