# -*- coding: utf-8 -*-
"""
Created on Sat Nov 11 22:16:58 2023

@author: wsb15
"""
import numpy as np
import math
from typing import Callable
from abc import ABC, abstractmethod
from generative_model import Generative_Model,Hard_MDP_Wang
import multiprocessing
from rl_mb_AMDP import RL_MB_AMDP


def run_pmbp_n_times_avg_err_with_tmix(n,eps,tmix,flag):
    """Run ``pmbp`` ``n`` times on one hard instance and average its estimate.

    A single ``RL_MB_AMDP`` solver is built around ``Hard_MDP_Wang(1/tmix)``
    (with value iteration disabled) and reused for every repetition.

    Parameters
    ----------
    n : number of repetitions; truncated with ``int``.
    eps : first argument forwarded to ``pmbp`` on every call.
    tmix : passed to the solver constructor; ``1/tmix`` parametrises the MDP.
    flag : forwarded to ``pmbp`` as the ``Jin`` keyword argument.

    Returns
    -------
    ``np.array([eps, tmix, samp_sz, avg, n])`` where ``avg`` is the mean of
    the first value returned by ``pmbp`` over the repetitions, ``samp_sz`` is
    the second value returned by the *last* call, and ``n`` is the integer
    repetition count.
    """
    hard_mdp = Hard_MDP_Wang(1/tmix)
    solver = RL_MB_AMDP(hard_mdp, tmix, perform_value_iteration=False)
    n_runs = int(n)

    alpha_total = 0
    for _ in range(n_runs):
        alpha_hat, sample_size = solver.pmbp(eps, Jin=flag)
        alpha_total += alpha_hat

    # Turn the running sum into the mean over all repetitions.
    alpha_total /= n_runs
    return np.array([eps, tmix, sample_size, alpha_total, n_runs])

if __name__ == '__main__':
    # Total number of repetitions requested for every tmix value.
    rep = 300
    # One (n, eps, tmix, flag) tuple per worker task.
    parms = []
    num_cores = multiprocessing.cpu_count()
    print(num_cores)

    # Partition the repetitions into batches for the Wang hard instance.
    # NOTE(review): the False flag is forwarded as ``Jin=False`` to ``pmbp``;
    # presumably this selects the non-Jin variant -- confirm in rl_mb_AMDP.
    eps = 0.1
    for tmix in np.linspace(10, 1000, 10):
        # Batch size shrinks as tmix grows, but never drops below 5 runs.
        rep_per_batch = max(5, int(rep * 1.0 * 5 / tmix))
        remaining = rep
        while remaining > 0:
            # The final batch only takes whatever repetitions are left.
            batch = min(rep_per_batch, remaining)
            parms.append((batch, eps, tmix, False))
            remaining -= batch

    # The context manager shuts the worker processes down once starmap has
    # returned, instead of leaving the pool open until interpreter exit.
    with multiprocessing.Pool(num_cores) as pool:
        output = pool.starmap(run_pmbp_n_times_avg_err_with_tmix, parms)

    # Each row is [eps, tmix, samp_sz, avg_alph, n] for one batch.
    np.save('tmix_data_0.1', np.array(output))