import numpy as np
import multiprocessing
import math
from typing import Callable
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from dr_rl_empirical_kl import DR_RL_empirical_kl
from dr_rl_empirical_chi_square import DR_RL_empirical_chi_square
from abc import ABC, abstractmethod
from model.hard_mdp_unichain import Hard_MDP_Unichain
import matplotlib.pyplot as plt
from datetime import datetime
import pickle
import glob

# Iteration budget used by run_experiment_for_p to derive the largest sample
# size of the sweep (presumably also the value-iteration cap of the DR-RL
# solvers -- confirm against those classes).
MAX_ITER = 5000

def sample_error_plot_value_iteration(mdp, gamma, delta, n_sample, n_trajectory):
    """Estimate the mean sup-norm error of the empirical solver.

    For each of ``n_trajectory`` independent runs the solver is reset, empirical
    value iteration is run with ``n_sample`` samples, and the resulting value
    function is rescaled by ``(1 - gamma)`` and compared state-by-state with the
    solver's ``g_star`` (presumably the target average-reward gain -- confirm
    against the solver class).

    Args:
        mdp: MDP instance; must expose ``states`` and be accepted by the solver.
        gamma: Discount factor handed to the solver and used for the rescaling.
        delta: Third constructor argument of the solver (presumably the
            uncertainty-set radius -- confirm against the solver class).
        n_sample: Number of samples passed to ``value_iteration``.
        n_trajectory: Number of independent repetitions to average over.

    Returns:
        The per-run maximum absolute error over states, averaged over all runs.
    """
    print('Parameters: gamma, n_sample, n_trajectory', (gamma, n_sample, n_trajectory))
    # Solver under test. Swap in DR_RL_empirical_chi_square here to run the
    # Chi-square variant instead of the KL one.
    solver = DR_RL_empirical_kl(mdp, delta, gamma)
    reference = solver.g_star

    rescale = 1 - gamma
    mean_error = 0
    for _ in range(n_trajectory):
        # Each repetition starts from a freshly reset solver.
        solver.reset()
        solver.value_iteration(empirical=True, n_sample=n_sample)
        values = solver.v_star
        # Worst-case (over states) gap between rescaled value and reference.
        mean_error += max(
            abs(rescale * values[state] - reference[state])
            for state in mdp.states
        )

    mean_error /= n_trajectory
    print(f'Completed: n_sample={n_sample}, error={mean_error:.6f}')
    return mean_error

def run_experiment_for_p(p_values, delta=0.1, n_min=10, grain=10, trajectory_for_each_par=1):
    """Run the sample-size sweep for each ``p`` and pickle the resulting errors.

    For every ``p`` in ``p_values`` a log-spaced grid of sample sizes is
    evaluated in parallel with ``sample_error_plot_value_iteration``, using the
    discount factor ``gamma = 1 - 1/sqrt(n)`` for sample size ``n``. The
    resulting ``{n_sample: average_error}`` dict is written to
    ``data/Reduction_DMDP/kl/p_<p>_<timestamp>.pkl``.

    Args:
        p_values: Iterable of parameters passed to ``Hard_MDP_Unichain``.
        delta: Forwarded unchanged to ``sample_error_plot_value_iteration``.
        n_min: Smallest sample size of the grid.
        grain: Number of grid points between ``n_min`` and the derived maximum.
        trajectory_for_each_par: Repetitions averaged per grid point.

    Returns:
        None. Results are written to disk as a side effect.
    """
    # Largest n for which gamma = 1 - 1/sqrt(n) still satisfies
    # gamma ** MAX_ITER <= 1e-6 (obtained by solving that equation for n).
    n_max = int(1/((1-(1e-6)**(1/MAX_ITER))**2))
    # NOTE: astype(int) truncates, so a dense grid over a narrow range can
    # contain duplicate sizes; duplicates collapse to one key in the saved dict.
    n_samples = np.logspace(np.log10(n_min), np.log10(n_max), grain).astype(int)

    for p in p_values:
        print(f"\n=== Starting experiments for p={p} ===")

        # One task per sample size, each with its own MDP instance and the
        # discount factor tied to that sample size.
        input_parameter = [
            (Hard_MDP_Unichain(p), 1 - 1/np.sqrt(n), delta, n, trajectory_for_each_par)
            for n in n_samples
        ]

        # Run experiments in parallel
        with multiprocessing.Pool() as pool:
            results = pool.starmap(sample_error_plot_value_iteration, input_parameter)

        sample_error_data = dict(zip(n_samples, results))

        # Save results
        current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"data/Reduction_DMDP/kl/p_{p:.2f}_{current_time}.pkl"
        # Create the full parent directory of the output file; creating only
        # "data" would make open() fail when the nested folders do not exist.
        os.makedirs(os.path.dirname(filename), exist_ok=True)

        with open(filename, 'wb') as f:
            pickle.dump(sample_error_data, f)
        print(f"Saved results for p={p} to {filename}")

if __name__ == '__main__':
    # Values of p to sweep; each one gets its own experiment run and result file.
    p_values = [0.9, 0.5, 0.1]

    # Run the sweep for every p on a grid of 20 log-spaced sample sizes.
    run_experiment_for_p(p_values=p_values, grain=20)