#!/usr/bin/env python3

from pulp import *
import numpy as np
import scipy as sp
from multiprocessing import Pool
import time
from functools import partial
import gc


## Input format
# mu            : list of k pairs (mu_i, p_i), where the mu_i are increasing values and p_i is the probability associated with mu_i
# configuration : k integers c_i such that c_i bids occur between mu_{i-1} and mu_i
# r             : robustness target
# note: the LP adds one extra bid after the last mu_i so that some bid reaches the last mu_i (robustness for the last one)
def solve_config(configuration, mu, r):
  """Solve the bidding integer LP for one fixed bid configuration.

  Args:
    configuration: k non-negative integers; configuration[i] is the number of
      bids placed between mu_{i-1} and mu_i.
    mu: k pairs (mu_i, p_i) with mu_i increasing and p_i the probability
      associated with mu_i.
    r: robustness target.

  Returns:
    (objective, bids) when the CBC solver reports an optimal solution, where
    bids is the list of values of x0, x1, ..., x_{nb_bids-1} in bid order;
    (None, None) otherwise.
  """
  assert len(mu) == len(configuration)

  # + 1 as there is a last bid after the last mu
  nb_bids = sum(configuration) + 1
  # cum_conf[i] = number of bids before mu_i = index of the first bid >= mu_i
  cum_conf = np.cumsum(configuration)

  prob = LpProblem("OPT_bidding", LpMinimize)

  # bid variables, non-negative integers
  x = [LpVariable("x" + str(i), lowBound=0, cat='Integer') for i in range(nb_bids)]

  ## objective: each mu_i charges p_i for every bid up to and including
  ## bid j_i, the first one that reaches mu_i
  prob += lpSum(p_i * x[i]
                for j_i, (_, p_i) in zip(cum_conf, mu)
                for i in range(j_i + 1))

  ## constraints

  # 1st bid respects robustness
  prob += (x[0] <= r)

  # bids are non-decreasing
  for i in range(nb_bids - 1):
    prob += (x[i] <= x[i + 1])

  # configuration definition: bid j_i - 1 stays at or below mu_i, bid j_i reaches it
  for j_i, (mu_i, _) in zip(cum_conf, mu):
    if j_i > 0:
      prob += (x[j_i - 1] <= mu_i)
    prob += (x[j_i] >= mu_i)

  # partial robustness: x_i <= r * x_{i-1} - (x_0 + ... + x_{i-1})
  for i in range(1, nb_bids):
    prob += (x[i] <= r * x[i - 1] - lpSum(x[j] for j in range(i)))

  # extendable: total of all bids bounded by the larger root of z^2 - r z + r
  # times the last bid (NOTE(review): the coefficient is NaN when r < 4)
  prob += (lpSum(x[i] for i in range(nb_bids)) <= (r + np.sqrt(r * (r - 4))) / 2 * x[nb_bids - 1])

  prob.solve(PULP_CBC_CMD(msg=False))

  objective = value(prob.objective)
  # Read values from `x` rather than prob.variables(): PuLP returns the latter
  # sorted by name, which puts x10 before x2 once there are more than 10 bids.
  bids = [v.varValue for v in x]

  # presumably to limit memory growth across the many solves; kept as-is
  gc.collect()

  if prob.status == LpStatusOptimal:
    return objective, bids
  return None, None
  

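# k       = number of intervals
# nb_bids = number of bids to distribute among the intervals (ignored by first_conf)
# first_conf returns the initial configuration, with every interval at 0
# next_conf moves to the next configuration in reverse lexicographic order; called right after first_conf, it puts all bids in the first interval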
def first_conf(k, nb_bids):
  """Return a fresh all-zero configuration with k intervals.

  nb_bids is accepted but ignored; the first call to next_conf fills the
  configuration in.
  """
  return [0 for _ in range(k)]

# Update `configuration` in place to the next configuration and return 0; return 1 (leaving it unchanged) once the last configuration has been reached
def next_conf(nb_bids, configuration):
  """Advance `configuration` in place to the next way of spreading nb_bids bids.

  Starting from an all-zero configuration, successive calls visit
  [nb_bids, 0, ..., 0], ..., [0, ..., 0, nb_bids] in reverse lexicographic
  order.

  Returns:
    0 after moving to a new configuration, or 1 (without modifying it) when
    the last entry already holds all nb_bids bids.
  """
  tail = configuration[-1]

  if tail == 0:
    occupied = [idx for idx, count in enumerate(configuration) if count != 0]
    if not occupied:
      # all-zero start: put every bid in the first interval
      configuration[0] = nb_bids
    else:
      # shift one bid from the last occupied interval to its right neighbour
      pos = occupied[-1]
      configuration[pos] -= 1
      configuration[pos + 1] += 1
    return 0

  if tail == nb_bids:
    # every bid is already in the last interval: enumeration finished
    return 1

  # empty the last interval, take one bid from the last occupied interval
  # before it, and put tail + 1 bids right after that interval
  configuration[-1] = 0
  pos = max(idx for idx, count in enumerate(configuration) if count != 0)
  configuration[pos] -= 1
  configuration[pos + 1] = tail + 1
  return 0


class MyConf:
  """Iterator over every configuration with 1, 2, ..., max_nb_bids bids.

  For each total n in that range, yields each configuration that next_conf
  produces for n bids, in next_conf's order. Each yielded value is a fresh
  list copy, so a consumer that holds on to it (e.g. the lazy pool.imap in
  pll_pareto_optimal) never sees later in-place updates.
  """

  def __init__(self, k, max_nb_bids):
    # k: number of intervals, i.e. entries per configuration
    self.l = first_conf(k, 0)       # current configuration, mutated by next_conf
    self.cur_nb_bids = 1            # total number of bids currently enumerated
    self.max_nb_bids = max_nb_bids

  def __iter__(self):
    return self

  def __next__(self):
    # Nothing to enumerate (max_nb_bids < 1). Without this guard the
    # `>=` test below would be skipped forever and iteration never stops.
    if self.cur_nb_bids > self.max_nb_bids:
      raise StopIteration
    if next_conf(self.cur_nb_bids, self.l) == 0:
      return list(self.l)
    # all configurations with cur_nb_bids bids are exhausted
    if self.cur_nb_bids >= self.max_nb_bids:
      raise StopIteration
    # restart the enumeration with one more bid
    self.cur_nb_bids += 1
    self.l = first_conf(len(self.l), 0)
    next_conf(self.cur_nb_bids, self.l)
    return list(self.l)




# Return the maximum number of bids placed before the last mu.
# The cap is kept small to save computation time: all tested samples used fewer bids.
def get_nb_bids(r, mu):
  """Return the bid-count cap for robustness target r.

  The cap is 15 when r < 4.9 and 11 otherwise. mu is accepted for interface
  symmetry with the other helpers but is not used.
  """
  threshold = 4.9
  return 15 if r < threshold else 11


# compute the best strategy, to be aggressively r-extended
def pareto_optimal(r, mu):
  """Sequentially find the cheapest feasible bid configuration.

  Tries, for every total number of bids from 1 to get_nb_bids(r, mu), every
  configuration produced by next_conf, and solves each with solve_config.

  Returns:
    (best_objective, best_bids) for the lowest objective found (the first one
    wins on ties), or (None, None) if no configuration was solved optimally.
  """
  max_nb_bids = get_nb_bids(r, mu)
  best = None
  best_sol = None

  for nb_bids in range(1, max_nb_bids + 1):
    configuration = first_conf(len(mu), nb_bids)
    # next_conf updates `configuration` in place and returns 1 once exhausted
    while next_conf(nb_bids, configuration) == 0:
      obj, sol = solve_config(configuration, mu, r)
      if obj is None:
        continue
      if best is None or obj < best:
        best, best_sol = obj, sol
  return best, best_sol
  
# compute in parallel the best strategy, to be aggressively r-extended
def pll_pareto_optimal(r, mu, *, processes=7, chunksize=20):
  """Parallel version of pareto_optimal, using a multiprocessing Pool.

  Every configuration from MyConf(len(mu), get_nb_bids(r, mu)) is solved with
  solve_config in worker processes. Results come back in submission order
  (Pool.imap), so ties resolve the same way as in the sequential search.
  Prints r before starting and the best bid vector at the end.

  Args:
    r: robustness target.
    mu: k pairs (mu_i, p_i).
    processes: number of worker processes (default 7, the former hard-coded value).
    chunksize: number of configurations sent to a worker at a time (default 20).

  Returns:
    (best_objective, best_bids), or (None, None) if nothing was solved optimally.
  """
  print(r)

  max_nb_bids = get_nb_bids(r, mu)
  best = None
  best_sol = None

  configurations = MyConf(len(mu), max_nb_bids)

  with Pool(processes=processes) as pool:
    results = pool.imap(partial(solve_config, mu=mu, r=r), configurations, chunksize=chunksize)
    for obj, sol in results:
      if obj is None:
        continue
      if best is None or obj < best:
        best, best_sol = obj, sol

  gc.collect()

  print(best_sol)

  return best, best_sol

def consistency(obj, mu):
  """Normalise a cost by sum(mu_i * p_i).

  Args:
    obj: cost to normalise, or None.
    mu: pairs (mu_i, p_i).

  Returns:
    obj / sum(mu_i * p_i), or None when obj is None.
  """
  if obj is None:
    return None
  E_opt = sum(mu_i * p_i for (mu_i, p_i) in mu)
  return obj / E_opt


# zeta_which = 0, 1 or 2 selects the ratio zeta
# Return the consistency of the best geometric strategy of the form lambda * zeta^i
def get_zeta_consistency(zeta_which, r, mu):
  """Consistency of the best geometric strategy lambda * zeta**i.

  zeta_which selects the ratio:
    1             -> (r - sqrt(r (r - 4))) / 2
    2             -> (r + sqrt(r (r - 4))) / 2
    anything else -> r / 2

  For each mu_i, the geometric sequence is shifted so that a bid lands just
  above mu_i. The function keeps the shift with the lowest expected cost
  sum_j p_j * (total of bids up to the first one >= mu_j).

  Returns:
    consistency(best cost, mu), or None when the selected ratio is not a real
    number greater than 1 (r * (r - 4) < 0, or zeta <= 1).
  """
  if zeta_which in (1, 2):
    disc = r * (r - 4)
    if disc < 0:
      # complex roots: no real ratio exists (np.sqrt would yield NaN)
      return None
    root = np.sqrt(disc)
    zeta = (r - root) / 2 if zeta_which == 1 else (r + root) / 2
  else:
    zeta = r * 0.5

  # zeta == 1 would zero both log(zeta) and (zeta - 1) below; `not >` also rejects NaN
  if not zeta > 1:
    return None

  log_zeta = np.log(zeta)
  best_obj = None

  for (mu_i, _) in mu:
    # shift so that there is a bid at (slightly above, x1.001 against rounding
    # errors) mu_i; the resulting first bid lies in [1.001, 1.001 * zeta)
    first_bid = 1.001 * mu_i / (zeta ** np.floor(np.log(mu_i) / log_zeta))
    if first_bid > r:
      first_bid /= zeta

    cum_cost = 0
    for (mu_j, p_j) in mu:
      # index of the first bid >= mu_j. Clamped at 0: when the first bid
      # already exceeds mu_j by more than a factor zeta the raw ceil is
      # negative, which would make the geometric sum zero or negative
      # instead of first_bid.
      last_idx = max(np.ceil(np.log(mu_j / first_bid) / log_zeta), 0)
      # first_bid * (1 + zeta + ... + zeta**last_idx)
      cum_cost += p_j * first_bid * (zeta ** (last_idx + 1) - 1) / (zeta - 1)

    if best_obj is None or cum_cost < best_obj:
      best_obj = cum_cost

  return consistency(best_obj, mu)
  
  
  
  
  
  
  
