#!/usr/bin/env python
import os
import timeit
import datetime
import numpy as np
import tqdm
import pandas as pd
import jax
import jax.numpy as jnp
import torch
import lut_c_wrapper
import clut_sample
from collections import namedtuple

# How many samples every sampler draws in one timed call.
N_SAMPLES = 10 ** 7
# How many timed calls are averaged per measurement.
TIMING_RUNS = 10


# One parsed distribution file: Z is the total count, N the number of
# outcomes, P the per-outcome integer counts, h the value stored in the
# 'entropy' result column, alpha a float parameter, file the source path.
Distribution = namedtuple('Distribution', 'Z N P h alpha file')


def load_distribution(path: str) -> Distribution:
	"""Parse a distribution description file.

	Expected layout, one item per line::

		Z
		N P_1 P_2 ... P_N
		h
		alpha

	Args:
		path: Path of the file to read.

	Returns:
		A ``Distribution`` whose ``P`` is a NumPy array of the ``N`` integer
		counts and whose ``file`` field is ``path``.

	Raises:
		ValueError: If a line cannot be parsed as a number, the count line is
			empty, the number of counts differs from ``N``, or the counts do
			not sum to ``Z``.
	"""
	with open(path, 'r', encoding='utf-8') as f:
		Z = int(f.readline())
		counts = [int(tok) for tok in f.readline().split()]
		if not counts:
			raise ValueError(f'{path}: missing "N P_1 ... P_N" line')
		N, P = counts[0], np.array(counts[1:])
		# Explicit checks instead of ``assert`` so validation still runs
		# under ``python -O``.
		if len(P) != N:
			raise ValueError(f'{path}: header says N={N} but {len(P)} counts were given')
		total = int(P.sum())
		if total != Z:
			raise ValueError(f'{path}: counts sum to {total}, expected Z={Z}')
		h = float(f.readline())
		alpha = float(f.readline())
	return Distribution(Z, N, P, h, alpha, path)


def _mean_call_time(func) -> float:
	"""Return the mean wall-clock seconds of one ``func()`` call over TIMING_RUNS calls."""
	return timeit.Timer(func).timeit(number=TIMING_RUNS) / TIMING_RUNS


def benchmark_samplers(dist: Distribution) -> pd.DataFrame:
	"""Time every sampling backend on ``dist`` and return one result row per backend.

	For each sampler two timings are taken, each averaged over ``TIMING_RUNS``
	calls:

	* ``execution_time``: drawing ``N_SAMPLES`` samples.
	* ``preprocessing_time``: the matching entry of ``preprocessors``. For the
	  library samplers this is a single-sample draw, used as an estimate of
	  their per-call setup cost. For cLUT it is building a ``CLUTSampler``.

	``sampling_time`` subtracts the preprocessing estimate from the execution
	time for the library samplers only. The cLUT timed call reuses the
	pre-built ``clut_ctx``, so it contains no table construction.

	Every timed callable is called once, untimed, before being measured. This
	keeps one-off first-call costs (e.g. JAX compilation) out of the averages.

	A sampler that raises is reported on stdout and recorded as a row with
	``success=False`` and NaN timings, rather than being dropped.

	Args:
		dist: The distribution to sample from.

	Returns:
		A DataFrame with one row per sampler.
	"""
	# np.random.seed only seeds NumPy's legacy global RNG. Generators made by
	# np.random.default_rng() ignore it, so the numpy sampler is seeded
	# explicitly below. The global seed is kept for any legacy-RNG users.
	np.random.seed(42)
	torch.manual_seed(42)

	clut_ctx = lut_c_wrapper.CLUTSampler(dist.P)

	# JAX dispatches work asynchronously. Without block_until_ready() the
	# timer would only measure how long it takes to enqueue the computation.
	samplers = [
		("numpy.rng.choice", lambda: np.random.default_rng(42).choice(dist.N, size=N_SAMPLES, p=dist.P/float(dist.Z))),
		("jax.choice", lambda: jax.random.choice(jax.random.PRNGKey(0), dist.N, shape=(N_SAMPLES,), p=jnp.array(dist.P), mode='high').block_until_ready()),
		("torch.multinomial", lambda: torch.multinomial(torch.from_numpy(dist.P/float(dist.Z)), N_SAMPLES, replacement=True)),
		# NOTE(review): the astype() copy of the lookup table happens inside
		# the timed call; confirm it is meant to count as sampling time.
		("cLUT", lambda: clut_sample.sample_cLUT_fast(clut_ctx.cLUT.astype(np.uint32), clut_ctx.r, clut_ctx.c, N_SAMPLES)),
	]

	preprocessors = {
		"numpy.rng.choice": lambda: np.random.default_rng(42).choice(dist.N, size=1, p=dist.P/float(dist.Z)),
		"jax.choice": lambda: jax.random.choice(jax.random.PRNGKey(0), dist.N, shape=(1,), p=jnp.array(dist.P), mode='high').block_until_ready(),
		"torch.multinomial": lambda: torch.multinomial(torch.from_numpy(dist.P/float(dist.Z)), 1, replacement=True),
		"cLUT": lambda: lut_c_wrapper.CLUTSampler(dist.P),
	}

	results = []
	for name, sampler_func in samplers:
		try:
			sampler_func()  # untimed warm-up
			execution_time = _mean_call_time(sampler_func)

			preprocess_func = preprocessors[name]
			preprocess_func()  # untimed warm-up
			preprocessing_time = _mean_call_time(preprocess_func)
			success = True
		except Exception as e:
			print(name, '\n', e, '=======\n\n')
			execution_time = preprocessing_time = float('nan')
			success = False

		results.append({
			'sampler': name,
			'execution_time': execution_time,
			'preprocessing_time': preprocessing_time,
			'sampling_time': execution_time - (preprocessing_time if name != 'cLUT' else 0),
			'samples_generated': N_SAMPLES,
			'N': dist.N,
			'Z': dist.Z,
			'dist': dist.file,
			'compression_factor': 2**clut_ctx.r / (clut_ctx.r + 1),
			'r': clut_ctx.r,
			'c': clut_ctx.c,
			'entropy': dist.h,
			'alpha': dist.alpha,
			'success': success,
		})

	return pd.DataFrame(results)


def main():
	"""Benchmark every distribution file in the folder given as ``sys.argv[1]``.

	Each file in the folder is parsed with ``load_distribution`` and benchmarked
	with ``benchmark_samplers``. All result rows are appended to
	``results_DDMM.csv`` in the current working directory.
	"""
	import sys

	try:
		folder = sys.argv[1]
	except IndexError:
		sys.stderr.write(f'Usage: {sys.argv[0]} DISTRIBUTIONS_FOLDER\n')
		sys.exit(1)

	# Sorted for a reproducible order; sub-directories are not distributions.
	files = sorted(
		entry for entry in os.listdir(folder)
		if os.path.isfile(os.path.join(folder, entry))
	)

	# Computed once so that a run crossing midnight writes to a single file.
	today_str = datetime.datetime.now().strftime("%d%m")
	out_path = f'results_{today_str}.csv'

	for entry in tqdm.tqdm(files, desc='Benchmark distributions'):
		D = load_distribution(os.path.join(folder, entry))
		df = benchmark_samplers(D)
		# Emit the header only when starting a new (or empty) file. This
		# keeps re-runs on the same day from inserting header rows mid-file.
		write_header = not os.path.exists(out_path) or os.path.getsize(out_path) == 0
		df.to_csv(out_path, mode='a', header=write_header)


if __name__ == "__main__":
	main()
