import os
from pathlib import Path
import glob
import cv2
import random
from utils.metrics import bbox_iou
import torch

# Accepted image file extensions (lower-case, without the dot) used to filter the image list.
IMG_FORMATS = ('bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp', 'pfm')
# Class index of the instances that get the trigger pasted on them and are relabeled
# (the output folder names "person2cup" indicate this is the person class).
source_class = 0
# Class index those instances are relabeled to ("cup", per the output folder names).
target_class = 41

def img2label_paths(img_paths):
    """Return the YOLO label-file path for every image path in ``img_paths``.

    The last ``<sep>images<sep>`` component of each path is swapped for
    ``<sep>labels<sep>``, and everything after the final ``.`` is replaced
    by ``txt``. Paths without an ``images`` component keep their directory.
    """
    images_part = os.sep + 'images' + os.sep
    labels_part = os.sep + 'labels' + os.sep
    label_paths = []
    for img_path in img_paths:
        # Split on the *last* occurrence only, so earlier "images" dirs are untouched.
        swapped = labels_part.join(img_path.rsplit(images_part, 1))
        stem = swapped.rsplit('.', 1)[0]
        label_paths.append(stem + '.txt')
    return label_paths


# Source index to poison: a directory of images, or a list file whose lines are image paths
# (lines starting with './' are relative to the list file's folder). May also be a list of such.
path = '../datasets/coco/train2017.txt'
# path = '../datasets/coco/val2017.txt'
f = []  # every candidate path found, before filtering by extension
for src in (path if isinstance(path, list) else [path]):
    src = Path(src)  # os-agnostic
    if src.is_file():
        parent = str(src.parent) + os.sep
        with open(src) as listing:
            entries = listing.read().strip().splitlines()
        # Anchor './...' entries to the list file's directory; keep other entries verbatim.
        f.extend(parent + entry[2:] if entry.startswith('./') else entry for entry in entries)
    elif src.is_dir():
        f.extend(glob.glob(str(src / '**' / '*.*'), recursive=True))
    else:
        raise FileNotFoundError(f'{src} does not exist')

# Keep only image files, normalise separators for this OS, and sort for a stable order.
im_files = sorted(
    entry.replace('/', os.sep)
    for entry in f
    if entry.split('.')[-1].lower() in IMG_FORMATS
)

label_files = img2label_paths(im_files)

# Output locations for the poisoned copies (swap in the val2017 variants to poison the val split).
POISON_IMG_DIR = '../datasets/coco/images/train2017_poison_person2cup/'
POISON_LABEL_DIR = '../datasets/coco/labels/train2017_poison_person2cup/'
POISON_LIST_PREFIX = './images/train2017_poison_person2cup/'
# POISON_IMG_DIR = '../datasets/coco/images/val2017_poison_person2cup/'
# POISON_LABEL_DIR = '../datasets/coco/labels/val2017_poison_person2cup/'
# POISON_LIST_PREFIX = './images/val2017_poison_person2cup/'

# cv2.imwrite silently returns False and open() raises if these folders are missing.
os.makedirs(POISON_IMG_DIR, exist_ok=True)
os.makedirs(POISON_LABEL_DIR, exist_ok=True)


def _boxes_overlap(a, b):
    """Return True if two normalized (cx, cy, w, h) boxes intersect with positive area.

    Equivalent to ``IoU > 0``. CIoU is not a valid overlap test because its penalty
    terms can push overlapping pairs to <= 0.
    """
    ax, ay, aw, ah = a
    bx, by, bw, bh = b
    inter_w = min(ax + aw / 2, bx + bw / 2) - max(ax - aw / 2, bx - bw / 2)
    inter_h = min(ay + ah / 2, by + bh / 2) - max(ay - ah / 2, by - bh / 2)
    return inter_w > 0 and inter_h > 0


new_file = []  # list-file entries ('./images/...') for every poisoned image written
cnt = 0        # number of poisoned images written so far; also the next output file stem

trigger = cv2.imread("HelloKitty.jpg")
if trigger is None:
    raise FileNotFoundError("trigger image 'HelloKitty.jpg' could not be read")
trig_h, trig_w = trigger.shape[:2]  # OpenCV arrays are (rows, cols, channels)

for i, (label_path, im_path) in enumerate(zip(label_files, im_files)):
    if not os.path.exists(label_path) or not os.path.exists(im_path):
        continue
    with open(label_path) as fh:
        # Drop blank rows so every row has at least a class token.
        t = [row for row in fh.read().strip().splitlines() if row.strip()]
    img = cv2.imread(im_path)
    if img is None:
        print(f'WARNING: could not read image {im_path}, skipping')
        continue
    img_h, img_w = img.shape[:2]
    poison_flag = False

    for j, row in enumerate(t):
        x = row.split()
        if int(x[0]) != source_class:
            continue
        box = (float(x[1]), float(x[2]), float(x[3]), float(x[4]))  # normalized cx, cy, w, h
        cx, cy, bw, bh = box
        # Only poison large instances: skip boxes that are both narrower and shorter than half the image.
        if bw < 0.5 and bh < 0.5:
            continue

        # Skip instances that overlap any other instance of the same class.
        overlaps = False
        for k, other_row in enumerate(t):
            if k == j:
                continue
            other = other_row.split()
            if int(other[0]) == source_class and _boxes_overlap(box, [float(v) for v in other[1:5]]):
                overlaps = True
                break
        if overlaps:
            continue

        # Trigger width = random 15-20% of the box width (pixels); height keeps the trigger's aspect ratio.
        ratio = random.uniform(0.15, 0.2)
        new_w = int(img_w * bw * ratio) + 1
        new_h = int(trig_h * new_w / trig_w) + 1
        # Jitter the paste centre by up to +/-40% of the box size around the box centre.
        x_random = random.uniform(-0.4 * bw, 0.4 * bw)
        y_random = random.uniform(-0.4 * bh, 0.4 * bh)
        centre_x = img_w * (cx + x_random)
        centre_y = img_h * (cy + y_random)
        x1 = max(int(centre_x - new_w / 2), 0)
        y1 = max(int(centre_y - new_h / 2), 0)
        x2 = min(int(centre_x + new_w / 2), img_w)
        y2 = min(int(centre_y + new_h / 2), img_h)
        if x2 <= x1 or y2 <= y1:
            # Clipped paste region is empty (cv2.resize rejects a zero/negative dsize); try the next box.
            continue
        img[y1:y2, x1:x2] = cv2.resize(trigger, dsize=(x2 - x1, y2 - y1))

        # Relabel only after the trigger was actually pasted, so no relabeled box lacks a trigger.
        x[0] = str(target_class)
        t[j] = ' '.join(x)
        poison_flag = True
        break  # poison at most one instance per image

    if not poison_flag:
        continue

    out_img = POISON_IMG_DIR + str(cnt) + '.jpg'
    if not cv2.imwrite(out_img, img):
        raise IOError(f'failed to write poisoned image {out_img}')
    with open(POISON_LABEL_DIR + str(cnt) + '.txt', 'w') as fp:
        fp.writelines(item + '\n' for item in t)
    # Optional: record the pasted trigger box (centre x, centre y, width, height in pixels).
    # with open('./trigger_labels/val2017_poison_person2cup/' + str(cnt) + '.txt', 'w') as fp1:
    #     fp1.write(' '.join([str((x1+x2)//2), str((y1+y2)//2), str(x2-x1), str(y2-y1)]))
    new_file.append(POISON_LIST_PREFIX + str(cnt) + '.jpg')
    # Progress: poisoned count, current source index, images scanned per poisoned image.
    print(cnt, i, (i + 1) / (cnt + 1))
    cnt += 1

# Emit two index files: one listing only the poisoned copies, and one listing every
# original image followed by the poisoned copies (clean + poisoned training set).
poison_only_list = '../datasets/coco/train2017_poison_person2cup.txt'
combined_list = '../datasets/coco/train2017_poison_person2cup_dataset.txt'
# poison_only_list = '../datasets/coco/val2017_poison_person2cup.txt'
# combined_list = '../datasets/coco/val2017_poison_person2cup_dataset.txt'

with open(poison_only_list, 'w') as fp:
    fp.write(''.join(f'{entry}\n' for entry in new_file))

with open(combined_list, 'w') as fp:
    for entry in [*im_files, *new_file]:
        fp.write(f'{entry}\n')
