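# NOTE: everything from here down to the live `import argparse` below is a
# commented-out copy of what appears to be an earlier revision of this script
# (no observation scaler, no separate QR sigma). It is not executed.
# import argparse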
# import random
# import gym
# import d4rl
# import numpy as np
# import torch
# import torch.nn as nn
# import os
# import json

# from offlinerlkit.modules import ActorProb, TanhDiagGaussian
# from offlinerlkit.utils.load_dataset import qlearning_dataset
# from offlinerlkit.buffer import ReplayBuffer
# from offlinerlkit.utils.logger import Logger, make_log_dirs
# from offlinerlkit.policy_trainer import MFPolicyTrainer

# # Import TRACER components
# # Use the modules we fixed in tracer_module.py
# from offlinerlkit.modules.tracer_module import (
#     VectorizedCritic, # Fixed IQN Critic
#     DistributionalValueFunction,
#     ObservationModel
# )
# from offlinerlkit.policy.model_free.tracer_policy import TRACERPolicy

# def get_args():
#     parser = argparse.ArgumentParser()
#     parser.add_argument("--algo-name", type=str, default="tracer")
#     parser.add_argument("--task", type=str, default="halfcheetah-medium-expert-v2")
#     parser.add_argument("--seed", type=int, default=1)
#     parser.add_argument("--actor-lr", type=float, default=1e-3)
#     parser.add_argument("--critic-lr", type=float, default=1e-3)
#     parser.add_argument("--gamma", type=float, default=0.99)
#     parser.add_argument("--tau", type=float, default=0.005)
#     parser.add_argument("--hidden-dims", type=int, nargs='*', default=[256, 256])
    
#     # TRACER Specific Args
#     parser.add_argument("--beta", type=float, default=3.0)
#     parser.add_argument("--quantile", type=float, default=0.25)
#     parser.add_argument("--iql-tau", type=float, default=0.7)
#     parser.add_argument("--obser-sigma", type=float, default=0.3)
    
#     # Network Architecture Args
#     parser.add_argument("--num-q", type=int, default=2)
#     parser.add_argument("--num-quantiles", type=int, default=32)
#     parser.add_argument("--cosines-dim", type=int, default=64)
#     parser.add_argument("--num-model", type=int, default=3) # VAE ensemble size (must be 3 for s,a,r mask)
    
#     parser.add_argument("--epoch", type=int, default=3000)
#     parser.add_argument("--step-per-epoch", type=int, default=1000)
#     parser.add_argument("--batch-size", type=int, default=256)
#     parser.add_argument("--eval_episodes", type=int, default=10)
#     parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")

#     # Robustness Experiments Args
#     parser.add_argument("--drop_ratio", type=float, default=0.0)
#     parser.add_argument("--noisy_ratio", type=float, default=0.0)
#     parser.add_argument("--noise_scale", type=float, default=0.05)
#     parser.add_argument("--early_stop_epoch_number", type=int, default=3000)
    
#     return parser.parse_args()

# def train(load_path=None):
#     args = get_args()

#     # ==========================================================
#     # 1. Load Hyper-parameters
#     # ==========================================================
#     if load_path is not None:
#         json_file = os.path.join(load_path, 'hyper_param.json')
#         if os.path.exists(json_file):
#             with open(json_file, 'r') as file:
#                 new_args_dict = json.load(file)
#             blocked_terms = [
#                 'device', 'algo_name', 
#                 'actor_lr', 'critic_lr', 
#                 'obs_shape', 'action_dim', 'max_action'
#             ]
#             args_dict = vars(args)
#             for k, v in new_args_dict.items():
#                 if k in blocked_terms: continue
#                 if k in args_dict: args_dict[k] = v
#             args = argparse.Namespace(**args_dict)
#             print(f"Loaded hyperparameters from {json_file}")

#     # ==========================================================
#     # 2. Environment Setup
#     # ==========================================================
#     env = gym.make(args.task)
#     dataset = qlearning_dataset(env)
    
#     args.obs_shape = env.observation_space.shape
#     args.action_dim = np.prod(env.action_space.shape)
#     args.max_action = env.action_space.high[0]

#     random.seed(args.seed)
#     np.random.seed(args.seed)
#     torch.manual_seed(args.seed)
#     if torch.cuda.is_available():
#         torch.cuda.manual_seed_all(args.seed)
#         torch.backends.cudnn.deterministic = True

#     # ==========================================================
#     # 3. Network Initialization (Fixed for TRACER)
#     # ==========================================================
    
#     # MLP helper for Actor
#     class MLP(nn.Module):
#         def __init__(self, input_dim, hidden_dims, output_dim=None):
#             super().__init__()
#             layers = []
#             prev_dim = input_dim
#             for dim in hidden_dims:
#                 layers.append(nn.Linear(prev_dim, dim))
#                 layers.append(nn.ReLU())
#                 prev_dim = dim
#             if output_dim:
#                 layers.append(nn.Linear(prev_dim, output_dim))
#             self.net = nn.Sequential(*layers)
#             self.output_dim = prev_dim # Latent dim for TanhDiagGaussian
#         def forward(self, x): return self.net(x)

#     # Actor (Standard TanhGaussian)
#     actor_backbone = MLP(input_dim=np.prod(args.obs_shape), hidden_dims=args.hidden_dims)
#     dist = TanhDiagGaussian(
#         latent_dim=actor_backbone.output_dim,
#         output_dim=args.action_dim,
#         unbounded=True,
#         conditioned_sigma=True,
#         max_mu=args.max_action
#     )
#     actor = ActorProb(actor_backbone, dist, args.device)
    
#     # Critic (IQN Style - Vectorized)
#     critic = VectorizedCritic(
#         state_dim=np.prod(args.obs_shape),
#         action_dim=args.action_dim,
#         hidden_dim=256,
#         num_q=args.num_q,
#         num_quantiles=args.num_quantiles,
#         cosines_dim=args.cosines_dim
#     ).to(args.device)
    
#     # Value Net (IQN Style)
#     value_net = DistributionalValueFunction(
#         state_dim=np.prod(args.obs_shape),
#         hidden_dim=256,
#         num_v=1, # Value net usually isn't ensembled
#         cosines_dim=args.cosines_dim
#     ).to(args.device)
    
#     # Observation Model (Masked VAE)
#     obser_model = ObservationModel(
#         state_dim=np.prod(args.obs_shape),
#         action_dim=args.action_dim,
#         hidden_dim=256,
#         num_model=args.num_model, # Must be 3 for the s,a,r masking logic
#         device=args.device,
#         sigma=args.obser_sigma
#     ).to(args.device)

#     # ==========================================================
#     # 4. Optimizers
#     # ==========================================================
#     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
#     critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
#     value_optim = torch.optim.Adam(value_net.parameters(), lr=args.critic_lr) # Value usually shares LR with critic
#     obser_optim = torch.optim.Adam(obser_model.parameters(), lr=1e-3) # Model LR

#     # ==========================================================
#     # 5. Policy Setup
#     # ==========================================================
#     policy = TRACERPolicy(
#         actor, critic, value_net, obser_model,
#         actor_optim, critic_optim, value_optim, obser_optim,
#         tau=args.tau,
#         gamma=args.gamma,
#         beta=args.beta,
#         quantile=args.quantile,
#         iql_tau=args.iql_tau,
#         obser_sigma=args.obser_sigma,
#         num_quantiles=args.num_quantiles,
#         num_q=args.num_q,
#         device=args.device,
#         total_steps=args.epoch * args.step_per_epoch # For beta decay schedule
#     )

#     # ==========================================================
#     # 6. Buffer Setup
#     # ==========================================================
#     buffer = ReplayBuffer(
#         buffer_size=len(dataset["observations"]),
#         obs_shape=args.obs_shape,
#         obs_dtype=np.float32,
#         action_dim=args.action_dim,
#         action_dtype=np.float32,
#         device=args.device
#     )
#     buffer.load_dataset(dataset)

#     # ==========================================================
#     # 7. Logger & Trainer
#     # ==========================================================
#     log_dirs = make_log_dirs(args.task, args.algo_name, args.seed, vars(args), record_params=["noise_scale"])
    
#     # log_dirs = make_log_dirs(args.task, args.algo_name, args.seed, vars(args))
#     output_config = {
#         "consoleout_backup": "stdout", 
#         "policy_training_progress": "csv", 
#         "tb": "tensorboard"
#     }
#     logger = Logger(log_dirs, output_config)
#     logger.log_hyperparameters(vars(args))

#     policy_trainer = MFPolicyTrainer(
#         policy=policy,
#         eval_env=env,
#         buffer=buffer,
#         logger=logger,
#         epoch=args.epoch,
#         step_per_epoch=args.step_per_epoch,
#         batch_size=args.batch_size,
#         eval_episodes=args.eval_episodes,
#         noisy_ratio=args.noisy_ratio,
#         noise_scale=args.noise_scale,
#         early_stop_epoch_number=args.early_stop_epoch_number
#     )
    
#     policy_trainer.train()

# if __name__ == "__main__":
#     current_working_directory = os.getcwd()
#     load_path_ls = ['/data/hc-med-exp/seed-0', '/data/hc-med-exp/seed-1', '/data/hc-med-exp/seed-2',
#                     '/data/hc-med-rep/seed-0', '/data/hc-med-rep/seed-1', '/data/hc-med-rep/seed-2',
#                     '/data/hc-med/seed-0', '/data/hc-med/seed-1', '/data/hc-med/seed-2',
#                     '/data/hc-rnd/seed-0', '/data/hc-rnd/seed-1', '/data/hc-rnd/seed-2',
#                     '/data/hp-med-exp/seed-0', '/data/hp-med-exp/seed-1', '/data/hp-med-exp/seed-2',
#                     '/data/hp-med-rep/seed-0', '/data/hp-med-rep/seed-1', '/data/hp-med-rep/seed-2',
#                     '/data/hp-med/seed-0', '/data/hp-med/seed-1', '/data/hp-med/seed-2',
#                     '/data/hp-rnd/seed-0', '/data/hp-rnd/seed-1', '/data/hp-rnd/seed-2',
#                     '/data/wk-med-exp/seed-0', '/data/wk-med-exp/seed-1', '/data/wk-med-exp/seed-2',
#                     '/data/wk-med-rep/seed-0', '/data/wk-med-rep/seed-1', '/data/wk-med-rep/seed-2',
#                     '/data/wk-med/seed-0', '/data/wk-med/seed-1', '/data/wk-med/seed-2',
#                     '/data/wk-rnd/seed-0', '/data/wk-rnd/seed-1', '/data/wk-rnd/seed-2']
#     load_path_id = 21 # 0-6
#     train(current_working_directory + load_path_ls[load_path_id])

import argparse
import random
import gym
import d4rl
import numpy as np
import torch
import torch.nn as nn
import os
import json

from offlinerlkit.modules import ActorProb, TanhDiagGaussian
from offlinerlkit.utils.load_dataset import qlearning_dataset
from offlinerlkit.buffer import ReplayBuffer
from offlinerlkit.utils.logger import Logger, make_log_dirs
from offlinerlkit.policy_trainer import MFPolicyTrainer
# StandardScaler is fitted on the dataset observations when --norm-input is set.
from offlinerlkit.utils.scaler import StandardScaler

# Import TRACER components
from offlinerlkit.modules.tracer_module import (
    VectorizedCritic, 
    DistributionalValueFunction, 
    ObservationModel
)
from offlinerlkit.policy.model_free.tracer_policy import TRACERPolicy

def get_args(argv=None):
    """Parse the command-line options of the TRACER training script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        argparse.Namespace holding every option declared below.
    """
    parser = argparse.ArgumentParser()

    # General training options
    parser.add_argument("--algo-name", type=str, default="tracer2")
    parser.add_argument("--task", type=str, default="halfcheetah-medium-expert-v2")
    parser.add_argument("--seed", type=int, default=1)
    parser.add_argument("--actor-lr", type=float, default=1e-3)
    parser.add_argument("--critic-lr", type=float, default=1e-3)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--tau", type=float, default=0.005)
    parser.add_argument("--hidden-dims", type=int, nargs='*', default=[256, 256])

    # TRACER-specific options
    parser.add_argument("--beta", type=float, default=3.0)
    parser.add_argument("--quantile", type=float, default=0.25)
    parser.add_argument("--iql-tau", type=float, default=0.7)
    parser.add_argument("--sigma", type=float, default=0.1, help="Sigma for QR loss")
    parser.add_argument("--obser-sigma", type=float, default=0.1, help="Sigma for VAE loss")

    # Network architecture options
    parser.add_argument("--num-q", type=int, default=2)
    parser.add_argument("--num-quantiles", type=int, default=32)
    parser.add_argument("--cosines-dim", type=int, default=64)
    parser.add_argument("--num-model", type=int, default=3)

    parser.add_argument("--epoch", type=int, default=3000)
    parser.add_argument("--step-per-epoch", type=int, default=1000)
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--eval_episodes", type=int, default=10)
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")

    # Robustness-experiment options
    parser.add_argument("--drop_ratio", type=float, default=0.0)
    parser.add_argument("--noisy_ratio", type=float, default=0.0)
    parser.add_argument("--noise_scale", type=float, default=0.05)
    parser.add_argument("--early_stop_epoch_number", type=int, default=3000)

    # Observation normalization is on by default. A store_true flag whose
    # default is already True can never be switched off, so a complementary
    # --no-norm-input switch writes False into the same destination.
    parser.add_argument("--norm-input", dest="norm_input", action='store_true', default=True,
                        help='Normalize observation')
    parser.add_argument("--no-norm-input", dest="norm_input", action='store_false',
                        help='Disable observation normalization')

    return parser.parse_args(argv)

def train(load_path=None):
    """Assemble all TRACER components and run offline training on one task.

    Steps, in order: parse the CLI arguments, optionally override them from
    ``<load_path>/hyper_param.json``, create the gym environment and its
    Q-learning dataset, optionally fit an observation scaler, seed the RNGs,
    build the actor / critic / value / observation networks and their Adam
    optimizers, wrap them in ``TRACERPolicy``, fill a ``ReplayBuffer`` with
    the dataset, set up the logger and finally call
    ``MFPolicyTrainer.train()``.

    Args:
        load_path: Directory expected to contain ``hyper_param.json``. When it
            is ``None``, or the file does not exist, the command-line values
            are used unchanged and nothing is reported.
    """
    args = get_args()

    # 1. Load Hyper-parameters
    if load_path is not None:
        json_file = os.path.join(load_path, 'hyper_param.json')
        if os.path.exists(json_file):
            with open(json_file, 'r') as file:
                new_args_dict = json.load(file)
            # Keys listed here always keep their command-line / runtime value.
            blocked_terms = ['device', 'algo_name', 'actor_lr', 'critic_lr', 'obs_shape', 'action_dim', 'max_action']
            args_dict = vars(args)
            for k, v in new_args_dict.items():
                if k in blocked_terms: continue
                # Only options that the parser already defines are overridden;
                # unknown keys in the JSON file are ignored.
                if k in args_dict: args_dict[k] = v
            args = argparse.Namespace(**args_dict)

    # 2. Environment Setup
    env = gym.make(args.task)
    dataset = qlearning_dataset(env)
    
    # Optional observation scaler, fitted on the dataset observations.
    scaler = None
    if args.norm_input:
        scaler = StandardScaler()
        scaler.fit(dataset["observations"])
        print(f"Scaler fitted. Mean: {scaler.mu[:5]}..., Std: {scaler.std[:5]}...")
        # The dataset itself is not modified here; the scaler is only handed to
        # TRACERPolicy below. Presumably the policy applies it internally --
        # TODO confirm in tracer_policy.py.
    
    # Shapes and bounds derived from the environment are stored on args so
    # they are logged together with the hyper-parameters.
    args.obs_shape = env.observation_space.shape
    args.action_dim = np.prod(env.action_space.shape)
    # Uses the first component of the upper bound only; assumes every action
    # dimension shares the same bound -- TODO confirm for non-MuJoCo tasks.
    args.max_action = env.action_space.high[0]

    # Seed Python, NumPy and torch before any network is constructed.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
        torch.backends.cudnn.deterministic = True

    # 3. Network Initialization
    class MLP(nn.Module):
        """Stack of Linear+ReLU layers with an optional final Linear layer."""
        def __init__(self, input_dim, hidden_dims, output_dim=None):
            super().__init__()
            layers = []
            prev_dim = input_dim
            for dim in hidden_dims:
                layers.append(nn.Linear(prev_dim, dim))
                layers.append(nn.ReLU())
                prev_dim = dim
            if output_dim:
                layers.append(nn.Linear(prev_dim, output_dim))
            self.net = nn.Sequential(*layers)
            # NOTE(review): this records the width of the last hidden layer
            # (or input_dim when hidden_dims is empty), even when a final
            # output layer was appended above. The only use in this function
            # passes no output_dim, so the value matches the network output.
            self.output_dim = prev_dim 
        def forward(self, x): return self.net(x)

    # Actor: MLP backbone feeding a TanhDiagGaussian head.
    actor_backbone = MLP(input_dim=np.prod(args.obs_shape), hidden_dims=args.hidden_dims)
    dist = TanhDiagGaussian(
        latent_dim=actor_backbone.output_dim,
        output_dim=args.action_dim,
        unbounded=True,
        conditioned_sigma=True,
        max_mu=args.max_action
    )
    actor = ActorProb(actor_backbone, dist, args.device)
    
    # NOTE(review): the critic is built with action_dim + 1. According to the
    # original author's note the extra input dimension is the reward --
    # confirm against VectorizedCritic in tracer_module.py.
    # hidden_dim is hard-coded to 256 here and below; --hidden-dims only
    # affects the actor backbone.
    critic = VectorizedCritic(
        state_dim=np.prod(args.obs_shape),
        action_dim=args.action_dim + 1, 
        hidden_dim=256,
        num_q=args.num_q,
        num_quantiles=args.num_quantiles,
        cosines_dim=args.cosines_dim
    ).to(args.device)
    
    value_net = DistributionalValueFunction(
        state_dim=np.prod(args.obs_shape),
        hidden_dim=256,
        num_v=1, 
        cosines_dim=args.cosines_dim
    ).to(args.device)
    
    obser_model = ObservationModel(
        state_dim=np.prod(args.obs_shape),
        action_dim=args.action_dim,
        hidden_dim=256,
        num_model=args.num_model,
        device=args.device,
        sigma=args.obser_sigma
    ).to(args.device)

    # 4. Optimizers
    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
    critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
    # The value network reuses the critic learning rate.
    value_optim = torch.optim.Adam(value_net.parameters(), lr=args.critic_lr)
    # The observation-model learning rate is fixed and not exposed on the CLI.
    obser_optim = torch.optim.Adam(obser_model.parameters(), lr=1e-3)

    # 5. Policy Setup
    policy = TRACERPolicy(
        actor, critic, value_net, obser_model,
        actor_optim, critic_optim, value_optim, obser_optim,
        tau=args.tau,
        gamma=args.gamma,
        beta=args.beta,
        quantile=args.quantile,
        iql_tau=args.iql_tau,
        sigma=args.sigma,              # --sigma (help text: "Sigma for QR loss")
        obser_sigma=args.obser_sigma,  # --obser-sigma (help text: "Sigma for VAE loss")
        num_quantiles=args.num_quantiles,
        num_q=args.num_q,
        device=args.device,
        scaler=scaler,
        # Total number of gradient steps of the run; presumably consumed by a
        # schedule inside the policy -- TODO confirm.
        total_steps=args.epoch * args.step_per_epoch
    )

    # 6. Buffer Setup
    # The buffer is sized to hold exactly the offline dataset.
    buffer = ReplayBuffer(
        buffer_size=len(dataset["observations"]),
        obs_shape=args.obs_shape,
        obs_dtype=np.float32,
        action_dim=args.action_dim,
        action_dtype=np.float32,
        device=args.device
    )
    buffer.load_dataset(dataset) # observations are loaded unscaled (see scaler above)

    # 7. Logger & Trainer
    log_dirs = make_log_dirs(args.task, args.algo_name, args.seed, vars(args), record_params=["noise_scale"])
    output_config = {
        "consoleout_backup": "stdout", 
        "policy_training_progress": "csv", 
        "tb": "tensorboard"
    }
    logger = Logger(log_dirs, output_config)
    logger.log_hyperparameters(vars(args))

    # The training environment doubles as the evaluation environment.
    policy_trainer = MFPolicyTrainer(
        policy=policy,
        eval_env=env,
        buffer=buffer,
        logger=logger,
        epoch=args.epoch,
        step_per_epoch=args.step_per_epoch,
        batch_size=args.batch_size,
        eval_episodes=args.eval_episodes,
        noisy_ratio=args.noisy_ratio,
        noise_scale=args.noise_scale,
        early_stop_epoch_number=args.early_stop_epoch_number
    )
    
    policy_trainer.train()

if __name__ == "__main__":
    # Run directories (relative to the current working directory) that may
    # hold a saved hyper_param.json: 3 environments x 4 datasets x 3 seeds.
    run_dirs = (
        '/data/hc-med-exp/seed-0', '/data/hc-med-exp/seed-1', '/data/hc-med-exp/seed-2',
        '/data/hc-med-rep/seed-0', '/data/hc-med-rep/seed-1', '/data/hc-med-rep/seed-2',
        '/data/hc-med/seed-0', '/data/hc-med/seed-1', '/data/hc-med/seed-2',
        '/data/hc-rnd/seed-0', '/data/hc-rnd/seed-1', '/data/hc-rnd/seed-2',
        '/data/hp-med-exp/seed-0', '/data/hp-med-exp/seed-1', '/data/hp-med-exp/seed-2',
        '/data/hp-med-rep/seed-0', '/data/hp-med-rep/seed-1', '/data/hp-med-rep/seed-2',
        '/data/hp-med/seed-0', '/data/hp-med/seed-1', '/data/hp-med/seed-2',
        '/data/hp-rnd/seed-0', '/data/hp-rnd/seed-1', '/data/hp-rnd/seed-2',
        '/data/wk-med-exp/seed-0', '/data/wk-med-exp/seed-1', '/data/wk-med-exp/seed-2',
        '/data/wk-med-rep/seed-0', '/data/wk-med-rep/seed-1', '/data/wk-med-rep/seed-2',
        '/data/wk-med/seed-0', '/data/wk-med/seed-1', '/data/wk-med/seed-2',
        '/data/wk-rnd/seed-0', '/data/wk-rnd/seed-1', '/data/wk-rnd/seed-2',
    )
    # Index of the run to launch; valid values are 0-35 (35 = wk-rnd, seed 2).
    selected_run = 35
    # Plain concatenation is intentional: the entries start with '/', so
    # os.path.join would discard the working-directory prefix.
    train(os.getcwd() + run_dirs[selected_run])
