import timeit
import argparse
from tqdm import tqdm
import pickle

from trueskill import BimodalDistribution, BimodalDistributionFast, \
						LognormalDistribution, LognormalDistributionFast, \
						NormalDistribution, NormalDistributionFast, TrueSkill

# Command-line interface. Parsed once at import time into the module-level
# `args`, which get_model(), get_model_for_tests() and the script body read.
parser = argparse.ArgumentParser(
	description="Benchmark or statistically test TrueSkill sampling with baseline and CLUT distributions."
)
parser.add_argument("method", help='Model variant: "baseline" or "clut" (case-insensitive).')
parser.add_argument("--distribution", type=str, default="bimodal",
					help='Skill proposal distribution: "lognormal", "normal" or "bimodal".')
parser.add_argument("--precision", type=int, default=32,
					help="Precision passed to the CLUT (*Fast) distributions.")
parser.add_argument("--step", type=int, default=3,
					help="Value forwarded as range_step to the distributions.")
parser.add_argument("--samples", type=int, default=1_000_000,
					help="Number of samples drawn per sample_with_outcome call.")
parser.add_argument("--loops", type=int, default=10,
					help="Number of timed repetitions in benchmark mode.")
parser.add_argument("--sampler", type=str, default="libclut32int",
					help="Sampler name forwarded to the distributions.")
parser.add_argument("--init", type=int, default=1,
					help="Value forwarded as init_table to the distributions.")
parser.add_argument("--plot", type=int, default=0,
					help="Forwarded as plot_results in the full benchmark run.")
# type=int (was type=str): a value given on the command line was parsed as the
# string "1", which never compared equal to the integer 1 in the mode check.
parser.add_argument("--benchmark", type=int, default=1,
					help="1 runs the timing benchmark; any other value runs the statistical test.")

args = parser.parse_args()


def get_model():
	"""Build the TrueSkill model selected by the command-line arguments.

	Reads the module-level ``args``:

	* ``args.method`` (case-insensitive) chooses between the reference
	  distributions (``"baseline"``) and their CLUT counterparts (``"clut"``,
	  the ``*Fast`` classes).
	* ``args.distribution`` (case-insensitive) chooses ``"lognormal"``,
	  ``"normal"`` or ``"bimodal"``.

	The same distribution object is used as the proposal for both skills.

	Returns:
		TrueSkill: the configured model, named
		``"<method>_<distribution>_<sampler>"`` (all lower-cased).

	Raises:
		ValueError: if ``args.method`` or ``args.distribution`` is not recognised.
	"""
	method = args.method.lower()
	distribution = args.distribution.lower()

	if method == "baseline":
		# Only the bimodal baseline receives the sampler / step / init options;
		# the lognormal and normal baselines are built from their parameters alone.
		if distribution == "lognormal":
			proposal_skill_dist = LognormalDistribution(mu=1, sigma=0.2)
		elif distribution == "normal":
			proposal_skill_dist = NormalDistribution(mu=0, sigma=1)
		elif distribution == "bimodal":
			proposal_skill_dist = BimodalDistribution(
				mu1=-1, sigma1=1,
				mu2=2, sigma2=1,
				sampler=args.sampler,
				range_step=args.step,
				init_table=args.init,
			)
		else:
			raise ValueError(f"Unknown distribution: {args.distribution!r}")
	elif method == "clut":
		# Options shared by every CLUT-backed (*Fast) distribution.
		fast_options = {
			"precision": args.precision,
			"sampler": args.sampler,
			"range_step": args.step,
			"init_table": args.init,
		}
		if distribution == "lognormal":
			proposal_skill_dist = LognormalDistributionFast(mu=1, sigma=0.2, **fast_options)
		elif distribution == "normal":
			proposal_skill_dist = NormalDistributionFast(mu=0, sigma=1, **fast_options)
		elif distribution == "bimodal":
			proposal_skill_dist = BimodalDistributionFast(
				mu1=-1, sigma1=1,
				mu2=2, sigma2=1,
				**fast_options,
			)
		else:
			raise ValueError(f"Unknown distribution: {args.distribution!r}")
	else:
		raise ValueError(f"Unknown method: {args.method!r}")

	return TrueSkill(
		proposal_skill1=proposal_skill_dist,
		proposal_skill2=proposal_skill_dist,
		beta=1,
		name=f"{method}_{distribution}_{args.sampler.lower()}",
	)


def discrete_mean(values, probabilities):
	"""Return the weighted sum ``sum(value * weight)`` over paired entries.

	This is the mean of a discrete distribution only when ``probabilities``
	already sums to 1; no normalisation is done here.

	Args:
		values: support points.
		probabilities: one weight per entry of ``values``. Pairing follows
			``zip`` semantics, so any surplus entries in the longer sequence
			are ignored.

	Returns:
		The weighted sum, or ``0`` for empty input.
	"""
	weighted_terms = (value * weight for value, weight in zip(values, probabilities))
	return sum(weighted_terms)

def discrete_variance(values, probabilities):
	"""Return ``sum(weight * (value - mean) ** 2)`` over paired entries.

	``mean`` is obtained from ``discrete_mean`` on the same inputs. The result
	is the variance of a discrete distribution only when ``probabilities``
	already sums to 1; no normalisation is done here. Both sequences are
	iterated twice (once for the mean, once here), so they must be
	re-iterable, not one-shot iterators.

	Args:
		values: support points.
		probabilities: one weight per entry of ``values`` (``zip`` pairing).

	Returns:
		The weighted sum of squared deviations, or ``0`` for empty input.
	"""
	centre = discrete_mean(values, probabilities)
	squared_deviations = (
		weight * (value - centre) ** 2
		for value, weight in zip(values, probabilities)
	)
	return sum(squared_deviations)

def get_model_for_tests(config):
	"""Build a bimodal-proposal TrueSkill model for the statistical test.

	Args:
		config: a sequence whose first item is the method (``"baseline"`` or
			``"clut"``) and whose second item is the sampler name forwarded to
			the distribution, e.g. ``("clut", "lut_c_wrapper")``.

	Reads ``args.step`` (and ``args.precision`` for ``"clut"``) from the
	module-level ``args``. The lookup table is always initialised
	(``init_table=1``), independent of ``--init``.

	Returns:
		TrueSkill: a model using the same bimodal distribution for both skills.

	Raises:
		ValueError: if ``config[0]`` is neither ``"baseline"`` nor ``"clut"``.
	"""
	method, sampler = config[0], config[1]

	if method == "baseline":
		proposal_skill_dist = BimodalDistribution(
			mu1=-1, sigma1=1,
			mu2=2, sigma2=1,
			sampler=sampler,
			range_step=args.step,
			init_table=1,
		)
	elif method == "clut":
		proposal_skill_dist = BimodalDistributionFast(
			mu1=-1, sigma1=1,
			mu2=2, sigma2=1,
			precision=args.precision,
			sampler=sampler,
			range_step=args.step,
			init_table=1,
		)
	else:
		# Previously fell through and failed later with an UnboundLocalError.
		raise ValueError(f"Unknown method in config: {method!r}")

	return TrueSkill(
		proposal_skill1=proposal_skill_dist,
		proposal_skill2=proposal_skill_dist,
		beta=1,
	)

# Script entry point.
# --benchmark 1: time model.sample_with_outcome with and without the
# post-processing step.
# Any other value: repeat the sampling for several fixed configurations and
# record the weighted mean/variance of "skill1", pickled for later analysis.
# int(...) makes the check work whether --benchmark was parsed as str or int.
if int(args.benchmark) == 1:
	# timeit runs `setup_code` once per timeit() call, outside the timed region,
	# so model construction is excluded. `globals=` lets the setup see
	# get_model without importing from __main__.
	setup_code = "model = get_model()"
	full_code = f"model.sample_with_outcome(n={args.samples}, plot_results={args.plot})"
	sampling_code = f"model.sample_with_outcome(n={args.samples}, sample_only=True)"

	print(
		"Full runtime:",
		timeit.timeit(full_code, setup_code, number=args.loops, globals=globals()),
	)
	print(
		"Sampling-only runtime:",
		timeit.timeit(sampling_code, setup_code, number=args.loops, globals=globals()),
	)

else:
	import os  # only needed here, to create the output directory

	n_trials = 50  # independent model runs per configuration
	configs = [("baseline", "discrete"), ("baseline", "continuous"), ("clut", "lut_c_wrapper")]
	output_path = "results/trueskill/stat_test.pickle"

	data = {}
	for config in configs:
		data[config] = {
			"means": [],
			"variances": [],
		}

		for _ in tqdm(range(n_trials)):
			model = get_model_for_tests(config)
			results = model.sample_with_outcome(n=args.samples)
			# Normalise the weights in place so they sum to 1; discrete_mean and
			# discrete_variance assume normalised weights.
			# NOTE(review): in-place `/=` assumes an array-like (e.g. numpy) — confirm.
			results["weights"] /= sum(results["weights"])

			for key in ["skill1"]:
				data[config]["means"].append(
					discrete_mean(results[key], results["weights"])
				)
				data[config]["variances"].append(
					discrete_variance(results[key], results["weights"])
				)

	print(data)

	# The output directory may not exist yet on a fresh checkout.
	os.makedirs(os.path.dirname(output_path), exist_ok=True)
	with open(output_path, "wb") as f:
		pickle.dump(data, f)