import sys
import os

sys.path.append(os.path.abspath("../../")) 
sys.path.append(os.path.abspath("../../..")) 

import torch
from configs.metamat_ds_config import MetamatDsConfig
from configs.experiment_config import ExpConfig
from data_utils.load_data import get_x_rt_data
from data_utils.utils import create_train_val_split, scaler, unscaler
from base_models.autoencoder import Encoder, Decoder
from base_models.simulator_Nf import ForwardSimulator
import methods.vae_based.vae_move as vaemove
import importlib
import commons.semantic_loss as semloss
from semantic_loss_pytorch import SemanticLoss
import numpy as np
import csv
import importlib
importlib.reload(vaemove)

# Render tensors and arrays on a single line without scientific notation,
# so that each one written via csv ends up as a one-line cell.
torch.set_printoptions(sci_mode=False, linewidth=300)
np.set_printoptions(linewidth=np.inf)  # type: ignore

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")




# ---- EXPERIMENT PARAMETERS ------

# Selects both the constraint used by the semantic loss and the tag used in
# the semantic-loss log filename. It is sufficient to change this parameter.
experiment_loss = semloss.SemanticExperiment.PALINDROME_2


# ---- LOAD DATASET -------
metamat_config = MetamatDsConfig()
metamat_config.print_config()

n_layer = metamat_config.num_lay

TEST_SIZE = 500
# Only the test split (x_test, y_test) is used by this script.
_, x_test, _, y_test = get_x_rt_data(metamat_config, TEST_SIZE)

# Build float32 tensors directly on the target device.
x_test = torch.tensor(x_test, dtype=torch.float32, device=device)
y_test = torch.tensor(y_test, dtype=torch.float32, device=device)

# Change thickness of material from nanometer range to (-1, 1).
# NOTE(review): scaler's return value is discarded, so it presumably rescales
# x_test in place — confirm against data_utils.utils.scaler.
scaler(x_test, n_layer)

# ---- LOAD MODELS -----

# weights_only=False unpickles the complete model objects, so these
# checkpoint files must come from a trusted source.
vae = torch.load(
    "../../trained_models/vae_kl/VAE_LR=0.001_BS=256_E=150.pt",
    weights_only=False,
).to(device)
vae_simulator = torch.load(
    "../../trained_models/vae_kl/VAE_LR=0.001_BS=256_E=150_SIMULATOR.pt",
    weights_only=False,
).to(device)

latent_dim = vae.latent_dim

# Keep the first TEST_SIZE test samples (inputs and targets) on the device.
test_data_x, test_data_y = (
    data[:TEST_SIZE, :].to(device) for data in (x_test, y_test)
)



# ---- CREATE LOG FILES -------
# exist_ok=True avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs("logs/", exist_ok=True)

# newline="" is required when handing a file to the csv module: otherwise the
# writer's "\r\n" row terminator is translated to "\r\r\n" on Windows, which
# shows up as blank lines between rows.
# The files are opened without a context manager because rows are appended to
# them throughout the rest of the script.
file_no_loss = open(
    "logs/[][No_loss]results.csv", "w", newline="", encoding="utf-8"
)
file_loss = open(
    f"logs/[{experiment_loss.get_log_filenames()[0]}][Sem_loss]results.csv",
    "w",
    newline="",
    encoding="utf-8",
)


# Content description of csv file
# 0: learning rate
# 1: idx of material of the test set
# 2: current epoch
# 3: idx of the generated material (out of 128 for every epoch)
# 4: individual simulator loss (sum of squared errors -> during analysis compute sqrt(mean(simulator loss)) to obtain srmse)
# 5: individual semantic loss
# 6: one-hot accuracy in [0,1] of the generated point
# 7: a tensor representing the effective generated material
header = ["Lr", "Mat_idx", "Epochs", "Point", "Simulator loss", "Semantic loss", "Onehot", "Decoded mat"]
csv.writer(file_no_loss).writerow(header)
csv.writer(file_loss).writerow(header)

# ------------- BEGIN EXPERIMENT --------------------

# Dimensions of the logical encoding: variables x_i_j indexed by layer i and
# material j.
# NOTE(review): these are hard-coded and not tied to n_layer
# (metamat_config.num_lay); confirm they match the dataset.
n_lay = 5
n_mat = 5

# Create the x_i_j variables together with the manager and vtree that the
# formula is built and saved with.
x, manager, vtree = semloss.construct_vars(n_lay, n_mat)
formula_str, vtree_str = 'formula.sdd', 'order.vtree'

# Build the constraint selected by experiment_loss over the x_i_j variables.
build_constraint = experiment_loss.get_constraint_function(n_lay, n_mat)  # type: ignore
formula = build_constraint(x)  # type: ignore

# SemanticLoss is constructed from file paths, so both files are written first.
manager.save(formula_str.encode(), formula)
vtree.save(vtree_str.encode())
sloss = SemanticLoss(formula_str, vtree_str)


torch.manual_seed(0)

LEARNING_RATE = 0.01  # learning_rate passed to both runs; also logged as column 0
EPOCHS = 200          # epochs passed to both runs
NUM_POINTS = 128      # latent starting points optimized per test material
# Indices inside each history entry that are written to the CSV, in column
# order: simulator loss, semantic loss, one-hot accuracy, decoded material.
HISTORY_FIELDS = (1, 2, 3, 5)

# One writer per file, reused for every row.
writer_no_loss = csv.writer(file_no_loss)
writer_loss = csv.writer(file_loss)

try:
    for mat_idx, (test_point_x, test_point_y) in enumerate(zip(test_data_x, test_data_y)):
        # Add a leading dimension of size 1 (presumably the batch axis that
        # search_point_multi expects — confirm against vae_move).
        test_point_x = test_point_x.unsqueeze(0)
        test_point_y = test_point_y.unsqueeze(0)

        initial_point = torch.randn((NUM_POINTS, vae.latent_dim)).to(device)

        # Both methods start optimizing the same initial points. Each run gets
        # its own copy so that any in-place update made by search_point_multi
        # (not visible from here) cannot leak into the other run.
        history_noloss = vaemove.search_point_multi(
            vae, vae_simulator, initial_point.clone(), test_point_y, test_point_x,
            metamat_config, learning_rate=LEARNING_RATE, epochs=EPOCHS,
        )
        history_semloss = vaemove.search_point_multi(
            vae, vae_simulator, initial_point.clone(), test_point_y, test_point_x,
            metamat_config, learning_rate=LEARNING_RATE, epochs=EPOCHS, sloss=sloss,
        )

        # Rows are indexed as history[epoch][point]; iteration follows the
        # semantic-loss run and the no-loss run is indexed in lockstep.
        for epoch, epoch_semloss in enumerate(history_semloss):
            for point, entry_semloss in enumerate(epoch_semloss):
                entry_noloss = history_noloss[epoch][point]
                prefix = [LEARNING_RATE, mat_idx, epoch, point]
                writer_no_loss.writerow(prefix + [entry_noloss[i] for i in HISTORY_FIELDS])
                writer_loss.writerow(prefix + [entry_semloss[i] for i in HISTORY_FIELDS])

        # Flush after each material so completed results are on disk even if
        # the run is interrupted later.
        file_no_loss.flush()
        file_loss.flush()
finally:
    # Release the log files even if an exception aborts the experiment.
    file_no_loss.close()
    file_loss.close()


# -------------- END EXPERIMENT --------------------------