# -*- coding: utf-8 -*-
"""
Created on Fri Oct 15 21:20:22 2021

@author: grego
"""


import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import multiprocessing

from irl_funcs import test_irl


if __name__ == "__main__":
    r'''
    Run the IRL sample-complexity experiment and write one CSV per size.

    The constant 'SIZE_LIST' controls which truncation values 'k' are tested.
    Each SIZE in 'SIZE_LIST' controls the "size" of the problem, i.e. the
    highest degree trigonometric function in the truncated basis
    (k = 2*SIZE + 1).

    The constant 'GAMMA' is the discount factor of the generated MDPs,
    a float in the open interval (0, 1).

    The constant 'COVER_SIZE' is the size of the covering \bar{S} used in
    Algorithm 2 from the paper.

    The constant 'CORES' is the number of worker processes used to run the
    experiment; for each size the work is split into 'CORES' jobs.

    The constant 'PER_CORE_REP' is the number of times the IRL problem should
    be solved for each truncation value and sampling value on each core.

    The constant 'MULTIPLIER_LIST' controls the values for the sample
    parameter 'N' tested for each truncation size. For each multiplier in
    the list, the script will test each algorithm with
    (N = k^2 * multiplier).

    The constant 'seeds' is a list of three integers which are used to create
    the random IRL problems tested.

    For each size, the concatenated results of all jobs are written to
    'sc_<seeds[0]>_Size=<size>.csv' with columns 'N' and 'Correctness'.
    '''

    ##### SET THESE PARAMETERS
    SIZE_LIST = [3, 4, 5]
    GAMMA = 0.7
    COVER_SIZE = 40

    CORES = 12
    PER_CORE_REP = 20
    MULTIPLIER_LIST = [10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70]
    seeds = [949, 370, 687]
    ####

    # beta^-1 = 26.549686 for the default seeds given above

    # More seeds listed below.  Associated value of beta should be correct,
    # when using the same Numpy version at least.

    # seeds = [963, 454, 971] # beta^-1 = 19.4
    # seeds = [267, 750, 237] # beta^-1 = 18.350708
    # seeds = [663,  333,  289] # beta^-1 = 19.111889
    # seeds = [461,  531,  510] # beta^-1 = 19.239672
    # seeds = [96,  503,  706]  # beta^-1 = 19.753266

    # Create one pool with exactly CORES workers and reuse it for every size.
    # (Pool() with no argument would spawn os.cpu_count() workers instead.)
    with multiprocessing.Pool(processes=CORES) as pool:
        for size in SIZE_LIST:
            print('-' * 30)
            print('SIZE EQUALS:', size)
            print('-' * 30)

            true_k = 2 * size + 1

            # You can also directly set N_list if you want to test the
            # performance for fixed values of 'N'.
            N_list = [true_k ** 2 * mult for mult in MULTIPLIER_LIST]

            # CORES identical argument tuples -> CORES jobs, each solving the
            # problem PER_CORE_REP times per value of N.
            # NOTE(review): whether the jobs draw different random samples
            # depends on test_irl's internal RNG handling — confirm.
            job_args = (size, GAMMA, COVER_SIZE, seeds, N_list, PER_CORE_REP)
            arg_list = [job_args] * CORES

            # starmap blocks until every job has returned its result.
            df_list = pool.starmap(test_irl, arg_list)

            result_df = pd.concat(df_list, axis=0)
            # Assumes each test_irl result is a two-column DataFrame
            # (N, correctness) — TODO confirm against irl_funcs.
            result_df.columns = ['N', 'Correctness']
            # Tag the file with the first seed so that runs with different
            # seed triples do not overwrite each other's results.
            result_df.to_csv(f'sc_{seeds[0]}_Size={size}.csv', index=False)
            