import os
import numpy as np
import pandas as pd
import time
from tqdm import tqdm
from sklearn.metrics import roc_auc_score
from multiprocessing import Pool

# Chunked computation of Hamming distances using multiprocessing
def hamming_distance_worker(args):
    """Return, for each real sequence in a chunk, its smallest distance to any fake sequence.

    Args:
        args: A ``(real_chunk, fake)`` tuple, packed into one argument so the
            function can be handed to ``Pool.imap``.

    Returns:
        A list with one entry per sequence in ``real_chunk``: the minimum,
        over all sequences in ``fake``, of the number of mismatching positions.

    Raises:
        ValueError: If ``fake`` is empty (``min`` of an empty list).
    """
    real_chunk, fake = args
    nearest = []
    for real_seq in real_chunk:
        real_len = len(real_seq)
        mismatches = []
        for fake_seq in fake:
            # Only the common prefix is compared; any extra tail on the
            # longer sequence does not add to the distance.
            overlap = min(real_len, len(fake_seq))
            # assumes sequences are NumPy arrays so `!=` is elementwise -- TODO confirm
            mismatches.append(np.sum(real_seq[:overlap] != fake_seq[:overlap]))
        nearest.append(min(mismatches))
    return nearest

def hamming_distance_multiprocessing(real, fake, chunk_size=100, num_workers=32):
    """Compute nearest-fake distances for every real sequence using a process pool.

    Args:
        real: Sliceable collection of real sequences.
        fake: Collection of fake sequences; the full collection is sent
            along with every chunk.
        chunk_size: Number of real sequences handled per task.
        num_workers: Number of worker processes in the pool.

    Returns:
        A flat list with one minimum distance per entry of ``real``, in the
        same order as ``real``.
    """
    # One task per consecutive slice of `real`, each paired with all of `fake`.
    tasks = [
        (real[start:start + chunk_size], fake)
        for start in range(0, len(real), chunk_size)
    ]

    # imap yields results in task order; tqdm only adds a progress bar.
    with Pool(num_workers) as pool:
        progress = tqdm(pool.imap(hamming_distance_worker, tasks), total=len(tasks))
        per_chunk = list(progress)

    # Stitch the per-chunk lists back into one list aligned with `real`.
    flat = []
    for chunk_result in per_chunk:
        flat.extend(chunk_result)
    return flat

def each_group(real, train, test, fake, save_path, num_workers=4):
    """Compute nearest-fake distances for both splits, save them, and score them.

    Args:
        real: Full collection of real sequences (forwarded to the F1 and
            accuracy computations).
        train: Real sequences of the train split.
        test: Real sequences of the test split.
        fake: Fake sequences that the real ones are compared against.
        save_path: Directory that receives ``distance_train.npy`` and
            ``distance_test.npy``.
        num_workers: Number of worker processes for the distance computation.

    Returns:
        A dict with the keys ``'auroc'``, ``'f1'`` and ``'acc'``.
    """
    # Distances are computed for train first, then test.
    train_dists = hamming_distance_multiprocessing(train, fake, num_workers=num_workers)
    test_dists = hamming_distance_multiprocessing(test, fake, num_workers=num_workers)

    # Persist the raw distances so metrics can be recomputed later without
    # repeating the pairwise comparison.
    np.save(os.path.join(save_path, 'distance_train.npy'), train_dists)
    np.save(os.path.join(save_path, 'distance_test.npy'), test_dists)

    # AUROC and F1 receive arrays; accuracy receives the plain lists.
    return {
        'auroc': compute_auroc(np.array(train_dists), np.array(test_dists)),
        'f1': compute_f1(real, fake, np.array(train_dists), np.array(test_dists)),
        'acc': compute_accuracy(real, fake, train_dists, test_dists),
    }

def compute_auroc(distance_train, distance_test):
    """Score how well the distance separates test entries from train entries.

    Train distances are labelled 0 and test distances are labelled 1; the
    raw distance is used as the score passed to ``roc_auc_score``.

    Args:
        distance_train: Array of distances for the train split.
        distance_test: Array of distances for the test split.

    Returns:
        The value returned by ``roc_auc_score`` for those labels and scores.
    """
    n_train = distance_train.shape[0]
    n_test = distance_test.shape[0]

    labels = np.concatenate([np.zeros(n_train), np.ones(n_test)])
    scores = np.concatenate([distance_train, distance_test], axis=0)
    return roc_auc_score(labels, scores)

def compute_f1(real, fake, distance_train, distance_test):
    """Compute F1 of a distance-threshold rule at several thresholds.

    An entry is predicted positive when its distance is at most ``theta``,
    where ``theta`` is a fraction ``p`` of the longest sequence in ``real``.
    Train entries are the positives: a train distance within ``theta`` is a
    true positive, a train distance above it a false negative, and a test
    distance within it a false positive.

    Args:
        real: Collection of real sequences; only their lengths are used.
        fake: Unused; kept so the signature matches ``compute_accuracy``.
        distance_train: Distances for the train split (array or plain list).
        distance_test: Distances for the test split (array or plain list).

    Returns:
        A dict mapping each fraction ``p`` in
        ``[0, 0.01, 0.02, 0.05, 0.1, 0.2]`` to the F1 score at that threshold.

    Raises:
        ValueError: If ``real`` is empty (``max`` of an empty sequence).
    """
    # Coerce to arrays so that plain lists work too: a list compared with a
    # float via `<=` raises TypeError instead of comparing elementwise.
    distance_train = np.asarray(distance_train)
    distance_test = np.asarray(distance_test)

    max_seq = max(len(r) for r in real)
    percent = [0, 0.01, 0.02, 0.05, 0.1, 0.2]
    f1 = {}
    for p in percent:
        theta = max_seq * p
        n_tp = np.sum(distance_train <= theta)
        n_fn = len(distance_train) - n_tp
        n_fp = np.sum(distance_test <= theta)
        # Equivalent to 2*TP / (2*TP + FP + FN).
        f1[p] = n_tp / (n_tp + (n_fp + n_fn) / 2)

    return f1

def compute_accuracy(real, fake, distance_train, distance_test):
    """Compute accuracy of a distance-threshold rule at several thresholds.

    For each fraction ``p`` the threshold is ``p`` times the length of the
    longest sequence in ``real``. A train distance counts as correct when it
    is at most the threshold; a test distance counts as correct when it is
    strictly above it.

    Args:
        real: Collection of real sequences; only their lengths are used.
        fake: Unused; kept for a signature parallel to ``compute_f1``.
        distance_train: Iterable of distances for the train split.
        distance_test: Iterable of distances for the test split.

    Returns:
        A dict mapping each fraction ``p`` in
        ``[0, 0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 0.8]`` to the share of
        correctly classified entries.
    """
    longest = max([len(seq) for seq in real])
    total = len(distance_train) + len(distance_test)

    scores = dict()
    for fraction in [0, 0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 0.8]:
        cutoff = longest * fraction
        train_hits = sum(1 for d in distance_train if d <= cutoff)
        test_hits = sum(1 for d in distance_test if d > cutoff)
        scores[fraction] = (train_hits + test_hits) / total

    return scores

def main(
    real_base_path: str,
    real_input_file: str,
    split_base_path: str,
    split_file: str,
    split_seed: str,
    fake_base_path: str,
    fake_input_file: str,
    load_npy: bool,
):
    """Run the privacy evaluation, or re-score previously saved distances.

    Args:
        real_base_path: Directory containing the real data file.
        real_input_file: File name of the real data (loaded with ``np.load``).
        split_base_path: Directory containing the split CSV.
        split_file: File name of the split CSV. It must have a ``stay_id``
            column and a column named ``split_seed`` whose values include
            ``'train'`` and ``'test'``.
        split_seed: Name of the split column to use.
        fake_base_path: Directory containing the fake data file; distances
            and ``result.txt`` are also written to / read from here.
        fake_input_file: File name of the fake data (loaded with ``np.load``).
        load_npy: When falsy, compute distances and metrics from scratch and
            write ``result.txt``. When truthy, load ``distance_train.npy`` and
            ``distance_test.npy`` and report accuracy over 10 random
            subsamples of the train distances.

    Raises:
        AssertionError: If a ``stay_id`` appears in both train and test.
    """
    start1 = time.time()

    # Load the train/test split; `stay_id` values are used directly as
    # indices into the real data array below.
    split = pd.read_csv(os.path.join(split_base_path, split_file))
    train_idx = split[split[split_seed] == 'train']['stay_id'].to_numpy()
    test_idx = split[split[split_seed] == 'test']['stay_id'].to_numpy()
    assert np.intersect1d(train_idx, test_idx).size == 0

    # NOTE: allow_pickle=True unpickles arbitrary objects; only point these
    # paths at trusted files.
    real = np.load(os.path.join(real_base_path, real_input_file), allow_pickle=True)
    train, test = real[train_idx], real[test_idx]
    fake = np.load(os.path.join(fake_base_path, fake_input_file), allow_pickle=True)

    elapsed1 = time.time() - start1
    print(f"Loaded Train {train.shape}, Test {test.shape}, Fake {fake.shape}...")
    print(f"Time to load data: {elapsed1} seconds...")

    if not load_npy:
        print(f"Save path: {fake_base_path}...")

        start2 = time.time()
        result = each_group(real, train, test, fake, fake_base_path)
        elapsed2 = time.time() - start2

        # Build the report once and reuse it for stdout and result.txt.
        # each_group() stores accuracy under the key 'acc'; looking it up as
        # 'accuracy' raised KeyError, so map the printed label to the real key.
        report = [f"auroc: {result['auroc']}"]
        for label, key in (('f1', 'f1'), ('accuracy', 'acc')):
            for p, value in result[key].items():
                report.append(f'{label} (p={p}): {round(value, 5)}')

        for line in report:
            print(line)
        print(f"Time for privacy evaluation: {elapsed2} seconds...")

        with open(os.path.join(fake_base_path, 'result.txt'), 'w') as f:
            f.writelines(line + '\n' for line in report)
            f.write("Time used: " + str(elapsed1 + elapsed2) + " seconds.\n")
            f.write("Loading time used: " + str(elapsed1) + " seconds.\n")
            f.write("Computing time used: " + str(elapsed2) + " seconds.\n")

    else:
        # Reuse the distances saved by an earlier full run.
        distance_train = np.load(os.path.join(fake_base_path, "distance_train.npy"))
        distance_test = np.load(os.path.join(fake_base_path, "distance_test.npy"))
        print(f"Loaded train {distance_train.shape}, test {distance_test.shape}...")

        n = 10  # Number of resampling iterations
        sample_size = len(distance_test)
        accuracies = {}  # p -> list of accuracies, one per iteration

        for _ in range(n):
            # Subsample train (without replacement) to the size of test so
            # both classes carry equal weight in the accuracy.
            indices = np.random.choice(len(distance_train), size=sample_size, replace=False)
            sampled_train = distance_train[indices]

            result = compute_accuracy(real, fake, sampled_train, distance_test)

            for p, value in result.items():
                accuracies.setdefault(p, []).append(value)
                print(f'Accuracy (p={p}): {round(value, 5)}')

        # Summarise the spread across iterations.
        for p, values in accuracies.items():
            mean_acc = np.mean(values)
            std_acc = np.std(values)
            print(f'p={p}: Mean Accuracy = {round(mean_acc, 5)}, Standard Deviation = {round(std_acc, 5)}')


if __name__ == "__main__":
    # Prerequisite: Create privacy data using
    #   python convert_table_to_text.py --ehr {ehr} --obs_size {obs_size} --create_privacy_data True
    # The privacy data created must be used for privacy evaluation (both real, syn).
    #
    # The previous entry point decorated a function with `ex.automain`, but no
    # `ex` object is defined or imported in this file, and it called main()
    # with a single config argument although main() takes eight parameters.
    # Parse the parameters from the command line and pass them by keyword.
    import argparse

    parser = argparse.ArgumentParser(
        description="Distance-based privacy evaluation of synthetic data."
    )
    for option in (
        "real_base_path",
        "real_input_file",
        "split_base_path",
        "split_file",
        "split_seed",
        "fake_base_path",
        "fake_input_file",
    ):
        parser.add_argument(f"--{option}", type=str, required=True)
    parser.add_argument(
        "--load_npy",
        type=int,
        choices=(0, 1),
        default=0,
        help="1 to re-score saved distance_*.npy files, 0 to compute from scratch",
    )
    cli_args = parser.parse_args()

    main(
        real_base_path=cli_args.real_base_path,
        real_input_file=cli_args.real_input_file,
        split_base_path=cli_args.split_base_path,
        split_file=cli_args.split_file,
        split_seed=cli_args.split_seed,
        fake_base_path=cli_args.fake_base_path,
        fake_input_file=cli_args.fake_input_file,
        load_npy=bool(cli_args.load_npy),
    )