import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import time
import random
from collections import deque
from torch.nn.utils.rnn import pad_sequence

from src.gift.models.encoders import HistoryEncoder
from src.gift.models.networks import Actor, Critic as EnhancedCritic
from src.gift.buffers.her_buffer import HERReplayBuffer, search_reward_threshold_adaptive
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
from src.gift.utils.evaluator import evaluate_agent
def expectile_loss(diff, tau):
    """Asymmetric squared (expectile) loss used by IQL.

    Positive residuals are weighted by ``tau`` and non-positive ones by
    ``1 - tau``; the weighted squared residuals are averaged over all elements.

    Args:
        diff: residual tensor (e.g. target Q minus V).
        tau: expectile level as a Python float.

    Returns:
        Scalar tensor holding the mean weighted squared residual.
    """
    is_positive = diff > 0
    weights = torch.where(is_positive, tau, 1 - tau)
    squared_residual = diff ** 2
    return torch.mean(weights * squared_residual)

class ValueNetwork(nn.Module):
    """State-value network V(s) used by IQL.

    A plain MLP: one ``Linear`` + ``ReLU`` pair per entry of ``hiddens_sac``
    followed by a final ``Linear`` layer with a single output.
    """

    def __init__(self, state_dim, hiddens_sac=(256, 256)):
        """Build the MLP.

        Args:
            state_dim: size of the input state vector.
            hiddens_sac: hidden-layer widths, in order. Any iterable of ints is
                accepted; the default is an immutable tuple so no list object is
                shared between calls.
        """
        super().__init__()
        layers = []
        input_dim = state_dim
        for hidden_dim in hiddens_sac:
            layers.append(nn.Linear(input_dim, hidden_dim))
            layers.append(nn.ReLU())
            input_dim = hidden_dim
        layers.append(nn.Linear(input_dim, 1))
        self.net = nn.Sequential(*layers)

    def forward(self, state):
        """Return V(state); the last dimension has size 1."""
        return self.net(state)


class BehaviorPolicy(nn.Module):
    """Behavior-policy model with an independent Beta distribution per action dim.

    Actions therefore live in the unit interval. Hidden layers are built from
    the ``hiddens_bhvr`` list of widths, each followed by a ReLU.
    """

    def __init__(self, state_dim, action_dim, hiddens_bhvr=(256, 256)):
        """Build the network.

        Args:
            state_dim: size of the input state vector.
            action_dim: number of action dimensions (one Beta per dimension).
            hiddens_bhvr: non-empty iterable of hidden-layer widths. The default
                is an immutable tuple so no list object is shared between calls.

        Raises:
            ValueError: if ``hiddens_bhvr`` is empty.
        """
        super().__init__()
        if not hiddens_bhvr:
            raise ValueError("hiddens_bhvr 列表不能为空。")
        net_layers = []
        input_dim = state_dim
        for hidden_dim in hiddens_bhvr:
            net_layers.append(nn.Linear(input_dim, hidden_dim))
            net_layers.append(nn.ReLU())
            input_dim = hidden_dim
        self.net = nn.Sequential(*net_layers)
        # Two outputs per action dimension: the alpha and beta logits.
        self.output_layer = nn.Linear(input_dim, action_dim * 2)

    def forward(self, state, action=None):
        """Sample an action, or score a given one.

        Args:
            state: input state tensor.
            action: optional action tensor. When given, it is clamped to
                [1e-6, 1 - 1e-6] before scoring so boundary values stay finite.

        Returns:
            ``(action, log_prob)`` when ``action`` is None (reparameterized
            sample), otherwise just ``log_prob``. ``log_prob`` is summed over the
            action dimensions with ``keepdim=True``.
        """
        logits = self.output_layer(self.net(state))
        alpha_logits, beta_logits = torch.chunk(logits, 2, dim=-1)
        # softplus(.) + 1 keeps both concentrations > 1, i.e. a unimodal Beta.
        alpha = F.softplus(alpha_logits) + 1.0
        beta = F.softplus(beta_logits) + 1.0
        beta_dist = torch.distributions.Beta(alpha, beta)

        if action is None:
            sampled = beta_dist.rsample()
            log_prob = beta_dist.log_prob(sampled).sum(-1, keepdim=True)
            return sampled, log_prob

        action_clamped = torch.clamp(action, 1e-6, 1 - 1e-6)
        return beta_dist.log_prob(action_clamped).sum(-1, keepdim=True)

class SAC_HER_Agent:
    """Goal-conditioned offline RL agent trained from a hindsight-experience-replay buffer.

    A shared ``HistoryEncoder`` maps (history, goal) to a state vector of size
    ``config.model.hiddens_enc[-1]``. On top of it sit a stochastic actor, twin
    critics with a soft-updated target copy and, depending on configuration, a
    Beta behavior policy (DR / action_diff) and an IQL state-value network.
    ``config.model.baserl`` selects the update rule: 'IQL' uses expectile value
    regression plus advantage-weighted actor updates, 'CQL' adds a conservative
    penalty to the SAC-style critic loss, and any other value runs SAC
    (optionally with the DR reward re-weighting).
    """

    def __init__(self, dataset_collection, config, input_dim=1, output_dim=1, treatment_dim=2, static_dim=1, hidden_dim=128, future_length=5,
                 discount=0.99, beta=0.005, lr=3e-4, alpha=0.2,
                 buffer_size=200000, batch_size=128, use_automatic_entropy=True,
                 goal_threshold=1e-3, k_future=10, use_amp=False, reward_mode='combined',
                 use_attention=False, num_heads=4, DR=True, recover=False, action_diff=True,
                 use_cql=False,
                 cql_alpha=5.0,
                 cql_n_actions=10,
                 iql_tau=0.7,
                 iql_beta=3.0,
                 actor_update_interval=5,
                 input_x=False):
        """Build networks, optimizers and the HER replay buffer.

        Args:
            dataset_collection: data passed to ``search_reward_threshold_adaptive``
                to tune the HER reward threshold.
            config: experiment config; reads ``model.hiddens_enc``,
                ``model.hiddens_sac``, ``model.hiddens_bhvr``, ``model.epsilon_1``,
                ``model.epsilon_2``, ``model.baserl`` and ``model.her_params``.
            input_dim, output_dim, treatment_dim, static_dim: feature sizes
                forwarded to the encoder; ``treatment_dim`` is also the action size.
            hidden_dim: width used only by the optional target predictor.
            future_length: default plan length for ``generate_treatment_plan*``.
            discount: TD discount factor.
            beta: soft-update rate of the target critic.
            lr: learning rate shared by all Adam optimizers.
            alpha: entropy temperature; replaced by exp(log_alpha) (initially 1.0)
                when ``use_automatic_entropy`` is True.
            buffer_size, k_future, reward_mode: HER replay buffer settings.
            batch_size: number of experiences sampled per update.
            use_automatic_entropy: learn the entropy temperature.
            goal_threshold: initial value for the reward-threshold search; also
                scales the early-stop distance used when planning.
            use_amp: creates a GradScaler when CUDA is available (the update code
                in this class does not use it).
            use_attention, num_heads: encoder attention options.
            DR: enable importance-ratio reward re-weighting; forced off for CQL/IQL.
            recover: enable spectral norm in the encoder and build a target predictor.
            action_diff: regularize the actor toward the learned behavior policy.
            use_cql: unused; CQL is selected via ``config.model.baserl == 'CQL'``.
            cql_alpha, cql_n_actions: CQL penalty weight and sampled actions per state.
            iql_tau, iql_beta: IQL expectile level and advantage temperature.
            actor_update_interval: actor, temperature and target-critic updates
                happen once every this many critic updates.
            input_x: include vitals in the encoder inputs.
        """
        self.hiddens_enc = config.model.hiddens_enc
        self.hiddens_sac = config.model.hiddens_sac
        self.hiddens_bhvr = config.model.hiddens_bhvr
        # Clamp bounds for the DR importance ratio.
        self.epsilon_1 = config.model.epsilon_1
        self.epsilon_2 = config.model.epsilon_2
        self.state_dim = self.hiddens_enc[-1]
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.treatment_dim = treatment_dim
        self.static_dim = static_dim
        self.hidden_dim = hidden_dim
        self.future_length = future_length
        self.discount = discount
        self.beta = beta
        self.alpha = alpha
        self.batch_size = batch_size
        self.goal_threshold = goal_threshold
        self.use_amp = use_amp

        self.baserl = config.model.baserl
        self.DR = DR
        if self.baserl in ['CQL', 'IQL']:
            # The DR reward re-weighting only exists in the SAC-style target.
            self.DR = False
        self.cql_alpha = cql_alpha
        self.cql_n_actions = cql_n_actions
        self.iql_tau = iql_tau
        self.iql_beta = iql_beta

        self.actor_update_interval = actor_update_interval
        self.critic_updates = 0

        self.recover = recover
        self.action_diff = action_diff
        self.input_x = input_x
        self.encoder = HistoryEncoder(
            input_dim=input_dim, output_dim=output_dim, treatment_dim=treatment_dim,
            static_dim=static_dim, hiddens_enc=self.hiddens_enc, use_attention=use_attention,
            num_heads=num_heads, use_spectral_norm=recover, input_x=self.input_x
        ).to(DEVICE)
        self.actor = Actor(
            state_dim=self.state_dim, action_dim=treatment_dim, hiddens_sac=self.hiddens_sac
        ).to(DEVICE)
        self.critic = EnhancedCritic(
            state_dim=self.state_dim, action_dim=treatment_dim, hiddens_sac=self.hiddens_sac
        ).to(DEVICE)
        self.critic_target = EnhancedCritic(
            state_dim=self.state_dim, action_dim=treatment_dim, hiddens_sac=self.hiddens_sac
        ).to(DEVICE)
        self.critic_target.load_state_dict(self.critic.state_dict())
        if self.baserl == 'IQL':
            self.value_net = ValueNetwork(self.state_dim, self.hiddens_sac).to(DEVICE)
            self.value_optimizer = optim.Adam(self.value_net.parameters(), lr=lr)
        if self.recover:
            # NOTE(review): TargetPredictor is not imported anywhere in this file, so
            # recover=True raises NameError here; confirm its module and import it.
            # NOTE(review): state_dim=hidden_dim, while encoder states have size
            # self.state_dim (hiddens_enc[-1]); confirm this is intended.
            self.target_predictor = TargetPredictor(
                state_dim=hidden_dim, goal_dim=output_dim, hidden_dim=hidden_dim*2
            ).to(DEVICE)
            self.target_predictor_optimizer = optim.Adam(self.target_predictor.parameters(), lr=lr)
        if self.DR or self.action_diff:
            self.behavior_policy = BehaviorPolicy(
                state_dim=self.state_dim, action_dim=treatment_dim, hiddens_bhvr=self.hiddens_bhvr
            ).to(DEVICE)
            self.behavior_optimizer = optim.Adam(self.behavior_policy.parameters(), lr=lr)
        self.encoder_optimizer = optim.Adam(self.encoder.parameters(), lr=lr)
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=lr)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=lr)
        self.use_automatic_entropy = use_automatic_entropy
        self.target_entropy = -float(treatment_dim)
        if use_automatic_entropy:
            self.log_alpha = torch.zeros(1, requires_grad=True, device=DEVICE)
            self.alpha_optimizer = optim.Adam([self.log_alpha], lr=lr)
            self.alpha = self.log_alpha.exp().item()

        # HER settings come from config; ``her_future_length`` is kept separate from
        # the ``future_length`` argument (plan length) to avoid shadowing it.
        her_params = config['model']['her_params']
        min_history_length = her_params['min_history_length']
        max_history_length = her_params['max_history_length']
        her_future_length = her_params['future_length']
        target_hit_ratio = her_params['target_hit_ratio']
        optimal_threshold = search_reward_threshold_adaptive(
            dataset_collection, min_history_length=min_history_length, max_history_length=max_history_length,
            capacity=buffer_size, k_future=k_future, future_length=her_future_length, reward_mode=reward_mode,
            initial_threshold=goal_threshold, target_hit_ratio=target_hit_ratio, max_iter=10, hit_tolerance=0.01, max_high_limit=1
        )
        self.memory = HERReplayBuffer(buffer_size, k_future=k_future, reward_threshold=optimal_threshold, reward_mode=reward_mode)
        if use_amp and torch.cuda.is_available():
            self.scaler = torch.cuda.amp.GradScaler()
        self.train_info = {
            'critic_losses': [], 'actor_losses': [], 'alpha_losses': [], 'alphas': [],
            'behavior_losses': [], 'target_losses': [], 'dynamics_losses': [], 'action_diffs': [],
            'value_losses': []
        }

        self.total_steps = 0
        # Accept both the string 'None' and a real None for "no base algorithm";
        # a real None previously produced algorithm=None and broke the suffixes below.
        if self.baserl not in (None, 'None'):
            self.algorithm = self.baserl
        else:
            self.algorithm = "SAC-DR" if self.DR else "SAC"

        if self.recover:
            self.algorithm += "-Recover"
        if self.action_diff:
            self.algorithm += "-ActionDiff"

    def update_parameters(self):
        """Sample a batch and run one update; return the loss dict, or None if sampling returned None."""
        batch = self.memory.sample(self.batch_size)
        if batch is None:
            return None

        history_batch, actions, rewards, next_history_batch, goals, dones = self._process_batch(batch)
        return self._update_standard(history_batch, actions, rewards, next_history_batch, goals, dones)

    def _update_standard(self, history_batch, actions, rewards, next_history_batch, goals, dones):
        """Run one training step for SAC/SAC-DR, CQL or IQL.

        Order: (1) fit the behavior policy when DR or action_diff is on;
        (2) IQL only: fit V(s) by expectile regression; (3) update encoder and
        critics with the TD loss (+ CQL penalty when baserl == 'CQL'); (4) every
        ``actor_update_interval`` critic updates, update the actor, the entropy
        temperature (non-IQL with automatic entropy) and the target critic.

        Args:
            history_batch, next_history_batch: padded history dicts from
                ``_process_history_batch``.
            actions, goals: batch tensors from ``_process_batch``.
            rewards, dones: (B, 1) tensors.

        Returns:
            Dict of scalar losses/diagnostics for this step.
        """
        state = self.encoder(history_batch, goals)
        next_state = self.encoder(next_history_batch, goals)

        behavior_loss = torch.tensor(0.0, device=DEVICE)
        if self.DR or self.action_diff:
            # Fit on detached states: this backward pass never enters the encoder
            # graph, so ``state`` stays usable below without re-encoding.
            behavior_data_log_probs = self.behavior_policy(state.detach(), actions)
            behavior_loss = -behavior_data_log_probs.mean()
            self.behavior_optimizer.zero_grad()
            behavior_loss.backward()
            self.behavior_optimizer.step()

        value_loss = torch.tensor(0.0, device=DEVICE)
        if self.baserl == 'IQL':
            with torch.no_grad():
                q1_data, q2_data = self.critic_target(state, actions)
                target_q_for_v = torch.min(q1_data, q2_data)
            # Detach: otherwise value_loss.backward() frees the encoder graph that the
            # critic loss below still has to backpropagate through (RuntimeError).
            v = self.value_net(state.detach())
            value_loss = expectile_loss(target_q_for_v - v, self.iql_tau)
            self.value_optimizer.zero_grad(set_to_none=True)
            value_loss.backward()
            self.value_optimizer.step()
            with torch.no_grad():
                next_v = self.value_net(next_state)
                target_q = rewards + (1 - dones) * self.discount * next_v
        else:
            with torch.no_grad():
                next_actions, next_log_probs = self.actor(next_state)
                target_q1, target_q2 = self.critic_target(next_state, next_actions)
                soft_next_q = torch.min(target_q1, target_q2) - self.alpha * next_log_probs

                if self.DR:
                    # Reward re-weighted by the clamped policy/behavior likelihood ratio.
                    policy_log_probs = self.actor.log_prob(state, actions)
                    behavior_log_probs = self.behavior_policy(state, actions)
                    rho = torch.exp(policy_log_probs - behavior_log_probs).clamp(self.epsilon_1, self.epsilon_2)
                    target_q = rewards * rho + (1 - dones) * self.discount * soft_next_q
                else:
                    target_q = rewards + (1 - dones) * self.discount * soft_next_q

        current_q1, current_q2 = self.critic(state, actions)
        critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)
        if self.baserl == 'CQL':
            critic_loss = critic_loss + self._cql_penalty(state, current_q1, current_q2)

        self.encoder_optimizer.zero_grad(set_to_none=True)
        self.critic_optimizer.zero_grad(set_to_none=True)
        critic_loss.backward()
        self.critic_optimizer.step()
        self.encoder_optimizer.step()
        self.critic_updates += 1

        actor_loss = torch.tensor(0.0, device=DEVICE)
        alpha_loss = torch.tensor(0.0, device=DEVICE)
        action_diff_value = torch.tensor(0.0, device=DEVICE)

        if self.critic_updates % self.actor_update_interval == 0:
            # Re-encode after the encoder step; actor updates do not train the encoder.
            state_det = self.encoder(history_batch, goals).detach()
            new_actions = log_probs = None

            if self.baserl == 'IQL':
                with torch.no_grad():
                    q1, q2 = self.critic_target(state_det, actions)
                    adv = torch.min(q1, q2) - self.value_net(state_det)
                # Advantage-weighted log-likelihood of dataset actions (weights capped at 100).
                exp_adv = torch.exp(self.iql_beta * adv).clamp(max=100.0)
                policy_log_probs = self.actor.log_prob(state_det, actions)
                actor_loss = -(exp_adv * policy_log_probs).mean()
                if self.action_diff:
                    # The action_diff regularizer needs on-policy samples; IQL's loss
                    # alone never draws any (previously a NameError on new_actions).
                    new_actions, log_probs = self.actor(state_det)
            else:
                new_actions, log_probs = self.actor(state_det)
                q1, q2 = self.critic(state_det, new_actions)
                actor_loss = (self.alpha * log_probs - torch.min(q1, q2)).mean()

            if self.action_diff:
                with torch.no_grad():
                    behavior_log_probs = self.behavior_policy(state_det, new_actions)
                action_diff_value = (log_probs - behavior_log_probs).mean()
                # Weight is capped at 0.01 and only starts decaying once
                # exp(-0.001 * total_steps) drops below 0.01 (~4600 steps).
                diff_weight = float(min(0.01, np.exp(-0.001 * self.total_steps)))
                actor_loss = actor_loss + diff_weight * action_diff_value

            self.actor_optimizer.zero_grad(set_to_none=True)
            actor_loss.backward()
            self.actor_optimizer.step()

            if self.baserl != 'IQL' and self.use_automatic_entropy:
                alpha_loss = -(self.log_alpha * (log_probs.detach() + self.target_entropy)).mean()
                self.alpha_optimizer.zero_grad()
                alpha_loss.backward()
                self.alpha_optimizer.step()
                self.alpha = self.log_alpha.exp().item()

            # Soft update of the target critic with rate self.beta.
            for target_param, param in zip(self.critic_target.parameters(), self.critic.parameters()):
                target_param.data.copy_(target_param.data * (1.0 - self.beta) + param.data * self.beta)

        self.train_info['critic_losses'].append(critic_loss.item())
        self.train_info['actor_losses'].append(actor_loss.item())
        self.train_info['behavior_losses'].append(behavior_loss.item())
        if self.baserl != 'IQL':
            self.train_info['alpha_losses'].append(alpha_loss.item())
            self.train_info['alphas'].append(self.alpha)
        else:
            self.train_info['value_losses'].append(value_loss.item())
        if self.action_diff:
            self.train_info['action_diffs'].append(action_diff_value.item())

        self.total_steps += 1
        return {
            'critic_loss': critic_loss.item(),
            'actor_loss': actor_loss.item(),
            'alpha_loss': alpha_loss.item(),
            'alpha': self.alpha if self.baserl != 'IQL' else 0.0,
            'behavior_loss': behavior_loss.item(),
            'value_loss': value_loss.item() if self.baserl == 'IQL' else 0.0,
            'action_diff': action_diff_value.item() if self.action_diff else 0.0
        }

    def _cql_penalty(self, state, current_q1, current_q2):
        """Return the CQL penalty: logsumexp of Q over sampled actions minus Q on data actions.

        Q is evaluated on ``cql_n_actions`` uniform-random and ``cql_n_actions``
        policy actions per state. The batch size is taken from ``state`` rather
        than ``self.batch_size`` so smaller batches also work.
        """
        n = state.shape[0]
        flat_state = state.unsqueeze(1).repeat(1, self.cql_n_actions, 1).reshape(-1, state.shape[-1])
        # torch.rand draws from [0, 1); assumes actions live in the unit box — confirm.
        random_actions = torch.rand(n, self.cql_n_actions, self.treatment_dim, device=DEVICE)

        with torch.no_grad():
            policy_actions, _ = self.actor(flat_state)

        q1_rand, q2_rand = self.critic(flat_state, random_actions.reshape(-1, self.treatment_dim))
        q1_pi, q2_pi = self.critic(flat_state, policy_actions.reshape(-1, self.treatment_dim))

        cat_q1 = torch.cat([q1_rand.reshape(n, -1), q1_pi.reshape(n, -1)], dim=1)
        cat_q2 = torch.cat([q2_rand.reshape(n, -1), q2_pi.reshape(n, -1)], dim=1)

        lse_q1 = torch.logsumexp(cat_q1, dim=1, keepdim=True)
        lse_q2 = torch.logsumexp(cat_q2, dim=1, keepdim=True)

        return ((lse_q1 - current_q1).mean() + (lse_q2 - current_q2).mean()) * self.cql_alpha

    def _process_batch(self, batch):
        """Convert sampled experiences into batched tensors on DEVICE.

        A 3-tuple is unpacked as ``(experiences, _, _)`` — presumably the
        (batch, weights, indices) form of a prioritized sample; verify against
        ``HERReplayBuffer.sample``.

        Returns:
            (history_batch, actions, rewards, next_history_batch, goals, dones);
            rewards and dones are reshaped to (B, 1).
        """
        if isinstance(batch, tuple) and len(batch) == 3:
            batch, _, _ = batch

        history_dicts, actions, rewards, next_history_dicts, goals, dones = [], [], [], [], [], []
        for experience in batch:
            history_dicts.append(experience.history_dict)
            actions.append(experience.action)
            rewards.append(experience.reward)
            next_history_dicts.append(experience.next_history_dict)
            goals.append(experience.goal)
            dones.append(experience.done)
        actions = torch.FloatTensor(np.array(actions)).to(DEVICE)
        rewards = torch.FloatTensor(np.array(rewards)).reshape(-1, 1).to(DEVICE)
        goals = torch.FloatTensor(np.array(goals)).to(DEVICE)
        dones = torch.FloatTensor(np.array(dones)).reshape(-1, 1).to(DEVICE)

        processed_history = self._process_history_batch(history_dicts)
        processed_next_history = self._process_history_batch(next_history_dicts)

        return processed_history, actions, rewards, processed_next_history, goals, dones

    def _process_history_batch(self, history_dicts):
        """Pad variable-length history dicts into batch tensors.

        Each array is indexed with ``[0]`` and its sequence length read from
        ``outputs.shape[1]``, so each dict presumably holds (1, T, D) arrays —
        confirm. ``vitals`` is included only when ``self.input_x`` is set.
        """
        outputs_list, static_list, treatments_list, vitals_list = [], [], [], []
        seq_lengths = []

        for d in history_dicts:
            outputs, static_features, current_treatments = d['outputs'], d['static_features'], d['current_treatments']
            seq_lengths.append(outputs.shape[1])
            outputs_list.append(torch.FloatTensor(outputs[0]))
            static_list.append(torch.FloatTensor(static_features[0]))
            treatments_list.append(torch.FloatTensor(current_treatments[0]))
            if self.input_x:
                vitals_list.append(torch.FloatTensor(d['vitals'][0]))

        batch_history = {
            'outputs': pad_sequence(outputs_list, batch_first=True).to(DEVICE),
            'static_features': pad_sequence(static_list, batch_first=True).to(DEVICE),
            'current_treatments': pad_sequence(treatments_list, batch_first=True).to(DEVICE),
            'seq_lengths': torch.LongTensor(seq_lengths).to(DEVICE)
        }
        if self.input_x:
            batch_history['vitals'] = pad_sequence(vitals_list, batch_first=True).to(DEVICE)

        return batch_history

    def select_action(self, history_dict, goal, evaluate=False):
        """Select one action for a single (history, goal) pair.

        numpy arrays and torch tensors in ``history_dict`` are placed on DEVICE;
        other values are passed through unchanged.

        Args:
            evaluate: use the deterministic action instead of a sample.

        Returns:
            The action as a numpy array.
        """
        history_tensor_dict = {}
        for k, v in history_dict.items():
            if isinstance(v, np.ndarray):
                history_tensor_dict[k] = torch.FloatTensor(v).to(DEVICE)
            elif isinstance(v, torch.Tensor):
                history_tensor_dict[k] = v.to(DEVICE)
            else:
                history_tensor_dict[k] = v
        goal_tensor = torch.FloatTensor(goal).to(DEVICE) if isinstance(goal, np.ndarray) else goal.to(DEVICE)

        with torch.no_grad():
            state = self.encoder(history_tensor_dict, goal_tensor)
            action = self.actor(state, deterministic=True) if evaluate else self.actor(state, deterministic=False)[0]
        return action.cpu().numpy()

    def generate_treatment_plan(self, history_dict, goal, dataset_collection, future_dict, future_length=None, early_stop=True):
        """Generate a treatment plan step by step, re-simulating after every action.

        At each step the deterministic actor action is appended to the plan, the
        whole plan so far is simulated from the original ``history_dict`` via
        ``dataset_collection.val_f.simulate_output_after_actions``, and the
        history fed to the encoder is extended by ``_update_patient_history``.

        Args:
            history_dict: initial history (numpy arrays; other values are kept as-is).
            goal: target outcome (numpy array or tensor).
            dataset_collection: provides ``val_f.simulate_output_after_actions``
                and ``train_scaling_params``.
            future_dict: unused.
            future_length: plan length; defaults to ``self.future_length``.
            early_stop: stop once the simulated output is within
                ``goal_threshold * 0.001`` (L2 norm) of the goal.

        Returns:
            (actions, outputs, steps_taken): stacked per-step action arrays,
            stacked per-step outputs (empty array if none) and the number of
            steps actually simulated.
        """
        if future_length is None:
            future_length = self.future_length
        self.actor.eval()
        self.encoder.eval()
        actions = []
        outputs = []
        steps_taken = 0
        try:
            current_history = {
                key: torch.FloatTensor(value).to(DEVICE) if isinstance(value, np.ndarray) else value
                for key, value in history_dict.items()
            }
            goal_tensor = torch.FloatTensor(goal).to(DEVICE) if isinstance(goal, np.ndarray) else goal.to(DEVICE)
            goal_np = goal if isinstance(goal, np.ndarray) else goal.cpu().numpy()
            # Working numpy copy that grows by one step per iteration.
            updated_history = {
                key: value.copy() if isinstance(value, np.ndarray) else value
                for key, value in history_dict.items()
            }
            for t in range(future_length):
                with torch.no_grad():
                    state = self.encoder(current_history, goal_tensor)
                    action_np = self.actor(state, deterministic=True).unsqueeze(1).cpu().numpy()
                    actions.append(action_np)
                    actions_tensor = torch.FloatTensor(np.concatenate(actions, axis=1)).to(DEVICE)
                    output = dataset_collection.val_f.simulate_output_after_actions(
                        history_dict,
                        actions_tensor,
                        dataset_collection.train_scaling_params,
                    )

                outputs.append(output)
                steps_taken += 1
                if early_stop and np.linalg.norm(output - goal_np) < self.goal_threshold * 0.001:
                    break
                if t == future_length - 1:
                    break
                # Shared with the batched planner; also handles batch sizes > 1 and
                # a missing/exhausted 'future_vitals' entry.
                updated_history = self._update_patient_history(updated_history, action_np, output)
                current_history = {
                    k: torch.FloatTensor(v).to(DEVICE) if isinstance(v, np.ndarray) else v
                    for k, v in updated_history.items()
                }
        finally:
            # Always return the modules to training mode, even if simulation fails.
            self.actor.train()
            self.encoder.train()

        actions = np.array(actions)
        outputs = np.array(outputs) if outputs else np.array([])

        return actions, outputs, steps_taken

    def generate_treatment_plan_batch(self, history_dict_batch, goal_batch, dataset_collection, future_dict_batch, future_length=None, early_stop=True):
        """Batched version of ``generate_treatment_plan``.

        All patients are encoded together; simulation runs per patient. Early
        stopping only records the first step at which a patient reached its goal
        (``steps_taken_batch``); every patient still gets ``future_length`` steps.
        ``future_dict_batch`` is unused.

        Returns:
            (actions_batch, outputs_batch, steps_taken_batch), one entry per patient.
        """
        if future_length is None:
            future_length = self.future_length

        self.actor.eval()
        self.encoder.eval()
        try:
            batch_size = len(history_dict_batch)
            device = DEVICE

            # Stack tensor-valued fields along dim 0; non-tensor fields stay as lists.
            H_t_batch = {}
            for history_dict in history_dict_batch:
                for key, value in history_dict.items():
                    H_t_batch.setdefault(key, []).append(
                        torch.FloatTensor(value) if isinstance(value, np.ndarray) else value
                    )
            for key in H_t_batch:
                if isinstance(H_t_batch[key][0], torch.Tensor):
                    H_t_batch[key] = torch.cat(H_t_batch[key], dim=0).to(device)

            goal_tensor_batch = [
                torch.FloatTensor(g).unsqueeze(0) if isinstance(g, np.ndarray) else (g.unsqueeze(0) if g.dim() == 1 else g)
                for g in goal_batch
            ]
            goal_tensor_batch = torch.cat(goal_tensor_batch, dim=0).to(device)
            goal_np_batch = [g if isinstance(g, np.ndarray) else g.cpu().numpy() for g in goal_batch]

            updated_history_batch = [
                {
                    k: v.copy() if isinstance(v, np.ndarray) else (v.cpu().numpy().copy() if hasattr(v, 'cpu') else v)
                    for k, v in history_dict.items()
                }
                for history_dict in history_dict_batch
            ]

            current_H_t_batch = H_t_batch
            actions_batch = [[] for _ in range(batch_size)]
            outputs_batch = [[] for _ in range(batch_size)]
            steps_taken_batch = [future_length] * batch_size

            for t in range(future_length):
                with torch.no_grad():
                    state_batch = self.encoder(current_H_t_batch, goal_tensor_batch)
                    action_batch = self.actor(state_batch, deterministic=True)
                    step_actions = [action_batch[i:i + 1].cpu().numpy().reshape(1, 1, -1) for i in range(batch_size)]

                    for i in range(batch_size):
                        actions_batch[i].append(step_actions[i])
                        actions_tensor = torch.FloatTensor(np.concatenate(actions_batch[i], axis=1)).to(device)
                        output = dataset_collection.val_f.simulate_output_after_actions(
                            history_dict_batch[i], actions_tensor, dataset_collection.train_scaling_params
                        )
                        outputs_batch[i].append(output)

                        reached = np.linalg.norm(output - goal_np_batch[i]) < self.goal_threshold * 0.001
                        if early_stop and reached and steps_taken_batch[i] > t + 1:
                            steps_taken_batch[i] = t + 1

                    if t == future_length - 1:
                        break

                    for i in range(batch_size):
                        updated_history_batch[i] = self._update_patient_history(
                            updated_history_batch[i], step_actions[i], outputs_batch[i][-1]
                        )

                    current_H_t_batch = {}
                    for key in updated_history_batch[0]:
                        batch_data = [torch.FloatTensor(h[key]) if isinstance(h[key], np.ndarray) else h[key] for h in updated_history_batch]
                        if isinstance(batch_data[0], torch.Tensor):
                            current_H_t_batch[key] = torch.cat(batch_data, dim=0).to(device)
        finally:
            self.actor.train()
            self.encoder.train()

        actions_batch = [np.array(actions).squeeze(1) if actions else np.array([]) for actions in actions_batch]
        outputs_batch = [np.array(outputs) if outputs else np.array([]) for outputs in outputs_batch]

        return actions_batch, outputs_batch, steps_taken_batch

    def _update_patient_history(self, updated_history, action_np, output):
        """Append one (action, output) step to a numpy history dict, in place.

        'current_treatments' and 'outputs' grow by one time step; 'static_features'
        is padded with its last row to match the new 'outputs' length; 'vitals'
        takes the next row of 'future_vitals' (zeros once it is exhausted), which
        is then consumed.

        Returns:
            The same (mutated) dict.
        """
        for key, value in {'current_treatments': action_np, 'outputs': output}.items():
            if key in updated_history:
                prev = updated_history[key]
                new_val = np.zeros((prev.shape[0], prev.shape[1] + 1, prev.shape[2]))
                new_val[:, :-1, :] = prev
                new_val[:, -1, :] = value.reshape(prev.shape[0], -1)
                updated_history[key] = new_val

        if 'static_features' in updated_history:
            static = updated_history['static_features']
            new_len = updated_history['outputs'].shape[1]
            if new_len > static.shape[1]:
                new_static = np.zeros((static.shape[0], new_len, static.shape[2]))
                new_static[:, :static.shape[1], :] = static
                new_static[:, static.shape[1]:, :] = static[:, -1:, :]
                updated_history['static_features'] = new_static

        if 'vitals' in updated_history:
            prev_vitals = updated_history['vitals']
            new_vitals = np.zeros((prev_vitals.shape[0], prev_vitals.shape[1] + 1, prev_vitals.shape[2]))
            new_vitals[:, :-1, :] = prev_vitals
            if 'future_vitals' in updated_history and updated_history['future_vitals'].shape[1] > 0:
                # Index (not slice) the time axis: a (B, 1, D) slice cannot be
                # assigned into the (B, D) target when B > 1.
                new_vitals[:, -1, :] = updated_history['future_vitals'][:, 0, :]
                updated_history['future_vitals'] = updated_history['future_vitals'][:, 1:, :]
            updated_history['vitals'] = new_vitals

        return updated_history

    def train_offline(self, iterations, progress_interval=100, eval_interval=5000, dataset_collection=None):
        """Run ``iterations`` offline updates with periodic logging and evaluation.

        Every ``progress_interval`` iterations the mean losses of the most recent
        successful updates are printed. When ``dataset_collection`` is given, the
        agent is evaluated every ``eval_interval`` iterations and the best model
        (lowest ``avg_rmse``) is saved to ``best_model_<algorithm>.pth``.

        Returns:
            (losses, eval_results): per-update loss dicts and
            ``(iteration, metrics)`` tuples.
        """
        print(f"\n开始离线训练 ({self.algorithm})...")
        train_start = time.time()

        losses = []
        eval_results = []
        best_rmse = float('inf')
        best_iter = 0

        for i in range(iterations):
            loss_info = self.update_parameters()
            if loss_info is not None:
                losses.append(loss_info)

            if (i+1) % progress_interval == 0:
                recent_losses = losses[-progress_interval:]
                if not recent_losses:
                    # Nothing to average yet (every update returned None); avoid nan means.
                    print(f"迭代 {i+1}/{iterations}, 本区间内无参数更新 (采样批次为空)")
                else:
                    avg_critic_loss = np.mean([l['critic_loss'] for l in recent_losses])
                    avg_actor_loss = np.mean([l['actor_loss'] for l in recent_losses])

                    print_str = f"迭代 {i+1}/{iterations}, Critic损失: {avg_critic_loss:.4f}, Actor损失: {avg_actor_loss:.4f}"

                    if self.baserl != 'IQL':
                        avg_alpha = np.mean([l['alpha'] for l in recent_losses])
                        print_str += f", Alpha: {avg_alpha:.4f}"
                    # The behavior policy is trained whenever DR or action_diff is on.
                    if self.DR or self.action_diff:
                        avg_behavior_loss = np.mean([l['behavior_loss'] for l in recent_losses])
                        print_str += f", Behavior损失: {avg_behavior_loss:.4f}"
                    if self.baserl == 'IQL':
                        avg_value_loss = np.mean([l['value_loss'] for l in recent_losses])
                        print_str += f", Value损失: {avg_value_loss:.4f}"
                    if self.action_diff:
                        avg_action_diff = np.mean([l['action_diff'] for l in recent_losses])
                        print_str += f", Action差异: {avg_action_diff:.4f}"

                    print(print_str)

            if dataset_collection is not None and (i+1) % eval_interval == 0:
                print(f"\n评估模型 (迭代 {i+1})...")
                metrics = evaluate_agent(self, dataset_collection, num_episodes=100)
                eval_results.append((i+1, metrics))

                if metrics['avg_rmse'] < best_rmse:
                    best_rmse = metrics['avg_rmse']
                    best_iter = i+1
                    self.save(f"best_model_{self.algorithm}.pth")

                print(f"当前最佳RMSE: {best_rmse:.6f} (迭代 {best_iter})")

        train_time = time.time() - train_start
        print(f"训练完成，用时: {train_time:.1f}秒")
        print(f"最佳RMSE: {best_rmse:.6f} (迭代 {best_iter})")

        return losses, eval_results

    def save(self, path):
        """Save network weights, entropy state and training statistics to ``path``.

        Optional networks (behavior policy, target predictor, IQL value net) are
        saved only when they exist for the current configuration.
        """
        save_dict = {
            'encoder': self.encoder.state_dict(),
            'actor': self.actor.state_dict(),
            'critic': self.critic.state_dict(),
            'critic_target': self.critic_target.state_dict(),
            'log_alpha': self.log_alpha if self.use_automatic_entropy and self.baserl != 'IQL' else None,
            'alpha': self.alpha,
            'train_info': self.train_info,
            'total_steps': self.total_steps,
            'algorithm': self.algorithm
        }

        if self.DR or self.action_diff:
            save_dict['behavior_policy'] = self.behavior_policy.state_dict()
        if self.recover:
            save_dict['target_predictor'] = self.target_predictor.state_dict()
        if self.baserl == 'IQL':
            save_dict['value_net'] = self.value_net.state_dict()

        torch.save(save_dict, path)
        print(f"模型保存到 {path}")

    def load(self, path):
        """Restore a checkpoint written by ``save``.

        Optional components are restored only if present in both the checkpoint
        and the current configuration. ``torch.load`` unpickles data, so only
        load trusted checkpoint files.
        """
        checkpoint = torch.load(path, map_location=DEVICE)
        self.encoder.load_state_dict(checkpoint['encoder'])
        self.actor.load_state_dict(checkpoint['actor'])
        self.critic.load_state_dict(checkpoint['critic'])
        self.critic_target.load_state_dict(checkpoint['critic_target'])

        if 'behavior_policy' in checkpoint and (self.DR or self.action_diff):
            self.behavior_policy.load_state_dict(checkpoint['behavior_policy'])
        if self.recover and 'target_predictor' in checkpoint:
            self.target_predictor.load_state_dict(checkpoint['target_predictor'])
        if self.baserl == 'IQL' and 'value_net' in checkpoint:
            self.value_net.load_state_dict(checkpoint['value_net'])

        if self.use_automatic_entropy and checkpoint.get('log_alpha') is not None and self.baserl != 'IQL':
            self.log_alpha.data.copy_(checkpoint['log_alpha'].data)

        self.alpha = checkpoint['alpha']
        self.train_info = checkpoint['train_info']
        self.total_steps = checkpoint['total_steps']
        if 'algorithm' in checkpoint:
            self.algorithm = checkpoint['algorithm']
        print(f"从 {path} 加载模型 ({self.algorithm})")

