import os

import numpy as np
import pandas as pd
from stable_baselines3 import PPO
from stable_baselines3.ppo import MlpPolicy
from stable_baselines3.common.evaluation import evaluate_policy

from imitations.util.util import make_vec_env
from imitations.algorithms.adversarial.airl import AIRL
from imitations.rewards.reward_nets import BasicShapedRewardNet
from imitations.util.networks import RunningNorm
from gymnasium.envs.mujoco.swimmer_cons import SwimmerTest

# Toggle between a short smoke-test budget and the full training budget.
FAST = True
N_RL_TRAIN_STEPS = 100_000 if FAST else 2_000_000

# Output CSV files, relative to the current working directory.
# Forward slashes are accepted by both Windows and POSIX; backslash separators
# only act as path separators on Windows (on POSIX they end up in the file name).
csv_file_path1 = './New_Dis_Acc/AIRL_Swimmer.csv'
csv_file_path2 = './New_Violation_Rate/AIRL_Swimmer.csv'

def _append_csv_row(path, frame):
    """Append ``frame`` to the CSV at ``path``, writing the header only for an empty file.

    The parent directory is created if missing, so a fresh checkout does not fail
    with FileNotFoundError only after a long training run has finished.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    # In append mode the stream is positioned at end-of-file, so tell() == 0
    # means the file is empty and still needs its header row.
    with open(path, encoding='utf-8-sig', mode='a', newline='') as f:
        frame.to_csv(f, header=f.tell() == 0, index=False)


for i in range(10):  # Loop to run the experiment 10 times
    # NOTE(review): drawn from numpy's global RNG, which is never seeded here, so
    # the seeds differ between script runs and may repeat across iterations.
    SEED = np.random.randint(10, 101)

    venv = make_vec_env(
        "Swimmer-cons",
        rng=np.random.default_rng(SEED),
        n_envs=8,
    )
    try:
        # Demonstrations for AIRL come from SwimmerTest.rollout(); presumably
        # expert trajectories -- TODO confirm against swimmer_cons.
        a = SwimmerTest()
        rollouts = a.rollout()

        learner = PPO(
            env=venv,
            policy=MlpPolicy,
            n_steps=1000,
            batch_size=50,
            ent_coef=0.0,
            learning_rate=0.0005,
            gamma=0.95,
            clip_range=0.1,
            vf_coef=0.1,
            n_epochs=5,
            seed=SEED,
        )
        reward_net = BasicShapedRewardNet(
            observation_space=venv.observation_space,
            action_space=venv.action_space,
            normalize_input_layer=RunningNorm,
        )
        airl_trainer = AIRL(
            demonstrations=rollouts,
            demo_batch_size=50,
            venv=venv,
            gen_algo=learner,
            reward_net=reward_net,
        )

        # Re-seed before each evaluation so both evaluations start from the same
        # environment seed.
        venv.seed(SEED)
        learner_rewards_before_training, _ = evaluate_policy(
            learner, venv, 50, return_episode_rewards=True
        )
        airl_trainer.train(N_RL_TRAIN_STEPS)
        venv.seed(SEED)
        learner_rewards_after_training, _ = evaluate_policy(
            learner, venv, 50, return_episode_rewards=True
        )
    finally:
        # Release this iteration's vectorized environments before the next
        # iteration creates a new set.
        venv.close()

    # Record completion of this iteration in both CSV files.
    # TODO(review): only iteration/seed are written, although the directory
    # names suggest accuracy / violation-rate metrics were intended.
    df = pd.DataFrame({'iteration': [i + 1], 'seed': [SEED]})
    _append_csv_row(csv_file_path1, df)
    _append_csv_row(csv_file_path2, df)

    print(f"Experiment {i + 1} with SEED {SEED}:")
    # The evaluation results were previously computed and discarded; report them.
    print(
        f"Mean reward before training: {np.mean(learner_rewards_before_training):.2f}, "
        f"after training: {np.mean(learner_rewards_after_training):.2f}"
    )
    print("Iteration completed and recorded.")
    print()
