import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import time
import random
from collections import deque
from torch.nn.utils.rnn import pad_sequence

from src.gift.models.encoders import HistoryEncoder
from src.gift.models.networks import Actor, Critic as EnhancedCritic
from src.gift.buffers.her_buffer import HERReplayBuffer, search_reward_threshold_adaptive
# from src.her.buffers.prioritized_buffer import PrioritizedReplayBuffer as HERReplayBuffer
# Use the default CUDA device when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
from src.gift.utils.evaluator import evaluate_agent

#--- New: IQL Helper Functions and Networks ---
def expectile_loss(diff, tau):
    """Expectile (asymmetric squared) loss used by IQL.

    Args:
        diff: Residual tensor (e.g. ``target_q - v``).
        tau: Expectile level. Strictly positive residuals are weighted by
            ``tau``; all other residuals are weighted by ``1 - tau``.

    Returns:
        Scalar tensor: the mean of the weighted squared residuals.
    """
    not_positive = diff <= 0
    asym_weight = torch.where(not_positive, 1 - tau, tau)
    return torch.mean(asym_weight * diff.pow(2))

class ValueNetwork(nn.Module):
    """State-value network V(s) used by IQL.

    A plain MLP with one ``Linear -> ReLU`` pair per entry of ``hiddens_sac``,
    followed by a linear head that produces a single value per input row.
    """

    # Tuple default: a list default would be one object shared by every call.
    def __init__(self, state_dim, hiddens_sac=(256, 256)):
        super().__init__()
        layers = []
        in_features = state_dim
        for width in hiddens_sac:
            layers.append(nn.Linear(in_features, width))
            layers.append(nn.ReLU())
            in_features = width
        layers.append(nn.Linear(in_features, 1))
        # Attribute name `net` is kept so existing state_dict keys still load.
        self.net = nn.Sequential(*layers)

    def forward(self, state):
        """Return V(state); the last dimension has size 1."""
        return self.net(state)
#--- End of IQL helpers ---


class BehaviorPolicy(nn.Module):
    """
    Behaviour-policy model: an MLP that parameterises an independent Beta
    distribution per action dimension, so actions lie in the interval (0, 1).

    Hidden layers are built from the ``hiddens_bhvr`` sequence (one
    ``Linear -> ReLU`` pair per entry). The output head emits ``2 * action_dim``
    logits, which are mapped to the Beta concentrations ``alpha`` and ``beta``.
    """

    # Tuple default: a list default would be one object shared by every call.
    def __init__(self, state_dim, action_dim, hiddens_bhvr=(256, 256)):
        super().__init__()

        # At least one hidden layer is required.
        if not hiddens_bhvr:
            raise ValueError("the hiddens_bhvr list cannot be empty.")

        net_layers = []
        input_dim = state_dim
        for hidden_dim in hiddens_bhvr:
            net_layers.append(nn.Linear(input_dim, hidden_dim))
            net_layers.append(nn.ReLU())
            input_dim = hidden_dim
        self.net = nn.Sequential(*net_layers)

        # Two logits per action dimension: one for alpha, one for beta.
        self.output_layer = nn.Linear(input_dim, action_dim * 2)

    def forward(self, state, action=None):
        """Sample an action, or score a given one, under the Beta policy.

        Args:
            state: Input features for ``self.net``.
            action: Optional actions to score. They are clamped to
                ``[1e-6, 1 - 1e-6]`` before ``log_prob`` to avoid infinities
                at the boundaries.

        Returns:
            ``(action, log_prob)`` when ``action`` is None, otherwise ``log_prob``.
            ``log_prob`` is summed over the last dimension with ``keepdim=True``.
        """
        logits = self.output_layer(self.net(state))
        alpha_logits, beta_logits = torch.chunk(logits, 2, dim=-1)

        # softplus(.) + 1 > 1 keeps both concentrations above 1, so the density
        # is unimodal (Beta(1, 1) would be uniform). This stabilises training.
        alpha = F.softplus(alpha_logits) + 1.0
        beta = F.softplus(beta_logits) + 1.0
        beta_dist = torch.distributions.Beta(alpha, beta)

        if action is None:
            # rsample() is reparameterised, so gradients flow through the sample.
            action = beta_dist.rsample()
            log_prob = beta_dist.log_prob(action).sum(-1, keepdim=True)
            return action, log_prob

        action_clamped = torch.clamp(action, 1e-6, 1 - 1e-6)
        return beta_dist.log_prob(action_clamped).sum(-1, keepdim=True)

class SAC_HER_Agent:
    def __init__(self, dataset_collection, config, input_dim=1, output_dim=1, treatment_dim=2, static_dim=1, hidden_dim=128, future_length=5,
                 discount=0.99, beta=0.005, lr=3e-4, alpha=0.2, 
                 buffer_size=200000, batch_size=128, use_automatic_entropy=True,
                 goal_threshold=1e-3, k_future=10, use_amp=False, reward_mode='combined',
                 use_attention=False, num_heads=4, DR=True, recover=False, action_diff=True,
                 use_cql=False,
                 cql_alpha=5.0,
                 cql_n_actions=10,
                 iql_tau=0.7,     # IQL expectile level used by the value loss
                 iql_beta=3.0,    # IQL advantage scale: actor weights are exp(iql_beta * adv)
                 actor_update_interval=5,
                 input_x=False):
        """Build all networks, optimizers and the HER replay buffer.

        Args:
            dataset_collection: Passed to ``search_reward_threshold_adaptive`` to
                calibrate the HER reward threshold for the replay buffer.
            config: Model configuration. It is read both by attribute
                (``config.model.hiddens_enc``) and by key
                (``config['model']['her_params']``), so it must support both
                access styles (presumably an OmegaConf ``DictConfig`` -- TODO confirm).
                ``config.model.baserl`` selects the base algorithm: ``'CQL'``,
                ``'IQL'`` or the string ``'None'`` for SAC / SAC-DR.
            input_dim, output_dim, treatment_dim, static_dim: Feature sizes
                forwarded to ``HistoryEncoder``. ``treatment_dim`` is also the
                action dimension of the actor, critics and behaviour policy.
            hidden_dim: Stored; otherwise only used to size the optional ``TargetPredictor``.
            future_length: Default planning horizon stored as ``self.future_length``.
            discount: TD discount factor.
            beta: Polyak rate for the target-critic soft update.
            lr: Learning rate shared by every Adam optimizer.
            alpha: Initial entropy temperature. It is overwritten by
                ``exp(log_alpha) = 1.0`` when ``use_automatic_entropy`` is True.
            buffer_size, k_future, reward_mode: HER replay buffer / threshold-search settings.
            goal_threshold: Initial threshold for the reward-threshold search;
                also scales the early-stop test in treatment planning.
            use_amp: When True and CUDA is available, creates ``self.scaler``.
            use_attention, num_heads: Forwarded to ``HistoryEncoder``.
            DR: Enables the importance-weighted (DR) reward; forced off for CQL/IQL.
            recover: Creates a ``TargetPredictor`` and is passed to the encoder
                as ``use_spectral_norm``.
            action_diff: Creates the behaviour policy and enables the
                action-diff regulariser in actor updates.
            use_cql: Not read by this constructor; the algorithm is chosen by
                ``config.model.baserl``.
            cql_alpha, cql_n_actions: CQL penalty weight and number of sampled actions.
            iql_tau, iql_beta: IQL expectile level and advantage scale.
            actor_update_interval: The actor, alpha and target critic are
                updated once every this many critic updates.
            input_x: Forwarded to ``HistoryEncoder``; when True, batches also
                carry a ``vitals`` sequence.
        """

        self.hiddens_enc = config.model.hiddens_enc
        self.hiddens_sac = config.model.hiddens_sac
        self.hiddens_bhvr = config.model.hiddens_bhvr
        # The encoder's last hidden width is the RL state dimension.
        self.state_dim = self.hiddens_enc[-1]
        
        # Core parameters
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.treatment_dim = treatment_dim
        self.static_dim = static_dim
        self.hidden_dim = hidden_dim
        self.future_length = future_length
        self.discount = discount
        self.beta = beta
        self.alpha = alpha
        self.batch_size = batch_size
        self.goal_threshold = goal_threshold
        self.use_amp = use_amp
        
        # --- Algorithm selection ---
        self.baserl = config.model.baserl
        self.DR = DR
        if self.baserl in ['CQL', 'IQL']:
            self.DR = False # DR importance weighting is disabled for CQL and IQL
        
        # CQL-specific parameters
        self.cql_alpha = cql_alpha
        self.cql_n_actions = cql_n_actions

        # IQL-specific parameters
        self.iql_tau = iql_tau
        self.iql_beta = iql_beta
        # --- End of algorithm selection ---

        # Delayed actor updates are counted in critic updates.
        self.actor_update_interval = actor_update_interval
        self.critic_updates = 0
        
        self.recover = recover
        self.action_diff = action_diff
        self.input_x = input_x
        
        # History encoder: maps (history, goal) to the RL state vector.
        self.encoder = HistoryEncoder(
            input_dim=input_dim, output_dim=output_dim, treatment_dim=treatment_dim,
            static_dim=static_dim, hiddens_enc=self.hiddens_enc, use_attention=use_attention,
            num_heads=num_heads, use_spectral_norm=recover, input_x=self.input_x
        ).to(DEVICE)
        
        # Actor network
        self.actor = Actor(
            state_dim=self.state_dim, action_dim=treatment_dim, hiddens_sac=self.hiddens_sac
        ).to(DEVICE)
        
        # Critic network (returns two Q heads, see _update_standard)
        self.critic = EnhancedCritic(
            state_dim=self.state_dim, action_dim=treatment_dim, hiddens_sac=self.hiddens_sac
        ).to(DEVICE)
        
        # Target critic, initialised as an exact copy of the critic
        self.critic_target = EnhancedCritic(
            state_dim=self.state_dim, action_dim=treatment_dim, hiddens_sac=self.hiddens_sac
        ).to(DEVICE)
        self.critic_target.load_state_dict(self.critic.state_dict())

        # IQL's state-value network V(s) and its optimizer
        if self.baserl == 'IQL':
            self.value_net = ValueNetwork(self.state_dim, self.hiddens_sac).to(DEVICE)
            self.value_optimizer = optim.Adam(self.value_net.parameters(), lr=lr)
        
        # Target predictor
        # NOTE(review): `TargetPredictor` is not imported in the visible part of
        # this file, so recover=True raises NameError unless it is defined
        # elsewhere. It is also sized with `hidden_dim`, whereas the encoder
        # state is `self.state_dim` (hiddens_enc[-1]) -- confirm this is intended.
        if self.recover:
            self.target_predictor = TargetPredictor(
                state_dim=hidden_dim, goal_dim=output_dim, hidden_dim=hidden_dim*2
            ).to(DEVICE)
            self.target_predictor_optimizer = optim.Adam(self.target_predictor.parameters(), lr=lr)
        
        # Behaviour policy (needed for DR importance ratios and the action-diff term)
        if self.DR or self.action_diff:
            self.behavior_policy = BehaviorPolicy(
                state_dim=self.state_dim, action_dim=treatment_dim, hiddens_bhvr=self.hiddens_bhvr
            ).to(DEVICE)
            self.behavior_optimizer = optim.Adam(self.behavior_policy.parameters(), lr=lr)
        
        # Optimizers
        self.encoder_optimizer = optim.Adam(self.encoder.parameters(), lr=lr)
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=lr)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=lr)
        
        # Automatic entropy adjustment (target entropy = -action_dim)
        self.use_automatic_entropy = use_automatic_entropy
        self.target_entropy = -float(treatment_dim)
        if use_automatic_entropy:
            self.log_alpha = torch.zeros(1, requires_grad=True, device=DEVICE)
            self.alpha_optimizer = optim.Adam([self.log_alpha], lr=lr)
            # Overrides the `alpha` argument: exp(0) = 1.0
            self.alpha = self.log_alpha.exp().item()
        
        # Experience replay buffer. The reward threshold is searched so that the
        # hit ratio approaches config's target_hit_ratio.
        # NOTE(review): `future_length` is rebound here from config; this only
        # affects the search below -- self.future_length was already set above.
        min_history_length, max_history_length = config['model']['her_params']['min_history_length'], config['model']['her_params']['max_history_length']
        future_length, target_hit_ratio = config['model']['her_params']['future_length'], config['model']['her_params']['target_hit_ratio']
        optimal_threshold = search_reward_threshold_adaptive(
            dataset_collection, min_history_length=min_history_length, max_history_length=max_history_length,
            capacity=buffer_size, k_future=k_future, future_length=future_length, reward_mode=reward_mode,
            initial_threshold=goal_threshold, target_hit_ratio=target_hit_ratio, max_iter=10, hit_tolerance=0.01, max_high_limit=1
        )
        self.memory = HERReplayBuffer(buffer_size, k_future=k_future, reward_threshold=optimal_threshold, reward_mode=reward_mode)
        
        # Mixed precision training
        # NOTE(review): self.scaler only exists when AMP is requested and CUDA is available.
        if use_amp and torch.cuda.is_available(): self.scaler = torch.cuda.amp.GradScaler()
        
        # Per-update training statistics
        self.train_info = {
            'critic_losses': [], 'actor_losses': [], 'alpha_losses': [], 'alphas': [],
            'behavior_losses': [], 'target_losses': [], 'dynamics_losses': [], 'action_diffs': [],
            'value_losses': [] # For IQL
        }
        
        self.total_steps = 0
        
        # Human-readable algorithm tag.
        # NOTE(review): compares against the *string* 'None'; if the config holds
        # a real None, self.algorithm becomes None and the `+=` lines below raise.
        if self.baserl != 'None':
            self.algorithm = self.baserl
        else:
            self.algorithm = "SAC-DR" if self.DR else "SAC"
        
        if self.recover: self.algorithm += "-Recover"
        if self.action_diff: self.algorithm += "-ActionDiff"

    def update_parameters(self):
        """Sample one mini-batch from replay and run a single training step.

        Returns:
            ``None`` if the replay buffer returned no batch; otherwise the
            metrics dict produced by ``_update_standard``.
        """
        sampled = self.memory.sample(self.batch_size)
        if sampled is None:
            return None

        (hist, acts, rews,
         next_hist, goal_batch, done_flags) = self._process_batch(sampled)

        # Only the standard (SAC/DR, CQL, IQL) update path exists for now.
        return self._update_standard(hist, acts, rews, next_hist, goal_batch, done_flags)
    
    def _update_standard(self, history_batch, actions, rewards, next_history_batch, goals, dones):
        """Run one training step for SAC / SAC-DR, CQL or IQL (chosen by ``self.baserl``).

        Steps:
          1. Encode the current histories (with grad) and the next histories (no grad).
          2. Fit the behaviour policy by maximum likelihood (if DR or action_diff).
          3. Build the critic loss.
             IQL: expectile-regress V(s) onto the target critic's min-Q, then
             use the TD target ``r + gamma * (1 - done) * V(s')``.
             SAC/CQL: soft TD target from the target critic and actor; DR
             importance-weights the reward, and CQL adds a log-sum-exp penalty.
          4. Step the critic and encoder optimizers on the critic loss.
          5. Every ``actor_update_interval`` critic updates, also:
             - update the actor (advantage-weighted likelihood for IQL, SAC
               objective otherwise, plus the optional action-diff regulariser);
             - update the entropy temperature (non-IQL with automatic entropy);
             - soft-update the target critic with rate ``self.beta``.

        Args:
            history_batch, next_history_batch: Padded history dicts from
                ``_process_history_batch``.
            actions: Buffer actions (presumably ``(B, treatment_dim)`` -- TODO confirm).
            rewards, dones: ``(B, 1)`` float tensors.
            goals: Goal tensor fed to the encoder with each history.

        Returns:
            Dict of Python floats with these entries:
            - critic, actor, alpha, behaviour and value losses;
            - the current alpha (0.0 for IQL);
            - the action-diff value.
        """
        # 1) Encode states. next_state only feeds no-grad target computations,
        #    so no autograd graph is built for it.
        state = self.encoder(history_batch, goals)
        with torch.no_grad():
            next_state = self.encoder(next_history_batch, goals)

        # 2) Update the behaviour policy (if needed). It trains on detached
        #    features, so the encoder graph held by `state` stays intact for the
        #    critic loss and no re-encoding is required.
        behavior_loss = torch.tensor(0.0, device=DEVICE)
        if self.DR or self.action_diff:
            bc_log_probs = self.behavior_policy(state.detach(), actions)
            behavior_loss = -bc_log_probs.mean()
            self.behavior_optimizer.zero_grad()
            behavior_loss.backward()
            self.behavior_optimizer.step()

        critic_loss = torch.tensor(0.0, device=DEVICE)
        value_loss = torch.tensor(0.0, device=DEVICE)

        # 3) Critic loss (and, for IQL, a value-network step).
        if self.baserl == 'IQL':
            with torch.no_grad():
                q1_target, q2_target = self.critic_target(state, actions)
                target_q_for_v = torch.min(q1_target, q2_target)

            # 3a) Value network. It uses detached features: backpropagating this
            #     loss through `state` would free the encoder graph that the
            #     critic loss below still needs. The encoder gradients it would
            #     have produced are cleared by zero_grad() before the encoder
            #     step anyway.
            v = self.value_net(state.detach())
            value_loss = expectile_loss(target_q_for_v - v, self.iql_tau)
            self.value_optimizer.zero_grad(set_to_none=True)
            value_loss.backward()
            self.value_optimizer.step()

            # 3b) Critic regression onto r + gamma * (1 - done) * V(s').
            with torch.no_grad():
                next_v = self.value_net(next_state)
                target_q = rewards + (1 - dones) * self.discount * next_v

            current_q1, current_q2 = self.critic(state, actions)
            critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)

        else:  # SAC / SAC-DR / CQL
            with torch.no_grad():
                next_actions, next_log_probs = self.actor(next_state)
                target_q1, target_q2 = self.critic_target(next_state, next_actions)
                target_q = torch.min(target_q1, target_q2) - self.alpha * next_log_probs

                if self.DR:
                    # Importance ratio pi(a|s) / mu(a|s), clipped for stability.
                    policy_log_probs = self.actor.log_prob(state, actions)
                    behavior_log_probs = self.behavior_policy(state, actions)
                    rho = torch.exp(policy_log_probs - behavior_log_probs).clamp(0.01, 10.0)
                    target_q = rewards * rho + (1 - dones) * self.discount * target_q
                else:
                    target_q = rewards + (1 - dones) * self.discount * target_q

            current_q1, current_q2 = self.critic(state, actions)
            critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)

            if self.baserl == 'CQL':
                # Conservative penalty: push down log-sum-exp of Q over uniform
                # and policy actions, and push up Q on dataset actions.
                # Use the real batch size; the sampled batch may be smaller
                # than self.batch_size.
                batch_n = state.shape[0]
                random_actions = torch.rand(batch_n, self.cql_n_actions, self.treatment_dim, device=DEVICE)
                flat_state = state.unsqueeze(1).repeat(1, self.cql_n_actions, 1).reshape(-1, state.shape[-1])

                with torch.no_grad():
                    policy_actions, _ = self.actor(flat_state)

                q1_rand, q2_rand = self.critic(flat_state, random_actions.reshape(-1, self.treatment_dim))
                q1_pi, q2_pi = self.critic(flat_state, policy_actions.reshape(-1, self.treatment_dim))

                cat_q1 = torch.cat([q1_rand.reshape(batch_n, -1), q1_pi.reshape(batch_n, -1)], dim=1)
                cat_q2 = torch.cat([q2_rand.reshape(batch_n, -1), q2_pi.reshape(batch_n, -1)], dim=1)

                lse_q1 = torch.logsumexp(cat_q1, dim=1, keepdim=True)
                lse_q2 = torch.logsumexp(cat_q2, dim=1, keepdim=True)

                cql_loss = ((lse_q1 - current_q1).mean() + (lse_q2 - current_q2).mean()) * self.cql_alpha
                critic_loss += cql_loss

        # 4) Critic + encoder optimisation (the encoder is trained only through the critic loss).
        self.encoder_optimizer.zero_grad(set_to_none=True)
        self.critic_optimizer.zero_grad(set_to_none=True)
        critic_loss.backward()
        self.critic_optimizer.step()
        self.encoder_optimizer.step()
        self.critic_updates += 1

        # 5) Delayed actor / alpha / target updates.
        actor_loss = torch.tensor(0.0, device=DEVICE)
        alpha_loss = torch.tensor(0.0, device=DEVICE)
        action_diff_value = torch.tensor(0.0, device=DEVICE)

        if self.critic_updates % self.actor_update_interval == 0:
            # Re-encode with the freshly updated encoder; the actor must not move it.
            state_det = self.encoder(history_batch, goals).detach()

            if self.baserl == 'IQL':
                # Advantage-weighted regression onto dataset actions.
                with torch.no_grad():
                    q1, q2 = self.critic_target(state_det, actions)
                    q = torch.min(q1, q2)
                    v = self.value_net(state_det)
                    adv = q - v
                exp_adv = torch.exp(self.iql_beta * adv).clamp(max=100.0)
                policy_log_probs = self.actor.log_prob(state_det, actions)
                actor_loss = -(exp_adv * policy_log_probs).mean()
                if self.action_diff:
                    # AWR never samples from the actor, but the action-diff
                    # regulariser below needs on-policy samples and log-probs.
                    new_actions, log_probs = self.actor(state_det)
            else:
                new_actions, log_probs = self.actor(state_det)
                q1, q2 = self.critic(state_det, new_actions)
                q = torch.min(q1, q2)
                actor_loss = (self.alpha * log_probs - q).mean()

            if self.action_diff:
                with torch.no_grad():
                    behavior_log_probs = self.behavior_policy(state_det, new_actions)
                action_diff_value = (log_probs - behavior_log_probs).mean()
                # Weight is capped at 0.01 and decays only once exp(-0.001 * t) < 0.01.
                diff_weight = min(0.01, 1.0 * np.exp(-0.001 * self.total_steps))
                actor_loss += diff_weight * action_diff_value

            self.actor_optimizer.zero_grad(set_to_none=True)
            actor_loss.backward()
            self.actor_optimizer.step()

            # 6) Entropy temperature (SAC/CQL only).
            if self.baserl != 'IQL' and self.use_automatic_entropy:
                alpha_loss = - (self.log_alpha * (log_probs.detach() + self.target_entropy)).mean()
                self.alpha_optimizer.zero_grad()
                alpha_loss.backward()
                self.alpha_optimizer.step()
                self.alpha = self.log_alpha.exp().item()

            # 7) Polyak (soft) update of the target critic.
            for target_param, param in zip(self.critic_target.parameters(), self.critic.parameters()):
                target_param.data.copy_(target_param.data * (1.0 - self.beta) + param.data * self.beta)

        # Record training information.
        self.train_info['critic_losses'].append(critic_loss.item())
        self.train_info['actor_losses'].append(actor_loss.item())
        self.train_info['behavior_losses'].append(behavior_loss.item())
        if self.baserl != 'IQL':
            self.train_info['alpha_losses'].append(alpha_loss.item())
            self.train_info['alphas'].append(self.alpha)
        if self.baserl == 'IQL':
            self.train_info['value_losses'].append(value_loss.item())
        if self.action_diff:
            self.train_info['action_diffs'].append(action_diff_value.item())

        self.total_steps += 1
        return {
            'critic_loss': critic_loss.item(),
            'actor_loss': actor_loss.item(),
            'alpha_loss': alpha_loss.item(),
            'alpha': self.alpha if self.baserl != 'IQL' else 0.0,
            'behavior_loss': behavior_loss.item(),
            'value_loss': value_loss.item() if self.baserl == 'IQL' else 0.0,
            'action_diff': action_diff_value.item() if self.action_diff else 0.0
        }

    def _process_batch(self, batch):
        """Convert a sampled replay batch into training tensors.

        ``batch`` is either a sequence of experience records or a 3-tuple whose
        first element is that sequence (prioritized-replay format); the other
        two tuple elements are ignored here. Each record must expose:
        ``history_dict``, ``action``, ``reward``, ``next_history_dict``,
        ``goal`` and ``done``.

        Returns:
            ``(history, actions, rewards, next_history, goals, dones)``:
            - ``history`` and ``next_history`` are padded dicts from
              ``_process_history_batch``;
            - the remaining entries are float32 tensors on ``DEVICE``;
            - ``rewards`` and ``dones`` are reshaped to ``(B, 1)``.
        """
        if isinstance(batch, tuple) and len(batch) == 3:
            experiences = list(batch[0])
        else:
            experiences = list(batch)

        def column(attr):
            # Gather one field across the batch into a single NumPy array.
            return np.array([getattr(exp, attr) for exp in experiences])

        actions = torch.FloatTensor(column('action')).to(DEVICE)
        rewards = torch.FloatTensor(column('reward')).reshape(-1, 1).to(DEVICE)
        goals = torch.FloatTensor(column('goal')).to(DEVICE)
        dones = torch.FloatTensor(column('done')).reshape(-1, 1).to(DEVICE)

        history = self._process_history_batch([exp.history_dict for exp in experiences])
        next_history = self._process_history_batch([exp.next_history_dict for exp in experiences])

        return history, actions, rewards, next_history, goals, dones
    
    def _process_history_batch(self, history_dicts):
        """Collate per-sample history dicts into one zero-padded batch.

        Each dict's arrays are indexed with ``[0]`` (a leading singleton axis,
        presumably ``(1, T, F)`` -- TODO confirm). The sequence length is read
        from ``outputs.shape[1]``. Sequences are right-padded to the longest
        one with ``pad_sequence(batch_first=True)``.

        Returns:
            Dict with these keys, all on ``DEVICE``:
            - padded ``outputs``, ``static_features`` and ``current_treatments``;
            - ``seq_lengths`` (LongTensor of the original lengths);
            - padded ``vitals``, only when ``self.input_x`` is set.
        """
        outputs_list, static_list, treatments_list, vitals_list = [], [], [], []
        seq_lengths = []

        for d in history_dicts:
            outputs, static_features, current_treatments = d['outputs'], d['static_features'], d['current_treatments']
            seq_lengths.append(outputs.shape[1])
            outputs_list.append(torch.FloatTensor(outputs[0]))
            static_list.append(torch.FloatTensor(static_features[0]))
            treatments_list.append(torch.FloatTensor(current_treatments[0]))
            if self.input_x:
                vitals_list.append(torch.FloatTensor(d['vitals'][0]))

        batch_history = {
            'outputs': pad_sequence(outputs_list, batch_first=True).to(DEVICE),
            'static_features': pad_sequence(static_list, batch_first=True).to(DEVICE),
            'current_treatments': pad_sequence(treatments_list, batch_first=True).to(DEVICE),
            'seq_lengths': torch.LongTensor(seq_lengths).to(DEVICE),
        }
        if self.input_x:
            batch_history['vitals'] = pad_sequence(vitals_list, batch_first=True).to(DEVICE)

        return batch_history
    
    def select_action(self, history_dict, goal, evaluate=False):
        """Choose a single action for one (history, goal) pair.

        NumPy arrays in ``history_dict`` are converted to float tensors on
        ``DEVICE``; other values pass through unchanged. A NumPy ``goal`` is
        converted the same way; any other ``goal`` is moved with ``.to(DEVICE)``.

        Args:
            evaluate: If True, use the actor's deterministic action; otherwise
                take the first element of a stochastic actor call.

        Returns:
            The chosen action as a NumPy array.
        """
        def as_device_tensor(value):
            if isinstance(value, np.ndarray):
                return torch.FloatTensor(value).to(DEVICE)
            return value

        encoder_inputs = {key: as_device_tensor(val) for key, val in history_dict.items()}
        if isinstance(goal, np.ndarray):
            goal_t = torch.FloatTensor(goal).to(DEVICE)
        else:
            goal_t = goal.to(DEVICE)

        with torch.no_grad():
            features = self.encoder(encoder_inputs, goal_t)
            if evaluate:
                chosen = self.actor(features, deterministic=True)
            else:
                chosen = self.actor(features, deterministic=False)[0]
        return chosen.cpu().numpy()

    def generate_treatment_plan(self, history_dict, goal, dataset_collection, future_dict, future_length=None, early_stop=True):
        """
        Roll out a treatment plan one step at a time.

        Each step does the following:
          1. The actor picks a deterministic action from the encoded
             (history, goal) state.
          2. Every action chosen so far is passed, together with the ORIGINAL
             ``history_dict``, to
             ``dataset_collection.val_f.simulate_output_after_actions``.
          3. The chosen action and the simulated output are appended to a
             private copy of the history, which is re-encoded for the next step.

        Args:
            history_dict: History state dictionary (NumPy arrays and/or other values).
            goal: Goal state (NumPy array or tensor).
            dataset_collection: Provides ``val_f.simulate_output_after_actions``
                and ``train_scaling_params``.
            future_dict: Unused; kept for interface compatibility.
            future_length: Plan length; defaults to ``self.future_length``.
            early_stop: Stop as soon as the L2 distance between the simulated
                output and the goal is below ``self.goal_threshold * 0.001``.

        Returns:
            actions: ``np.array`` of the per-step actions. Each action is the
                actor output with a step axis inserted at position 1.
            outputs: ``np.array`` of per-step simulated outputs (an empty array if none).
            steps_taken: Number of steps actually simulated.

        The actor and encoder are put in eval mode for the rollout. They are
        switched back to train mode afterwards, also when an exception is raised.
        """
        if future_length is None:
            future_length = self.future_length

        self.actor.eval()
        self.encoder.eval()
        try:
            # Tensor view of the history consumed by the encoder.
            current_history = {
                key: torch.FloatTensor(value).to(DEVICE) if isinstance(value, np.ndarray) else value
                for key, value in history_dict.items()
            }
            goal_tensor = torch.FloatTensor(goal).to(DEVICE) if isinstance(goal, np.ndarray) else goal.to(DEVICE)
            goal_np = goal if isinstance(goal, np.ndarray) else goal.cpu().numpy()

            # Private NumPy copy of the history that is extended after every step.
            updated_history = {
                key: value.copy() if isinstance(value, np.ndarray) else value
                for key, value in history_dict.items()
            }

            actions = []
            outputs = []
            steps_taken = 0
            with torch.no_grad():
                for t in range(future_length):
                    state = self.encoder(current_history, goal_tensor)
                    action_np = self.actor(state, deterministic=True).unsqueeze(1).cpu().numpy()
                    actions.append(action_np)

                    # The simulator replays the whole action sequence so far,
                    # starting from the original history.
                    actions_tensor = torch.FloatTensor(np.concatenate(actions, axis=1)).to(DEVICE)
                    output = dataset_collection.val_f.simulate_output_after_actions(
                        history_dict,
                        actions_tensor,
                        dataset_collection.train_scaling_params,
                    )
                    outputs.append(output)
                    steps_taken += 1

                    if early_stop and np.linalg.norm(output - goal_np) < self.goal_threshold * 0.001:
                        break
                    # The history is not needed after the final step.
                    if t == future_length - 1:
                        break

                    updated_history = self._advance_plan_history(updated_history, action_np, output)
                    current_history = {
                        k: torch.FloatTensor(v).to(DEVICE) if isinstance(v, np.ndarray) else v
                        for k, v in updated_history.items()
                    }
        finally:
            self.actor.train()
            self.encoder.train()

        actions = np.array(actions)
        outputs = np.array(outputs) if outputs else np.array([])
        return actions, outputs, steps_taken

    @staticmethod
    def _advance_plan_history(updated_history, action_np, output):
        """Append one planned step to a NumPy history dict in place and return it.

        Each present key is extended by one step as follows:
        - ``current_treatments`` gains ``action_np``;
        - ``outputs`` gains ``output``;
        - ``static_features`` is padded up to the new ``outputs`` length by
          repeating its last step;
        - ``vitals`` gains the first remaining step of ``future_vitals``;
        - ``future_vitals`` drops that consumed step.
        """
        def append_step(prev, step_value):
            grown = np.zeros((prev.shape[0], prev.shape[1] + 1, prev.shape[2]))
            grown[:, :-1, :] = prev
            grown[:, -1, :] = step_value
            return grown

        if 'current_treatments' in updated_history:
            prev_treatments = updated_history['current_treatments']
            # action_np carries a length-1 step axis. Flatten it per row so the
            # slot is filled correctly for any batch size, not only B == 1.
            updated_history['current_treatments'] = append_step(
                prev_treatments, action_np.reshape(prev_treatments.shape[0], -1)
            )

        if 'outputs' in updated_history:
            updated_history['outputs'] = append_step(updated_history['outputs'], output)

        if 'static_features' in updated_history:
            static_features = updated_history['static_features']
            seq_len = static_features.shape[1]
            new_len = updated_history['outputs'].shape[1]
            if new_len > seq_len:
                new_static = np.zeros((static_features.shape[0], new_len, static_features.shape[2]))
                new_static[:, :seq_len, :] = static_features
                # Static features do not change over time: repeat the last step.
                new_static[:, seq_len:, :] = static_features[:, -1:, :]
                updated_history['static_features'] = new_static

        if 'vitals' in updated_history:
            # [:, 0, :] keeps the slot 2-D so rows are assigned one-to-one.
            updated_history['vitals'] = append_step(
                updated_history['vitals'], updated_history['future_vitals'][:, 0, :]
            )

        if 'future_vitals' in updated_history:
            updated_history['future_vitals'] = updated_history['future_vitals'][:, 1:, :]

        return updated_history
    
    def generate_treatment_plan_batch(self, history_dict_batch, goal_batch, dataset_collection, future_dict_batch, future_length=None, early_stop=True):
        """
        Batched version of ``generate_treatment_plan``.

        The encoder and actor run once per step on the whole batch. The
        simulator is still called per patient, with that patient's ORIGINAL
        history and the actions chosen for it so far.

        Early stopping does NOT shorten the rollout here: every patient is
        simulated for ``future_length`` steps. ``steps_taken_batch[i]`` records
        the first step at which patient ``i`` came within
        ``self.goal_threshold * 0.001`` (L2) of its goal, or ``future_length``
        if it never did.

        Per-patient arrays are concatenated along axis 0 with ``torch.cat``, so
        they must agree on every other axis.

        Args:
            history_dict_batch: List of per-patient history dicts.
            goal_batch: List of per-patient goals (NumPy arrays or tensors).
            dataset_collection: Provides ``val_f.simulate_output_after_actions``
                and ``train_scaling_params``.
            future_dict_batch: Unused; kept for interface compatibility.
            future_length: Plan length; defaults to ``self.future_length``.
            early_stop: Whether to record the first goal-reaching step.

        Returns:
            actions_batch: Per-patient arrays of shape ``(steps, 1, action_dim)``.
            outputs_batch: Per-patient arrays of simulated outputs.
            steps_taken_batch: List of ints, as described above.

        The actor and encoder are switched back to train mode afterwards, even
        if an exception is raised.
        """
        if future_length is None:
            future_length = self.future_length

        batch_size = len(history_dict_batch)
        device = DEVICE

        self.actor.eval()
        self.encoder.eval()
        try:
            # Stack each key across patients; non-array values stay as per-patient lists.
            H_t_batch = {}
            for history_dict in history_dict_batch:
                for key, value in history_dict.items():
                    H_t_batch.setdefault(key, []).append(
                        torch.FloatTensor(value) if isinstance(value, np.ndarray) else value
                    )
            for key in H_t_batch:
                if isinstance(H_t_batch[key][0], torch.Tensor):
                    H_t_batch[key] = torch.cat(H_t_batch[key], dim=0).to(device)

            goal_tensor_batch = torch.cat(
                [torch.FloatTensor(g).unsqueeze(0) if isinstance(g, np.ndarray) else (g.unsqueeze(0) if g.dim() == 1 else g)
                 for g in goal_batch],
                dim=0,
            ).to(device)
            goal_np_batch = [g if isinstance(g, np.ndarray) else g.cpu().numpy() for g in goal_batch]

            # Private NumPy copies of each history, extended after every step.
            updated_history_batch = [
                {k: v.copy() if isinstance(v, np.ndarray) else (v.cpu().numpy().copy() if hasattr(v, 'cpu') else v)
                 for k, v in history_dict.items()}
                for history_dict in history_dict_batch
            ]

            current_H_t_batch = H_t_batch
            actions_batch = [[] for _ in range(batch_size)]
            outputs_batch = [[] for _ in range(batch_size)]
            steps_taken_batch = [future_length] * batch_size

            with torch.no_grad():
                for t in range(future_length):
                    state_batch = self.encoder(current_H_t_batch, goal_tensor_batch)
                    # One device->host transfer per step instead of one per patient.
                    action_rows = self.actor(state_batch, deterministic=True).cpu().numpy()

                    for i in range(batch_size):
                        actions_batch[i].append(action_rows[i:i+1].reshape(1, 1, -1))
                        actions_tensor = torch.FloatTensor(np.concatenate(actions_batch[i], axis=1)).to(device)

                        output = dataset_collection.val_f.simulate_output_after_actions(
                            history_dict_batch[i], actions_tensor, dataset_collection.train_scaling_params
                        )
                        outputs_batch[i].append(output)

                        # Record (but do not act on) the first goal-reaching step.
                        if early_stop and np.linalg.norm(output - goal_np_batch[i]) < self.goal_threshold * 0.001 and steps_taken_batch[i] > t + 1:
                            steps_taken_batch[i] = t + 1

                    if t == future_length - 1:
                        break

                    for i in range(batch_size):
                        updated_history_batch[i] = self._update_patient_history(
                            updated_history_batch[i], actions_batch[i][-1], outputs_batch[i][-1]
                        )

                    # Re-stack only the tensor-valued keys for the next encoder call.
                    current_H_t_batch = {}
                    for key in updated_history_batch[0]:
                        batch_data = [torch.FloatTensor(h[key]) if isinstance(h[key], np.ndarray) else h[key] for h in updated_history_batch]
                        if isinstance(batch_data[0], torch.Tensor):
                            current_H_t_batch[key] = torch.cat(batch_data, dim=0).to(device)
        finally:
            self.actor.train()
            self.encoder.train()

        actions_batch = [np.array(actions).squeeze(1) if actions else np.array([]) for actions in actions_batch]
        outputs_batch = [np.array(outputs) if outputs else np.array([]) for outputs in outputs_batch]

        return actions_batch, outputs_batch, steps_taken_batch

    def _update_patient_history(self, updated_history, action_np, output):
        """
        Update Patient History
        """
        for key, value in {'current_treatments': action_np, 'outputs': output}.items():
            if key in updated_history:
                prev = updated_history[key]
                new_val = np.zeros((prev.shape[0], prev.shape[1] + 1, prev.shape[2]))
                new_val[:, :-1, :] = prev
                new_val[:, -1, :] = value.reshape(prev.shape[0], -1)
                updated_history[key] = new_val

        if 'static_features' in updated_history:
            static = updated_history['static_features']
            new_len = updated_history['outputs'].shape[1]
            if new_len > static.shape[1]:
                new_static = np.zeros((static.shape[0], new_len, static.shape[2]))
                new_static[:, :static.shape[1], :] = static
                new_static[:, static.shape[1]:, :] = static[:, -1:, :]
                updated_history['static_features'] = new_static

        if 'vitals' in updated_history:
            prev_vitals = updated_history['vitals']
            new_vitals = np.zeros((prev_vitals.shape[0], prev_vitals.shape[1] + 1, prev_vitals.shape[2]))
            new_vitals[:, :-1, :] = prev_vitals
            if 'future_vitals' in updated_history and updated_history['future_vitals'].shape[1] > 0:
                new_vitals[:, -1, :] = updated_history['future_vitals'][:, :1, :]
                updated_history['future_vitals'] = updated_history['future_vitals'][:, 1:, :]
            updated_history['vitals'] = new_vitals

        return updated_history

    def train_offline(self, iterations, progress_interval=100, eval_interval=5000, dataset_collection=None):
        """
        Run ``iterations`` offline updates, reporting and evaluating periodically.

        Every ``progress_interval`` iterations the mean of the most recent loss
        entries is printed. Every ``eval_interval`` iterations, if
        ``dataset_collection`` is given, the agent is evaluated with
        ``evaluate_agent``. Whenever ``avg_rmse`` improves, the model is saved
        to ``best_model_{self.algorithm}.pth``.

        Args:
            iterations: number of calls to ``self.update_parameters()``.
            progress_interval: print period in iterations; 0/None disables it.
            eval_interval: evaluation period in iterations; 0/None disables it.
            dataset_collection: data passed to ``evaluate_agent``; ``None``
                skips evaluation entirely.

        Returns:
            ``(losses, eval_results)``. ``losses`` is the list of non-``None``
            dicts returned by ``update_parameters``. ``eval_results`` is a list
            of ``(iteration, metrics)`` tuples.
        """
        print(f"\nStart offline training ({self.algorithm})...")
        train_start = time.time()

        losses = []
        eval_results = []
        best_rmse = float('inf')
        best_iter = 0

        def window_mean(window, key):
            return np.mean([entry[key] for entry in window])

        for i in range(iterations):
            step = i + 1
            loss_info = self.update_parameters()
            if loss_info is not None:
                losses.append(loss_info)

            if progress_interval and step % progress_interval == 0:
                # update_parameters() may return None (skipped update), so this
                # window can be shorter than progress_interval, reach back past
                # the previous report, or be empty.
                recent_losses = losses[-progress_interval:]
                if not recent_losses:
                    print(f"Iterations {step}/{iterations}, no loss information yet")
                else:
                    avg_critic_loss = window_mean(recent_losses, 'critic_loss')
                    avg_actor_loss = window_mean(recent_losses, 'actor_loss')
                    print_str = f"Iterations {step}/{iterations}, Critic loss: {avg_critic_loss: .4f}, Actor loss: {avg_actor_loss: .4f}"

                    if self.baserl != 'IQL':
                        avg_alpha = window_mean(recent_losses, 'alpha')
                        print_str += f", Alpha: {avg_alpha:.4f}"
                    if self.DR:
                        avg_behavior_loss = window_mean(recent_losses, 'behavior_loss')
                        print_str += f", Behavior loss: {avg_behavior_loss: .4f}"
                    if self.baserl == 'IQL':
                        avg_value_loss = window_mean(recent_losses, 'value_loss')
                        print_str += f", Value loss: {avg_value_loss: .4f}"
                    if self.action_diff:
                        avg_action_diff = window_mean(recent_losses, 'action_diff')
                        print_str += f", Action diff: {avg_action_diff: .4f}"

                    print(print_str)

            if dataset_collection is not None and eval_interval and step % eval_interval == 0:
                print(f"\nEvaluation model (iteration {step})...")
                metrics = evaluate_agent(self, dataset_collection, num_episodes=100)
                eval_results.append((step, metrics))

                if metrics['avg_rmse'] < best_rmse:
                    best_rmse = metrics['avg_rmse']
                    best_iter = step
                    self.save(f"best_model_{self.algorithm}.pth")

                print(f"Current best RMSE: {best_rmse: .6f} (iteration {best_iter})")

        train_time = time.time() - train_start
        print(f"Training completed in: {train_time: .1f} s")
        print(f"Best RMSE: {best_rmse: .6f} (Iteration {best_iter})")

        return losses, eval_results
    
    def save(self, path):
        """
        Write a checkpoint of this agent to ``path`` with ``torch.save``.

        Always stored:
        * the encoder, actor, critic and target-critic state dicts;
        * ``alpha``, ``train_info``, ``total_steps`` and ``algorithm``;
        * ``log_alpha``, which holds ``self.log_alpha`` only when automatic
          entropy tuning is enabled and the base algorithm is not IQL, and is
          ``None`` otherwise.

        Stored only when the matching feature is enabled:
        * ``behavior_policy`` (DR or action_diff);
        * ``target_predictor`` (recover);
        * ``value_net`` (IQL).

        Args:
            path: destination file path.
        """
        tunes_alpha = self.use_automatic_entropy and self.baserl != 'IQL'
        checkpoint = dict(
            encoder=self.encoder.state_dict(),
            actor=self.actor.state_dict(),
            critic=self.critic.state_dict(),
            critic_target=self.critic_target.state_dict(),
            log_alpha=self.log_alpha if tunes_alpha else None,
            alpha=self.alpha,
            train_info=self.train_info,
            total_steps=self.total_steps,
            algorithm=self.algorithm,
        )

        # Optional sub-networks exist only for the matching configuration.
        needs_behavior_policy = self.DR or self.action_diff
        if needs_behavior_policy:
            checkpoint['behavior_policy'] = self.behavior_policy.state_dict()
        if self.recover:
            checkpoint['target_predictor'] = self.target_predictor.state_dict()
        if self.baserl == 'IQL':
            checkpoint['value_net'] = self.value_net.state_dict()

        torch.save(checkpoint, path)
        print(f"Model saved to {path}")
    
    def load(self, path):
        """
        Restore this agent from a checkpoint produced by ``save``.

        The file is read with ``torch.load`` mapped onto ``DEVICE``.
        * Always restored: the encoder, actor, critic and target critic.
        * Restored only when present in the checkpoint and enabled here:
          ``behavior_policy`` (DR or action_diff), ``target_predictor``
          (recover) and ``value_net`` (IQL).
        * ``log_alpha`` is copied in place only when automatic entropy tuning
          is enabled outside IQL and the stored value is not ``None``.
        * ``alpha``, ``train_info`` and ``total_steps`` are required keys.
          ``algorithm`` is overwritten only when the checkpoint contains it.

        Args:
            path: checkpoint file path.

        Raises:
            KeyError: if a required key is missing from the checkpoint.
        """
        ckpt = torch.load(path, map_location=DEVICE)

        for part in ('encoder', 'actor', 'critic', 'critic_target'):
            getattr(self, part).load_state_dict(ckpt[part])

        if 'behavior_policy' in ckpt and (self.DR or self.action_diff):
            self.behavior_policy.load_state_dict(ckpt['behavior_policy'])
        if self.recover and 'target_predictor' in ckpt:
            self.target_predictor.load_state_dict(ckpt['target_predictor'])
        if self.baserl == 'IQL' and 'value_net' in ckpt:
            self.value_net.load_state_dict(ckpt['value_net'])

        stored_log_alpha = ckpt.get('log_alpha')
        restore_log_alpha = (
            self.use_automatic_entropy
            and stored_log_alpha is not None
            and self.baserl != 'IQL'
        )
        if restore_log_alpha:
            # Copy into the existing tensor so anything already holding a
            # reference to self.log_alpha sees the restored value.
            self.log_alpha.data.copy_(stored_log_alpha.data)

        self.alpha = ckpt['alpha']
        self.train_info = ckpt['train_info']
        self.total_steps = ckpt['total_steps']
        if 'algorithm' in ckpt:
            self.algorithm = ckpt['algorithm']
        print(f"Load model from {path} ({self.algorithm})")

