# -*- coding: utf-8 -*-
"""
Created on Sat Nov 11 22:16:58 2023

@author: wsb15
"""
import numpy as np
import math
from typing import Callable
from abc import ABC, abstractmethod
from generative_model import Generative_Model,Hard_MDP_Wang
import multiprocessing
from rl_mb_AMDP import RL_MB_AMDP
# Presumably the mixing-time bound of the hard instance (TODO confirm):
# Hard_MDP_Wang is built from its reciprocal, RL_MB_AMDP from the value itself.
# The model lives at module level because run_pmbp_n_times_avg_err reads it,
# which makes it reachable from multiprocessing worker processes.
tmix = 10
hard_instance_rl = RL_MB_AMDP(
    Hard_MDP_Wang(1 / tmix),
    tmix,
    perform_value_iteration=False,
)
def run_pmbp_n_times_avg_err(n, eps, flag, *, model=None):
    """Run ``model.pmbp`` ``n`` times and average the returned estimates.

    Parameters
    ----------
    n : int or float
        Number of repetitions; coerced with ``int()``. Must be at least 1.
    eps : float
        Accuracy parameter forwarded to ``pmbp``.
    flag : bool
        Forwarded to ``pmbp`` as the ``Jin`` keyword argument.
    model : object, optional
        Any object exposing ``pmbp(eps, Jin=...) -> (estimate, sample_size)``.
        Defaults to the module-level ``hard_instance_rl``.

    Returns
    -------
    numpy.ndarray
        ``[eps, int(flag), samp_sz, avg_alph, n]``, where ``avg_alph`` is the
        mean of the ``n`` estimates and ``samp_sz`` is the sample size reported
        by the last call (presumably constant for a fixed ``eps`` -- confirm).

    Raises
    ------
    ValueError
        If ``int(n) < 1``; there would be nothing to average.
    """
    if model is None:
        model = hard_instance_rl
    n = int(n)
    if n < 1:
        # Without this guard n == 0 divides by zero and n < 0 leaves
        # samp_sz unbound, both with confusing tracebacks.
        raise ValueError(f"n must be a positive integer, got {n}")

    total_alph = 0
    samp_sz = None
    for _ in range(n):
        alph_est, samp_sz = model.pmbp(eps, Jin=flag)
        total_alph += alph_est
    avg_alph = total_alph / n
    return np.array([eps, int(flag), samp_sz, avg_alph, n])

def split_repetitions(total, batch_size, min_batch=5):
    """Split ``total`` repetitions into batches of ``batch_size``.

    ``batch_size`` is first raised to at least ``min_batch``. The result is
    a run of full batches followed by one smaller remainder batch if
    ``total`` is not an exact multiple. The sizes always sum to ``total``.
    Returns an empty list when ``total <= 0``.
    """
    if total <= 0:
        return []
    batch_size = max(batch_size, min_batch)
    full, remainder = divmod(total, batch_size)
    sizes = [batch_size] * full
    if remainder:
        sizes.append(remainder)
    return sizes


def build_param_grid(rep):
    """Build ``(n, eps, flag)`` tasks for ``run_pmbp_n_times_avg_err``.

    Each ``eps`` gets ``rep`` repetitions in total, split into batches so they
    can be spread across pool workers. ``flag`` is False for the Wang sweep
    and True for the Jin sweep.
    """
    params = []
    # Wang sweep: batch size scales like eps**2, normalised at eps = 0.1
    # (presumably to keep per-task runtime roughly balanced -- confirm).
    for eps in np.linspace(0.02, 0.005, 8):
        batch = int(rep*1.0/0.1**2*eps**2)
        params.extend((size, eps, False) for size in split_repetitions(rep, batch))
    # Jin sweep: batch size scales like eps**3 / 3, normalised at eps = 0.1.
    for eps in np.linspace(0.1, 0.04, 8):
        batch = int(rep*1.0/0.1**3*eps**3/3)
        params.extend((size, eps, True) for size in split_repetitions(rep, batch))
    return params


if __name__ == '__main__':
    rep = 300
    num_cores = multiprocessing.cpu_count()
    print(num_cores)

    parms = build_param_grid(rep)
    # The context manager shuts down the worker processes. starmap blocks
    # until every result is back, so nothing is lost on exit.
    with multiprocessing.Pool(num_cores) as pool:
        output = pool.starmap(run_pmbp_n_times_avg_err, parms)
    np.save('compare_data', np.array(output))