import sys
import os
# Put the current working directory on the import path.
# NOTE(review): this depends on where the script is launched from -- presumably
# the repository root, so that the `docta` package below resolves; confirm.
o_path = os.getcwd()
sys.path.append(o_path) # set path so that modules from other folders can be loaded
import numpy as np
import torch
import argparse

from docta.utils.config import Config
from docta.datasets.jigsaw import JigsawToxicity
from docta.core.preprocess import Preprocess
from docta.datasets.data_utils import load_embedding
from docta.apis import DetectLabel
from docta.core.report import Report

# Use PyTorch's 'file_system' strategy for sharing tensors between processes.
# NOTE(review): presumably chosen to avoid "too many open files" errors from
# the default file-descriptor strategy when data-loading workers are used -- confirm.
torch.multiprocessing.set_sharing_strategy('file_system')

def parse_args():
    """Parse the command-line options of this script.

    Options:
        --config: path of the config file (default ``./config/jigsaw.py``).
        --sel_label: integer index (0--6) of the attribute to analyse; the
            help text lists the attribute names in index order (default 0).

    Returns:
        argparse.Namespace with the attributes ``config`` and ``sel_label``.

    Raises:
        SystemExit: raised by argparse when the arguments are invalid,
            including a ``--sel_label`` value outside 0--6.
    """
    parser = argparse.ArgumentParser(description='Train a classifier')
    parser.add_argument('--config', help='train config file path', default='./config/jigsaw.py')
    parser.add_argument(
        '--sel_label',
        help="0--6: 'toxicity', 'severe_toxicity', 'obscene', 'sexual_explicit', 'identity_attack', 'insult', 'threat'",
        default=0,
        type=int,
        # Reject out-of-range indices at parse time: otherwise an index >= 7
        # only fails when the value is used as a list index much later in the
        # script, and a negative index silently wraps around to another
        # attribute.
        choices=range(7),
    )
    return parser.parse_args()




# Read the command line and load the configuration file it names.
args = parse_args()
cfg = Config.fromfile(args.config)
print(cfg.dataset_type)

# Run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    cfg.device = torch.device('cuda')
else:
    cfg.device = torch.device('cpu')



import re

# Load the training split; labels are kept in a plain list so that each entry
# can be overwritten in place below.
dataset = JigsawToxicity(cfg, train=True)
if isinstance(dataset.label, np.ndarray):
    dataset.label = dataset.label.tolist()

# Every raw label is split on ", " and ": "; the pieces at the odd positions
# 1, 3, ..., 13 are converted to floats, giving seven scores per sample.
# NOTE(review): this assumes each raw label is a string of seven
# "name: value" pairs separated by ", " -- confirm against JigsawToxicity.
field_separator = re.compile(', |: ')
score_positions = list(range(1, 14, 2))
for idx, raw_label in enumerate(dataset.label):
    pieces = np.array(field_separator.split(raw_label))
    dataset.label[idx] = pieces[score_positions].astype(float).tolist()

# No test split is used by this script.
test_dataset = None


# Encode the dataset features, then note the checkpoint indices exposed by the
# pre-processor (`save_ckpt_idx`).
pre_processor = Preprocess(cfg, dataset, test_dataset)
pre_processor.encode_feature()
ckpt_idx = pre_processor.save_ckpt_idx
print(ckpt_idx)

# ckpt_idx = [0, 10]


def data_path(x):
    """Return the path string built from cfg.save_path for checkpoint index *x*."""
    return cfg.save_path + f'embedded_{cfg.dataset_type}_{x}.pt'


# Replace the raw dataset with the one rebuilt from the saved embeddings.
dataset, _ = load_embedding(ckpt_idx, data_path, duplicate=False)


# Record which of the seven score columns is being analysed.
cfg.sel_label = args.sel_label
print(f"sel label is {cfg.sel_label}")

# One cut-off per score column, indexed by cfg.sel_label.
thre = [0.3, 0.1, 0.1, 0.1, 0.1, 0.3, 0.1]

# Pull the selected score out of every sample's label row ...
label_sel = [dataset.label[row][cfg.sel_label] for row in range(len(dataset))]

# ... and binarise it: 1 when the score reaches the cut-off, else 0.
_label = (np.array(label_sel) >= thre[cfg.sel_label]).astype(int)
dataset.label = _label
print(f"Ratio (%) of positive labels (harmful): {np.mean(_label)*100}. Thre: {thre[cfg.sel_label]}")



# Keep a copy of the current (binarised) labels and store, for each class
# index in range(cfg.num_classes), how many samples carry that label.
raw_labels = dataset.label.copy()
# Convert once and count with np.sum: the built-in sum() would walk the
# boolean array element by element in Python for every class.
_raw_label_array = np.asarray(raw_labels)
cfg.noisy_prior = [np.sum(_raw_label_array == i) for i in range(cfg.num_classes)]


# Run the detector; it is handed `report`, which is the object saved below.
report = Report()
detector = DetectLabel(cfg, dataset, report=report)
detector.detect()

# Save the report under "<data_root>/report/", creating the folder if needed.
# The file name encodes the dataset type, the hoc_cfg.balance setting and the
# selected label index.
report_folder = cfg.data_root + '/report/'
os.makedirs(report_folder, exist_ok=True)
report_name = f"{cfg.dataset_type}_balanced-{cfg.hoc_cfg.balance}_report_label_{cfg.sel_label}.pt"
report_path = report_folder + report_name
torch.save(report, report_path)
print(f'Report saved to {report_path}')
