from re import X
import sys
import os
sys.path.append(os.path.abspath("../../")) 
sys.path.append(os.path.abspath("../../../")) 
sys.path.append(os.path.abspath("../")) 

import torch
import torch.nn as nn
import matplotlib.pyplot as plt

from dataset_5_layer.configs.metamat_ds_config import MetamatDsConfig
from dataset_5_layer.configs.experiment_config import ExpConfig
from dataset_5_layer.data_utils.load_data import get_x_rt_data
from dataset_5_layer.data_utils.utils import create_train_val_split, scaler, unscaler, evaluate_reconstruction
from dataset_5_layer.base_models.autoencoder import Encoder, Decoder, train_AE
from dataset_5_layer.base_models.simulator_Nf import ForwardSimulator

def main():
    """Train the layer/material autoencoder and optionally save its weights.

    Steps:
      1. Build the dataset/experiment configs and print the dataset config.
      2. Load the training arrays returned by ``get_x_rt_data`` and split them
         into train/validation parts with a fixed seed (42).
      3. Scale the training inputs with ``scaler`` and train ``Encoder`` /
         ``Decoder`` via ``train_AE``.
      4. Report reconstruction quality on the validation inputs.
      5. Ask on stdin whether to save both state dicts as
         ``encoder_<n_layer>_<run_id>.pt`` / ``decoder_<n_layer>_<run_id>.pt``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    metamat_config = MetamatDsConfig()
    metamat_config.print_config()
    exp_config = ExpConfig()

    n_layer = metamat_config.num_lay
    n_mat = metamat_config.num_mat

    # The held-out test arrays (2nd and 4th outputs) are not used by this
    # script, so they are no longer converted and copied to the device.
    X_train, _, Y_train, _ = get_x_rt_data(metamat_config, exp_config.invd_num_test)

    X_train = torch.tensor(X_train, dtype=torch.float32, device=device)
    Y_train = torch.tensor(Y_train, dtype=torch.float32, device=device)

    # Only the inputs are needed to train/evaluate the autoencoder; the
    # target splits are discarded.
    x_train, x_val, _, _ = create_train_val_split(X_train, Y_train, seed=42)

    # `scaler`'s return value is ignored, so it presumably normalizes x_train
    # in place -- TODO confirm against data_utils.utils.
    scaler(x_train, n_layer)
    # NOTE(review): x_val is NOT passed through `scaler`, yet the model is
    # trained on scaled inputs and evaluated on x_val below. Confirm whether
    # the validation inputs should be scaled the same way.

    # as_tensor converts array-likes but does not re-copy-construct an
    # existing tensor (torch.tensor(tensor) emits a UserWarning).
    x_train = torch.as_tensor(x_train, dtype=torch.float32, device=device)

    # Place the models on the same device as the data used for evaluation.
    encoder = Encoder(n_layer, n_mat).to(device)
    decoder = Decoder(n_layer, n_mat).to(device)

    # The positional 150 and 0.001 are presumably epochs and learning rate --
    # confirm against train_AE's signature. Its return value is used below as
    # the run identifier for logging and checkpoint file names.
    run_id = train_AE(encoder, decoder, x_train, n_layer, 150, 0.001, device, batch_size=1024, REG_WEIGHT=0.0, log=False)  # type: ignore

    # Evaluate in inference mode: eval() switches off train-only layer
    # behavior (e.g. dropout), and no_grad() avoids building an autograd graph.
    encoder.eval()
    decoder.eval()
    with torch.no_grad():
        reconstructed = decoder(encoder(x_val))

    print("#### AE PERFORMANCES #####")
    evaluate_reconstruction(x_val, reconstructed, n_layer, n_mat, log_file=run_id)

    # Accept only an explicit "y"/"yes" in any case. A substring test would
    # reject "Y" and accept answers such as "nay". EOF (non-interactive stdin)
    # is treated as "no" instead of crashing after training.
    try:
        answer = input('Do you want to save this model? [y/n]: ')
    except EOFError:
        answer = ''
    if answer.strip().lower() in ('y', 'yes'):
        torch.save(encoder.state_dict(), f'encoder_{n_layer}_{run_id}.pt')
        torch.save(decoder.state_dict(), f'decoder_{n_layer}_{run_id}.pt')

    # Mirrors the `scaler` call above; x_train is not used after this point.
    unscaler(x_train, n_layer)

if __name__ == "__main__":
    # Run the training script only when executed directly, not on import.
    main()


