import os
import re
import pandas as pd
import json
import random
import ast
import pickle
import numpy as np
import datasets
from datasets import load_dataset


def sample_subset_by_lbl(vec_dir, num_class, sentences, labels, preds_not_used, num_samples=None, rseed=0):
	"""Select a reproducible subset of a labelled dataset and cache it on disk.

	``sentences``, ``labels`` and ``preds_not_used`` are treated as parallel
	sequences: element ``i`` of each belongs to the same example.

	Sampling strategy:
	  * ``num_samples is None``  -> the three inputs are returned untouched and
	    nothing is written to disk.
	  * ``num_samples <= 100``   -> class-balanced sampling: every class gets
	    ``num_samples // num_class`` examples and the last class
	    (``num_class - 1``) additionally receives the remainder. The picked
	    examples are then shuffled.
	  * ``num_samples > 100``    -> uniform sampling without replacement over
	    the whole dataset, ignoring the labels.

	After sampling, the selection is handed to ``save_selected_vecs_txt`` to be
	stored at ``<vec_dir>/train_<num_samples>/train.txt`` unless that file
	already exists.

	NOTE: the cache path depends only on ``num_samples``, and an existing file
	is never overwritten, so calling again with a different ``rseed`` returns a
	different selection than the one stored on disk.

	NOTE: this function reseeds the *global* ``numpy.random`` generator (and,
	in the class-balanced branch, the global ``random`` generator) with
	``rseed``.

	Args:
		vec_dir: Directory under which the ``train_<num_samples>`` folder is
			created.
		num_class: Number of classes; labels are expected to be the integers
			``0 .. num_class - 1``.
		sentences: Sequence of examples, indexable by integer position.
		labels: Sequence of integer class labels, parallel to ``sentences``.
		preds_not_used: Sequence parallel to ``sentences`` that is carried
			along with the selection.
		num_samples: Size of the subset, or ``None`` to keep everything.
		rseed: Seed used for all random choices.

	Returns:
		A triple ``(sentences, labels, preds_not_used)`` restricted to the
		selected examples. The class-balanced branch returns tuples, the
		uniform branch returns lists, and the ``None`` case returns the
		original input objects.

	Raises:
		KeyError: If a label is not in ``range(num_class)`` (class-balanced
			branch only).
		ValueError: If more examples are requested than are available.
	"""
	if num_samples is None:
		return sentences, labels, preds_not_used

	np.random.seed(rseed)
	if num_samples <= 100:
		# Bucket example positions by their class label.
		inds_by_class = {k: [] for k in range(num_class)}
		for i, lbl in enumerate(labels):
			inds_by_class[lbl].append(i)

		# Equal share per class; the last class absorbs the remainder so the
		# quotas always add up to exactly num_samples.
		per_class = num_samples // num_class
		quota = {k: per_class for k in range(num_class - 1)}
		quota[num_class - 1] = num_samples - per_class * (num_class - 1)

		picked = []
		for k, class_inds in inds_by_class.items():
			if quota[k] > len(class_inds):
				raise ValueError(
					f"class {k} has only {len(class_inds)} example(s) but {quota[k]} were requested"
				)
			picked.extend(np.random.choice(class_inds, size=quota[k], replace=False))

		# Shuffle so the result is not grouped by class. Shuffling the index
		# list yields the same permutation as shuffling the zipped triples
		# (the permutation depends only on the length) and, unlike
		# ``zip(*triples)``, also works when nothing was selected.
		random.seed(rseed)
		random.shuffle(picked)

		final_selected_sentences = tuple(sentences[i] for i in picked)
		final_selected_labels = tuple(labels[i] for i in picked)
		final_selected_preds_not_used = tuple(preds_not_used[i] for i in picked)
	else:
		inds = np.random.choice(len(labels), size=num_samples, replace=False)
		final_selected_sentences = [sentences[i] for i in inds]
		final_selected_labels = [labels[i] for i in inds]
		final_selected_preds_not_used = [preds_not_used[i] for i in inds]

	# Persist the selection once; an already existing file is left untouched.
	res_dir = os.path.join(vec_dir, f"train_{num_samples}")
	os.makedirs(res_dir, exist_ok=True)
	res_fp = os.path.join(res_dir, "train.txt")
	if not os.path.exists(res_fp):
		save_selected_vecs_txt(res_fp, final_selected_sentences, final_selected_labels, final_selected_preds_not_used)

	return final_selected_sentences, final_selected_labels, final_selected_preds_not_used

