import os
import pandas as pd
import json
import random
import shutil
from tqdm import tqdm

# SEED seeds the global `random` module; every .jpg found by the walk below
# draws one random.random() value to choose its split.
# RATE is the probability (per image) of being assigned to the train split.
SEED, RATE = 2024, 0.8
random.seed(SEED)


def makedirs(path: str) -> None:
    """Create directory ``path`` and any missing parents; no-op if it exists.

    Args:
        path: Directory to create.

    Raises:
        FileExistsError: if ``path`` exists but is not a directory.
    """
    # exist_ok=True avoids the check-then-create race of a separate
    # existence test.
    os.makedirs(path, exist_ok=True)


# Source image trees; dataset_roots[k] is exported into dataset_save_roots[k].
# expanduser is required: os.walk/os.makedirs treat '~' literally, so an
# unexpanded path would silently walk nothing or create a './~' directory.
dataset_roots = [
    os.path.expanduser('~/Project/pytorch/new_data/class_256/class-5'),
    os.path.expanduser('~/Project/pytorch/new_data/class_256_sem/2000x'),
]
dataset_save_roots = [
    os.path.expanduser('~/Project/pytorch/new_data/metric_om'),
    os.path.expanduser('~/Project/pytorch/new_data/metric_sem'),
]

raw_ann_file = os.path.expanduser('~/Project/pytorch/new_data/sum1126.csv')

raw_ann = pd.read_csv(raw_ann_file)
# Map sample id (first CSV column) -> values of the last 7 columns, as a plain
# Python list so it can be written straight into the JSON annotation files.
# int() (rather than .item()) accepts the id whether pandas parsed it as a
# numpy integer, a float, or a digit string, and matches the int keys that
# are derived from image file names later on. Duplicate ids: last row wins.
_ann_keys = raw_ann.iloc[:, 0]
_ann_labels = raw_ann.iloc[:, -7:].to_numpy().tolist()
raw_ann_js = {int(ky): labels for ky, labels in zip(_ann_keys, _ann_labels)}


def _export_split(samples, image_root, ann_path):
    """Copy split images to ``image_root`` and write their labels to ``ann_path``.

    ``samples`` is a list of ``(source_image_path, label_list)`` tuples. The
    running index, zero-padded to 6 digits, is used both as the new image
    file name (``NNNNNN.jpg``) and as the key in the JSON written to
    ``ann_path``.
    """
    ann_js = dict()
    for sid, (src_path, label) in enumerate(tqdm(samples)):
        str_sid = '{:>06d}'.format(sid)
        ann_js[str_sid] = label
        image_save_path = os.path.join(image_root, '{}.jpg'.format(str_sid))
        shutil.copy(src_path, image_save_path)
    with open(ann_path, 'w', encoding='utf-8') as wf:
        json.dump(ann_js, wf)


if len(dataset_roots) != len(dataset_save_roots):
    raise ValueError('dataset_roots and dataset_save_roots must have the same length')

for dataset_root, dataset_save_root in zip(dataset_roots, dataset_save_roots):
    # expanduser is idempotent; it guards against '~' paths, which os.walk
    # would otherwise treat literally and then silently find nothing.
    dataset_root = os.path.expanduser(dataset_root)
    dataset_save_root = os.path.expanduser(dataset_save_root)
    if not os.path.isdir(dataset_root):
        raise FileNotFoundError('dataset root not found: {}'.format(dataset_root))

    train_list = []
    val_list = []
    for root, dirs, files in os.walk(dataset_root):
        # os.walk's order is filesystem-dependent. Sort dirs in place (this
        # steers the top-down walk) and files, so each image always receives
        # the same random.random() draw and the seeded split is reproducible.
        dirs.sort()
        for file in sorted(files):
            if not file.endswith('.jpg'):
                continue
            # File names are expected to start with '<sample id>_'.
            ann_ky = int(file.split('_')[0])
            image_path = os.path.join(root, file)
            try:
                label = raw_ann_js[ann_ky]
            except KeyError:
                raise KeyError(
                    'no annotation for sample id {} ({})'.format(ann_ky, image_path)
                ) from None
            if random.random() < RATE:
                train_list.append((image_path, label))
            else:
                val_list.append((image_path, label))

    train_image_root = os.path.join(dataset_save_root, 'train')
    val_image_root = os.path.join(dataset_save_root, 'val')
    makedirs(train_image_root)
    makedirs(val_image_root)
    _export_split(
        train_list,
        train_image_root,
        os.path.join(dataset_save_root, 'ann_train.json'),
    )
    _export_split(
        val_list,
        val_image_root,
        os.path.join(dataset_save_root, 'ann_val.json'),
    )