import os
import numpy as np
import tensorflow as tf
import keras_nlp
from tensorflow.keras import regularizers  # type: ignore
from tensorflow.keras.models import Model  # type: ignore
from tensorflow.keras.layers import Input, Dense, GlobalAveragePooling1D  # type: ignore
from tensorflow.keras.optimizers import Adam

from utils import hash_observation, log  # type: ignore


class AlphaNet:
    """Value network and policy network with per-state prediction caches.

    The value model takes a state with an action index appended
    (``input_dim + 1`` features) and outputs a single value. The policy model
    takes a state and outputs ``output_dim`` logits. Predictions are memoised
    in ``value_map`` / ``policy_map`` (keyed by ``hash_observation``) until
    ``clear_map`` is called.
    """

    # Offset added to the round number (first input column) of the transformer
    # policy before one-hot encoding. The original author notes the input holds
    # 7 state values plus ``num_rounds`` round values.
    _ROUND_OFFSET = 7
    # Number of attention heads used by the transformer policy encoder.
    _NUM_ATTENTION_HEADS = 2

    def __init__(self, input_dim, output_dim, num_rounds, type="mlp"):
        """Build the value model and the requested policy model.

        Args:
            input_dim: Number of features in a state vector.
            output_dim: Number of policy logits (one per action).
            num_rounds: Sizes the one-hot encoding of the transformer policy.
            type: Policy architecture, ``"mlp"`` or ``"transformer"``. The name
                shadows the builtin but is kept for backward compatibility.

        Raises:
            ValueError: If ``type`` is not a supported architecture.
        """
        # Validate first so a bad argument fails before any model is built.
        if type == "mlp":
            build_policy_model = self._build_policy_model
        elif type == "transformer":
            build_policy_model = self._build_transformer_policy_model
        else:
            raise ValueError(f"Unsupported model type: {type}")

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_rounds = num_rounds  # Used for transformer model
        self.hidden_sizes = [128, 64, 32]
        self.learning_rate = 0.001
        self.value_map = {}  # cache of value-model predictions
        self.policy_map = {}  # cache of policy-model predictions

        # The value model is created before the policy model, as before.
        self.value_model = self._build_value_model()
        self.policy_model = build_policy_model()

    @staticmethod
    def _linear_head(x, units, name):
        """Apply a bias-free, L2-regularised linear Dense layer called *name*."""
        return Dense(
            units,
            use_bias=False,
            activation="linear",
            kernel_regularizer=regularizers.l2(0.0001),
            name=name,
            kernel_initializer="glorot_uniform",
        )(x)

    def policy_head(self, x):
        """Return ``output_dim`` policy logits computed from tensor *x*."""
        return self._linear_head(x, self.output_dim, "policy_head")

    def value_head(self, x):
        """Return a single linear value output computed from tensor *x*."""
        return self._linear_head(x, 1, "value_head")

    def _dense_stack(self, inputs):
        """Apply the ReLU hidden layers ``dense_0``, ``dense_1``, ... to *inputs*.

        One layer is created per entry of ``self.hidden_sizes``; with no
        entries the input tensor is returned unchanged.
        """
        x = inputs
        for i, hidden_size in enumerate(self.hidden_sizes):
            # "glorot_uniform" is also the Keras Dense default initializer.
            x = Dense(hidden_size, activation="relu", name=f"dense_{i}", kernel_initializer="glorot_uniform")(x)
        return x

    def _compile_policy_model(self, model):
        """Compile *model* with mean softmax cross-entropy on logits and Adam; return it."""

        def custom_loss(y_true, logits):
            return tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=logits))

        model.compile(loss=custom_loss, optimizer=Adam(learning_rate=self.learning_rate))
        return model

    def _build_value_model(self):
        """Build and compile the MLP value model (input: state + action index)."""
        inputs = Input(shape=(self.input_dim + 1,), name="input")
        outputs = self.value_head(self._dense_stack(inputs))
        model = Model(inputs=inputs, outputs=outputs)
        model.compile(
            loss=tf.keras.losses.Huber(delta=0.5),  # Try huber loss since conflicting rewards
            optimizer=Adam(learning_rate=self.learning_rate),
        )
        return model

    def _build_policy_model(self):
        """Build and compile the MLP policy model (input: state, output: logits)."""
        inputs = Input(shape=(self.input_dim,), name="input")
        outputs = self.policy_head(self._dense_stack(inputs))
        return self._compile_policy_model(Model(inputs=inputs, outputs=outputs))

    def _build_transformer_policy_model(self):
        """Build and compile the transformer policy model.

        The model accepts inputs of any sequence length, but the one-hot depth
        is fixed at ``_ROUND_OFFSET + num_rounds + 1``. Per the original
        author's notes, the input holds 7 state values plus ``num_rounds``
        round values, and the extra ``+ 1`` is zero padding meant to make the
        depth divisible by the number of attention heads.
        """
        inputs = Input(shape=(None,), name="input")
        # TODO: avoid hardcoding the round offset and the head count
        round_offset = self._ROUND_OFFSET
        # Calculated by the range of possible values after shifting round number.
        # NOTE(review): with 2 heads this length is even only when num_rounds is
        # even (7 + num_rounds is then odd) -- confirm behaviour for odd num_rounds.
        one_hot_length = round_offset + self.num_rounds + 1
        round_number = inputs[:, 0] + round_offset  # shape: [batch_size]
        rest_states = inputs[:, 1:]  # shape: [batch_size, sequence_length - 1]

        # Reconstruct input with shifted round number
        shifted_inputs = tf.concat([tf.expand_dims(round_number, axis=1), rest_states], axis=1)  # [batch_size, sequence_length]
        x = tf.cast(shifted_inputs, tf.int32)

        # One-hot encoding
        x = tf.one_hot(x, depth=one_hot_length)  # shape: [batch_size, sequence_length, one_hot_length]

        # The encoder keeps the input shape; intermediate_dim is only the width
        # of its internal feed-forward layer.
        x = keras_nlp.layers.TransformerEncoder(
            intermediate_dim=one_hot_length * 4, num_heads=self._NUM_ATTENTION_HEADS
        )(x)  # shape: [batch_size, sequence_length, one_hot_length]

        # Average pooling over the sequence length dimension
        x = GlobalAveragePooling1D()(x)  # shape: [batch_size, one_hot_length]

        # Final linear layer to get logits for each action
        x = Dense(self.output_dim, activation="linear")(x)  # shape: [batch_size, num_actions]

        return self._compile_policy_model(Model(inputs=inputs, outputs=x))

    # This function is deprecated but leave it here for reference in case
    def _build_model(self):
        """Build the deprecated two-headed model (shared trunk, value + policy heads)."""
        inputs = Input(shape=(self.input_dim,), name="input")
        x = self._dense_stack(inputs)

        vh = self.value_head(x)
        ph = self.policy_head(x)

        model = Model(inputs=inputs, outputs=[vh, ph])
        model.compile(
            loss={"value_head": "mean_squared_error", "policy_head": tf.nn.softmax_cross_entropy_with_logits},
            optimizer=Adam(learning_rate=self.learning_rate),
            loss_weights={"value_head": 0.5, "policy_head": 0.5},
        )

        return model

    def predict_all_values(self, state, action_space):
        """Return the predicted value of every action in *state*.

        Args:
            state: Input state without the action.
            action_space: Number of possible actions; actions are
                ``0 .. action_space - 1``.

        Returns:
            A list with one entry per action: the squeezed value-model
            prediction converted with ``tolist()``. Predictions are cached in
            ``self.value_map``.
        """
        # The state row is the same for every action, so build it once.
        state_row = np.array(state).reshape(1, -1)
        all_values = []
        for a in range(action_space):
            value_states = np.append(state_row, [[a]], axis=1)
            hashed_state = hash_observation(value_states)
            if hashed_state not in self.value_map:
                self.value_map[hashed_state] = self.predict_value(value_states)
            all_values.append(np.squeeze(self.value_map[hashed_state]).tolist())
        return all_values

    def predict_value(self, state):
        """Return the value model's prediction for *state* (state + action, uncached)."""
        return self.value_model.predict(state, verbose=0)

    def predict_policy(self, state):
        """Return the policy model's logits for *state*, cached in ``self.policy_map``."""
        np_state = np.array(state).reshape(1, -1)
        hashed_state = hash_observation(np_state)
        if hashed_state not in self.policy_map:
            self.policy_map[hashed_state] = self.policy_model.predict(np_state, verbose=0)
        return self.policy_map[hashed_state]

    @staticmethod
    def _unzip_buffer(buffer, label):
        """Transpose *buffer* (a sequence of sample tuples) into per-field tuples.

        The buffer is read by integer index because it is only assumed to
        support ``len()`` and ``buffer[i]``.

        Raises:
            ValueError: If the buffer holds no samples.
        """
        samples = [buffer[i] for i in range(len(buffer))]
        if not samples:
            raise ValueError(f"Cannot train {label} model: buffer is empty")
        return list(zip(*samples))

    def train_value(self, value_buffer, epochs):
        """Fit the value model on ``(state, value, action)`` samples.

        Raises:
            ValueError: If ``value_buffer`` is empty.
        """
        value_states, values, actions = self._unzip_buffer(value_buffer, "value")
        value_states = np.array(value_states)
        values = np.array(values)
        actions = np.array(actions)
        # The value model expects the action index as an extra trailing feature.
        value_states = np.hstack((value_states, actions.reshape(-1, 1)))
        print("Train value model")
        self.value_model.fit(value_states, values, epochs=epochs, verbose=2, validation_split=0, batch_size=32)

    def train_policy(self, policy_buffer, epochs):
        """Fit the policy model on ``(state, policy)`` samples.

        Raises:
            ValueError: If ``policy_buffer`` is empty.
        """
        policy_states, policies = self._unzip_buffer(policy_buffer, "policy")
        policy_states = np.array(policy_states)
        policies = np.array(policies)
        print("Train policy model")
        self.policy_model.fit(policy_states, policies, epochs=epochs, verbose=2, validation_split=0, batch_size=32)

    def train(self, value_buffer, policy_buffer, epochs):
        """Train the policy model; ``value_buffer`` is currently unused."""
        # NOTE: value-model training is currently disabled.
        # self.train_value(value_buffer, epochs)
        self.train_policy(policy_buffer, epochs)

    def save(self, path):
        """Save model weight"""
        self.value_model.save_weights(os.path.join(path, "value"))
        self.policy_model.save_weights(os.path.join(path, "policy"))

    def load(self, path):
        """Load model weight (only the policy weights are restored)."""
        # NOTE: value-model loading is currently disabled.
        # self.value_model.load_weights(os.path.join(path, "value"))
        self.load_policy_model(path)

    def load_policy_model(self, path):
        """Load the policy model weights saved under ``<path>/policy``."""
        self.policy_model.load_weights(os.path.join(path, "policy"))

    def clear_map(self):
        """Empty both prediction caches."""
        self.value_map.clear()
        self.policy_map.clear()
