import copy
import time
import random
from collections import deque

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from src.gift.models.encoders import HistoryEncoder
from src.gift.models.networks import Actor
from src.gift.buffers.her_buffer import HERReplayBuffer, search_reward_threshold_adaptive

DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")  # module-wide compute device
def _build_mlp(input_dim, output_dim, hiddens, use_layer_norm=True):
    """Create a feed-forward network as an ``nn.Sequential``.

    Every width in ``hiddens`` contributes ``Linear -> [LayerNorm] -> ReLU``
    (LayerNorm only when ``use_layer_norm`` is True); a final ``Linear`` maps the
    last width (or ``input_dim`` when ``hiddens`` is empty) to ``output_dim``.
    """
    widths = [input_dim, *hiddens]
    modules = []
    for fan_in, fan_out in zip(widths[:-1], widths[1:]):
        modules.append(nn.Linear(fan_in, fan_out))
        if use_layer_norm:
            modules.append(nn.LayerNorm(fan_out))
        modules.append(nn.ReLU())
    modules.append(nn.Linear(widths[-1], output_dim))
    return nn.Sequential(*modules)


def _apply_cold_start_init(module):
    """
    Re-initialise an output layer with near-zero values (SCRL "cold start").

    For an ``nn.Sequential`` only its last module is touched, and only when that
    module is an ``nn.Linear``; a bare ``nn.Linear`` is touched directly. Weight
    and (if present) bias are drawn from U(-1e-12, 1e-12). Any other module is
    left unchanged.
    """
    if isinstance(module, nn.Sequential):
        target = module[-1]
        if not isinstance(target, nn.Linear):
            return
    elif isinstance(module, nn.Linear):
        target = module
    else:
        return

    bound = 1e-12
    target.weight.data.uniform_(-bound, bound)
    if target.bias is not None:
        target.bias.data.uniform_(-bound, bound)


def _xavier_init_if_linear(m):
    """Xavier-uniform (gain 1) weights and zero bias for ``nn.Linear``; no-op otherwise.

    Intended for use with ``nn.Module.apply`` so every linear layer of a network
    gets this default before any layer-specific override.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.xavier_uniform_(m.weight, gain=1.0)
    if m.bias is not None:
        nn.init.zeros_(m.bias)
class _HistoryGoalEncoder(nn.Module):
    """
    Adapter around ``HistoryEncoder`` exposing its history and goal branches separately.

    * ``encode_history`` concatenates per-step features, runs them through
      ``encoder_base.lstm`` (packed when sequence lengths are supplied) and
      returns ``h_n[-1]``.
    * ``encode_goal`` runs a goal array/tensor through ``encoder_base.goal_encoder``.

    ``history_dim`` and ``goal_dim`` hold the widths of those two embeddings so
    downstream networks can be sized from them.
    """

    def __init__(self, config, input_dim, output_dim, treatment_dim, static_dim, use_attention, num_heads):
        super().__init__()
        print(f"input_dim:{input_dim}")
        base = HistoryEncoder(
            input_dim=input_dim,
            output_dim=output_dim,
            treatment_dim=treatment_dim,
            static_dim=static_dim,
            hiddens_enc=config.model.hiddens_enc,
            use_attention=use_attention,
            num_heads=num_heads,
            use_spectral_norm=False,
            input_x=config.dataset.get('input_x', False)
        ).to(DEVICE)
        self.encoder_base = base

        # History width: the LSTM's hidden size when discoverable, otherwise the
        # last configured encoder width.
        lstm = getattr(base, "lstm", None)
        if hasattr(lstm, "hidden_size"):
            self.history_dim = lstm.hidden_size
        else:
            self.history_dim = config.model.hiddens_enc[-1]

        # Goal width: out_features of a trailing Linear in a Sequential goal
        # encoder, otherwise the last configured encoder width.
        goal_net = getattr(base, "goal_encoder", None)
        head = goal_net[-1] if isinstance(goal_net, nn.Sequential) else None
        if isinstance(head, nn.Linear):
            self.goal_dim = head.out_features
        else:
            self.goal_dim = config.model.hiddens_enc[-1]

    def encode_history(self, history_batch):
        """Encode a collated history batch into one vector per sequence.

        ``history_batch`` must contain ``'outputs'``, ``'static_features'`` and
        ``'current_treatments'`` (plus ``'vitals'`` when ``encoder_base.input_x``
        is truthy). They are concatenated along dim 2, with dim 1 read as time.
        2-D static features are tiled across every time step. When
        ``'seq_lengths'`` is present and not None the input is packed so padding
        is skipped.

        Returns:
            ``h_n[-1]`` of ``encoder_base.lstm`` (for a unidirectional LSTM, the
            last layer's final hidden state).
        """
        outputs = history_batch['outputs']
        static = history_batch['static_features']
        treatments = history_batch['current_treatments']
        if static.dim() == 2:
            static = static.unsqueeze(1).expand(-1, outputs.size(1), -1)

        pieces = [outputs, static, treatments]
        if self.encoder_base.input_x:
            pieces.append(history_batch['vitals'])
        features = torch.cat(pieces, dim=2)

        lengths = history_batch['seq_lengths'] if 'seq_lengths' in history_batch else None
        if lengths is None:
            lstm_input = features
        else:
            # Packing needs lengths on the CPU; enforce_sorted=False accepts any order.
            lstm_input = nn.utils.rnn.pack_padded_sequence(
                features, lengths.cpu(), batch_first=True, enforce_sorted=False
            )
        _, (h_n, _) = self.encoder_base.lstm(lstm_input)
        return h_n[-1]

    def encode_goal(self, goal_batch):
        """Embed goals with ``encoder_base.goal_encoder``.

        Accepts a numpy array (converted to float32) or a tensor (moved to
        ``DEVICE``); a 1-D input is treated as a single goal and given a batch axis.

        Raises:
            TypeError: for any other input type.
        """
        if isinstance(goal_batch, torch.Tensor):
            goals = goal_batch.to(DEVICE)
        elif isinstance(goal_batch, np.ndarray):
            goals = torch.as_tensor(goal_batch, dtype=torch.float32, device=DEVICE)
        else:
            raise TypeError("Unsupported type for goal_batch; expected np.ndarray or torch.Tensor.")

        if goals.dim() == 1:
            goals = goals.unsqueeze(0)
        return self.encoder_base.goal_encoder(goals)

    def forward(self, history_batch=None, goal_batch=None):
        """Return ``(history_encoding, goal_encoding)``; each is None when its input is None."""
        history_encoding = None if history_batch is None else self.encode_history(history_batch)
        goal_encoding = None if goal_batch is None else self.encode_goal(goal_batch)
        return history_encoding, goal_encoding
class _SCRL_Critic(nn.Module):
    """
    Contrastive SCRL critic.

    Scores a (history, action) pair against a goal by cosine similarity:
    f(s, a, g) = <phi(s, a)/|phi(s, a)|, psi(g)/|psi(g)|>. phi and psi are
    independent MLPs projecting into a shared ``representation_dim`` space. Their
    linear layers are Xavier-initialised, then the output layer is overwritten
    with near-zero "cold start" values.
    """

    def __init__(self, history_dim, goal_dim, action_dim, hiddens_contrast, representation_dim, use_layer_norm=True):
        super().__init__()

        self.representation_dim = representation_dim
        # phi sees [history encoding, action]; psi sees only the goal encoding.
        self.phi_network = self._make_tower(
            history_dim + action_dim, hiddens_contrast, representation_dim, use_layer_norm
        )
        self.psi_network = self._make_tower(
            goal_dim, hiddens_contrast, representation_dim, use_layer_norm
        )

    @staticmethod
    def _make_tower(in_dim, hiddens, out_dim, use_layer_norm):
        # Build one projection MLP: Xavier everywhere, then a cold-start output layer.
        tower = _build_mlp(
            input_dim=in_dim,
            output_dim=out_dim,
            hiddens=hiddens,
            use_layer_norm=use_layer_norm
        )
        tower.apply(_xavier_init_if_linear)
        _apply_cold_start_init(tower)
        return tower

    def encode_phi(self, history_encoding, action):
        """Project the concatenation ``[history_encoding, action]`` (along dim 1)."""
        return self.phi_network(torch.cat([history_encoding, action], dim=1))

    def encode_psi(self, goal_encoding):
        """Project a goal encoding."""
        return self.psi_network(goal_encoding)

    def forward(self, history_encoding, action, goal_encoding):
        """Return the row-wise cosine similarity as a column of shape ``(batch, 1)``."""
        phi_unit = F.normalize(self.encode_phi(history_encoding, action), p=2, dim=1)
        psi_unit = F.normalize(self.encode_psi(goal_encoding), p=2, dim=1)
        return torch.sum(phi_unit * psi_unit, dim=-1, keepdim=True)
class SCRL_Agent:
    """
    Stable Contrastive Reinforcement Learning (SCRL) agent with HER replay on medical data.

    The agent holds:
    - a history/goal encoder;
    - a contrastive critic that scores (history, action) against goals;
    - a goal-conditioned actor;
    - a HER replay buffer;
    - an optional circular queue of past goal embeddings used as extra negatives.

    ``update_parameters`` alternates a contrastive (InfoNCE) critic step with an
    actor step regularised by behaviour cloning. ``generate_treatment_plan_batch``
    rolls the deterministic policy out through ``dataset_collection.val_f``'s
    simulator.
    """
    def __init__(
        self,
        dataset_collection,
        config,
        input_dim=1,
        output_dim=1,
        treatment_dim=2,
        static_dim=1,
        future_length=5,
        discount=0.99,
        lr=3e-4,
        buffer_size=200000,
        batch_size=2048,
        goal_threshold=1e-3,
        k_future=10,
        use_amp=False,
        reward_mode='combined',
        use_attention=False,
        num_heads=4,
        bc_reg_lambda=0.1,
        use_data_aug=True,
        temperature=0.1,
        queue_size=65536,
        encoder_grad_clip=1.0,
        actor_grad_clip=1.0,
        critic_grad_clip=1.0
    ):
        """Build the encoder, contrastive critic, actor, optimizers and HER buffer.

        Args:
            dataset_collection: Passed to ``search_reward_threshold_adaptive`` to
                choose the HER reward threshold.
            config: Experiment config. It is read by attribute
                (``config.model.hiddens_enc``, ``config.model.hiddens_sac``,
                ``config.model.representation_dim``) and by key
                (``config['model']['her_params'][...]``).
            input_dim, output_dim, treatment_dim, static_dim: Feature sizes
                forwarded to the history encoder. ``treatment_dim`` is also the
                action size of the critic and the actor.
            future_length: Default horizon for ``generate_treatment_plan_batch``;
                also passed to the threshold search.
            discount: Stored as ``self.discount``; not read elsewhere in the code
                shown here.
            lr: Learning rate shared by the encoder, critic and actor Adam optimizers.
            buffer_size: Capacity of the HER buffer and of the threshold search.
            batch_size: Number of transitions sampled per ``update_parameters`` call.
            goal_threshold: Initial value for the reward-threshold search; also
                scales the early-stop test in ``generate_treatment_plan_batch``.
            k_future: HER parameter passed to the buffer and the threshold search.
            use_amp: If True and CUDA is available, a ``GradScaler`` is created.
            reward_mode: Reward mode for the buffer and the threshold search.
            use_attention, num_heads: Forwarded to the history encoder.
            bc_reg_lambda: Weight of the behaviour-cloning term in the actor loss.
            use_data_aug: Enables random augmentation of the BC branch.
            temperature: Divisor applied to cosine similarities in the InfoNCE loss.
            queue_size: Number of past goal embeddings kept as extra negatives;
                values <= 0 disable the queue.
            encoder_grad_clip, actor_grad_clip, critic_grad_clip: Max gradient norms;
                ``None`` disables clipping for that module.
        """
        self.hiddens_enc = config.model.hiddens_enc
        self.hiddens_sac = config.model.hiddens_sac

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.treatment_dim = treatment_dim
        self.static_dim = static_dim
        self.future_length = future_length
        self.discount = discount
        self.batch_size = batch_size
        self.goal_threshold = goal_threshold
        self.use_amp = use_amp
        self.bc_reg_lambda = bc_reg_lambda
        self.use_data_aug = use_data_aug
        # Augmentation knobs: probability of augmenting the BC branch per update,
        # Gaussian noise std for non-sequence tensors, and the max time-shift pad.
        self.data_aug_prob = 0.5
        self.data_aug_noise_std = 0.01
        self.data_aug_time_pad = 4

        self.encoder = _HistoryGoalEncoder(
            config=config,
            input_dim=input_dim,
            output_dim=output_dim,
            treatment_dim=treatment_dim,
            static_dim=static_dim,
            use_attention=use_attention,
            num_heads=num_heads
        ).to(DEVICE)

        self.history_dim = self.encoder.history_dim
        self.goal_dim = self.encoder.goal_dim

        representation_dim = config.model.representation_dim
        # The critic's MLP towers reuse the SAC hidden widths.
        self.critic = _SCRL_Critic(
            history_dim=self.history_dim,
            goal_dim=self.goal_dim,
            action_dim=treatment_dim,
            hiddens_contrast=self.hiddens_sac,
            representation_dim=representation_dim,
            use_layer_norm=True
        ).to(DEVICE)

        # The actor conditions on [history encoding, goal encoding].
        actor_state_dim = self.history_dim + self.goal_dim
        self.actor = Actor(
            state_dim=actor_state_dim,
            action_dim=treatment_dim,
            hiddens_sac=self.hiddens_sac
        ).to(DEVICE)

        self.encoder_optimizer = optim.Adam(self.encoder.parameters(), lr=lr)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=lr)
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=lr)

        # Search for the HER reward threshold, starting from goal_threshold, before
        # building the buffer. The selection criterion lives in
        # search_reward_threshold_adaptive.
        optimal_threshold = search_reward_threshold_adaptive(
            dataset_collection,
            min_history_length=config['model']['her_params']['min_history_length'],
            max_history_length=config['model']['her_params']['max_history_length'],
            capacity=buffer_size,
            k_future=k_future,
            future_length=future_length,
            reward_mode=reward_mode,
            initial_threshold=goal_threshold,
            target_hit_ratio=config['model']['her_params']['target_hit_ratio'],
            max_iter=10,
            hit_tolerance=0.01,
            max_high_limit=1
        )
        self.memory = HERReplayBuffer(
            buffer_size,
            k_future=k_future,
            reward_threshold=optimal_threshold,
            reward_mode=reward_mode
        )

        # NOTE(review): self.scaler exists only when AMP is requested on CUDA, and
        # update_parameters does not use it, so training runs in full precision.
        if use_amp and torch.cuda.is_available():
            self.scaler = torch.cuda.amp.GradScaler()

        self.train_info = {'critic_losses': [], 'actor_losses': [], 'bc_losses': []}
        self.total_steps = 0
        self.algorithm = "SCRL"
        self.temperature = temperature
        self.representation_dim = representation_dim
        # Circular queue of past normalised goal embeddings (psi), used as extra
        # negatives in the critic loss. ptr is the next write slot and filled is
        # the number of valid rows.
        self.queue_size = max(queue_size, 0)
        if self.queue_size > 0:
            self.neg_queue = torch.zeros(self.queue_size, self.representation_dim, device=DEVICE)
            self.neg_queue_ptr = 0
            self.neg_queue_filled = 0
        else:
            self.neg_queue = None
            self.neg_queue_ptr = 0
            self.neg_queue_filled = 0

        # Max gradient norms; None disables clipping for that module.
        self.encoder_grad_clip = encoder_grad_clip
        self.actor_grad_clip = actor_grad_clip
        self.critic_grad_clip = critic_grad_clip
    def _enqueue_negatives(self, embeddings):
        if self.neg_queue is None or embeddings.numel() == 0:
            return
        embeddings = embeddings.detach()
        batch_size = embeddings.shape[0]

        if batch_size >= self.queue_size:
            self.neg_queue.copy_(embeddings[-self.queue_size:])
            self.neg_queue_ptr = 0
            self.neg_queue_filled = self.queue_size
            return

        end_ptr = self.neg_queue_ptr + batch_size
        if end_ptr <= self.queue_size:
            self.neg_queue[self.neg_queue_ptr:end_ptr] = embeddings
        else:
            first_len = self.queue_size - self.neg_queue_ptr
            self.neg_queue[self.neg_queue_ptr:] = embeddings[:first_len]
            remaining = end_ptr - self.queue_size
            self.neg_queue[:remaining] = embeddings[first_len:]
        self.neg_queue_ptr = (self.neg_queue_ptr + batch_size) % self.queue_size
        self.neg_queue_filled = min(self.neg_queue_filled + batch_size, self.queue_size)

    def _random_temporal_shift(self, tensor):
        if tensor.dim() != 3:
            return tensor.clone()
        seq_len = tensor.shape[1]
        if seq_len <= 1 or self.data_aug_time_pad <= 0:
            return tensor.clone()
        pad = min(self.data_aug_time_pad, seq_len)
        left = tensor[:, :1, :].repeat(1, pad, 1)
        right = tensor[:, -1:, :].repeat(1, pad, 1)
        padded = torch.cat([left, tensor, right], dim=1)
        max_offset = pad * 2
        offsets = torch.randint(0, max_offset + 1, (tensor.shape[0],), device=tensor.device)
        base_idx = torch.arange(seq_len, device=tensor.device).unsqueeze(0)
        gather_idx = base_idx + offsets.unsqueeze(1)
        gather_idx = gather_idx.clamp_(0, padded.shape[1] - 1)
        gather_idx = gather_idx.unsqueeze(-1).expand(-1, -1, tensor.shape[2])
        augmented = torch.gather(padded, 1, gather_idx)
        return augmented

    def _augment(self, batch_data, apply_aug=True):
        if not apply_aug:
            return batch_data
        if isinstance(batch_data, dict):
            aug_batch = {}
            for key, value in batch_data.items():
                if key == 'seq_lengths':
                    aug_batch[key] = value
                    continue
                if torch.is_tensor(value):
                    if value.dim() >= 3:
                        aug_batch[key] = self._random_temporal_shift(value)
                    else:
                        noise = torch.randn_like(value) * self.data_aug_noise_std
                        aug_batch[key] = torch.clamp(value + noise, -1.0, 1.0)
                else:
                    aug_batch[key] = copy.deepcopy(value)
            return aug_batch
        if torch.is_tensor(batch_data):
            if batch_data.dim() >= 3:
                return self._random_temporal_shift(batch_data)
            noise = torch.randn_like(batch_data) * self.data_aug_noise_std
            return torch.clamp(batch_data + noise, -1.0, 1.0)
        if isinstance(batch_data, np.ndarray):
            noise = np.random.randn(*batch_data.shape).astype(batch_data.dtype) * self.data_aug_noise_std
            augmented = batch_data + noise
            np.clip(augmented, -1.0, 1.0, out=augmented)
            return augmented
        return copy.deepcopy(batch_data)

    def _get_last_output(self, history_dict):
        if 'outputs' in history_dict:
            return history_dict['outputs'][:, -1, :]
        if 'prev_outputs' in history_dict:
            return history_dict['prev_outputs'][:, -1, :]
        raise ValueError("No 'outputs' or 'prev_outputs' in history_dict")

    def _process_batch(self, batch):
        """Convert sampled experiences into batched tensors on ``DEVICE``.

        ``batch`` is an iterable of experiences exposing ``history_dict``,
        ``action``, ``reward``, ``next_history_dict``, ``goal`` and ``done``. A
        3-tuple is first unwrapped to its first element (presumably the
        experiences, with sampling metadata in the other two slots — TODO confirm
        against HERReplayBuffer.sample).

        Returns:
            ``(history, actions, rewards, next_history, goals, dones)``. The
            histories are collated by ``_process_history_batch``; rewards and dones
            are reshaped to ``(batch, 1)``; everything is float32.
        """
        if isinstance(batch, tuple) and len(batch) == 3:
            batch = batch[0]

        records = [
            (exp.history_dict, exp.action, exp.reward, exp.next_history_dict, exp.goal, exp.done)
            for exp in batch
        ]

        def _stack(column):
            # Gather one field across the batch into a float32 tensor on DEVICE.
            return torch.as_tensor(
                np.array([rec[column] for rec in records]), dtype=torch.float32, device=DEVICE
            )

        actions = _stack(1)
        rewards = _stack(2).reshape(-1, 1)
        goals = _stack(4)
        dones = _stack(5).reshape(-1, 1)

        processed_history = self._process_history_batch([rec[0] for rec in records])
        processed_next_history = self._process_history_batch([rec[3] for rec in records])

        return processed_history, actions, rewards, processed_next_history, goals, dones

    def _process_history_batch(self, history_dicts):
        """Collate per-patient history dicts into zero-padded, batch-first tensors.

        From each dict, element ``[0]`` of ``'outputs'``, ``'static_features'``,
        ``'current_treatments'`` (and ``'vitals'`` when the encoder uses
        ``input_x``) is taken. This presumably strips a leading batch axis of size
        1 — TODO confirm against the dataset. Each field is padded with
        ``pad_sequence(batch_first=True)``. ``'seq_lengths'`` holds the unpadded
        length of each ``'outputs'`` sequence.

        Returns:
            dict of tensors on ``DEVICE`` with keys ``'outputs'``,
            ``'static_features'``, ``'current_treatments'``, ``'seq_lengths'``
            and, if applicable, ``'vitals'``.
        """
        with_vitals = self.encoder.encoder_base.input_x

        def _field(record, key):
            return torch.as_tensor(record[key][0], dtype=torch.float32)

        def _pad(seqs):
            return pad_sequence(seqs, batch_first=True).to(DEVICE)

        outputs_seqs = [_field(d, 'outputs') for d in history_dicts]
        static_seqs = [_field(d, 'static_features') for d in history_dicts]
        treatment_seqs = [_field(d, 'current_treatments') for d in history_dicts]

        collated = {
            'outputs': _pad(outputs_seqs),
            'static_features': _pad(static_seqs),
            'current_treatments': _pad(treatment_seqs),
            'seq_lengths': torch.as_tensor(
                [seq.shape[0] for seq in outputs_seqs], dtype=torch.long, device=DEVICE
            ),
        }
        if with_vitals:
            collated['vitals'] = _pad([_field(d, 'vitals') for d in history_dicts])
        return collated
    def update_parameters(self):
        """Run one SCRL training step: a contrastive critic update, then an actor update.

        Samples ``self.batch_size`` transitions from the HER buffer, then:

        1. Critic + encoder step. InfoNCE over cosine similarities between
           ``phi(history, action)`` and ``psi(goal)``, divided by
           ``self.temperature``. Each row's own goal is its positive; the other
           in-batch goals and any queued goal embeddings are negatives. The batch's
           normalised goal embeddings are then pushed into the queue.
        2. Actor + encoder step. Maximise the critic's cosine score of the policy's
           action, plus ``self.bc_reg_lambda`` times a behaviour-cloning NLL on the
           dataset actions. The BC input is optionally an augmented copy of the
           batch.

        Returns:
            ``None`` if the buffer returned no batch; otherwise a dict with scalar
            ``'critic_loss'``, ``'actor_loss'`` (contrastive term only) and
            ``'bc_loss'``.
        """
        batch = self.memory.sample(self.batch_size)
        if batch is None:
            return None

        # Rewards, next histories and dones are unpacked but not used: the
        # contrastive objective needs only histories, actions and goals.
        history_batch, actions, rewards, next_history_batch, goals, dones = self._process_batch(batch)
        batch_size = actions.shape[0]

        # NOTE(review): use_amp / self.scaler are not applied here; this step
        # always runs in full precision.

        # ---- Critic + encoder step (InfoNCE) -----------------------------------
        self.encoder_optimizer.zero_grad(set_to_none=True)
        self.critic_optimizer.zero_grad(set_to_none=True)

        history_encoding = self.encoder.encode_history(history_batch)
        goal_encoding = self.encoder.encode_goal(goals)

        phi_embeddings = self.critic.encode_phi(history_encoding, actions)
        psi_pos_embeddings = self.critic.encode_psi(goal_encoding)

        phi_norm = F.normalize(phi_embeddings, dim=1)
        psi_pos_norm = F.normalize(psi_pos_embeddings, dim=1)

        # Entry (i, j) scores history/action i against goal j. The diagonal holds
        # the positives; every other column acts as a negative.
        logits_inbatch = torch.matmul(phi_norm, psi_pos_norm.t()) / self.temperature
        labels = torch.arange(batch_size, device=DEVICE)

        if self.neg_queue is not None and self.neg_queue_filled > 0:
            # Queued goal embeddings from earlier batches become extra negative
            # columns appended after the in-batch ones, so ``labels`` still
            # indexes the positives.
            negatives = self.neg_queue[:self.neg_queue_filled]
            logits_queue = torch.matmul(phi_norm, negatives.t()) / self.temperature
            logits = torch.cat([logits_inbatch, logits_queue], dim=1)
        else:
            logits = logits_inbatch

        critic_loss = F.cross_entropy(logits, labels)

        critic_loss.backward()
        if self.encoder_grad_clip is not None:
            nn.utils.clip_grad_norm_(self.encoder.parameters(), self.encoder_grad_clip)
        if self.critic_grad_clip is not None:
            nn.utils.clip_grad_norm_(self.critic.parameters(), self.critic_grad_clip)

        self.encoder_optimizer.step()
        self.critic_optimizer.step()

        with torch.no_grad():
            self._enqueue_negatives(psi_pos_norm)

        # ---- Actor + encoder step ------------------------------------------------
        self.actor_optimizer.zero_grad(set_to_none=True)
        self.encoder_optimizer.zero_grad(set_to_none=True)

        history_encoding_actor = self.encoder.encode_history(history_batch)
        goal_encoding_actor = self.encoder.encode_goal(goals)

        actor_state = torch.cat([history_encoding_actor, goal_encoding_actor], dim=1)
        # Only the sampled actions are needed; the log-probabilities are unused.
        new_actions, _ = self.actor(actor_state, return_log_prob=True)

        # The critic is a fixed scorer in this step, so feed it *detached*
        # encodings. Otherwise the actor loss would train the shared encoder to
        # inflate phi.psi directly, independent of the chosen action, which undoes
        # the contrastive step above. The encoder still gets gradients through
        # ``actor_state`` (via ``new_actions``) and through the BC term. Critic
        # gradients produced here are never stepped; the critic ``zero_grad`` at
        # the start of the next update clears them.
        phi_for_actor = self.critic.encode_phi(history_encoding_actor.detach(), new_actions)
        with torch.no_grad():
            psi_for_actor = self.critic.encode_psi(goal_encoding_actor)
        critic_value = (F.normalize(phi_for_actor, dim=1) * F.normalize(psi_for_actor, dim=1)).sum(dim=-1, keepdim=True)

        actor_contrastive_loss = -critic_value.mean()

        # With probability data_aug_prob, compute the BC term on an augmented copy.
        apply_bc_aug = self.use_data_aug and (random.random() < self.data_aug_prob)
        if apply_bc_aug:
            history_batch_bc = self._augment(history_batch, apply_aug=True)
            goals_bc = self._augment(goals, apply_aug=True)
            history_encoding_bc = self.encoder.encode_history(history_batch_bc)
            goal_encoding_bc = self.encoder.encode_goal(goals_bc)
            actor_state_bc = torch.cat([history_encoding_bc, goal_encoding_bc], dim=1)
        else:
            actor_state_bc = actor_state

        bc_log_probs = self.actor.log_prob(actor_state_bc, actions)
        bc_loss = -bc_log_probs.mean()

        actor_loss = actor_contrastive_loss + self.bc_reg_lambda * bc_loss

        actor_loss.backward()
        if self.actor_grad_clip is not None:
            nn.utils.clip_grad_norm_(self.actor.parameters(), self.actor_grad_clip)
        if self.encoder_grad_clip is not None:
            nn.utils.clip_grad_norm_(self.encoder.parameters(), self.encoder_grad_clip)

        self.actor_optimizer.step()
        self.encoder_optimizer.step()

        critic_loss_value = critic_loss.item()
        actor_loss_value = actor_contrastive_loss.item()
        bc_loss_value = bc_loss.item()

        self.train_info['critic_losses'].append(critic_loss_value)
        self.train_info['actor_losses'].append(actor_loss_value)
        self.train_info['bc_losses'].append(bc_loss_value)
        self.total_steps += 1

        return {
            'critic_loss': critic_loss_value,
            'actor_loss': actor_loss_value,
            'bc_loss': bc_loss_value
        }
    def _update_patient_history(self, updated_history, action_np, output):
        updated_history = copy.deepcopy(updated_history)

        def _append_time_axis(array, new_val):
            new_val = new_val.reshape(array.shape[0], -1)
            new_storage = np.zeros((array.shape[0], array.shape[1] + 1, array.shape[2]), dtype=array.dtype)
            new_storage[:, :-1, :] = array
            new_storage[:, -1, :] = new_val
            return new_storage

        for key, value in {'current_treatments': action_np, 'outputs': output}.items():
            if key in updated_history:
                updated_history[key] = _append_time_axis(updated_history[key], value)

        if 'static_features' in updated_history:
            static = updated_history['static_features']
            new_len = updated_history['outputs'].shape[1]

            if static.ndim == 2:
                static_dim = static.shape[1]
                new_static = np.repeat(static[:, np.newaxis, :], new_len, axis=1)
                updated_history['static_features'] = new_static
            elif static.ndim == 3 and new_len > static.shape[1]:
                static_dim = static.shape[2]
                new_static = np.zeros((static.shape[0], new_len, static_dim), dtype=static.dtype)
                new_static[:, :static.shape[1], :] = static
                new_static[:, static.shape[1]:, :] = static[:, -1:, :]
                updated_history['static_features'] = new_static

        if 'vitals' in updated_history:
            prev_vitals = updated_history['vitals']
            new_vitals = np.zeros(
                (prev_vitals.shape[0], prev_vitals.shape[1] + 1, prev_vitals.shape[2]),
                dtype=prev_vitals.dtype
            )
            new_vitals[:, :-1, :] = prev_vitals
            if 'future_vitals' in updated_history and updated_history['future_vitals'].shape[1] > 0:
                new_vitals[:, -1, :] = updated_history['future_vitals'][:, :1, :]
                updated_history['future_vitals'] = updated_history['future_vitals'][:, 1:, :]
            updated_history['vitals'] = new_vitals

        return updated_history

    def generate_treatment_plan_batch(
        self,
        history_dict_batch,
        goal_batch,
        dataset_collection,
        future_dict_batch,
        future_length=None,
        early_stop=True
    ):
        """Plan treatments for a batch of patients by rolling out the deterministic policy.

        At each step the current histories are encoded and the actor proposes an
        action (``deterministic=True``). Each patient's outcome is then simulated
        with ``dataset_collection.val_f.simulate_output_after_actions``, starting
        from the patient's *original* history and applying every action chosen so
        far. The simulated output is appended to the patient's working history
        for the next step.

        Args:
            history_dict_batch: Per-patient history dicts (as accepted by
                ``_process_history_batch``).
            goal_batch: Per-patient goals, each a numpy array or tensor.
            dataset_collection: Provides ``val_f`` (the simulator) and
                ``train_scaling_params``.
            future_dict_batch: Not used by this method.
            future_length: Number of planning steps; defaults to ``self.future_length``.
            early_stop: If True, ``steps_taken`` records the first step whose
                simulated output is within tolerance of the goal. The rollout itself
                still runs for ``future_length`` steps.

        Returns:
            ``(actions, outputs, steps_taken)``: per-patient stacked actions,
            stacked simulated outputs, and step counts.

        Raises:
            TypeError: if a goal is neither a numpy array nor a tensor.
        """
        if future_length is None:
            future_length = self.future_length

        # Plan in eval mode. Restore the caller's modes afterwards, even if the
        # simulator raises.
        actor_was_training = self.actor.training
        encoder_was_training = self.encoder.training
        self.actor.eval()
        self.encoder.eval()
        try:
            batch_size = len(history_dict_batch)
            H_t_batch_processed = self._process_history_batch(history_dict_batch)

            goal_rows = []
            goal_np_batch = []
            for g in goal_batch:
                if isinstance(g, np.ndarray):
                    goal_np_batch.append(g)
                    goal_rows.append(torch.as_tensor(g, dtype=torch.float32).unsqueeze(0))
                elif isinstance(g, torch.Tensor):
                    g_cpu = g.detach().cpu().float()
                    goal_np_batch.append(g_cpu.numpy())
                    goal_rows.append(g_cpu.unsqueeze(0))
                else:
                    raise TypeError("Unsupported goal type during planning.")
            goal_tensor_batch = torch.cat(goal_rows, dim=0).to(DEVICE)

            with torch.no_grad():
                # Goals are fixed during the rollout, so encode them once.
                goal_encoding_batch = self.encoder.encode_goal(goal_tensor_batch)

            updated_history_batch_dicts = [copy.deepcopy(h) for h in history_dict_batch]
            current_H_t_processed = H_t_batch_processed

            actions_batch = [[] for _ in range(batch_size)]
            outputs_batch = [[] for _ in range(batch_size)]
            steps_taken_batch = [future_length] * batch_size

            for t in range(future_length):
                with torch.no_grad():
                    history_encoding_batch = self.encoder.encode_history(current_H_t_processed)
                    actor_state_batch = torch.cat([history_encoding_batch, goal_encoding_batch], dim=1)
                    action_batch = self.actor(actor_state_batch, deterministic=True)
                # One host copy per step, reused for every patient.
                step_actions = action_batch.cpu().numpy()

                for i in range(batch_size):
                    action_np = step_actions[i:i + 1].reshape(1, 1, -1)
                    actions_batch[i].append(action_np)

                    # Simulate from the original history with all actions so far.
                    actions_tensor = np.concatenate(actions_batch[i], axis=1)
                    actions_tensor_torch = torch.as_tensor(actions_tensor, dtype=torch.float32, device=DEVICE)

                    output = dataset_collection.val_f.simulate_output_after_actions(
                        history_dict_batch[i],
                        actions_tensor_torch,
                        dataset_collection.train_scaling_params
                    )
                    outputs_batch[i].append(output)

                    # NOTE(review): the tolerance is goal_threshold * 0.001 (1e-6 with
                    # the default goal_threshold); confirm the extra factor is intended.
                    if early_stop and np.linalg.norm(output - goal_np_batch[i]) < self.goal_threshold * 0.001:
                        steps_taken_batch[i] = min(steps_taken_batch[i], t + 1)

                if t == future_length - 1:
                    break

                # Append this step's action/output to every working history.
                for i in range(batch_size):
                    updated_history_batch_dicts[i] = self._update_patient_history(
                        updated_history_batch_dicts[i],
                        actions_batch[i][-1],
                        outputs_batch[i][-1]
                    )
                current_H_t_processed = self._process_history_batch(updated_history_batch_dicts)

            actions_batch_np = [np.array(actions).squeeze(1) if actions else np.array([]) for actions in actions_batch]
            outputs_batch_np = [np.array(outputs) if outputs else np.array([]) for outputs in outputs_batch]

            return actions_batch_np, outputs_batch_np, steps_taken_batch
        finally:
            self.actor.train(actor_was_training)
            self.encoder.train(encoder_was_training)
    def train_offline(self, iterations, progress_interval=100, eval_interval=5000, dataset_collection=None):
        """Call ``update_parameters`` ``iterations`` times and log running averages.

        Every ``progress_interval`` iterations, the mean critic/actor/BC losses of
        the most recent ``progress_interval`` recorded updates are printed.
        ``eval_interval`` and ``dataset_collection`` are accepted but not used.

        Returns:
            The list of non-None loss dicts returned by ``update_parameters``.
        """
        print(f"\n开始离线训练 ({self.algorithm})...")
        started_at = time.time()

        history = []
        for step in range(1, iterations + 1):
            info = self.update_parameters()
            if info is not None:
                history.append(info)

            if step % progress_interval != 0 or not history:
                continue

            window = history[-progress_interval:]
            means = {
                key: np.mean([entry[key] for entry in window])
                for key in ('critic_loss', 'actor_loss', 'bc_loss')
            }
            print(
                f"迭代 {step}/{iterations}, "
                f"Critic损失: {means['critic_loss']:.4f}, "
                f"Actor损失: {means['actor_loss']:.4f}, "
                f"BC损失: {means['bc_loss']:.4f}"
            )

        elapsed = time.time() - started_at
        print(f"训练完成，用时: {elapsed:.1f}秒")
        return history
    def save(self, path):
        save_dict = {
            'encoder': self.encoder.state_dict(),
            'actor': self.actor.state_dict(),
            'critic': self.critic.state_dict(),
            'train_info': self.train_info,
            'total_steps': self.total_steps,
            'algorithm': self.algorithm,
            'neg_queue': self.neg_queue if self.neg_queue is not None else None,
            'neg_queue_ptr': self.neg_queue_ptr,
            'neg_queue_filled': self.neg_queue_filled,
            'temperature': self.temperature,
            'representation_dim': self.representation_dim,
            'queue_size': self.queue_size
        }
        print(f"SCRL 模型保存到 {path}")

    def load(self, path):
        """Restore networks, training statistics and the negative queue from ``path``.

        Expects a checkpoint written by ``save``. ``algorithm`` and ``temperature``
        are restored when present. A saved queue is loaded only when this agent's
        queue is enabled. If the saved and current queue sizes differ, the first
        ``min(saved, current)`` rows are kept, and the write pointer is placed
        right after the valid rows so no unwritten (zero) rows end up counted as
        filled.

        Raises:
            ValueError: if the saved ``representation_dim`` or queue width does not
                match this agent.
        """
        checkpoint = torch.load(path, map_location=DEVICE)
        self.encoder.load_state_dict(checkpoint['encoder'])
        self.actor.load_state_dict(checkpoint['actor'])
        self.critic.load_state_dict(checkpoint['critic'])

        self.train_info = checkpoint['train_info']
        self.total_steps = checkpoint['total_steps']
        if 'algorithm' in checkpoint:
            self.algorithm = checkpoint['algorithm']

        if 'temperature' in checkpoint:
            self.temperature = checkpoint['temperature']
        if 'representation_dim' in checkpoint:
            if checkpoint['representation_dim'] != self.representation_dim:
                raise ValueError("Loaded representation_dim mismatch.")
        if 'queue_size' in checkpoint and checkpoint['queue_size'] != self.queue_size:
            print("Warning: Loaded queue size differs from current configuration; resizing queue.")

        if (
            'neg_queue' in checkpoint
            and checkpoint['neg_queue'] is not None
            and self.queue_size > 0
        ):
            queue_tensor = checkpoint['neg_queue'].to(DEVICE)
            if queue_tensor.shape[1] != self.representation_dim:
                raise ValueError("Loaded queue representation_dim mismatch.")

            if queue_tensor.shape[0] != self.queue_size:
                resized_queue = torch.zeros(self.queue_size, self.representation_dim, device=DEVICE)
                copy_len = min(queue_tensor.shape[0], self.queue_size)
                resized_queue[:copy_len] = queue_tensor[:copy_len]
                self.neg_queue = resized_queue
                # Never claim more valid rows than were actually copied.
                self.neg_queue_filled = min(checkpoint.get('neg_queue_filled', copy_len), copy_len)
                # After the copy, only rows [0, filled) hold embeddings. Writing
                # must continue right after them; reusing the saved pointer could
                # leave zero rows inside the "filled" prefix, and those would then
                # be scored as negatives.
                self.neg_queue_ptr = self.neg_queue_filled % self.queue_size
            else:
                self.neg_queue = queue_tensor
                self.neg_queue_ptr = checkpoint.get('neg_queue_ptr', 0) % self.queue_size
                self.neg_queue_filled = min(checkpoint.get('neg_queue_filled', self.queue_size), self.queue_size)

        print(f"从 {path} 加载模型 ({self.algorithm})")