import os
import csv
import random
import os
import argparse
# ---------------------------------------------------------------------------
# Command-line configuration.
# --root       directory holding train.txt / test.txt (read below).
# --benchmark  where the generated benchmark is written; when omitted the
#              output goes under --root (the script's historical behaviour).
# --split      number of sequential tasks the training stream is cut into.
# --label-rate fraction of each class (per task) placed in the labeled file.
# --seed       seed for the labeled/unlabeled sampling; omit for the old
#              unseeded (non-reproducible) behaviour.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='Get CGLM benchmark')
parser.add_argument('--root', default='', metavar='DIR',
                    help='CGLM dataroot containing train.txt and test.txt')
parser.add_argument('--benchmark', default=None, metavar='DIR',
                    help='Path to save benchmark (defaults to --root)')
parser.add_argument('--split', type=int, default=20,
                    help='Number of sequential tasks (default: 20)')
parser.add_argument('--label-rate', type=float, default=0.05,
                    help='Fraction of each class labeled per task (default: 0.05)')
parser.add_argument('--seed', type=int, default=None,
                    help='Random seed for the labeled/unlabeled split (default: unseeded)')
args = parser.parse_args()
if args.split < 1:
    parser.error('--split must be >= 1')
if not 0.0 <= args.label_rate <= 1.0:
    parser.error('--label-rate must be in [0, 1]')
# random.seed(None) reseeds from system entropy, matching the original
# unseeded behaviour; an explicit --seed makes the benchmark reproducible.
random.seed(args.seed)

data_path = args.root
label_rate = args.label_rate
split = args.split
benchmark_root = args.benchmark if args.benchmark is not None else data_path
# Output directory is named after the configuration, e.g. "20_0.05".
outpath = os.path.join(benchmark_root, f'{split}_{label_rate}')
os.makedirs(outpath, exist_ok=True)

# Load the training stream from train.txt (tab-separated, one sample per line).
# Each raw class name is mapped to a contiguous string id ("0", "1", ...) in
# order of first appearance; that id is appended as the last field of the row.
# The per-task split later in this script unpacks each row as
# (class, path, time, id), so train.txt rows are expected to have 3 fields.
# NOTE(review): the sequential task split assumes rows are in chronological
# order — confirm against the data preparation step.
traindata = []
classes = {}  # raw class name -> string class id
with open(os.path.join(data_path, 'train.txt'), 'r') as f:
    for line in f:
        line = line.strip()
        if not line:
            # Blank lines (e.g. a trailing empty line) would otherwise become a
            # bogus class '' and break the 4-field unpack downstream.
            continue
        info = line.split('\t')
        if info[0] not in classes:
            classes[info[0]] = str(len(classes))
        info.append(classes[info[0]])
        traindata.append(info)

# Load test.txt (tab-separated), replace the raw class name in the first
# column with the string id assigned while reading train.txt, and write the
# rows to <outpath>/test.csv. All other columns are copied unchanged.
testdata = []
with open(os.path.join(data_path, 'test.txt'), 'r') as f:
    for line in f:
        line = line.strip()
        if not line:
            # Skip blank lines instead of failing on a lookup of class ''.
            continue
        info = line.split('\t')
        if info[0] not in classes:
            raise KeyError(f'test class {info[0]!r} does not appear in train.txt')
        info[0] = classes[info[0]]
        testdata.append(info)
with open(f'{outpath}/test.csv', 'w', newline='') as f:
    csv.writer(f).writerows(testdata)

# Cut the training stream into `split` consecutive tasks of equal length.
# The len(traindata) % split leftover samples at the end of the stream are
# folded into the final task rather than silently dropped.
if len(traindata) < split:
    raise ValueError(
        f'train.txt has {len(traindata)} samples, fewer than split={split} tasks')
task_length = len(traindata) // split
taski = [traindata[task_length * i:task_length * (i + 1)] for i in range(split - 1)]
taski.append(traindata[task_length * (split - 1):])

# Record each task's boundary: the time field (second-to-last column, just
# before the appended class id) of the task's last sample, one per line.
with open(f'{outpath}/{split}time.txt', 'w') as t:
    for task in taski:
        t.write(f'{task[-1][-2]}\n')


# For every task, randomly mark round(n * label_rate) of the n samples of each
# class as labeled and the rest as unlabeled. Both files get rows of
# [class id, path, time], in the original order within each class; classes
# appear in order of first occurrence in the task.
# NOTE(review): classes with few samples in a task can round to 0 labeled
# samples (e.g. n=9 at 0.05) — confirm this is intended for the benchmark.
for taskid, task in enumerate(taski):
    by_class = {}  # raw class name -> [(path, time, class id), ...]
    for cls, path, timestamp, class_id in task:
        by_class.setdefault(cls, []).append((path, timestamp, class_id))

    with open(f'{outpath}/{taskid}_labeled.csv', 'w', newline='') as labeled_f, \
         open(f'{outpath}/{taskid}_unlabeled.csv', 'w', newline='') as unlabeled_f:
        lwriter = csv.writer(labeled_f)
        uwriter = csv.writer(unlabeled_f)
        for samples in by_class.values():
            total = len(samples)
            # Same random.sample call as before (so a seeded run produces the
            # same split); wrapped in a set for O(1) membership tests.
            labeled = set(random.sample(range(total), round(total * label_rate)))
            for i, (path, timestamp, class_id) in enumerate(samples):
                writer = lwriter if i in labeled else uwriter
                writer.writerow([class_id, path, timestamp])



        
