import argparse
import random
import os, json
import gym
import d4rl

import numpy as np
import torch

from offlinerlkit.nets import MLP
from offlinerlkit.modules import ActorProb, EnsembleCritic, TanhDiagGaussian
from offlinerlkit.utils.load_dataset import qlearning_dataset
from offlinerlkit.buffer import ReplayBuffer
from offlinerlkit.utils.logger import Logger, make_log_dirs
from offlinerlkit.policy_trainer import MFPolicyTrainer
from offlinerlkit.policy import RORLPolicy
from offlinerlkit.utils.scaler import StandardScaler

def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty token (including ``"False"``) becomes ``True``. This parser
    accepts the usual textual spellings instead.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("true", "t", "yes", "y", "1", "on"):
        return True
    # The empty string is kept falsy, matching what ``bool("")`` returned.
    if token in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_args(argv=None):
    """Build the command-line parser for the RORL run and parse arguments.

    Args:
        argv: optional list of argument strings. When ``None`` (the default)
            ``sys.argv[1:]`` is parsed, as ``argparse`` does by default.

    Returns:
        The populated ``argparse.Namespace``.
    """
    # Boolean options accept an explicit value ("--auto-alpha False") or can be
    # given bare ("--auto-alpha"), in which case they evaluate to True.
    bool_option = dict(type=_str2bool, nargs="?", const=True)

    parser = argparse.ArgumentParser()
    parser.add_argument("--algo-name", type=str, default="rorl2")
    parser.add_argument("--task", type=str, default="walker2d-medium-expert-v2")
    parser.add_argument("--seed", type=int, default=1)

    parser.add_argument("--actor-lr", type=float, default=3e-4)
    parser.add_argument("--critic-lr", type=float, default=3e-4)
    parser.add_argument("--hidden-dims", type=int, nargs='*', default=[256, 256, 256])
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--tau", type=float, default=0.005)
    parser.add_argument("--alpha", type=float, default=0.2)
    parser.add_argument("--auto-alpha", default=True, **bool_option)
    parser.add_argument("--target-entropy", type=int, default=None)
    parser.add_argument("--alpha-lr", type=float, default=1e-4)
    parser.add_argument("--num-critics", type=int, default=10)
    parser.add_argument("--max-q-backup", default=False, **bool_option)
    parser.add_argument("--deterministic-backup", default=False, **bool_option)
    parser.add_argument("--normalize-reward", default=False, **bool_option)
    parser.add_argument("--norm-input", action='store_true', help='Normalize observation')

    # === RORL Specific Arguments ===
    # Smoothing / Consistency
    parser.add_argument('--num-samples', default=20, type=int)
    parser.add_argument('--policy-smooth-eps', default=0.0, type=float)
    parser.add_argument('--policy-smooth-reg', default=1.0, type=float)
    parser.add_argument('--q-smooth-eps', default=0.0, type=float)
    parser.add_argument('--q-smooth-reg', default=0.005, type=float)
    parser.add_argument('--q-smooth-tau', default=0.2, type=float)

    # OOD Conservative Penalty (Newly Added)
    parser.add_argument('--q-ood-eps', default=0.0, type=float)
    parser.add_argument('--q-ood-reg', default=0.0, type=float)
    parser.add_argument('--q-ood-uncertainty-reg', default=0.0, type=float)
    parser.add_argument('--q-ood-uncertainty-reg-min', default=0.0, type=float)
    parser.add_argument('--q-ood-uncertainty-decay', default=1e-6, type=float)

    parser.add_argument('--obs-std', default=1.0, type=float)

    parser.add_argument("--epoch", type=int, default=3000)
    parser.add_argument("--early_stop_epoch_number", type=int, default=3000)
    parser.add_argument("--step-per-epoch", type=int, default=1000)
    parser.add_argument("--eval_episodes", type=int, default=10)
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")

    # for robustness experiments
    parser.add_argument("--drop_ratio", type=float, default=0.0)
    parser.add_argument("--noisy_ratio", type=float, default=1.0)
    parser.add_argument("--noise_scale", type=float, default=0.05)

    # for RORL smoothing hyperparameters
    parser.add_argument("--config_file", type=str, default=None, help="Path to the RORL hyperparameter json file")

    return parser.parse_args(argv)


def train(load_path):
    """Train an RORL policy on a D4RL task.

    The run configuration is assembled in three layers, each overriding the
    previous one:

    1. command-line arguments from ``get_args()``;
    2. every key of the JSON file named by ``--config_file`` (if given);
    3. ``task``, ``seed``, ``noise_scale`` and ``early_stop_epoch_number``
       read from ``<load_path>/hyper_param.json`` (if ``load_path`` is not
       ``None``).

    Args:
        load_path: directory containing a ``hyper_param.json`` file, or
            ``None`` to skip the third layer.

    Raises:
        FileNotFoundError: if ``--config_file`` is given but does not exist.
    """
    args = get_args()
    if args.config_file is not None:
        if not os.path.exists(args.config_file):
            raise FileNotFoundError(f"Config file not found: {args.config_file}")

        print(f"Loading RORL Hyperparameters from: {args.config_file}")
        with open(args.config_file, 'r') as f:
            config_dict = json.load(f)

        # Walk the JSON and force-override the parsed arguments.
        for key, value in config_dict.items():
            # Keys unknown to the parser are deliberately added as new
            # attributes (rather than rejected) for flexibility; only the
            # log message differs between the two cases.
            already_defined = hasattr(args, key)
            setattr(args, key, value)
            if already_defined:
                print(f"  - Override {key}: {value}")
            else:
                print(f"  - Set new arg {key}: {value}")

    if load_path is not None:
        json_file = os.path.join(load_path, 'hyper_param.json')
        with open(json_file, 'r') as file:
            new_args_dict = json.load(file)

        # Only these entries of the stored hyper-parameters are taken over.
        update_terms = ['task', 'seed', 'noise_scale', 'early_stop_epoch_number']
        for k in update_terms:
            if k in new_args_dict:
                setattr(args, k, new_args_dict[k])

    env = gym.make(args.task)
    dataset = qlearning_dataset(env)

    if args.normalize_reward:
        mu, std = dataset["rewards"].mean(), dataset["rewards"].std()
        dataset["rewards"] = (dataset["rewards"] - mu) / (std + 1e-3)

    args.obs_shape = env.observation_space.shape
    args.action_dim = np.prod(env.action_space.shape)
    args.max_action = env.action_space.high[0]

    # Seed Python, NumPy, torch and the environment before any network is built.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.deterministic = True
    env.seed(args.seed)

    # NOTE: args.obs_std is always recomputed here, so any value supplied via
    # --obs-std or the config file is replaced.
    scaler = None
    if args.norm_input:
        scaler = StandardScaler()
        scaler.fit(dataset["observations"])
        args.obs_std = 1.0
    else:
        # Standard deviation along axis 0 of the observations, averaged.
        args.obs_std = dataset["observations"].std(axis=0).mean()

    actor_backbone = MLP(input_dim=np.prod(args.obs_shape), hidden_dims=args.hidden_dims)
    dist = TanhDiagGaussian(
        latent_dim=getattr(actor_backbone, "output_dim"),
        output_dim=args.action_dim,
        unbounded=True,
        conditioned_sigma=True,
        max_mu=args.max_action
    )
    actor = ActorProb(actor_backbone, dist, args.device)
    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)

    critics = EnsembleCritic(
        np.prod(args.obs_shape), args.action_dim,
        args.hidden_dims, num_ensemble=args.num_critics,
        device=args.device
    )

    # Custom critic initialisation: bias 0.1 on every other module of
    # critics.model, and small uniform weights/bias on the last module.
    for layer in critics.model[::2]:
        torch.nn.init.constant_(layer.bias, 0.1)
    torch.nn.init.uniform_(critics.model[-1].weight, -3e-3, 3e-3)
    torch.nn.init.uniform_(critics.model[-1].bias, -3e-3, 3e-3)
    critics_optim = torch.optim.Adam(critics.parameters(), lr=args.critic_lr)

    if args.auto_alpha:
        # Compare against None so that an explicit target entropy of 0 is
        # honoured instead of being replaced by the default.
        if args.target_entropy is not None:
            target_entropy = args.target_entropy
        else:
            target_entropy = -np.prod(env.action_space.shape)
        args.target_entropy = target_entropy
        log_alpha = torch.zeros(1, requires_grad=True, device=args.device)
        alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr)
        alpha = (target_entropy, log_alpha, alpha_optim)
    else:
        alpha = args.alpha

    policy = RORLPolicy(
        actor,
        critics,
        actor_optim,
        critics_optim,
        tau=args.tau,
        gamma=args.gamma,
        alpha=alpha,
        max_q_backup=args.max_q_backup,
        deterministic_backup=args.deterministic_backup,
        # Smoothing
        num_samples=args.num_samples,
        policy_smooth_eps=args.policy_smooth_eps,
        policy_smooth_reg=args.policy_smooth_reg,
        q_smooth_eps=args.q_smooth_eps,
        q_smooth_reg=args.q_smooth_reg,
        q_smooth_tau=args.q_smooth_tau,
        # OOD
        q_ood_eps=args.q_ood_eps,
        q_ood_reg=args.q_ood_reg,
        q_ood_uncertainty_reg=args.q_ood_uncertainty_reg,
        q_ood_uncertainty_reg_min=args.q_ood_uncertainty_reg_min,
        q_ood_uncertainty_decay=args.q_ood_uncertainty_decay,

        obs_std=args.obs_std,
        scaler=scaler,
        device=args.device
    )

    # The buffer is sized to hold the whole offline dataset.
    buffer = ReplayBuffer(
        buffer_size=len(dataset["observations"]),
        obs_shape=args.obs_shape,
        obs_dtype=np.float32,
        action_dim=args.action_dim,
        action_dtype=np.float32,
        device=args.device
    )
    buffer.load_dataset(dataset)

    log_dirs = make_log_dirs(args.task, args.algo_name, args.seed, vars(args),
                             record_params=["num_critics","noise_scale", "policy_smooth_eps","policy_smooth_reg","q_smooth_eps", "q_smooth_reg", "norm_input"])
    output_config = {
        "consoleout_backup": "stdout",
        "policy_training_progress": "csv",
        "tb": "tensorboard"
    }
    logger = Logger(log_dirs, output_config)
    logger.log_hyperparameters(vars(args))

    policy_trainer = MFPolicyTrainer(
        noisy_ratio=args.noisy_ratio,
        noise_scale=args.noise_scale,
        policy=policy,
        eval_env=env,
        buffer=buffer,
        logger=logger,
        epoch=args.epoch,
        step_per_epoch=args.step_per_epoch,
        batch_size=args.batch_size,
        eval_episodes=args.eval_episodes,
        early_stop_epoch_number=args.early_stop_epoch_number
    )

    policy_trainer.train()

if __name__ == "__main__":
    # Run directories follow /data/<env>-<dataset>/seed-<k>; the list is
    # ordered env-major, then dataset, then seed (3 * 4 * 3 = 36 entries).
    env_tags = ("hc", "hp", "wk")
    dataset_tags = ("med-exp", "med-rep", "med", "rnd")
    load_path_ls = [
        f"/data/{env_tag}-{dataset_tag}/seed-{seed_idx}"
        for env_tag in env_tags
        for dataset_tag in dataset_tags
        for seed_idx in range(3)
    ]

    # Index into load_path_ls (valid range 0-35); 11 selects /data/hc-rnd/seed-2.
    load_path_id = 11
    # The entries start with "/", so plain concatenation (not os.path.join)
    # is what places them underneath the current working directory.
    train(os.getcwd() + load_path_ls[load_path_id])