#!/usr/bin/env python3

from bidding import *
from enum import Enum
import scipy.stats as stats
import argparse


class Rand_mu(Enum):
    """Selects how draw_random_mu draws the mu values of a random distribution."""
    # Each mu is drawn uniformly between mu_min and mu_max.
    UNIFORM = 1
    # Truncated Gaussian on [mu_min, mu_max] with scale mu_sigma_spread.
    GAUSSIAN_SPREAD = 2
    # Truncated Gaussian on [mu_min, mu_max] with scale mu_sigma_narrow.
    GAUSSIAN_NARROW = 3
    
class Rand_p(Enum):
    """Selects how draw_random_mu draws the probabilities attached to the mu values."""
    # Independent uniform draws, normalised so that they sum to 1.
    UNIFORM = 1
    # Every point gets the same probability 1/mu_k.
    EQUIPROBABLE = 2

    

# Parameters of the random distributions generated by draw_random_mu.
# They are also echoed in the header of every file written by compute_samples.

# Number of (mu_i, p_i) points in each generated distribution.
mu_k = 4
# Bounds of the interval from which the mu values are drawn.
mu_min = 1
mu_max = 10000
# Centre (loc) of the truncated Gaussian: the midpoint of [mu_min, mu_max].
mu_gauss_average = (mu_min+mu_max)/2
# Scale of the truncated Gaussian for GAUSSIAN_SPREAD and GAUSSIAN_NARROW respectively.
mu_sigma_spread = 4000
mu_sigma_narrow = 2000

def draw_random_mu(rand_mu, rand_p):
  """Draw one random discrete distribution as a list of (mu_i, p_i) pairs.

  Args:
    rand_mu: Rand_mu mode. UNIFORM draws each mu uniformly between mu_min
      and mu_max; GAUSSIAN_NARROW draws from a normal centred on
      mu_gauss_average with scale mu_sigma_narrow, truncated to
      [mu_min, mu_max]; any other value uses scale mu_sigma_spread.
    rand_p: Rand_p mode. EQUIPROBABLE gives every point probability
      1/mu_k; any other value draws uniform weights and normalises them
      so that they sum to 1.

  Returns:
    A list of mu_k tuples (mu_i, p_i) of plain Python floats, ordered by
    increasing mu_i. The probabilities are drawn independently of the mu
    values, so only the mu values are sorted.
  """
  # Probabilities are drawn first so the sequence of RNG calls is stable.
  if rand_p == Rand_p.EQUIPROBABLE:
    p = [1/mu_k] * mu_k
  else:
    p = [np.random.rand() for i in range(mu_k)]
    p = np.divide(p,np.sum(p))

  if rand_mu == Rand_mu.UNIFORM:
    mu = [np.random.rand()*(mu_max-mu_min)+mu_min for i in range(mu_k)]
  else:
    sigma = mu_sigma_narrow if rand_mu == Rand_mu.GAUSSIAN_NARROW else mu_sigma_spread
    # truncnorm expects its bounds in standardised units, i.e. (bound - loc) / scale.
    lower = (mu_min-mu_gauss_average)/sigma
    upper = (mu_max-mu_gauss_average)/sigma
    mu = stats.truncnorm.rvs(lower, upper, loc=mu_gauss_average, scale=sigma, size=mu_k)

  # The branches above yield a mix of Python floats and NumPy scalars.
  # Normalise to plain floats so every mode returns the same types and the
  # textual dump of the samples does not depend on NumPy's scalar repr
  # (NumPy >= 2 prints scalars as "np.float64(...)").
  mu = sorted(float(m) for m in mu)
  return [(m, float(q)) for m, q in zip(mu, p)]
  
  
  
  

def compute_samples(filename, nb_samples, rand_mu, rand_p, rmin, rmax, rstep):
  """Draw random distributions and write their consistency values to a file.

  The file starts with two '#'-prefixed, tab-separated header lines (the
  parameter names, then their values), followed by a '#'-prefixed dump of
  the sampled distributions and one tab-separated data row
  "r, algo, zeta1, zeta2, zeta0" per (r, sample) pair.

  Args:
    filename: path of the output file; it is overwritten if it exists.
    nb_samples: number of random distributions to draw.
    rand_mu: Rand_mu mode forwarded to draw_random_mu.
    rand_p: Rand_p mode forwarded to draw_random_mu.
    rmin, rmax, rstep: start, stop and step passed to np.arange to build
      the list of r values (rmax itself is excluded).
  """
  with open(filename,'w') as outfile:
    # Header: parameter names on one line, their values on the next.
    outfile.write("#\tnb_samples\trand_mu\trand_p\trmin\trmax\trstep\tmu_k\tmu_min\tmu_max\tmu_gauss_average\tmu_sigma_spread\tmu_sigma_narrow\n")
    outfile.write(f"#\t{nb_samples}\t{rand_mu}\t{rand_p}\t{rmin}\t{rmax}\t{rstep}\t{mu_k}\t{mu_min}\t{mu_max}\t{mu_gauss_average}\t{mu_sigma_spread}\t{mu_sigma_narrow}\n")

    # All samples are drawn up front so that every r is evaluated on the
    # same set of distributions.
    samples_mu = [draw_random_mu(rand_mu,rand_p) for i in range(nb_samples)]

    # The braces are written with a literal backslash in front of them.
    # The backslashes are spelled "\\" because "\{" is not a valid escape
    # sequence and triggers a SyntaxWarning on recent Python versions.
    outfile.write("\n# list of \\{(mu_i,p_i)\\}: \n# " +  f"{samples_mu}\n")

    r_list = np.arange(rmin, rmax, rstep)

    outfile.write("\n\n# r\talgo\tzeta1\tzeta2\tzeta0\n")
    for r in r_list:
      for mu in samples_mu:
        # consistency, pll_pareto_optimal and get_zeta_consistency are not
        # defined in this file (presumably provided by the bidding module).
        alg = consistency( pll_pareto_optimal(r,mu)[0] , mu)
        zeta1 = get_zeta_consistency(1,r,mu)
        zeta2 = get_zeta_consistency(2,r,mu)
        zeta0 = get_zeta_consistency(0,r,mu)
        outfile.write(f"{r}\t{alg}\t{zeta1}\t{zeta2}\t{zeta0}\n")
  
  
def compute_all_data(which_rand = -1):
  """Run compute_samples for the (Rand_mu, Rand_p) combinations.

  Args:
    which_rand: index of the single combination to compute. The six
      combinations are ordered with the mu mode varying slowest
      (UNIFORM, GAUSSIAN_NARROW, GAUSSIAN_SPREAD) and the p mode fastest
      (UNIFORM, EQUIPROBABLE). Any value not greater than -1 computes all
      of them.

  One output file named "out.<mu mode>.<p mode>" is written per combination.
  """
  sample_count = 10
  r_first, r_stop, r_increment = 4, 12.1, 1

  mu_modes = (Rand_mu.UNIFORM, Rand_mu.GAUSSIAN_NARROW, Rand_mu.GAUSSIAN_SPREAD)
  p_modes = (Rand_p.UNIFORM, Rand_p.EQUIPROBABLE)

  combos = []
  for mu_mode in mu_modes:
    for p_mode in p_modes:
      combos.append((mu_mode, p_mode))

  # Keep only the requested combination when an index was given.
  if which_rand > -1:
    combos = [combos[which_rand]]

  print(combos)

  for mu_mode, p_mode in combos:
    target_file = f"out.{mu_mode}.{p_mode}"
    compute_samples(target_file, sample_count, mu_mode, p_mode, r_first, r_stop, r_increment)
  

def try_adversarial_mu(target):
  """Evaluate hand-built two-point distributions for r = 4, 5, ..., 12.

  For each r, the distribution is [(mu_1, 0.5), (mu_1*(factor+0.5), 0.5)]
  with mu_1 = 5000. The factor is (r+sqrt(r*(r-4)))/2 * 0.5 when target is
  "zeta" and r/4 for any other target.

  Args:
    target: label selecting the factor formula; it is also used in the
      output file name "out-adversarial-<target>".

  Each result row is written to the output file and echoed on stdout.
  """
  base_values = [5000]

  with open(f"out-adversarial-{target}",'w') as outfile:
    r_values = np.arange(4, 12.1,1 )

    outfile.write("\n\n# r\talgo\tzeta1\tzeta2\tzeta0\n")

    for r in r_values:
      if target == "zeta":
        factor = (r+np.sqrt(r*(r-4)))/2 * 0.5
      else:
        factor = r/4

      for mu_1 in base_values:
        mu = [(mu_1,0.5),(mu_1*(factor+0.5),0.5)]
        # consistency, pll_pareto_optimal and get_zeta_consistency are not
        # defined in this file (presumably provided by the bidding module).
        alg = consistency( pll_pareto_optimal(r,mu)[0] , mu)
        zeta1, zeta2, zeta0 = [get_zeta_consistency(kind,r,mu) for kind in (1, 2, 0)]
        row = f"{r}\t{alg}\t{zeta1}\t{zeta2}\t{zeta0}\n"
        outfile.write(row)
        # print adds its own newline after the row's trailing "\n".
        print(row)
  
  
  
def _main():
  """Parse the command line and run the selected experiment.

  The single positional argument which_rand selects what to run:
    0-5: one (Rand_mu, Rand_p) combination, by index, via compute_all_data;
    6:   the two adversarial experiments ("zeta" and "r");
    negative values: every combination (compute_all_data treats any
      index not greater than -1 as "all").
  Values above 6 are rejected with a usage error.
  """
  # NOTE(review): the first positional argument of ArgumentParser is `prog`,
  # so "plot file" is shown as the program name in the usage message.
  parser = argparse.ArgumentParser("plot file")
  parser.add_argument(
    "which_rand",
    help="A 0-6 integer describing which randomness generation configuration should be run.",
    type=int)
  which_rand = parser.parse_args().which_rand

  # Report an out-of-range index as a usage error instead of letting it
  # surface as an IndexError traceback from compute_all_data.
  if which_rand > 6:
    parser.error(f"which_rand must be at most 6, got {which_rand}")

  if which_rand == 6:
    try_adversarial_mu("zeta")
    try_adversarial_mu("r")
  else:
    compute_all_data(which_rand)


# Only run the experiments when executed as a script, not when imported.
if __name__ == "__main__":
  _main()
  
  
  
  
