import os
os.environ["KERAS_BACKEND"] = "torch"

import numpy as np
from keras.layers import Input, Dense
from keras.models import Model
from keras import regularizers
import keras
from keras import layers, optimizers, metrics



# 1. Dimensions (original_dim, latent_dim) are passed as arguments to build_model.

def _dense_stack(in_dim, widths, out_dim):
    """Return a Sequential MLP: one ``Dense(relu) -> BatchNormalization`` block per width, then a linear output.

    Hidden Dense layers are named ``layer1``, ``layer2``, ... and the final
    linear projection is named ``output``.
    """
    stack = [layers.InputLayer(shape=(in_dim,))]
    for idx, width in enumerate(widths, start=1):
        stack.append(layers.Dense(width, activation='relu', name=f'layer{idx}'))
        stack.append(layers.BatchNormalization())
    stack.append(layers.Dense(out_dim, activation='linear', name='output'))
    return keras.Sequential(stack)


def build_model(original_dim, latent_dim, *, hidden_dims=(256, 256, 128, 32, 8),
                optimizer='adam', loss='mse', return_decoder=False):
    """Build a compiled dense autoencoder and an encoder model that shares its weights.

    The encoder is a stack of ``Dense(relu) -> BatchNormalization`` blocks with
    widths ``hidden_dims``, followed by a linear ``Dense(latent_dim)``. The
    decoder mirrors it: the same blocks with widths ``hidden_dims`` in reverse
    order, followed by a linear ``Dense(original_dim)`` reconstruction layer.

    Args:
        original_dim: Number of input features per sample.
        latent_dim: Size of the latent representation.
        hidden_dims: Widths of the encoder's hidden layers, input side first.
            The default reproduces the original fixed architecture.
        optimizer: Optimizer passed to ``autoencoder.compile``.
        loss: Loss passed to ``autoencoder.compile``.
        return_decoder: If True, also build a stand-alone decoder model (latent
            vector -> reconstruction) that reuses the autoencoder's decoder layers.

    Returns:
        ``(autoencoder, encoder)`` by default; ``(autoencoder, encoder, decoder)``
        when ``return_decoder`` is True. Only the autoencoder is compiled.
    """
    hidden_dims = tuple(hidden_dims)

    input_vec = Input(shape=(original_dim,), name='encoder_input')

    # The encoder and decoder reuse the same layer names ("layer1", ..., "output");
    # this relies on each Sequential sub-model keeping its own layer namespace.
    # NOTE: the latent layer is a linear map of the last hidden layer, so latent
    # codes span at most hidden_dims[-1] dimensions; latent_dim larger than that
    # adds no capacity.
    _encoder = _dense_stack(original_dim, hidden_dims, latent_dim)
    _decoder = _dense_stack(latent_dim, hidden_dims[::-1], original_dim)

    latent = _encoder(input_vec)
    output_vec = _decoder(latent)

    # Full autoencoder: input -> latent -> reconstruction.
    autoencoder = Model(inputs=input_vec, outputs=output_vec, name='autoencoder')
    autoencoder.compile(optimizer=optimizer, loss=loss)

    # Feature extractor built on the same layer objects, so it sees trained weights.
    encoder = Model(inputs=input_vec, outputs=latent, name='encoder')

    if not return_decoder:
        return autoencoder, encoder

    # Decoder on its own, for reconstruction from latent vectors.
    latent_input = Input(shape=(latent_dim,), name='decoder_input')
    decoder = Model(inputs=latent_input, outputs=_decoder(latent_input), name='decoder')
    return autoencoder, encoder, decoder


def train_ae(X_train, autoencoder, *, epochs=300, batch_size=3000,
             validation_split=0.05, callbacks=None, **fit_kwargs):
    """Fit ``autoencoder`` to reconstruct ``X_train`` (input == target) and return it.

    Args:
        X_train: Training samples; presumably a numpy array of shape
            ``(n_samples, original_dim)`` matching the model input — confirm at call site.
        autoencoder: Compiled model exposing ``fit`` (e.g. from ``build_model``).
        epochs: Number of training epochs.
        batch_size: Samples per gradient step.
        validation_split: Fraction of samples held out for validation. Keras
            takes these from the *end* of ``X_train``, before shuffling, so
            shuffle ordered data beforehand.
        callbacks: Optional iterable of Keras callbacks (EarlyStopping,
            ModelCheckpoint, ...). Defaults to none.
        **fit_kwargs: Extra keyword arguments forwarded to ``autoencoder.fit``.

    Returns:
        The same ``autoencoder`` object, trained in place. The History object
        returned by ``fit`` is discarded.
    """
    autoencoder.fit(
        X_train, X_train,
        epochs=epochs,
        batch_size=batch_size,
        validation_split=validation_split,
        # Fresh list per call so a caller's iterable is never mutated or shared.
        callbacks=list(callbacks) if callbacks is not None else [],
        **fit_kwargs,
    )
    return autoencoder


def transform(encoder, X_train, **predict_kwargs):
    """Encode samples into latent features with a trained encoder.

    Args:
        encoder: Model whose ``predict`` maps inputs to latent vectors
            (e.g. the encoder returned by ``build_model``).
        X_train: Samples to encode; presumably shape ``(n_samples, original_dim)``
            — confirm at call site.
        **predict_kwargs: Forwarded to ``encoder.predict`` (e.g. ``batch_size``,
            ``verbose``). Keras uses a batch size of 32 when none is given,
            which can be slow on large arrays.

    Returns:
        Whatever ``encoder.predict`` returns; for the ``build_model`` encoder,
        an array of shape ``(n_samples, latent_dim)``.
    """
    return encoder.predict(X_train, **predict_kwargs)
