# %%
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from stable_baselines3 import SAC
from scorers import LogReg
from data_generation import generate_synthetic_data
import os
from stable_baselines3.common.monitor import Monitor
from environments import SingleCandidateEnv, TrainingEnvWrapper, TestEnvWrapper

# %%
# Single source of randomness for the whole script: every sampling step that
# accepts a generator should draw from `rng` so runs are repeatable.
random_seed = 139
rng = np.random.default_rng(seed=random_seed)

# %%
# Build the synthetic training population: 10 continuous features, no
# categorical ones, sampled with the shared seeded generator.
n_continuous = 10
n_categorical = 0
(
    X,
    y,
    continuous,
    categorical,
    means,
    std_devs,
    probs,
    cont_weights,
    feature_means,
    feature_std_devs,
    feature_ranges,
) = generate_synthetic_data(
    n_agents=10000,
    n_continuous=n_continuous,
    n_categorical=n_categorical,
    rng=rng,
)

# %%
# Fit the logistic-regression scorer on 80% of the synthetic data.
# NOTE: the `X_test` produced by this split is reassigned further down the
# script (candidate pool), and `y_test` is not read anywhere in this file.
X_train_val, X_test, y_train_val, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=random_seed,
)
model = LogReg(
    threshold=0.5,
    ignore_features=["groups"],
)
model.fit(X_train_val, y_train_val)

# %%
# Candidate pool from which new candidates are drawn. It is generated with
# the feature means, standard deviations and ranges returned for the training
# set. Only the feature table (first return value) is kept; this rebinds
# `X_test`.
X_test, *_ = generate_synthetic_data(
    n_agents=500000,
    n_continuous=n_continuous,
    n_categorical=n_categorical,
    rng=rng,
    continuous_means=feature_means,
    continuous_std_devs=feature_std_devs,
    feature_ranges=feature_ranges,
)

# %%
# Fixed difficulty value for each continuous feature, keyed by feature name.
# Passed to the environments as `feature_difficulties`.
# NOTE(review): presumably the reference values that the environment's
# `difficulty_estimates` are meant to approach — confirm in environments.py.
feature_difficulties = dict(
    cont_0=0.84,
    cont_1=0.15,
    cont_2=0.85,
    cont_3=0.78,
    cont_4=0.25,
    cont_5=0.18,
    cont_6=0.29,
    cont_7=0.83,
    cont_8=0.91,
    cont_9=0.10,
)

# %%
# --- Phase-1 training environment ---
steps = 9

# Difficulty estimates start at 0.5 for every column of the candidate pool,
# with a zeroed update counter per column.
difficulty_estimates = np.full(shape=X_test.shape[1], fill_value=0.5)
difficulty_update_counts = np.zeros(X_test.shape[1])

# Monitor writes its logs here; create the directory up front.
log_dir = "./logs/logs2"
os.makedirs(log_dir, exist_ok=True)

base_env = SingleCandidateEnv(
    X_test,
    model,
    feature_difficulties=feature_difficulties,
    difficulty_estimates=difficulty_estimates,
    difficulty_update_counts=difficulty_update_counts,
    steps=steps,
    current_idx=0,
    data_gen_func=generate_synthetic_data,
    seed=random_seed,
    phase=1,
)

# Layering: base env -> training wrapper -> Monitor (episode logging).
training_env = TrainingEnvWrapper(base_env, X_test)
monitored_env = Monitor(training_env, filename=log_dir)

# %%
# Phase-1 pre-training of the RL agent (recourse recommender policy phi):
# a fresh SAC model trained on the monitored phase-1 environment.
rl_model = SAC(
    policy="MultiInputPolicy",
    env=monitored_env,
    verbose=1,
)
rl_model.learn(total_timesteps=30000)

# %%
# Checkpoint the phase-1 agent; it is reloaded from this path for phase 2.
rl_model.save("RL_agent2_phase1")

# %%
# Evaluation environment for the phase-1 agent. It starts from the difficulty
# estimates and update counts held by the phase-1 training environment.
# NOTE(review): this env is built with phase=2 although its metrics are saved
# as phase-1 results — confirm that is intended.
difficulty_estimates = base_env.difficulty_estimates
difficulty_update_counts = base_env.difficulty_update_counts
base_env_eval = SingleCandidateEnv(
    X_test,
    model,
    feature_difficulties=feature_difficulties,
    difficulty_estimates=difficulty_estimates,
    difficulty_update_counts=difficulty_update_counts,
    steps=steps,
    current_idx=0,
    data_gen_func=generate_synthetic_data,
    seed=random_seed,
    phase=2,
)

# --- Setup metrics storage ---
# One list per metric; values are appended in episode order.
results = {
    "episode": [],
    "mse": [],
    "cost": [],
    "reward": []
}

# --- Run 50 single-step evaluation episodes ---
for episode in range(50):
    # Draw one candidate row from the pool using the seeded generator.
    sample_idx = rng.choice(len(X_test), size=1, replace=False)
    X_ = X_test.iloc[sample_idx].reset_index(drop=True)

    # Score of the sampled candidate: last column of predict_proba.
    # NOTE: this rebinds `probs`, which earlier in the script held a value
    # returned by generate_synthetic_data.
    probs = model.predict_proba(X_)[:, -1]
    scores_ = pd.Series(probs, index=X_.index)
    # Positional access: there is exactly one row, independent of index labels.
    current_score = scores_.iloc[0]

    # Goal score lies between the current score and 1. It is drawn from `rng`
    # (not the unseeded global np.random state) so the evaluation is
    # reproducible under `random_seed`.
    goal_score = current_score + (1 - current_score) * rng.random()

    # Wrap the evaluation env for this candidate/goal pair and reset it.
    test_env = TestEnvWrapper(base_env_eval, X_, goal_score)
    obs, info = test_env.reset()

    # Deterministic action from the trained agent.
    action, _ = rl_model.predict(obs, deterministic=True)

    # Each episode consists of a single environment step.
    obs, reward, done, truncated, info = test_env.step(action)

    # Metrics are read from the underlying base env after the step.
    results["episode"].append(episode)
    results["mse"].append(base_env_eval.squared_error)
    results["cost"].append(base_env_eval.cost)
    results["reward"].append(reward)

# --- Convert results to DataFrame ---
results_df = pd.DataFrame(results)

# --- Persist the phase-1 evaluation metrics ---
results_df.to_csv("evaluation_metrics_phase1.csv", index=False)

# %%
# --- Phase-2 training environment ---
steps = 9

# Carry over the difficulty estimates and update counts held by the phase-1
# training environment.
difficulty_estimates = base_env.difficulty_estimates
difficulty_update_counts = base_env.difficulty_update_counts

# Separate Monitor log directory for phase 2; create it up front.
log_dir = "./logs/logs3"
os.makedirs(log_dir, exist_ok=True)

base_env_phase2 = SingleCandidateEnv(
    X_test,
    model,
    feature_difficulties=feature_difficulties,
    difficulty_estimates=difficulty_estimates,
    difficulty_update_counts=difficulty_update_counts,
    steps=steps,
    current_idx=0,
    data_gen_func=generate_synthetic_data,
    seed=random_seed,
    phase=2,
)

# Layering: base env -> training wrapper -> Monitor (episode logging).
training_env_phase2 = TrainingEnvWrapper(base_env_phase2, X_test)
monitored_env_phase2 = Monitor(training_env_phase2, filename=log_dir)

# %%
# Phase-2 training: reload the phase-1 checkpoint, attach it to the monitored
# phase-2 environment and continue learning.
rl_model = SAC.load(
    "RL_agent2_phase1",
    env=monitored_env_phase2,
    verbose=1,
)
rl_model.learn(total_timesteps=200000)

# %%
# Checkpoint the agent after phase-2 training.
rl_model.save("RL_agent2")
# %%
# --- Setup metrics storage ---
# One list per metric; values are appended in episode order.
results = {
    "episode": [],
    "mse": [],
    "cost": [],
    "reward": []
}

# --- Run 50 single-step evaluation episodes ---
for episode in range(50):
    # Draw one candidate row from the pool using the seeded generator.
    sample_idx = rng.choice(len(X_test), size=1, replace=False)
    X_ = X_test.iloc[sample_idx].reset_index(drop=True)

    # Score of the sampled candidate: last column of predict_proba.
    # NOTE: this rebinds `probs`, which earlier in the script held a value
    # returned by generate_synthetic_data.
    probs = model.predict_proba(X_)[:, -1]
    scores_ = pd.Series(probs, index=X_.index)
    # Positional access: there is exactly one row, independent of index labels.
    current_score = scores_.iloc[0]

    # Goal score lies between the current score and 1. It is drawn from `rng`
    # (not the unseeded global np.random state) so the evaluation is
    # reproducible under `random_seed`.
    goal_score = current_score + (1 - current_score) * rng.random()

    # Wrap the phase-2 base env for this candidate/goal pair and reset it.
    test_env = TestEnvWrapper(base_env_phase2, X_, goal_score)
    obs, info = test_env.reset()

    # Deterministic action from the trained agent.
    action, _ = rl_model.predict(obs, deterministic=True)

    # Each episode consists of a single environment step.
    obs, reward, done, truncated, info = test_env.step(action)

    # Metrics are read from the underlying base env after the step.
    results["episode"].append(episode)
    results["mse"].append(base_env_phase2.squared_error)
    results["cost"].append(base_env_phase2.cost)
    results["reward"].append(reward)

# --- Convert results to DataFrame ---
results_df = pd.DataFrame(results)

# --- Persist the phase-2 evaluation metrics ---
results_df.to_csv("evaluation_metrics_phase2.csv", index=False)
# %%
