# %%
import numpy as np
from sklearn.model_selection import train_test_split
from stable_baselines3 import SAC
from scorers import LogReg
from data_generation import generate_synthetic_data
import os
from stable_baselines3.common.monitor import Monitor
from environments import SingleCandidateEnv, TestEnvWrapper, CandidateEnv

# %%
# Single source of randomness for the script: one fixed seed and one NumPy
# Generator built from it. Anything meant to be reproducible should draw
# from `rng` rather than from NumPy's global RNG.
random_seed = 139
rng = np.random.default_rng(seed=random_seed)

# %%
# Training data: 10,000 synthetic samples, each with `n_continuous` (= 10)
# continuous features and no categorical features. Besides X and y, the
# generator returns the distribution parameters it used; the feature
# statistics are reused below to build a matching candidate pool.
n_continuous, n_categorical = 10, 0
(
    X, y,
    continuous, categorical,
    means, std_devs, probs, cont_weights,
    feature_means, feature_std_devs, feature_ranges,
) = generate_synthetic_data(
    n_agents=10000,
    n_continuous=n_continuous,
    n_categorical=n_categorical,
    rng=rng,
)

# %%
# Fit the logistic-regression scorer on 80% of the training data; 20% is
# split off as a held-out set. The "groups" column is passed to
# `ignore_features` (presumably excluded from the fit — confirm in LogReg).
# NOTE(review): X_test / y_test from this split are never used — X_test is
# reassigned to the candidate pool in the next cell.
X_train_val, X_test, y_train_val, y_test = train_test_split(
    X,
    y,
    random_state=random_seed,
    test_size=0.2,
)
model = LogReg(ignore_features=["groups"], threshold=0.5)
model.fit(X_train_val, y_train_val)

# %%
# Candidate pool: 500,000 new samples generated with the feature means,
# standard deviations and ranges returned by the training-data call, so
# candidates follow the same feature distribution. Only the feature matrix
# (first return value) is kept. This rebinds X_test, replacing the held-out
# split from the previous cell.
X_test = generate_synthetic_data(
    n_agents=500000,
    n_continuous=n_continuous,
    n_categorical=n_categorical,
    rng=rng,
    continuous_means=feature_means,
    continuous_std_devs=feature_std_devs,
    feature_ranges=feature_ranges,
)[0]

# %%
# Fixed difficulty value for each continuous feature, keyed "cont_0" ..
# "cont_9" in order. Values look like they lie in [0, 1] — TODO confirm the
# scale expected by the environments.
feature_difficulties = {
    f"cont_{position}": difficulty
    for position, difficulty in enumerate(
        (0.84, 0.15, 0.85, 0.78, 0.25, 0.18, 0.29, 0.83, 0.91, 0.10)
    )
}

# %%
# --- Create base environment ---
# One difficulty estimate per column of the candidate pool, all starting at
# 0.5, with zero recorded updates.
difficulty_estimates = np.full(X_test.shape[1], 0.5)
difficulty_update_counts = np.zeros(X_test.shape[1])
base_single_env = SingleCandidateEnv(
    X_test,
    model,
    feature_difficulties=feature_difficulties,
    difficulty_estimates=difficulty_estimates,
    difficulty_update_counts=difficulty_update_counts,
    steps=0,
    current_idx=0,
    data_gen_func=generate_synthetic_data,
    seed=random_seed,
)

# --- Wrap with test wrapper ---
# Pick one random row of the pool (presumably a pandas DataFrame, given
# .iloc / reset_index) and a random goal score in [0, 1).
sample_idx = rng.choice(len(X_test), size=1, replace=False)
X_ = X_test.iloc[sample_idx].reset_index(drop=True)
# Draw from the seeded `rng`, not NumPy's unseeded global RNG
# (np.random.rand), so the goal score is reproducible across runs.
goal_score = rng.random()
test_env = TestEnvWrapper(base_single_env, X_, goal_score)

# %%
# Load the previously trained recourse-recommender policy (phi) from the
# "RL_agent2" checkpoint and attach it to the single-candidate test env.
rl_model2 = SAC.load(path="RL_agent2", env=test_env)

# %%
# Environment for training the candidate-level agent, wrapped in a Monitor
# that logs episode statistics to `log_dir`.
# `steps` is passed to CandidateEnv (presumably the episode length) and is
# also the divisor in the metric-averaging cells below. Passing the variable
# instead of a literal keeps the two in sync.
steps = 100
log_dir = "./logs/logs1"  # Specify a directory
os.makedirs(log_dir, exist_ok=True)  # Ensure the directory exists
base_env = CandidateEnv(
    X_test,
    model,
    feature_difficulties=feature_difficulties,
    base_single_env=base_single_env,
    rl_model2=rl_model2,
    steps=steps,
    decay_factor_distance=0.7,
    decay_factor_num=0.03,
    decay_factor_combination=0.07,
    current_idx=0,
    threshold=9,
    growth_k=10,
    data_gen_func=generate_synthetic_data,
    seed=random_seed,
    t_validity=1,
    beta=0.05,
    alpha=7,
    tau=5,
    method="ours",
    baseline=False,
)
monitored_env = Monitor(base_env, log_dir)

# %%
# Train the SAC agent on the monitored environment for 700k timesteps.
# Seeding SAC with the script-wide seed makes its initialisation and
# exploration reproducible, matching the seeded `rng` used elsewhere.
rl_model = SAC(
    "MultiInputPolicy",
    env=monitored_env,
    verbose=1,
    seed=random_seed,
)
rl_model.learn(total_timesteps=700000)

# %%
# Persist the trained agent under the "RL_agent1" checkpoint name.
rl_model.save(path="RL_agent1")

# %%
# Roll out the trained policy deterministically on the unwrapped environment.
# Record each episode's total reward and its per-step rewards.
num_episodes = 10
episode_rewards = []        # Total reward per episode
per_step_rewards = []       # Reward at each time step for each episode

for episode in range(num_episodes):
    obs, _ = base_env.reset()
    done = False
    total_reward = 0
    step_rewards = []

    while not done:
        action, _ = rl_model.predict(obs, deterministic=True)
        # step() returns (obs, reward, terminated, truncated, info). The
        # episode ends on either flag; checking only `terminated` would spin
        # forever if the env ends an episode by truncation alone.
        obs, reward, terminated, truncated, _ = base_env.step(action)
        done = terminated or truncated
        step_rewards.append(reward)
        total_reward += reward

    episode_rewards.append(total_reward)
    per_step_rewards.append(step_rewards)
    print(f"Episode {episode + 1}: Reward = {total_reward}")

# %%
# Per-step recourse reliability: divide each accumulated sum in
# base_env.reliabilities_sum (presumably one per episode) by `steps`, then
# average across episodes.
avg_reliabilities = [
    episode_total / steps for episode_total in base_env.reliabilities_sum
]
print("Average recourse reliability over episodes:", np.mean(avg_reliabilities))

# %%
# Per-step Gini coefficient: divide each accumulated sum in base_env.ginis_sum
# (presumably one per episode) by `steps`, then average across episodes.
avg_ginis = [episode_total / steps for episode_total in base_env.ginis_sum]
print("Average Gini coefficient over episodes:", np.mean(avg_ginis))

# %%
# Per-step cost: divide each accumulated sum in base_env.cost_sum (presumably
# one per episode) by `steps`, then average across episodes.
avg_cost = [episode_total / steps for episode_total in base_env.cost_sum]
print("Average cost over episodes:", np.mean(avg_cost))

# %%
# Per-step recourse feasibility: divide each accumulated sum in
# base_env.implementing_sum (presumably one per episode) by `steps`, then
# average across episodes.
avg_implementing = [
    episode_total / steps for episode_total in base_env.implementing_sum
]
print("Average recourse feasibility over episodes:", np.mean(avg_implementing))