import argparse
import random
import gym
import d4rl
import numpy as np
import torch
import os
import json

from offlinerlkit.utils.load_dataset import qlearning_dataset
from offlinerlkit.buffer import ReplayBuffer
from offlinerlkit.utils.logger import Logger, make_log_dirs
from offlinerlkit.policy_trainer import MFPolicyTrainer

# Import RFQI components
from offlinerlkit.modules.rfqi_module import RFQIActor, RFQICritic, RFQIVAE
from offlinerlkit.policy.model_free.rfqi_policy import RFQIPolicy

def get_args():
    """Parse the command-line options of the RFQI training script.

    Every option is declared as a ``(flag, type, default)`` row and registered
    in table order, so the resulting namespace carries one attribute per row
    (dashes in a flag become underscores in the attribute name).

    Returns:
        argparse.Namespace: the parsed options from ``sys.argv``.
    """
    default_device = "cuda" if torch.cuda.is_available() else "cpu"

    option_table = (
        ("--algo-name", str, "rfqi1"),
        ("--task", str, "hopper-medium-v2"),
        ("--seed", int, 0),
        # Core hyperparameters, aligned with the original RFQI script
        # (see train_rfqi.py).
        ("--actor-lr", float, 3e-4),
        ("--critic-lr", float, 1e-3),
        ("--gamma", float, 0.99),
        ("--tau", float, 0.005),
        # RFQI-specific parameters.
        ("--rho", float, 0.5),
        ("--lmbda", float, 0.75),
        ("--phi", float, 0.1),        # train_rfqi.py defaults to 0.1
        ("--adam-lr", float, 3e-4),   # used for the internal ETA optimisation
        ("--adam-eps", float, 1e-6),  # Adam epsilon
        # Training schedule.
        ("--epoch", int, 500),
        ("--step-per-epoch", int, 1000),
        ("--batch-size", int, 1000),  # RFQI uses a large batch size
        ("--eval_episodes", int, 10),
        ("--device", str, default_device),
        # Robustness-experiment options.
        ("--drop_ratio", float, 0.0),
        ("--noisy_ratio", float, 0.0),
        ("--noise_scale", float, 0.05),
        ("--early_stop_epoch_number", int, 500),
    )

    cli = argparse.ArgumentParser()
    for flag, value_type, fallback in option_table:
        cli.add_argument(flag, type=value_type, default=fallback)
    return cli.parse_args()

def _apply_saved_hyperparameters(args, load_path):
    """Overlay values from ``<load_path>/hyper_param.json`` onto ``args``.

    Only keys that already exist on ``args`` are replaced, and the keys listed
    in ``blocked_terms`` always keep the value parsed by this script. When the
    JSON file is missing, a warning is printed and ``args`` is left untouched.

    Args:
        args: Namespace produced by ``get_args()``; updated in place.
        load_path: Directory expected to contain ``hyper_param.json``.

    Returns:
        The same ``args`` namespace.
    """
    json_file = os.path.join(load_path, 'hyper_param.json')
    if not os.path.exists(json_file):
        # Previously this case was silent, so a wrong path went unnoticed and
        # the run proceeded with the script defaults.
        print(f"Warning: {json_file} not found; keeping command-line/default hyperparameters")
        return args

    with open(json_file, 'r') as file:
        saved_args = json.load(file)

    # These settings are owned by this script and never taken from the file.
    blocked_terms = {
        'device', 'algo_name', 'actor_lr', 'critic_lr', 'rho', 'lmbda',
        'phi', 'adam_lr', 'adam_eps', 'epoch',
    }
    # vars() returns the namespace's live attribute dict, so writing to it
    # updates ``args`` directly.
    args_dict = vars(args)
    for key, value in saved_args.items():
        if key in args_dict and key not in blocked_terms:
            args_dict[key] = value
    print(f"Loaded hyperparameters from {json_file}")
    return args


def train(load_path=None):
    """Train an RFQI policy on the task selected by the parsed arguments.

    Steps: parse the command line, optionally overlay hyperparameters stored
    in ``<load_path>/hyper_param.json``, create the environment and load its
    dataset, seed the random generators, build the actor / critic / VAE with
    their Adam optimizers, wrap them in ``RFQIPolicy``, load the dataset into
    a ``ReplayBuffer``, set up logging and run ``MFPolicyTrainer.train()``.

    Args:
        load_path: Optional directory containing a ``hyper_param.json`` from a
            previous run. ``None`` skips the overlay entirely.

    Returns:
        None.
    """
    args = get_args()
    # NOTE(review): --drop_ratio is parsed (and logged through vars(args)) but
    # is not forwarded to any component below — confirm whether that is intended.

    # 1. Load hyper-parameters from JSON
    if load_path is not None:
        args = _apply_saved_hyperparameters(args, load_path)

    # 2. Environment setup
    env = gym.make(args.task)
    dataset = qlearning_dataset(env)

    args.obs_shape = env.observation_space.shape
    args.action_dim = np.prod(env.action_space.shape)
    args.max_action = env.action_space.high[0]

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.deterministic = True
    # Seed the environment as well so rollouts on it follow args.seed.
    # Guarded because not every gym release exposes Env.seed().
    seed_env = getattr(env, "seed", None)
    if callable(seed_env):
        seed_env(args.seed)

    # 3. Networks
    state_dim = np.prod(args.obs_shape)  # flattened observation size, computed once
    latent_dim = args.action_dim * 2

    actor = RFQIActor(
        state_dim=state_dim,
        action_dim=args.action_dim,
        max_action=args.max_action,
        phi=args.phi
    ).to(args.device)

    critic = RFQICritic(
        state_dim=state_dim,
        action_dim=args.action_dim
    ).to(args.device)

    vae = RFQIVAE(
        state_dim=state_dim,
        action_dim=args.action_dim,
        latent_dim=latent_dim,
        max_action=args.max_action,
        device=args.device
    ).to(args.device)

    # 4. Optimizers
    # eps is passed explicitly to match the original RFQI code.
    # The VAE shares the actor learning rate.
    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr, eps=args.adam_eps)
    critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr, eps=args.adam_eps)
    vae_optim = torch.optim.Adam(vae.parameters(), lr=args.actor_lr, eps=args.adam_eps)

    # 5. Policy
    policy = RFQIPolicy(
        actor, critic, vae,
        actor_optim, critic_optim, vae_optim,
        state_dim=state_dim,
        action_dim=args.action_dim,
        tau=args.tau,
        gamma=args.gamma,
        rho=args.rho,
        lmbda=args.lmbda,
        adam_lr=args.adam_lr,
        adam_eps=args.adam_eps,  # forwarded to the policy for its internal ETA optimizer
        device=args.device
    )

    # 6. Buffer, sized to hold the whole offline dataset
    buffer = ReplayBuffer(
        buffer_size=len(dataset["observations"]),
        obs_shape=args.obs_shape,
        obs_dtype=np.float32,
        action_dim=args.action_dim,
        action_dtype=np.float32,
        device=args.device
    )
    buffer.load_dataset(dataset)

    # 7. Logger
    log_dirs = make_log_dirs(args.task, args.algo_name, args.seed, vars(args))
    output_config = {
        "consoleout_backup": "stdout",
        "policy_training_progress": "csv",
        "tb": "tensorboard"
    }
    logger = Logger(log_dirs, output_config)
    logger.log_hyperparameters(vars(args))

    # 8. Trainer; the training environment doubles as the evaluation environment
    policy_trainer = MFPolicyTrainer(
        policy=policy,
        eval_env=env,
        buffer=buffer,
        logger=logger,
        epoch=args.epoch,
        step_per_epoch=args.step_per_epoch,
        batch_size=args.batch_size,
        eval_episodes=args.eval_episodes,
        noisy_ratio=args.noisy_ratio,
        noise_scale=args.noise_scale,
        early_stop_epoch_number=args.early_stop_epoch_number
    )

    policy_trainer.train()

if __name__ == "__main__":
    # Saved runs are addressed as /data/<prefix>-<dataset>/seed-<k>, appended
    # to the current working directory. The list is ordered prefix-major, then
    # dataset tag, then seed: 3 prefixes x 4 tags x 3 seeds = 36 entries.
    task_prefixes = ('hc', 'hp', 'wk')
    dataset_tags = ('med-exp', 'med-rep', 'med', 'rnd')
    load_path_ls = [
        f'/data/{task_prefix}-{dataset_tag}/seed-{seed_idx}'
        for task_prefix in task_prefixes
        for dataset_tag in dataset_tags
        for seed_idx in range(3)
    ]

    # Index of the run whose hyperparameters are reused (valid range: 0-35).
    load_path_id = 35
    current_working_directory = os.getcwd()
    train(current_working_directory + load_path_ls[load_path_id])