

import random
import csv

# Fixed seed so the shuffles below (and therefore the selected sample and the
# row order of the output file) are the same on every run.
random.seed(123)

# Input files. Each line is parsed below as tab-separated fields: the first
# field is an integer label (0 or 1), the remaining fields form the text.
#fn_lm_base = ''
fn_lm_full = '/mnt/efs/fs2/hzt/causal/Optimus/outputs/finetune_causal_lm/bias_lm_yelp_10ep_sent_corrected/checkpoint-48000/eval_gen_results_n10000_binarized.txt'
fn_our = '/mnt/efs/fs2/hzt/causal/Optimus/outputs/finetune_lm/vae_gpt2encoder/basic-s2-beta1-gsfixed-newmask-t1_w0_wr1_lr5e5_gumbel_samelen_bz16_bak/outputs-242000/debug-semi-s2-gsfixed-newmask-t05_wc05_wzc05_wz05_w05_wr1_lr1e6_gumbel_samelen_bz8_dtrm_bak/outputs-226000/train_gan_dtrm_lr1e4_it226000_ep5_debug-semi-s2-gsfixed-newmask-t05_wc05_wzc05_wz05_w05_wr1_lr1e6_gumbel_samelen_bz8_dtrm_bak_it226000/outputs-158000/test.tsv'

# Destination CSV; one row per sampled text: (file index, label, text).
fn_out_ = '/mnt/efs/fs2/hzt/causal/Optimus/human_eval/yelp_biased_it226000_ganit158000.csv'

# Files to sample from. The position in this list is the "file index" written
# to the first CSV column (0 = fn_lm_full, 1 = fn_our).
#fns = [fn_lm_base, fn_lm_full, fn_our]
fns = [fn_lm_full, fn_our]

# file index -> {label -> list of texts}; filled by the loading loop below.
fn_label_text = {}

# Maximum number of texts taken per (file, label) pair.
num_eval = 10


# Load every input file into fn_label_text[file index][label] -> list of texts,
# then shuffle each label's list so that a prefix of it is a random sample.
#
# Each line must look like "<label>\t<text...>" where <label> parses as the
# integer 0 or 1; any further tab-separated fields are re-joined with single
# spaces to form the text.
for i, fn in enumerate(fns):
    fn_label_text[i] = {0: [], 1: []}
    with open(fn, 'r') as fin:
        for line_no, line in enumerate(fin, start=1):
            # str.split always yields at least one element, so this unpacking
            # cannot fail; `fields` is empty when the line has no tab.
            label, *fields = line.strip().split('\t')
            try:
                # The two things that can actually go wrong on a malformed
                # line: a non-integer label (ValueError) or an integer label
                # other than 0/1 (KeyError).
                bucket = fn_label_text[i][int(label)]
            except (ValueError, KeyError):
                # Stop with the offending location and content instead of an
                # anonymous traceback. SystemExit is a builtin, unlike the
                # interactive `exit()` helper, and exits with a non-zero status.
                raise SystemExit(
                    f"{fn}:{line_no}: expected '<0|1>\\t<text>', got {line!r}"
                )
            bucket.append(' '.join(fields))

    # Same call order as before (label 0, then label 1, per file) so the
    # seeded random stream - and hence the chosen sample - is unchanged.
    random.shuffle(fn_label_text[i][0])
    random.shuffle(fn_label_text[i][1])


# Collect up to num_eval texts for every (file index, label) pair as
# [file index, label, text] rows, then shuffle the rows so that the source
# file and label are interleaved in the output.
output_rows = []
for file_idx, _ in enumerate(fns):
    texts_by_label = fn_label_text[file_idx]
    for lbl in (0, 1):
        sampled = texts_by_label[lbl][:num_eval]
        output_rows.extend([file_idx, lbl, sample] for sample in sampled)
random.shuffle(output_rows)


# Write the shuffled rows as CSV: file index, label, text.
# newline='' is required by the csv module: the writer emits its own '\r\n'
# row terminators, and without it platforms that translate newlines would
# turn them into '\r\r\n'.
with open(fn_out_, 'w', newline='') as fout:
    writer = csv.writer(fout)
    writer.writerows(output_rows)
