# rem.py
'''
Random Ensemble Mixture (REM) agent for offline reinforcement learning.
Implementation related to the paper "An Optimistic Perspective on Offline Reinforcement Learning"
(Agarwal, Schuurmans and Norouzi, ICML 2020).

'''

import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"  # Suppress TensorFlow warnings

import random
import numpy as np
import tensorflow as tf
from collections import deque
from .Utils import prioritized_replay
from .Utils.rl_utils import (MODEL_PATH, STATE_SPACE, STATE_SIZE, MINIBATCH_SIZE, UPDATE_TARGET, EPOCHS, PER, ABS_ERROR_UPPER,
                             N_MODELS, DISCOUNT_FACTOR, ACTIONS, MAX_SIZE_MEMORY, create_model, feed_memory, save_models,
                             logging)

class REM:
    """Random Ensemble Mixture (REM) agent trained from a replay memory.

    Holds ``N_MODELS`` online Q-networks and ``N_MODELS`` target networks built
    with ``create_model``. For every training minibatch a weight vector is drawn
    from a flat Dirichlet distribution and used to mix the ensemble's Q-value
    predictions. The mixture is used to build the Bellman targets that every
    online model is then fitted to.
    """

    def __init__(self) -> None:
        """Create the ensembles, sync the target networks and save the initial models.

        Each online model ``m`` is saved to ``os.path.join(MODEL_PATH, str(m))``
        with ``save_format='tf'``. Target model ``m`` starts with a copy of
        online model ``m``'s weights.
        """
        self.replay_memory = deque(maxlen=MAX_SIZE_MEMORY)
        # Number of transitions reported by feed_memory(); stays 0 until it runs.
        self.replay_memory_size = 0
        self.per = prioritized_replay.PER(MAX_SIZE_MEMORY, ABS_ERROR_UPPER) if PER else None
        self.models = [create_model(STATE_SPACE) for _ in range(N_MODELS)]
        self.target_models = [create_model(STATE_SPACE) for _ in range(N_MODELS)]

        for m, (model, target) in enumerate(zip(self.models, self.target_models)):
            target.set_weights(model.get_weights())
            model.save(os.path.join(MODEL_PATH, str(m)), save_format='tf')

        logging.info("Random Ensemble Mixture initialized with %d models", N_MODELS)

    def _sync_target_models(self) -> None:
        """Copy every online model's weights into its paired target model."""
        for model, target in zip(self.models, self.target_models):
            target.set_weights(model.get_weights())

    def _prepare_minibatch(self) -> tuple:
        """Turn the replay memory into stacked-frame samples ready for training.

        Every window of ``STATE_SIZE`` consecutive transitions becomes one sample
        ``(states, action, next_states, reward, terminal)``:

        - ``states`` and ``next_states`` are the window's ``STATE_SIZE`` state
          and next-state entries (fields 0 and 2).
        - The action, reward and terminal flag (fields 1, 3 and 4) come from the
          window's last transition.

        The samples are shuffled and their count is stored in ``self.batch_size``.

        Returns:
            With PER: the ``(b_idx, minibatch, ISWeights)`` triple from one
            ``self.per.sample(MINIBATCH_SIZE)`` call. This happens after every
            sample has been added to the PER memory.
            Without PER: ``(None, minibatches, None)``, where ``minibatches`` is
            the shuffled samples chunked into lists of at most ``MINIBATCH_SIZE``.
        """
        # Snapshot the deque once: indexing the middle of a deque is O(n), which
        # made the sliding window below quadratic in the memory size.
        memory = list(self.replay_memory)
        # Never index past what the bounded deque actually holds.
        n_transitions = min(self.replay_memory_size, len(memory))

        batch = []
        for end in range(STATE_SIZE, n_transitions + 1):
            window = memory[end - STATE_SIZE:end]
            last = memory[end - 1]
            batch.append((
                [t[0] for t in window],
                last[1],
                [t[2] for t in window],
                last[3],
                last[4],
            ))

        random.shuffle(batch)
        self.batch_size = len(batch)

        if PER:
            for sample in batch:
                # New samples enter with the current maximum priority, or
                # ABS_ERROR_UPPER when the PER memory reports none (falsy).
                max_priority = self.per.get_max() or ABS_ERROR_UPPER
                self.per.add(sample, max_priority)
            b_idx, minibatch, ISWeights = self.per.sample(MINIBATCH_SIZE)
            return b_idx, minibatch, ISWeights

        minibatches = [batch[i:i + MINIBATCH_SIZE] for i in range(0, len(batch), MINIBATCH_SIZE)]
        return None, minibatches, None

    @tf.function(reduce_retracing=True)
    def _predict_ensemble(self, states: tf.Tensor, models: list, alpha: tf.Tensor) -> tf.Tensor:
        """Return the ``alpha``-weighted sum of the Q-values predicted by ``models``.

        Args:
            states: Batch of stacked states fed to every model.
            models: Ensemble members. Their outputs are stacked along a new
                leading axis of length ``len(models)``.
            alpha: 1-D mixing weights, one per model.
        """
        states = tf.convert_to_tensor(states, dtype=tf.float32)
        predictions = tf.stack([model(states, training=False) for model in models])
        # Broadcast alpha over the remaining two axes (assumed (batch, n_actions)
        # model output), then sum the models out.
        return tf.reduce_sum(predictions * alpha[:, tf.newaxis, tf.newaxis], axis=0)

    def _update_q_values(self, current_qs: tf.Tensor, future_qs: tf.Tensor, actions: list, rewards: list,
                         terminals: list) -> tf.Tensor:
        """Build Q-learning targets from the mixed current and future Q-values.

        Starts from a copy of ``current_qs``. For each row ``i``, only the entry
        of the taken action (column ``ACTIONS.index(actions[i])``) is overwritten:

        - ``rewards[i]`` when ``terminals[i]`` is truthy;
        - otherwise ``rewards[i] + DISCOUNT_FACTOR * max(future_qs[i])``.

        All other actions keep their current prediction.

        The batch size is taken from ``actions`` rather than ``MINIBATCH_SIZE``,
        so a final, smaller minibatch is handled as well.

        Returns:
            A float32 tensor with the same shape as ``current_qs``.
        """
        qs_target = np.array(current_qs, dtype=np.float32)  # copy: inputs stay untouched
        future = np.asarray(future_qs, dtype=np.float32)
        n = len(actions)
        rows = np.arange(n)
        cols = np.array([ACTIONS.index(a) for a in actions], dtype=np.int64)
        reward_arr = np.asarray(rewards, dtype=np.float32).reshape(n)
        done = np.array([bool(t) for t in terminals], dtype=bool)

        bootstrapped = reward_arr + DISCOUNT_FACTOR * future.max(axis=1)
        qs_target[rows, cols] = np.where(done, reward_arr, bootstrapped)
        return tf.convert_to_tensor(qs_target, dtype=tf.float32)

    def feed_memory(self) -> None:
        """Fill ``self.replay_memory`` using the module-level ``feed_memory`` helper.

        The helper's return value is stored in ``self.replay_memory_size``. It is
        used as the number of transitions available for training.
        """
        self.replay_memory_size = feed_memory(self.replay_memory)

    @staticmethod
    def _stack_states(minibatch: list, field: int) -> tf.Tensor:
        """Stack field ``field`` of every sample into a float32 (batch, STATE_SIZE, STATE_SPACE) tensor."""
        return tf.convert_to_tensor(
            np.array([np.reshape(sample[field], (STATE_SIZE, STATE_SPACE)) for sample in minibatch]),
            dtype=tf.float32,
        )

    def _iter_minibatches(self, minibatches: list):
        """Yield ``(b_idx, minibatch, ISWeights)`` triples for one training run.

        With PER: draws fresh prioritized samples from ``self.per`` while
        ``samples_left * EPOCHS >= MINIBATCH_SIZE``. ``samples_left`` starts at
        ``self.replay_memory_size`` and drops by ``MINIBATCH_SIZE`` after each
        yielded minibatch.

        Without PER: yields every precomputed minibatch exactly once, including
        a final one smaller than ``MINIBATCH_SIZE``.
        """
        if PER:
            samples_left = self.replay_memory_size
            while samples_left * EPOCHS >= MINIBATCH_SIZE:
                b_idx, minibatch, is_weights = self.per.sample(MINIBATCH_SIZE)
                yield b_idx, minibatch, is_weights
                samples_left -= MINIBATCH_SIZE
        else:
            for minibatch in minibatches:
                yield None, minibatch, None

    def train_model(self) -> None:
        """Train every online model on the replay memory, then save them.

        For each minibatch:

        1. Sample mixing weights ``alpha`` from a flat Dirichlet distribution.
        2. Mix the online ensemble's Q-values for the current states and the
           target ensemble's Q-values for the next states, using the same
           ``alpha`` for both.
        3. Build targets with ``_update_q_values``.
        4. Fit each online model to those targets for ``EPOCHS`` epochs.

        With PER, the fit is weighted by the importance-sampling weights, and the
        sampled priorities are refreshed with the absolute TD error of the taken
        actions.

        The target networks are re-synchronised whenever more than
        ``UPDATE_TARGET`` samples have been trained on since the last sync.

        Logs an error and returns without training if the replay memory is empty.
        """
        if not self.replay_memory:
            logging.error("Replay memory is empty. Call feed_memory() first.")
            return

        update_counter = 0
        # With PER this also fills the PER memory; its initial sample is unused
        # because _iter_minibatches draws its own samples.
        _, minibatches, _ = self._prepare_minibatch()

        for b_idx, minibatch, is_weights in self._iter_minibatches(minibatches):
            current_states = self._stack_states(minibatch, 0)
            next_states = self._stack_states(minibatch, 2)
            actions = [t[1] for t in minibatch]
            rewards = [t[3] for t in minibatch]
            terminals = [t[4] for t in minibatch]
            batch_len = len(actions)

            # One random convex combination per minibatch, shared by the online
            # (current-state) and target (next-state) estimates.
            alpha = tf.convert_to_tensor(np.random.dirichlet(np.ones(N_MODELS), size=1)[0], dtype=tf.float32)

            current_qs = self._predict_ensemble(current_states, self.models, alpha)
            future_qs = self._predict_ensemble(next_states, self.target_models, alpha)
            qs_target = self._update_q_values(current_qs, future_qs, actions, rewards, terminals)

            fit_kwargs = {"batch_size": MINIBATCH_SIZE, "epochs": EPOCHS, "verbose": 0}
            abs_errors = None
            if PER:
                action_indices = tf.convert_to_tensor([ACTIONS.index(a) for a in actions], dtype=tf.int32)
                gather_idx = tf.stack([tf.range(batch_len, dtype=tf.int32), action_indices], axis=1)
                # Absolute TD error of the taken action, computed before fitting;
                # used as the new priority.
                abs_errors = tf.abs(tf.gather_nd(qs_target, gather_idx) - tf.gather_nd(current_qs, gather_idx))
                fit_kwargs["sample_weight"] = is_weights

            for model in self.models:
                model.fit(current_states, qs_target, **fit_kwargs)

            if PER:
                self.per.batch_update(b_idx, abs_errors.numpy())

            update_counter += batch_len
            if update_counter > UPDATE_TARGET:
                self._sync_target_models()
                update_counter = 0
                logging.info("Target networks updated")

        save_models(self.models)

if __name__ == "__main__":
    # Script entry point: build the ensemble, load the offline dataset, train.
    agent = REM()
    agent.feed_memory()
    agent.train_model()
