'''
Maxmin Q-learning algorithm

Code related to the paper: Maxmin Q-learning: Controlling the Estimation Bias of Q-learning
'''

import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"  # Suppress TensorFlow warnings

import random
import tensorflow as tf
import numpy as np
from collections import deque
from .Utils import prioritized_replay
from .Utils.rl_utils import (MODEL_PATH, STATE_SPACE, STATE_SIZE, MINIBATCH_SIZE, UPDATE_TARGET, EPOCHS, PER, ABS_ERROR_UPPER,
                             N_MODELS, DISCOUNT_FACTOR, ACTIONS, MAX_SIZE_MEMORY, create_model, feed_memory, save_models,
                             logging)

class MAXMIN:
    """Maxmin Q-learning trainer operating on an ensemble of Q-networks.

    The instance keeps ``N_MODELS`` online models with one target model each,
    a bounded replay memory and, when ``PER`` is enabled, a prioritized replay
    buffer.  Transitions are indexed positionally throughout this class as
    ``(state, action, next_state, reward, terminal)``.
    """

    def __init__(self) -> None:
        """Build the online/target ensembles and the replay storage.

        Each target model receives the weights of its online counterpart and
        each online model is saved under ``MODEL_PATH/<index>`` with
        ``save_format='tf'``.
        """
        self.replay_memory = deque(maxlen=MAX_SIZE_MEMORY)
        self.per = prioritized_replay.PER(MAX_SIZE_MEMORY, ABS_ERROR_UPPER) if PER else None
        self.models = [create_model(STATE_SPACE) for _ in range(N_MODELS)]
        self.target_models = [create_model(STATE_SPACE) for _ in range(N_MODELS)]

        # Declared here so the attributes exist before feed_memory() and
        # _prepare_minibatch() assign their real values.
        self.replay_memory_size = 0
        self.batch_size = 0

        for index, (online, target) in enumerate(zip(self.models, self.target_models)):
            target.set_weights(online.get_weights())
            online.save(os.path.join(MODEL_PATH, str(index)), save_format='tf')

        logging.info("Maxmin Q-learning initialized with %d models", N_MODELS)

    def _prepare_minibatch(self) -> tuple:
        """Build stacked transitions from replay memory and split or sample them.

        Every window of ``STATE_SIZE`` consecutive stored transitions becomes
        one training transition: the states and next states of the window are
        collected into lists, while action, reward and terminal flag are taken
        from the last transition of the window.  The result is shuffled and its
        length is stored in ``self.batch_size``.

        Returns:
            With ``PER`` enabled: ``(b_idx, minibatch, ISWeights)`` as returned
            by ``self.per.sample(MINIBATCH_SIZE)`` after all stacked
            transitions were added to the prioritized buffer.
            Otherwise: ``(None, minibatches, None)`` where ``minibatches`` is a
            list of slices of at most ``MINIBATCH_SIZE`` transitions; only the
            last slice can be shorter.
        """
        # Random access into the middle of a deque is O(n); snapshot it once so
        # building the windows stays linear in the memory size.
        memory = list(self.replay_memory)
        # Never index past what is actually stored, whatever feed_memory() reported.
        usable = min(self.replay_memory_size, len(memory))

        batch = []
        for end in range(STATE_SIZE, usable + 1):
            window = memory[end - STATE_SIZE:end]
            last = memory[end - 1]
            batch.append((
                [step[0] for step in window],
                last[1],
                [step[2] for step in window],
                last[3],
                last[4],
            ))

        random.shuffle(batch)
        self.batch_size = len(batch)

        if PER:
            for transition in batch:
                # Fall back to the configured upper bound while the buffer
                # reports no (or a zero) maximum priority.
                max_priority = self.per.get_max() or ABS_ERROR_UPPER
                self.per.add(transition, max_priority)
            b_idx, minibatch, ISWeights = self.per.sample(MINIBATCH_SIZE)
            return b_idx, minibatch, ISWeights

        minibatches = [batch[start:start + MINIBATCH_SIZE]
                       for start in range(0, len(batch), MINIBATCH_SIZE)]
        return None, minibatches, None

    @tf.function(reduce_retracing=True)
    def _predict_ensemble(self, states: tf.Tensor, models: list) -> tf.Tensor:
        """Run every model in ``models`` on ``states`` in inference mode.

        Args:
            states: Batch of model inputs; converted to a float32 tensor.
            models: Models to evaluate (the online or the target ensemble).

        Returns:
            The per-model outputs stacked along a new leading axis, i.e. one
            entry per model.
        """
        states = tf.convert_to_tensor(states, dtype=tf.float32)
        return tf.stack([model(states, training=False) for model in models])

    def _update_q_values(self, current_qs_ensemble: tf.Tensor, future_qs_ensemble: tf.Tensor, actions: list,
                         rewards: list, terminals: list) -> tuple:
        """Compute the training targets for one randomly chosen ensemble member.

        The bootstrap value of a non-terminal transition is the maximum over
        actions of the element-wise minimum of the target ensemble's Q-values;
        terminal transitions use the reward alone.

        Args:
            current_qs_ensemble: Online predictions for the current states,
                expected shape ``[N_MODELS, batch, num_actions]``; must
                provide ``.numpy()``.
            future_qs_ensemble: Target predictions for the next states, same
                layout as ``current_qs_ensemble``.
            actions: Action taken in each transition; each must be an element
                of ``ACTIONS``.
            rewards: Reward of each transition.
            terminals: Truthy where the transition ended an episode.

        Returns:
            ``(qs_target, updated_index, abs_errors)``: the target Q-value
            matrix for the selected model, the index of that model, and the
            absolute TD errors per transition (``None`` unless ``PER``).

        Raises:
            ValueError: If an action is not contained in ``ACTIONS``.
        """
        current_qs = current_qs_ensemble.numpy()
        future_qs = future_qs_ensemble.numpy()

        # Only one ensemble member is trained per step.
        updated_index = random.randrange(N_MODELS)

        # Size everything from the batch actually passed in rather than from
        # MINIBATCH_SIZE, so a shorter batch cannot index out of range.
        batch_len = len(actions)
        rows = np.arange(batch_len)
        action_cols = np.array([ACTIONS.index(action) for action in actions], dtype=np.intp)
        reward_values = np.asarray(rewards, dtype=np.float64)
        is_terminal = np.asarray(terminals, dtype=bool)

        # Min over the ensemble axis, then the best action of that estimate.
        # Promoted to float64 so the target arithmetic is not rounded early.
        best_future = future_qs.min(axis=0).max(axis=1).astype(np.float64)
        targets = np.where(is_terminal, reward_values,
                           reward_values + DISCOUNT_FACTOR * best_future)

        selected_qs = current_qs[updated_index]
        abs_errors = np.abs(targets - selected_qs[rows, action_cols]) if PER else None

        # Only the Q-value of the action that was taken gets a new target; the
        # other actions keep the model's own prediction.
        qs_target = np.copy(selected_qs)
        qs_target[rows, action_cols] = targets

        return qs_target, updated_index, abs_errors

    def feed_memory(self) -> None:
        """Fill the replay memory and record the size reported by the loader."""
        self.replay_memory_size = feed_memory(self.replay_memory)

    @staticmethod
    def _stack_states(minibatch: list, position: int) -> tf.Tensor:
        """Return element ``position`` of every transition as one float32 tensor."""
        return tf.convert_to_tensor(
            np.array([np.reshape(t[position], (STATE_SIZE, STATE_SPACE)) for t in minibatch]),
            dtype=tf.float32
        )

    def _iter_training_batches(self, minibatches: list):
        """Yield ``(b_idx, minibatch, ISWeights)`` for each training step.

        With ``PER`` a fresh sample is drawn lazily for every step, so
        priorities updated by the previous step are taken into account; the
        sample budget starts at ``self.replay_memory_size`` and shrinks by
        ``MINIBATCH_SIZE`` per step.  Without ``PER`` the prepared minibatches
        are yielded in order with ``None`` indices and weights.
        """
        if PER:
            remaining = self.replay_memory_size
            while remaining * EPOCHS >= MINIBATCH_SIZE:
                b_idx, minibatch, ISWeights = self.per.sample(MINIBATCH_SIZE)
                yield b_idx, minibatch, ISWeights
                remaining -= MINIBATCH_SIZE
            return

        for minibatch in minibatches:
            # Only the trailing slice can be incomplete; it is skipped, while
            # every full minibatch (including a full last one) is trained on.
            if len(minibatch) < MINIBATCH_SIZE:
                break
            yield None, minibatch, None

    def train_model(self) -> None:
        """Train the MAXMIN ensemble on the replay memory and save the models.

        Logs an error and returns without training when the replay memory is
        empty.  Each step predicts with the online and target ensembles,
        computes the targets, fits the selected model for ``EPOCHS`` epochs and
        (with ``PER``) updates the sampled priorities.  Target models are
        synchronised once more than ``UPDATE_TARGET`` samples have been
        processed since the last synchronisation.
        """
        if not self.replay_memory:
            logging.error("Replay memory is empty. Call feed_memory() first.")
            return

        update_counter = 0

        # In PER mode the sample returned here is not trained on; the batch
        # generator draws a fresh one for every step.
        _, minibatches, _ = self._prepare_minibatch()

        for b_idx, minibatch, is_weights in self._iter_training_batches(minibatches):
            current_states = self._stack_states(minibatch, 0)
            next_states = self._stack_states(minibatch, 2)
            actions = [t[1] for t in minibatch]
            rewards = [t[3] for t in minibatch]
            terminals = [t[4] for t in minibatch]

            current_qs_ensemble = self._predict_ensemble(current_states, self.models)
            future_qs_ensemble = self._predict_ensemble(next_states, self.target_models)

            qs_target, updated_index, abs_errors = self._update_q_values(
                current_qs_ensemble, future_qs_ensemble, actions, rewards, terminals)

            # is_weights is None without PER, which is fit()'s default for
            # sample_weight, so one call covers both modes.
            self.models[updated_index].fit(
                current_states, qs_target, batch_size=MINIBATCH_SIZE,
                epochs=EPOCHS, sample_weight=is_weights, verbose=0
            )
            if PER:
                self.per.batch_update(b_idx, abs_errors)

            update_counter += MINIBATCH_SIZE

            if update_counter > UPDATE_TARGET:
                for online, target in zip(self.models, self.target_models):
                    target.set_weights(online.get_weights())
                update_counter = 0
                logging.info("Target networks updated")

        save_models(self.models)

if __name__ == "__main__":
    # Script entry point: build the ensemble, load the replay memory, train.
    # NOTE(review): the module uses package-relative imports, so it presumably
    # has to be launched as part of its package — confirm.
    agent = MAXMIN()
    agent.feed_memory()
    agent.train_model()
