import sys
import os
# Append the current working directory to the module search path.
o_path = os.getcwd()
sys.path.append(o_path) # set path so that modules from other folders can be loaded

import torch
import numpy as np
import pdb
from docta.utils.config import Config
from docta.datasets import HH_RLHF_new, PKU_ALIGN


# Select the 'file_system' strategy for sharing tensors between processes.
# NOTE(review): presumably chosen to avoid running out of file descriptors
# with the default strategy -- confirm against the dataset loaders.
torch.multiprocessing.set_sharing_strategy('file_system')


# key = "moss_harmless_en"

# Maps each benchmark config key (used in "./config/benchmark_<key>.py") to
# the name of the dataset class that is instantiated for it.
cfg_dataset_name_map = {
    "anthropic_harmless": "HH_RLHF_new",
    "anthropic_red_team": "HH_RLHF_new",
    "pku_align_beaverTails": "PKU_ALIGN",
    "pku_align_safe": "PKU_ALIGN",
}

# Explicit lookup from dataset class name to the class imported from
# docta.datasets. This replaces eval(): an unexpected name now fails with a
# KeyError instead of evaluating an arbitrary string as code.
_dataset_classes = {
    "HH_RLHF_new": HH_RLHF_new,
    "PKU_ALIGN": PKU_ALIGN,
}

# For every benchmark config: load the dataset and its label-curation report,
# apply the curated labels/confidences, and save the result next to the report.
for key, dataset_name in cfg_dataset_name_map.items():
    cfg = Config.fromfile(f"./config/benchmark_{key}.py")

    # get dataset
    dataset = _dataset_classes[dataset_name](cfg)

    report_folder = cfg.data_root + '/report/'
    report_path = report_folder + f"{cfg.dataset_type}_balanced-{cfg.hoc_cfg.balance}_report_label_{cfg.key}.pt"
    # NOTE(review): torch.load unpickles the report object, so only load
    # reports from trusted sources. Newer PyTorch releases default to
    # weights_only=True, which may reject this non-tensor object -- confirm
    # against the installed version.
    report = torch.load(report_path)

    # Each curation entry is indexed as (sample index, label, confidence).
    label_curation = report.curation['label_curation']
    label_curation.sort(key=lambda x: x[0])

    feature = dataset.feature
    label = dataset.label[0].copy()
    # Start from the original labels with confidence 1; entries listed in
    # the report overwrite both values below.
    label_curated = dataset.label[0].copy()
    label_conf = [1] * len(label_curated)

    num_samples = len(dataset)
    cnt = 0
    for item in label_curation:
        idx = item[0]
        # Entries whose index is beyond the dataset size are skipped.
        if idx < num_samples:
            label_curated[idx] = item[1]
            label_conf[idx] = item[2]
            cnt += 1
    print(f'# Cured labels: {cnt}/{len(label_curated)}')

    new_dataset = dict(
        feature=feature,
        label=label,
        label_curated=label_curated,
        label_conf=label_conf,
        raw_idx=dataset.label[1],
    )

    new_data_save_path = report_folder + f"cured_dataset_{cfg.dataset_type}_balanced-{cfg.hoc_cfg.balance}_{cfg.key}.pt"
    torch.save(new_dataset, new_data_save_path)

    print(f'new dataset saved to {new_data_save_path}')