# ddrqn.py
'''
Batch Constrained Deep Q-learning and Double DQN algorithms

Code based on the papers "Benchmarking Batch Deep Reinforcement Learning Algorithms" and
"Deep Reinforcement Learning with Double Q-learning".

'''

import numpy as np
import tensorflow as tf
import random
from collections import deque
from . import bcq  # Module to use BCQ
from .Utils import prioritized_replay
from .Utils.rl_utils import (MODEL_PATH, STATE_SPACE, STATE_SIZE, MINIBATCH_SIZE, UPDATE_TARGET, EPOCHS, PER, ABS_ERROR_UPPER,
                             N_MODELS, DISCOUNT_FACTOR, ACTIONS, MAX_SIZE_MEMORY, create_model, feed_memory, save_models,
                             logging)

# DDRQN-specific constants.
# Batch Constrained Q-learning switch: when True, DDRQN creates a
# bcq.GenerativeModel and computes the bootstrap term with its get_qmax();
# when False, the plain Double DQN target is used instead.
BCQ: bool = True

class DDRQN:
    """Offline Double DQN / Batch Constrained Q-learning (BCQ) trainer.

    Transitions are loaded into ``self.replay_memory`` by :meth:`feed_memory`
    and consumed by :meth:`train_model`, which fits ``self.model`` on
    bootstrapped targets and periodically copies its weights into
    ``self.target_model``. The bootstrap term comes from
    ``self.bcq.get_qmax`` when the module-level ``BCQ`` flag is set and from
    the Double DQN rule otherwise.
    """

    def __init__(self) -> None:
        """Initialize DDRQN with model, target model, and replay memory.

        Also creates the optional PER buffer and BCQ generative model (per the
        ``PER`` / ``BCQ`` flags), copies the online weights into the target
        network and saves the initial online model to ``MODEL_PATH + '/0'``.
        A failed save is logged, not raised.
        """
        self.replay_memory = deque(maxlen=MAX_SIZE_MEMORY)
        # Number of transitions reported by feed_memory(). Initialised here so
        # train_model() never fails with AttributeError before feed_memory().
        self.replay_memory_size = 0
        self.per = prioritized_replay.PER(MAX_SIZE_MEMORY, ABS_ERROR_UPPER) if PER else None
        self.bcq = bcq.GenerativeModel() if BCQ else None
        self.model = create_model(STATE_SPACE)
        self.target_model = create_model(STATE_SPACE)
        self.target_model.set_weights(self.model.get_weights())

        # Save initial model (single model for DDRQN/BCQ)
        try:
            self.model.save(MODEL_PATH+'/0', save_format='tf')
            logging.info("Initial model saved to %s", MODEL_PATH+'/0')
        except Exception as e:
            logging.error("Failed to save initial model: %s", e)

        logging.info("Double DQN/Batch Constrained Q-learning initialized")

    def _prepare_minibatch(self) -> tuple:
        """Build stacked-state transitions from replay memory, then split or sample them.

        Each built entry is
        ``(stacked_state, action, stacked_next_state, reward, terminal)``, where
        the stacks hold fields 0 and 2 of the ``STATE_SIZE`` consecutive memory
        entries ending at the transition, and the scalar fields come from that
        last entry. The entries are shuffled before use.

        Returns:
            With PER: the ``(b_idx, minibatch, ISWeights)`` triple from
            ``self.per.sample(MINIBATCH_SIZE)`` after adding every entry to the
            PER buffer. Without PER: ``(None, minibatches, None)``, where
            ``minibatches`` are consecutive chunks of at most MINIBATCH_SIZE
            entries (the last chunk may be shorter).
        """
        # Copy once: indexing the middle of a deque is O(n), which made the
        # window construction below quadratic.
        memory = list(self.replay_memory)
        # Never index past what the (bounded) deque actually holds.
        n_transitions = min(self.replay_memory_size, len(memory))

        batch = []
        for end in range(STATE_SIZE, n_transitions + 1):
            window = memory[end - STATE_SIZE:end]
            last = window[-1]
            batch.append((
                [t[0] for t in window],
                last[1],
                [t[2] for t in window],
                last[3],
                last[4],
            ))

        random.shuffle(batch)
        self.batch_size = len(batch)

        if PER:
            for transition in batch:
                # Seed each entry with the current max priority; fall back to
                # ABS_ERROR_UPPER when get_max() is falsy (e.g. an empty buffer).
                max_priority = self.per.get_max() or ABS_ERROR_UPPER
                self.per.add(transition, max_priority)
            b_idx, minibatch, ISWeights = self.per.sample(MINIBATCH_SIZE)
            return b_idx, minibatch, ISWeights

        minibatches = [batch[i:i + MINIBATCH_SIZE] for i in range(0, len(batch), MINIBATCH_SIZE)]
        return None, minibatches, None

    def _iter_minibatches(self, minibatches):
        """Yield ``(b_idx, minibatch, ISWeights)`` triples to train on.

        PER: keeps sampling MINIBATCH_SIZE entries while
        ``samples_left * EPOCHS >= MINIBATCH_SIZE``. ``samples_left`` starts
        at ``replay_memory_size`` and drops by MINIBATCH_SIZE after each
        yielded step.
        Non-PER: yields every chunk in ``minibatches`` exactly once, including
        the final (possibly partial) one.
        """
        if PER:
            samples_left = self.replay_memory_size
            while samples_left * EPOCHS >= MINIBATCH_SIZE:
                b_idx, minibatch, is_weights = self.per.sample(MINIBATCH_SIZE)
                yield b_idx, minibatch, is_weights
                samples_left -= MINIBATCH_SIZE
        else:
            for minibatch in minibatches:
                yield None, minibatch, None

    @staticmethod
    def _stack_batch(minibatch, field: int) -> tf.Tensor:
        """Stack field ``field`` of every entry into a float32 tensor of shape (len(minibatch), STATE_SIZE, STATE_SPACE)."""
        return tf.convert_to_tensor(
            np.array([np.reshape(t[field], (STATE_SIZE, STATE_SPACE)) for t in minibatch]),
            dtype=tf.float32,
        )

    @tf.function(reduce_retracing=True)
    def _predict(self, states: tf.Tensor, model) -> tf.Tensor:
        """Predict Q-values using the given model (inference mode, training=False)."""
        states = tf.convert_to_tensor(states, dtype=tf.float32)
        return model(states, training=False)

    def _update_q_values(self, current_qs: np.ndarray, next_qs: np.ndarray, future_qs: np.ndarray,
                         actions: list, rewards: list, terminals: list, next_states: np.ndarray) -> np.ndarray:
        """Update Q-values using Double DQN or BCQ logic.

        Starts from a copy of ``current_qs``, so only the taken action's entry
        in each row changes. That entry becomes ``reward`` for terminal
        transitions, and ``reward + DISCOUNT_FACTOR * q_next`` otherwise, where
        ``q_next`` is:

        * BCQ: ``self.bcq.get_qmax(next_states[i, STATE_SIZE - 1], future_qs[i])``.
        * Double DQN: ``future_qs[i, argmax(next_qs[i])]``.

        Args:
            current_qs: online-network Q-values of the current states
                (tf.Tensor or ndarray).
            next_qs: online-network Q-values of the next states.
            future_qs: target-network Q-values of the next states.
            actions, rewards, terminals: per-entry fields, one per minibatch row.
            next_states: stacked next states, indexed as ``[row, frame]``.

        Returns:
            Target matrix with the same shape as ``current_qs``.
        """
        # np.array copies and accepts both eager tensors and ndarrays.
        qs_target = np.array(current_qs, copy=True)
        # Iterate the actual rows: the last non-PER minibatch can be shorter
        # than MINIBATCH_SIZE.
        for i, (action, reward, terminal) in enumerate(zip(actions, rewards, terminals)):
            action_idx = ACTIONS.index(action)
            if terminal:
                qs_target[i, action_idx] = reward
                continue
            if BCQ and self.bcq:
                max_future_q = self.bcq.get_qmax(next_states[i, STATE_SIZE - 1], future_qs[i])
            else:  # Double DQN: online net selects the action, target net evaluates it.
                max_action = np.argmax(next_qs[i])
                max_future_q = future_qs[i, max_action]
            qs_target[i, action_idx] = reward + DISCOUNT_FACTOR * max_future_q
        return qs_target

    def feed_memory(self) -> None:
        """Load transitions into replay memory and record the count returned by rl_utils.feed_memory."""
        self.replay_memory_size = feed_memory(self.replay_memory)

    def train_model(self) -> None:
        """Train the DDRQN/BCQ model using replay memory.

        Iterates minibatches from :meth:`_iter_minibatches` and fits the online
        model on the targets from :meth:`_update_q_values`. With PER, the fit
        uses the importance-sampling weights and then updates priorities with
        the absolute TD errors. The target network is synced whenever more than
        UPDATE_TARGET samples have been trained on since the last sync. The
        online model is saved to ``MODEL_PATH + '/0'`` at the end; a failed
        save is logged, not raised.
        """
        if not self.replay_memory:
            logging.error("Replay memory is empty. Call feed_memory() first.")
            return

        update_counter = 0
        # In PER mode the sample returned here is discarded; _iter_minibatches
        # draws fresh samples from the filled PER buffer.
        _, minibatches, _ = self._prepare_minibatch()

        for b_idx, minibatch, is_weights in self._iter_minibatches(minibatches):
            current_states = self._stack_batch(minibatch, 0)
            next_states = self._stack_batch(minibatch, 2)
            actions = [t[1] for t in minibatch]
            rewards = [t[3] for t in minibatch]
            terminals = [t[4] for t in minibatch]

            current_qs = self._predict(current_states, self.model)
            next_qs = self._predict(next_states, self.model)
            future_qs = self._predict(next_states, self.target_model)

            # Update Q-values
            qs_target = self._update_q_values(current_qs, next_qs, future_qs, actions, rewards, terminals, next_states)

            # Train model
            if PER:
                rows = np.arange(len(minibatch))
                action_indices = [ACTIONS.index(a) for a in actions]
                abs_errors = np.abs(qs_target[rows, action_indices] - np.asarray(current_qs)[rows, action_indices])
                self.model.fit(current_states, qs_target, batch_size=MINIBATCH_SIZE, epochs=EPOCHS,
                               sample_weight=is_weights, verbose=0)
                self.per.batch_update(b_idx, abs_errors)
            else:
                self.model.fit(current_states, qs_target, batch_size=MINIBATCH_SIZE, epochs=EPOCHS, verbose=0)

            # Count the rows actually trained on (the last minibatch may be partial).
            update_counter += len(minibatch)

            # Update target network
            if update_counter > UPDATE_TARGET:
                self.target_model.set_weights(self.model.get_weights())
                update_counter = 0
                logging.info("Target network updated")

        # Save model (single model for DDRQN/BCQ)
        try:
            self.model.save(MODEL_PATH+'/0', save_format='tf')
            logging.info("Model saved to %s", MODEL_PATH+'/0')
        except Exception as e:
            logging.error("Failed to save model: %s", e)

if __name__ == "__main__":
    # Script entry point: build the agent, load the stored transitions, train.
    agent = DDRQN()
    agent.feed_memory()
    agent.train_model()
