import sys
import os
o_path = os.getcwd()
sys.path.append(o_path) # add the current working directory to sys.path so that modules from other folders (e.g. the `docta` package) can be imported

import torch
import argparse
import numpy as np

from docta.utils.config import Config
from docta.datasets import PKU_ALIGN
from docta.core.preprocess import Preprocess
from docta.datasets.data_utils import load_embedding

# Switch torch multiprocessing to the 'file_system' tensor-sharing strategy.
# NOTE(review): presumably chosen to avoid running out of file descriptors
# with the default strategy when worker processes share tensors — confirm.
torch.multiprocessing.set_sharing_strategy('file_system')

def parse_args():
    """Parse the command-line options of this script.

    Returns:
        An ``argparse.Namespace`` whose ``config`` attribute is the path of the
        config file to load (``./config/benchmark_pku_align_safe.py`` unless
        ``--config`` is given).
    """
    # NOTE(review): the description says "Train a classifier", but this script
    # runs label detection; the text is kept unchanged here.
    cli = argparse.ArgumentParser(description='Train a classifier')
    cli.add_argument(
        '--config',
        default='./config/benchmark_pku_align_safe.py',
        help='train config file path',
    )
    return cli.parse_args()




# Parse CLI options, load the config file and pick the compute device.
args = parse_args()
cfg = Config.fromfile(args.config)
print(cfg.dataset_type)
cfg.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')



dataset = PKU_ALIGN(cfg)
# Copy the full label container before `dataset.label` is overwritten below.
raw_labels = dataset.label.copy()
# Element/row 1 of the label container is stored on cfg as `raw_idx`; later it
# is used to count each distinct index only once when ind_sample is enabled.
# NOTE(review): presumably row 0 is the class label and row 1 the original
# sample index — confirm against PKU_ALIGN.
cfg.raw_idx = raw_labels[1].copy()

# Keep only element/row 0 as the working labels before the features are encoded.
dataset.label = dataset.label[0]

def get_last_words(sentence, word_count):
    """Return the last ``word_count`` whitespace-separated words of ``sentence``.

    The sentence is split with ``str.split()`` (any run of whitespace is a
    separator) and re-joined with single spaces, so the original spacing and
    newlines are not preserved. If the sentence has ``word_count`` words or
    fewer, all of its words are returned.

    Args:
        sentence: Text to truncate.
        word_count: Maximum number of trailing words to keep. Values <= 0
            yield an empty string.

    Returns:
        The truncated text.
    """
    if word_count <= 0:
        # words[-0:] is words[0:] (the whole list) and a negative count would
        # slice from the front, so both cases are handled explicitly.
        return ''
    words = sentence.split()
    return ' '.join(words[-word_count:])
# Number of trailing words kept from each text feature before encoding.
# NOTE(review): presumably chosen to keep inputs within the encoder's length
# limit — confirm.
MAX_FEATURE_WORDS = 360
for i in range(len(dataset)):
    dataset.feature[i] = get_last_words(dataset.feature[i], MAX_FEATURE_WORDS)

# Encode the truncated features. `save_ckpt_idx` is the list of indices that is
# passed to `load_embedding` below to reload the saved embeddings.
pre_processor = Preprocess(cfg, dataset, None)
pre_processor.encode_feature()
ckpt_idx = pre_processor.save_ckpt_idx
print(ckpt_idx)


# if "BeaverTails" in cfg.save_path:
#     ckpt_idx = [0, 4]
# elif "SafeRLHF" in cfg.save_path:
#     ckpt_idx = [0, 5]




def data_path(x):
    """Return the path of the saved embedding file for index ``x``.

    The name is a plain concatenation, so ``cfg.save_path`` is expected to end
    with a path separator.
    """
    return cfg.save_path + f'embedded_{cfg.dataset_type}_{x}.pt'


dataset, _ = load_embedding(ckpt_idx, data_path, duplicate=cfg.duplicate)
# Cast the reloaded labels to an integer dtype.
dataset.label = dataset.label.astype(int)


from docta.apis import DetectLabel
from docta.core.report import Report


# Per-class label counts of the (noisy) dataset, stored on cfg as `noisy_prior`.
raw_labels = dataset.label.copy()
if cfg.hoc_cfg.ind_sample:
    # Independent sampling: each distinct value of cfg.raw_idx is counted once,
    # using the label at its first occurrence.
    noisy_prior = [0] * cfg.num_classes
    seen_idx = set()
    for i, sample_idx in enumerate(cfg.raw_idx):
        if sample_idx not in seen_idx:
            noisy_prior[int(raw_labels[i])] += 1
            seen_idx.add(sample_idx)
    cfg.noisy_prior = noisy_prior
else:
    # No independent sampling: count every sample. Convert once and reduce with
    # numpy instead of iterating the boolean mask with Python's builtin sum().
    label_array = np.asarray(raw_labels)
    cfg.noisy_prior = [(label_array == c).sum() for c in range(cfg.num_classes)]

# Run label detection. The same `report` object is passed to the detector and
# saved afterwards, so presumably detect() fills it in — confirm in DetectLabel.
report = Report()
detector = DetectLabel(cfg, dataset, report=report)
detector.detect()

# Save the report as a .pt file under <data_root>/report/.
report_folder = cfg.data_root + '/report/'
os.makedirs(report_folder, exist_ok=True)
report_path = (
    f"{report_folder}{cfg.dataset_type}_balanced-{cfg.hoc_cfg.balance}"
    f"_report_label_{cfg.key}.pt"
)
torch.save(report, report_path)
print(f'Report saved to {report_path}')

