import sys
import os
# Add the current working directory to the module search path.
o_path = os.getcwd()
sys.path.append(o_path) # set path so that modules from other folders can be loaded

import torch
import argparse
from collections import Counter
import numpy as np

from docta.utils.config import Config
from docta.datasets import HH_RLHF_new, PKU_ALIGN
from docta.core.preprocess import Preprocess
from docta.datasets.data_utils import load_embedding

# Select PyTorch's 'file_system' strategy for sharing tensors between processes.
# NOTE(review): presumably chosen to avoid exhausting file descriptors during
# multi-process data loading -- confirm.
torch.multiprocessing.set_sharing_strategy('file_system')

def parse_args():
    """Parse the command-line options of this script.

    Returns:
        argparse.Namespace: namespace whose ``config`` attribute is the path
        of the config file (``./config/benchmark_anthropic_harmless.py`` when
        ``--config`` is not given).
    """
    cli = argparse.ArgumentParser(description='Train a classifier')
    cli.add_argument(
        '--config',
        default='./config/benchmark_anthropic_harmless.py',
        help='train config file path',
    )
    return cli.parse_args()




# Read the experiment configuration named on the command line.
args = parse_args()
cfg = Config.fromfile(args.config)
print(cfg.dataset_type)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    cfg.device = torch.device('cuda')
else:
    cfg.device = torch.device('cpu')



# Build the HH_RLHF_new dataset from the loaded config.
dataset = HH_RLHF_new(cfg)
# Keep a copy of the complete label container before it is narrowed below;
# its elements [0] and [1] are both used later (element [1] ends up in
# cfg.raw_idx).
# NOTE(review): presumably [0] holds the class labels and [1] a per-sample
# source index -- confirm against HH_RLHF_new.
raw_labels = dataset.label.copy()


# Get embedding
# From here on the dataset exposes only the first element of its label container.
dataset.label = dataset.label[0]

def get_last_words(sentence, word_count):
    """Return the last ``word_count`` whitespace-separated words of ``sentence``.

    Words are obtained with ``str.split()``, so leading/trailing whitespace is
    dropped and every run of whitespace between kept words becomes one space.

    Args:
        sentence: Text to truncate.
        word_count: Maximum number of trailing words to keep.

    Returns:
        The kept words joined by single spaces. If the sentence has fewer than
        ``word_count`` words, all of them are returned. A non-positive
        ``word_count`` yields an empty string.
    """
    if word_count <= 0:
        # ``words[-0:]`` is ``words[0:]`` and would return every word instead
        # of none, so non-positive counts are handled explicitly.
        return ''
    words = sentence.split()
    return ' '.join(words[-word_count:])
# Keep only the trailing 360 words of every sample before it is encoded.
for sample_idx in range(len(dataset)):
    tail = get_last_words(dataset.feature[sample_idx], 360)
    dataset.feature[sample_idx] = tail

# Encode the truncated features and remember which checkpoint indices the
# pre-processor reports, so the embeddings can be loaded back afterwards.
pre_processor = Preprocess(cfg, dataset, None)
pre_processor.encode_feature()
ckpt_idx = pre_processor.save_ckpt_idx
print(ckpt_idx)


# Both the "harmless" and the "red_team" runs load the PKU-Align BeaverTails
# benchmark as auxiliary data, so a single check covers the two cases.
# (Overrides left disabled by earlier experiments: ckpt_idx = [0, 2] for
# "harmless" and ckpt_idx = [0, 1] for "red_team".)
# NOTE(review): for any other save_path, cfg_aux and raw_labels_aux stay
# undefined, so the later ind_sample branch that reads them would raise
# NameError.
if any(tag in cfg.save_path for tag in ("harmless", "red_team")):
    cfg_aux = Config.fromfile("./config/benchmark_pku_align_beaverTails.py")
    dataset_aux = PKU_ALIGN(cfg_aux)
    raw_labels_aux = dataset_aux.label.copy()




def data_path(x):
    """Return the path of the embedding checkpoint file with index ``x``."""
    return cfg.save_path + f'embedded_{cfg.dataset_type}_{x}.pt'


# Replace the text dataset with the embeddings stored in the checkpoints.
dataset, _ = load_embedding(ckpt_idx, data_path, duplicate=cfg.duplicate)

dataset.label = dataset.label.astype(int)

# Collapse multi-level labels to binary: values <= 1 become 0, values > 1
# become 1. The mask is computed once, before either assignment, and the
# array's own max() avoids iterating it element by element in Python.
if dataset.label.max() > 1:
    positive = dataset.label > 1
    dataset.label[~positive] = 0
    dataset.label[positive] = 1
    # Counter is already imported at the top of the file; no local import needed.
    print(f'make it binary: \n Clean: {Counter(dataset.label)}')


# With independent sampling enabled: optionally reduce the dataset to one
# sample per raw-index step, then append the auxiliary dataset and its labels.
# NOTE(review): raw_labels_aux and cfg_aux are only defined when cfg.save_path
# contains "harmless" or "red_team"; otherwise this branch raises NameError.
if cfg.hoc_cfg.ind_sample:     
    if cfg.hoc_cfg.only_last:
        # use the last response
        # Sample i-1 is kept whenever the raw index grows by exactly 1 between
        # positions i-1 and i, i.e. the last sample before each index step.
        # NOTE(review): the final sample (position len-1) has no successor and
        # is therefore never selected -- confirm this is intended.
        # NOTE(review): raw_labels itself is not filtered here, so afterwards
        # it is longer than dataset.label -- confirm this is intended.
        _feature = []
        _label = []
        for i in range(1, len(raw_labels[1])):
            if raw_labels[1][i] - raw_labels[1][i-1] == 1:
                _feature.append(dataset.feature[i-1])
                _label.append(dataset.label[i-1])
        # Selected samples are re-indexed from 0.
        _idx = list(range(len(_feature)))
        if cfg.duplicate:
            # Store every selected sample twice (same index for both copies).
            dataset.feature = np.asarray(_feature + _feature)
            dataset.label = np.asarray(_label + _label)
            dataset.index = np.asarray(_idx + _idx)
        else:
            dataset.feature = np.asarray(_feature)
            dataset.label = np.asarray(_label)
            dataset.index = np.asarray(_idx)
    
    # Append the auxiliary labels / raw indices in place.
    # NOTE(review): assumes these are Python lists (concatenation); with NumPy
    # arrays `+=` would add element-wise instead -- TODO confirm.
    raw_labels[0] += raw_labels_aux[0]
    raw_labels[1] += raw_labels_aux[1]
    # Hard-coded checkpoint indices for the auxiliary embeddings.
    # NOTE(review): presumably must match embedding files already saved under
    # cfg_aux.save_path -- confirm.
    ckpt_idx_aux = [0, 4]
    data_path_aux = lambda x: cfg_aux.save_path + f'embedded_{cfg_aux.dataset_type}_{x}.pt'
    dataset_aux, _ = load_embedding(ckpt_idx_aux, data_path_aux, duplicate=cfg.duplicate)
    # Shift the auxiliary indices past the largest index of the main dataset so
    # the two index ranges do not overlap.
    dataset_aux.index += max(dataset.index) + 1
    dataset.update(dataset_aux) # use extra data

# Raw indices (including the auxiliary ones when ind_sample is on); read back
# below when the noisy prior is computed.
cfg.raw_idx = raw_labels[1].copy()


from docta.apis import DetectLabel
from docta.core.report import Report


# raw_labels is rebound here to a copy of the current per-sample labels.
raw_labels = dataset.label.copy()
if not cfg.hoc_cfg.ind_sample:
    # do not use independent sampling: count every sample of each class
    label_array = np.asarray(raw_labels)
    cfg.noisy_prior = [sum(label_array == cls) for cls in range(cfg.num_classes)]
else:
    # use independent sampling: each raw index contributes only once, through
    # the label found at the first position where that index appears
    noisy_prior = [0] * cfg.num_classes
    seen_idx = set()
    for pos, raw_id in enumerate(cfg.raw_idx):
        if raw_id in seen_idx:
            continue
        noisy_prior[int(raw_labels[pos])] += 1
        seen_idx.add(raw_id)
    cfg.noisy_prior = noisy_prior.copy()

# Run label-error detection; the detector writes its findings into `report`.
report = Report()
detector = DetectLabel(cfg, dataset, report=report)
detector.detect()

# Store the report under <data_root>/report/, creating the folder if needed.
report_folder = cfg.data_root + '/report/'
os.makedirs(report_folder, exist_ok=True)
report_name = f"{cfg.dataset_type}_balanced-{cfg.hoc_cfg.balance}_report_label_{cfg.key}.pt"
report_path = report_folder + report_name
torch.save(report, report_path)
print(f'Report saved to {report_path}')

