import numpy as np
import multiprocessing
import math
from typing import Callable
import sys
sys.path.append('../')
from  dr_rl_empirical_kl import DR_RL_empirical_kl
from  dr_rl_empirical_chi_square import DR_RL_empirical_chi_square
from abc import ABC, abstractmethod
from model.large_scale import Large_MDP
import matplotlib.pyplot as plt
from datetime import datetime
import pickle

# Iteration cap constant (presumably the value-iteration budget of the
# solvers -- TODO confirm). A name assigned at module scope is already
# global, so no `global` statement is required here.
MAX_ITER = 5000


def sample_error_plot_value_iteration(mdp, gamma, delta, n_sample, n_trajectory):
    """Estimate the average sup-norm error of the empirical solution.

    A ``DR_RL_empirical_kl`` solver is built once for ``(mdp, delta, gamma)``
    and its ``g_star`` attribute is taken as the reference. The solver is then
    reset and re-solved ``n_trajectory`` times with ``empirical=True`` and
    ``n_sample`` samples; each run's ``v_star`` is rescaled by ``(1 - gamma)``
    and compared with the reference state by state.

    Args:
        mdp: MDP object exposing ``states``; passed straight to the solver.
        gamma: Discount factor; also used for the ``(1 - gamma)`` rescaling.
        delta: Second constructor argument of the solver (presumably the
            radius of the KL uncertainty set -- TODO confirm).
        n_sample: Sample count forwarded to ``value_iteration``.
        n_trajectory: Number of independent repetitions to average over.

    Returns:
        The mean, over repetitions, of
        ``max_s |(1 - gamma) * v_star[s] - g_star[s]|``.

    Raises:
        ValueError: If ``n_trajectory`` is smaller than 1 (the average would
            be undefined).
    """
    # Fail fast, before the solver is constructed, instead of hitting a
    # ZeroDivisionError (or returning -0.0) at the very end.
    if n_trajectory < 1:
        raise ValueError(f"n_trajectory must be at least 1, got {n_trajectory}")

    print('at sample_error_plot_value_iteration, the parameter is: gamma, delta, n_sample, n_trajectory', (gamma, delta, n_sample, n_trajectory))
    dr_rl_empirical = DR_RL_empirical_kl(mdp, delta, gamma)
    # Reference values are read before any reset / empirical run.
    target_g_star = dr_rl_empirical.g_star

    total_error = 0
    for _ in range(n_trajectory):
        dr_rl_empirical.reset()
        dr_rl_empirical.value_iteration(empirical = True, n_sample = n_sample)
        v_star_emp = dr_rl_empirical.v_star
        # Worst-case (over states) gap between the rescaled empirical value
        # and the reference.
        total_error += max(
            abs((1-gamma)*v_star_emp[s] - target_g_star[s]) for s in mdp.states
        )
    average_error = total_error/n_trajectory
    print('n_sample, average_error', (n_sample, average_error))
    return average_error

def sample_error_experiment():
    """Run the sample-size vs. error sweep in parallel and pickle the result.

    For every sample size in a fixed grid, a ``Large_MDP`` (20 states,
    30 actions, seed 42) is built and ``sample_error_plot_value_iteration``
    is evaluated in a worker process with ``gamma = 1 - 1/sqrt(n_sample)``
    and ``delta = 0.4``. The resulting ``{n_sample: average_error}`` dict is
    pickled to ``rebuttal/large_scale/<timestamp>_<trajectories>_<grain>.pkl``.

    Each configuration is computed exactly once, in the pool; the blocking
    ``starmap`` call returns the results in input order and re-raises any
    exception raised in a worker.
    """
    import os  # local import: `os` is not imported at the top of this file

    delta = 0.4
    grain = 10  # only used as a tag in the output filename
    seed = 42
    trajectory_for_each_par = 1
    output_dir = "rebuttal/large_scale"

    # Fixed sample-size grid from 10 to 100000.
    n_samples = np.array([10, 32, 100, 316, 1000, 3162, 10000, 31622, 100000])

    def select_gamma(n):
        # Discount factor tied to the sample size.
        return 1- 1/np.sqrt(n)

    mdps = {n_sample: Large_MDP(num_of_state = 20, num_of_action = 30, random_seed=seed) for n_sample in n_samples}

    print('Sample, Erorr experiment starts, with n_samples, trajectory_for_each_par', (n_samples, trajectory_for_each_par))

    input_parameter = [
        (mdps[n_sample], select_gamma(int(n_sample)), delta, int(n_sample), trajectory_for_each_par)
        for n_sample in n_samples
    ]

    # Create the output directory up front so a missing path cannot throw
    # away the results after the (long) computation has finished.
    os.makedirs(output_dir, exist_ok=True)

    # The context manager tears the pool down once all results are in.
    with multiprocessing.Pool() as pool:
        errors = pool.starmap(sample_error_plot_value_iteration, input_parameter)

    # starmap preserves input order, so results line up with n_samples.
    sample_error_data = dict(zip(n_samples, errors))

    # Save the error_plot dictionary to a file
    current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"{output_dir}/{current_time}_{trajectory_for_each_par}_{grain}.pkl"

    with open(filename, 'wb') as f:
        pickle.dump(sample_error_data, f)
    print('Sample, Error experiment data saved to', filename)

# Script entry point: run the sweep only when executed directly, not on import
# (worker processes started by multiprocessing may re-import this module).
if __name__ == "__main__":
    sample_error_experiment()
    