import numpy as np
import multiprocessing
import math
from typing import Callable
import sys
sys.path.append('../')
from  dr_q_learning import DR_Q_learning
from  dr_rl_empirical_chi_square import DR_RL_empirical_chi_square
from abc import ABC, abstractmethod
from model.hard_mdp_unichain import Hard_MDP_Unichain
import matplotlib.pyplot as plt
from datetime import datetime
import pickle

# Module-level constant; presumably an iteration budget (name-based, TODO confirm).
# It is referenced by the experiment's sample-size formula
# n_max = int(1 / ((1 - (1e-6) ** (1 / MAX_ITER)) ** 2)).
# A `global` statement is meaningless at module scope, so a plain assignment suffices.
MAX_ITER = 5000


def sample_error_plot_rvi(mdp, gamma, delta, n_sample, n_trajectory):
    """Return the mean sup-norm gap between sampled-RVI ``g_star`` and a reference.

    A ``DR_Q_learning(mdp, delta, gamma)`` instance is constructed, and the
    ``g_star`` it exposes immediately afterwards is snapshotted as the target.
    Then, ``n_trajectory`` times, the learner is reset and
    ``relative_value_iteration_q(empirical=False, n_sample=n_sample)`` is run.
    Each run contributes ``max_s |g_star[s] - target[s]|`` over ``mdp.states``.

    Args:
        mdp: model passed to ``DR_Q_learning``; must expose ``states``.
        gamma: forwarded to ``DR_Q_learning`` as its third argument
            (presumably the discount factor -- TODO confirm).
        delta: forwarded to ``DR_Q_learning`` as its second argument.
        n_sample: number of samples handed to ``relative_value_iteration_q``.
        n_trajectory: number of independent repetitions to average over.

    Returns:
        The average of the per-run errors.

    Raises:
        ValueError: if ``n_trajectory`` is less than 1, since the average
            would be undefined.
    """
    import copy  # local import keeps this function self-contained

    if n_trajectory < 1:
        raise ValueError(f"n_trajectory must be >= 1, got {n_trajectory!r}")

    print('At sample_error_plot_rvi, the parameter is: gamma, delta, n_sample, n_trajectory', (gamma, delta, n_sample, n_trajectory))
    dr_q_learning = DR_Q_learning(mdp, delta, gamma)
    # Snapshot the reference rather than aliasing it. Whether reset() or
    # relative_value_iteration_q() update g_star in place is not visible from
    # here. If they do, a bare alias would track those updates and every error
    # below would silently read as 0.
    target_g_star = copy.deepcopy(dr_q_learning.g_star)

    total_error = 0
    for _ in range(n_trajectory):
        dr_q_learning.reset()
        dr_q_learning.relative_value_iteration_q(empirical=False, n_sample=n_sample)
        g_star_baseline = dr_q_learning.g_star
        total_error += max(abs(g_star_baseline[s] - target_g_star[s]) for s in mdp.states)
    average_error = total_error / n_trajectory
    print('n_sample, average_error', (n_sample, average_error))
    return average_error

def sample_error_experiment(n_samples=None, trajectory_for_each_par=1, save_results=False):
    """Evaluate ``sample_error_plot_rvi`` once per sample size, in a process pool.

    For every ``n`` in ``n_samples``, a ``Hard_MDP_Unichain(0.9)`` is built.
    ``sample_error_plot_rvi`` is then called with ``gamma = 1 - 1/sqrt(n)``,
    ``delta = 0.1``, ``n_sample = n`` and ``trajectory_for_each_par``
    trajectories. Each configuration runs exactly once, in a worker process.

    Args:
        n_samples: iterable of sample sizes (converted to ``int``).
            Defaults to ``[10]``.
        trajectory_for_each_par: trajectories averaged per sample size.
        save_results: when True, pickle the result dict to
            ``rebuttal/baseline/<timestamp>_<trajectories>_<len(n_samples)>.pkl``.
            The directory is created if it is missing.

    Returns:
        dict mapping each sample size (``int``) to its average error.
    """
    import os  # local import keeps this function self-contained

    delta = 0.1
    p = 0.9
    if n_samples is None:
        # Alternative sweeps previously configured here (overridden by [10]):
        #   np.logspace(np.log10(10),
        #               np.log10(int(1/ ((1- (1e-6)**(1/MAX_ITER))**2))), 10)
        #   [10, 32, 100, 316, 1000, 3162, 10000, 31622, 100000]
        n_samples = [10]
    n_samples = [int(n) for n in n_samples]

    def select_gamma(n):
        return 1 - 1 / np.sqrt(n)

    mdps = {n: Hard_MDP_Unichain(p) for n in n_samples}

    print('Sample, Error experiment starts, with n_samples, trajectory_for_each_par', (n_samples, trajectory_for_each_par))

    input_parameter = [
        (mdps[n], select_gamma(n), delta, n, trajectory_for_each_par)
        for n in n_samples
    ]

    # starmap blocks until every job is done and returns results in input order.
    # The context manager guarantees the worker processes are cleaned up.
    with multiprocessing.Pool() as pool:
        errors = pool.starmap(sample_error_plot_rvi, input_parameter)

    sample_error_data = dict(zip(n_samples, errors))
    for n, error in sample_error_data.items():
        print('n_sample, error', (n, error))

    print('Sample, Error experiment ends.')

    if save_results:
        current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
        out_dir = os.path.join('rebuttal', 'baseline')
        os.makedirs(out_dir, exist_ok=True)
        filename = os.path.join(
            out_dir, f"{current_time}_{trajectory_for_each_par}_{len(n_samples)}.pkl"
        )
        with open(filename, 'wb') as f:
            pickle.dump(sample_error_data, f)
        print('Sample, Error experiment data saved to', filename)

    return sample_error_data

def _main():
    """Script entry point: run the sample-error experiment with its defaults."""
    sample_error_experiment()


# The guard matters beyond style. With the "spawn" start method,
# multiprocessing workers re-import this module. Without the guard, each
# worker would try to launch the experiment (and its own pool) again.
if __name__ == '__main__':
    _main()
    