import sys
import os

sys.path.append(os.path.abspath("../../")) 
sys.path.append(os.path.abspath("../../..")) 

import importlib
import torch
import src.utils.data as data
import commons.utils as utils
import src.models.vae.vae_move as vae_move
import commons.semantic_loss as sem_loss
from semantic_loss_pytorch import SemanticLoss
import importlib
import numpy as np
import csv

importlib.reload(vae_move)
importlib.reload(sem_loss)
importlib.reload(data)
importlib.reload(utils)


device = "cuda:0" if torch.cuda.is_available() else "cpu"

# ---- EXPERIMENT PARAMETERS ------

# It is sufficient to change this parameter to switch which semantic-loss
# constraint is evaluated.
experiment_loss = sem_loss.SemanticExperiment.PERIODIC_2


# Number of test-set materials the search is run on; also forwarded to the
# data loader as `invd_steps`.
TEST_SIZE = 100
train_data, val_data, test_data = data.get_x_y_data(invd_steps=TEST_SIZE, device=device, val_split=None)

x_train, y_train = train_data
x_test, y_test = test_data

# `map_location=device` remaps tensors saved on a GPU onto the selected device
# while unpickling. Without it, a checkpoint saved on CUDA fails to load on a
# CPU-only machine before `.to(device)` is ever reached, defeating the CPU
# fallback above.
# NOTE: `weights_only=False` performs a full pickle load, which can execute
# arbitrary code — only load trusted checkpoints.
vae_simulator = torch.load("../trained_models/vae_simulator.pt", weights_only=False, map_location=device).to(device)
simulator = torch.load("../trained_models/simulator.pt", weights_only=False, map_location=device).to(device)
vae = torch.load("../trained_models/vae.pt", weights_only=False, map_location=device).to(device)

# --------- BEGIN EXPERIMENTS ---------------

# Values written to the CSV logs are stringified with str(); presumably these
# options are set so tensors/arrays render in scientific notation without
# numpy line wrapping.
torch.set_printoptions(sci_mode=True)
np.set_printoptions(linewidth=np.inf) # type: ignore


# ---- CREATE LOG FILES -------
os.makedirs("logs/", exist_ok=True)

# Size of the semantic-loss constraint problem (layers x materials).
NUM_LAYERS = 10
NUM_MATERIALS = 7

# Latent-space search hyper-parameters, shared by both runs so the comparison
# is like-for-like.
LEARNING_RATE = 0.01
EPOCHS = 200
NUM_POINTS = 128  # candidate points optimised per test material


# Content description of csv file
# 0: learning rate
# 1: idx of material of the test set
# 2: current epoch
# 3: idx of the generated material (out of NUM_POINTS for every epoch)
# 4: individual simulator loss (sum of squared errors -> during analysis compute sqrt(mean(simulator loss)) to obtain srmse)
# 5: individual semantic loss
# 6: one-hot accuracy in [0,1] of the generated point
# 7: a tensor representing the effective generated material (written via str())
header = ["Lr", "Mat_idx", "Epochs", "Point", "Simulator loss", "Semantic loss", "Onehot", "Decoded mat"]


def _history_rows(history, lr, mat_idx):
    """Yield one CSV row per (epoch, point) entry of a search history.

    ``history[epoch][point]`` is the per-point record returned by
    ``vae_move.search_point_multi``; its fields 1, 2, 3 and 5 fill the
    simulator-loss, semantic-loss, one-hot and decoded-material columns
    listed in the description above.
    """
    for epoch, points in enumerate(history):
        for point, record in enumerate(points):
            yield [lr, mat_idx, epoch, point, record[1], record[2], record[3], record[5]]


x, manager, vtree = sem_loss.construct_vars(NUM_LAYERS, NUM_MATERIALS)
# Define the formula used by semantic loss
formula = experiment_loss.get_constraint_function(num_lay=NUM_LAYERS, num_mat=NUM_MATERIALS)(x)  # type: ignore

# SemanticLoss is constructed from the formula/vtree files written here.
manager.save(b'formula.sdd', formula)
vtree.save(b'order.vtree')
sloss = SemanticLoss('formula.sdd', 'order.vtree')


torch.manual_seed(0)
test_data_x = x_test[:TEST_SIZE, :].to(device)
test_data_y = y_test[:TEST_SIZE, :].to(device)

no_loss_path = "logs/[][No_loss]results.csv"
sem_loss_path = f"logs/[{experiment_loss.get_log_filenames()[0]}][Sem_loss]results.csv"

# The logs are opened only after the formula setup succeeded, so a setup error
# does not truncate previous results. newline="" is required by the csv module,
# and `with` guarantees both files are closed even if a search raises midway.
with open(no_loss_path, "w", newline="") as file_no_loss, open(sem_loss_path, "w", newline="") as file_loss:
    writer_no_loss = csv.writer(file_no_loss)
    writer_loss = csv.writer(file_loss)
    writer_no_loss.writerow(header)
    writer_loss.writerow(header)

    for mat_idx, (test_point_x, test_point_y) in enumerate(zip(test_data_x, test_data_y)):
        test_point_x = test_point_x.unsqueeze(0)
        test_point_y = test_point_y.unsqueeze(0)

        # Drawn on CPU then moved, to keep the same random stream as before.
        initial_point = torch.randn((NUM_POINTS, vae.latent_dim)).to(device)

        # Both searches start from the same random points. Each receives its
        # own copy so that an in-place update inside search_point_multi (if it
        # performs one — not visible here) cannot leak into the second run.
        history_noloss = vae_move.search_point_multi(vae, vae_simulator, initial_point.clone(), test_point_y, test_point_x, learning_rate=LEARNING_RATE, epochs=EPOCHS)
        history_semloss = vae_move.search_point_multi(vae, vae_simulator, initial_point.clone(), test_point_y, test_point_x, learning_rate=LEARNING_RATE, epochs=EPOCHS, sloss=sloss)

        # Each history is walked with its own epoch/point counts.
        writer_no_loss.writerows(_history_rows(history_noloss, LEARNING_RATE, mat_idx))
        writer_loss.writerows(_history_rows(history_semloss, LEARNING_RATE, mat_idx))

        # Flush after every material so partial results survive an interrupted run.
        file_no_loss.flush()
        file_loss.flush()