import sys
import os
# Make modules under the current working directory importable. This must run
# before the `docta` imports below: it appends os.getcwd() (not this file's
# directory), so those imports only resolve from here if the script is launched
# from the directory containing the `docta` package (presumably the repository
# root) — unless docta is installed some other way.
o_path = os.getcwd()
sys.path.append(o_path) # set path so that modules from other folders can be loaded

import torch
from collections import Counter
import numpy as np
import pdb
from docta.utils.config import Config
from docta.datasets import JigsawToxicity

# Share tensors between processes via the 'file_system' strategy instead of the
# default. NOTE(review): presumably chosen to avoid exhausting open file
# descriptors when data is passed between worker processes — confirm.
torch.multiprocessing.set_sharing_strategy('file_system')



# Jigsaw dataset: for each of the seven toxicity attributes, binarize that
# attribute's score, apply the label curation stored in the matching Docta
# report, and save the cured dataset into the same report folder.
import re  # hoisted out of the loop; only needs to be imported once

cfg_name = "./config/jigsaw.py"
dataset_name = "JigsawToxicity"
# Explicit name -> class registry instead of eval() on a formatted string.
dataset_classes = {"JigsawToxicity": JigsawToxicity}

# Binarization threshold per attribute, indexed by cfg.sel_label: a sample is
# positive (harmful) for attribute k when its score is >= thre[k].
thre = [0.3, 0.1, 0.1, 0.1, 0.1, 0.3, 0.1]

# Raw labels are strings split on ", " and ": "; the seven float scores sit at
# the odd token positions 1, 3, ..., 13. NOTE(review): presumably the format is
# "name: score, name: score, ..." — confirm against the dataset loader.
label_splitter = re.compile(', |: ')
score_positions = list(range(1, 14, 2))

for key in range(len(thre)):
    # Config and dataset are rebuilt every iteration because both are mutated
    # below (cfg.sel_label is set, dataset.label is overwritten).
    cfg = Config.fromfile(cfg_name)

    # get dataset
    dataset = dataset_classes[dataset_name](cfg)

    # Parse every raw label string into its list of seven float scores
    # (slice assignment keeps the same list object, as the in-place loop did).
    if isinstance(dataset.label, np.ndarray):
        dataset.label = dataset.label.tolist()
    dataset.label[:] = [
        np.array(label_splitter.split(raw))[score_positions].astype(float).tolist()
        for raw in dataset.label
    ]

    cfg.sel_label = key
    print(f"sel label is {cfg.sel_label}")

    # Binarize the selected attribute's score with its threshold.
    label_sel = [dataset.label[i][cfg.sel_label] for i in range(len(dataset))]
    _label = (np.array(label_sel) >= thre[cfg.sel_label]).astype(int)
    dataset.label = _label
    print(f"Ratio (%) of positive labels (harmful): {np.mean(_label)*100}. Thre: {thre[cfg.sel_label]}")

    report_folder = cfg.data_root + '/report/'
    report_path = report_folder + f"{cfg.dataset_type}_balanced-{cfg.hoc_cfg.balance}_report_label_{cfg.sel_label}.pt"
    # The report is a pickled Python object (it exposes `.curation`), not a
    # tensor/state dict. PyTorch >= 2.6 defaults to weights_only=True and would
    # refuse to load it, so request full unpickling explicitly (the parameter
    # exists since PyTorch 1.13). Only load reports you produced yourself:
    # unpickling can execute arbitrary code.
    report = torch.load(report_path, weights_only=False)

    # Entries are indexed like item[0]=sample index, item[1]=curated label,
    # item[2]=confidence. The stable sort by index means that for duplicate
    # indices the entry appearing last in the report wins below.
    label_curation = report.curation['label_curation']
    label_curation.sort(key=lambda x: x[0])

    feature = dataset.feature
    label = dataset.label.copy()

    label_curated = dataset.label.copy()
    label_conf = [1] * len(label_curated)  # default confidence for untouched labels
    cnt = 0
    for item in label_curation:
        idx = item[0]
        # Skip indices beyond the dataset (guard kept from the original script).
        if idx < len(dataset):
            label_curated[idx] = item[1]
            label_conf[idx] = item[2]
            cnt += 1
    print(f'# Cured labels: {cnt}/{len(label_curated)}')
    new_dataset = dict(
        feature=feature,
        label=label,
        label_curated=label_curated,
        label_conf=label_conf,
        # Positional index of every sample (the same indexing the curation
        # entries use). Previously this stored dataset.label[1], i.e. the
        # scalar binary label of sample 1. NOTE(review): confirm downstream
        # consumers expect positional indices here.
        raw_idx=np.arange(len(label)),
    )

    new_data_save_path = report_folder + f"cured_dataset_{cfg.dataset_type}_balanced-{cfg.hoc_cfg.balance}_{cfg.sel_label}.pt"
    torch.save(new_dataset, new_data_save_path)

    print(f'new dataset saved to {new_data_save_path}')