import os
import sys
import json
from tqdm import tqdm

def preprocess_train(train_json_path):
    """Attach resolved image paths to every NLVR training example.

    Reads ``train_json_path`` as JSON Lines (one JSON object per non-blank
    line).  For each example the image-pair id is the ``identifier`` with its
    last ``-``-separated field dropped; the files ``{id}-img0.png`` and
    ``{id}-img1.png`` are looked up among the sub-directories of
    ``nlvr-data/images/train`` and their paths stored on the example under the
    keys ``'img0'`` and ``'img1'``.  The enriched examples are written as a
    single JSON array to ``nlvr-data/train.json`` and also returned.

    Args:
        train_json_path: path of the JSON-Lines training annotation file.

    Returns:
        The list of example dicts, each extended with ``'img0'``/``'img1'``.

    Raises:
        OSError: if either image of a pair cannot be found.
    """
    with open(train_json_path, 'r', encoding='utf-8') as f:
        # Skip blank lines (e.g. a trailing newline) instead of crashing.
        train_data = [json.loads(line) for line in f if line.strip()]

    train_img_path = os.path.join('nlvr-data', 'images', 'train')

    # Index every image file once.  Re-listing each sub-directory for every
    # example repeats the same directory scans N times over.
    img_index = {}
    for num in os.listdir(train_img_path):
        sub_dir = os.path.join(train_img_path, num)
        if not os.path.isdir(sub_dir):
            # Stray files would make os.listdir(sub_dir) raise.
            continue
        for fname in os.listdir(sub_dir):
            img_index.setdefault(fname, os.path.join(sub_dir, fname))

    new_train_data = []
    for imgs in tqdm(train_data, desc = 'Iterating over json'):
        img_id = '-'.join(imgs['identifier'].split('-')[:-1])
        # img0 is checked before img1, matching the original error order.
        for key in ('img0', 'img1'):
            fname = f'{img_id}-{key}.png'
            if fname not in img_index:
                raise OSError(f'Problem with {fname}')
            imgs[key] = img_index[fname]
        new_train_data.append(imgs)

    with open(os.path.join('nlvr-data', 'train.json'), 'w') as f:
        json.dump(new_train_data, f)
    return new_train_data

def split_train_valid(train_json, buckets = (96, 97, 98, 99), out_dir = 'nlvr-data'):
    """Move examples from the given directory buckets into a validation set.

    Every example whose ``'directory'`` value is in ``buckets`` goes to the
    validation split; all others stay in training.  ``train_json`` is updated
    in place so that it no longer contains the validation examples.  Both
    splits are written as JSON arrays to ``<out_dir>/valid.json`` and
    ``<out_dir>/train.json``, and their sizes are printed.

    Note: running this on an already-split ``train.json`` finds no bucket
    examples and therefore overwrites ``valid.json`` with an empty list.

    Args:
        train_json: list of example dicts, each with a ``'directory'`` key.
            Mutated in place.
        buckets: directory values that make up the validation split.
        out_dir: directory that receives ``valid.json`` and ``train.json``.

    Returns:
        ``(train_json, valid_json)``; ``train_json`` is the same list object
        that was passed in.
    """
    bucket_set = set(buckets)
    valid_json = [line for line in train_json if line['directory'] in bucket_set]
    # Equal examples share the same 'directory', so filtering by bucket removes
    # exactly what a per-item `remove()` loop would, in O(n) instead of O(n^2).
    # Slice assignment keeps the caller's list object updated in place.
    train_json[:] = [line for line in train_json if line['directory'] not in bucket_set]
    print('VALIDATION SET SIZE:', len(valid_json))
    print('TRAINING SET SIZE:', len(train_json))
    with open(os.path.join(out_dir, 'valid.json'), 'w') as f:
        json.dump(valid_json, f)
    with open(os.path.join(out_dir, 'train.json'), 'w') as f:
        json.dump(train_json, f)
    return train_json, valid_json

def preprocess_dev_test(json_path, split = 'dev'):
    """Attach image paths to every example of an NLVR dev/test split.

    Reads ``json_path`` as JSON Lines (one JSON object per non-blank line).
    For each example the image-pair id is the ``identifier`` with its last
    ``-``-separated field dropped.  The example gets ``'img0'``/``'img1'``
    pointing at ``nlvr-data/images/<split>/{id}-img0.png`` and ``-img1.png``.
    The enriched examples are written as a JSON array to
    ``new-<split>.json`` in the current working directory and returned.

    Args:
        json_path: path of the JSON-Lines annotation file for this split.
        split: image sub-directory name and output-file suffix (e.g. ``'dev'``
            or ``'test'``).

    Returns:
        The list of example dicts, each extended with ``'img0'``/``'img1'``.

    Raises:
        OSError: if either image file of a pair does not exist.
    """
    with open(json_path, 'r', encoding='utf-8') as f:
        # Skip blank lines (e.g. a trailing newline) instead of crashing.
        data = [json.loads(line) for line in f if line.strip()]
    img_path = os.path.join('nlvr-data', 'images', f'{split}')
    new_data = []
    for imgs in tqdm(data, desc = 'Iterating over json'):
        img_id = '-'.join(imgs['identifier'].split('-')[:-1])
        img_path0 = os.path.join(img_path, f'{img_id}-img0.png')
        img_path1 = os.path.join(img_path, f'{img_id}-img1.png')
        for candidate in (img_path0, img_path1):
            if not os.path.exists(candidate):
                raise OSError(f'Image path not found: {candidate}')
        imgs['img0'] = img_path0
        imgs['img1'] = img_path1
        new_data.append(imgs)
    # Name the output after the split.  A hard-coded 'new-dev.json' makes a
    # test-split run silently overwrite the dev output.
    with open(f'new-{split}.json', 'w') as f:
        json.dump(new_data, f)
    return new_data

if __name__ == '__main__':
    # Pipeline stages; enable the call for the stage you want to run:
    #   1. preprocess_train(train_path) reads the JSON-Lines training file,
    #      adds image paths and overwrites nlvr-data/train.json with a JSON
    #      array.  It must run before the split below, which json.load()s
    #      that array.
    #   2. split_train_valid(train_json) moves directory buckets 96-99 into
    #      nlvr-data/valid.json.  Running it again on the already-split
    #      train.json finds no such buckets and overwrites valid.json with an
    #      empty list.
    #   3. preprocess_dev_test(path, split=...) adds image paths for dev/test.
    train_path = os.path.join('nlvr-data', 'train.json')
    with open(train_path, 'r') as f:
        train_json = json.load(f)
    split_train_valid(train_json)
    # preprocess_train(train_path)
    # NOTE(review): the path below points the *test* split at train.json;
    # the test annotations presumably live in their own JSON-Lines file —
    # confirm the filename before enabling.
    # test_path = os.path.join('nlvr-data', 'train.json')
    # preprocess_dev_test(test_path, split = 'test')