import sys
import os

sys.path.append(os.path.abspath("../../..")) 
sys.path.append(os.path.abspath("../../")) 
sys.path.append(os.path.abspath("../")) 

import importlib
import torch
import src.utils.data as data
import commons.semantic_loss as sem_loss
from semantic_loss_pytorch import SemanticLoss
from models.adjoint import neural_adjoint as adj
import importlib
import numpy as np
import csv

# Re-execute the project modules so edits made after the initial import are
# picked up (presumably a leftover from interactive/notebook use).
for _reloaded_module in (sem_loss, adj):
    importlib.reload(_reloaded_module)

# Console formatting: scientific notation for torch tensors, very wide numpy rows.
torch.set_printoptions(sci_mode=True)
np.set_printoptions(linewidth=10000)
importlib.reload(data)

# Run on the first CUDA device when one is present, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# ---- EXPERIMENT PARAMETERS ------

# It is sufficient to change this parameter to switch the constraint under test
experiment_loss = sem_loss.SemanticExperiment.PERIODIC_3


TEST_SIZE = 100

# Design-space size. These two values drive the semantic-loss variables, the
# constraint formula and the dimension of the searched points, so define them
# once and reuse them everywhere.
# NOTE(review): presumably number of layers / number of candidate materials -- confirm.
num_lay = 10
num_mat = 7

train_data, val_data, test_data = data.get_x_y_data(invd_steps=TEST_SIZE, device=device, val_split=None)

x_train, y_train = train_data
x_test, y_test = test_data

# NOTE(security): weights_only=False lets torch.load unpickle arbitrary Python
# objects -- only load checkpoints from a trusted source. The path is resolved
# relative to the current working directory.
simulator = torch.load("../trained_models/simulator.pt", weights_only=False).to(device)

# ------- BEGIN EXPERIMENT ---------------

# Variables plus SDD manager/vtree, built for the same dimensions as above
# (previously hard-coded as 10, 7).
x, manager, vtree = sem_loss.construct_vars(num_lay, num_mat)

# Define the formula used by semantic loss
formula = experiment_loss.get_constraint_function(num_lay=num_lay, num_mat=num_mat)(x)  # type: ignore

# SemanticLoss is constructed from file names, so the compiled formula and its
# vtree are first written to the working directory.
manager.save(b'formula.sdd', formula)
vtree.save(b'order.vtree')
sloss = SemanticLoss('formula.sdd', 'order.vtree')


# ---- CREATE LOG FILES -------
os.makedirs("logs/", exist_ok=True)

# Search hyper-parameters shared by both runs; LR is also logged in column 0.
LR = 0.1
EPOCHS = 200
# Random starting points optimised per test target.
NUM_POINTS = 128

# Content description of csv file
# 0: learning rate
# 1: idx of material of the test set
# 2: current epoch
# 3: idx of the generated material (out of NUM_POINTS for every epoch)
# 4: individual simulator loss (mean of squared errors -> during analysis compute sqrt(simulator loss) to obtain srmse)
# 5: individual semantic loss
# 6: one-hot accuracy in [0,1] of the generated point
# 7: a tensor representing the effective generated material
header = ["Lr", "Mat_idx", "Epochs", "Point", "Simulator loss", "Semantic loss", "Onehot", "Decoded mat"]


def _history_rows(history, mat_idx, lr):
    """Yield one CSV row per (epoch, point) entry of a neural-adjoint search history.

    Fields 1-4 of each entry become CSV columns 4-7; field 0 is not logged.
    The history is walked using its own lengths, so a run that records a
    different number of epochs/points than its sibling run is still logged
    completely instead of raising IndexError or being truncated.
    """
    for epoch, epoch_entries in enumerate(history):
        for point, entry in enumerate(epoch_entries):
            yield [lr, mat_idx, epoch, point, entry[1], entry[2], entry[3], entry[4]]


no_loss_path = "logs/[][No_loss]results.csv"
sem_loss_path = f"logs/[{experiment_loss.get_log_filenames()[0]}][Sem_loss]results.csv"

# newline="" is what the csv module expects for files it writes (prevents blank
# lines on Windows); the with-statement closes both logs even if a search raises.
with open(no_loss_path, "w", newline="") as file_no_loss, open(sem_loss_path, "w", newline="") as file_loss:
    writer_no_loss = csv.writer(file_no_loss)
    writer_loss = csv.writer(file_loss)
    writer_no_loss.writerow(header)
    writer_loss.writerow(header)

    torch.manual_seed(0)

    test_data_x = x_test[:TEST_SIZE, :].to(device)  # type: ignore
    test_data_y = y_test[:TEST_SIZE, :].to(device)

    point_dim = num_lay * num_mat + num_lay

    # x is not needed by the search; it stays in the zip so the number of
    # iterations still matches the paired test set.
    for mat_idx, (_target_x, test_point_y) in enumerate(zip(test_data_x, test_data_y)):
        # Add a leading dimension of size 1 (presumably a batch axis).
        test_point_y = test_point_y.unsqueeze(0)

        initial_point = torch.randn((NUM_POINTS, point_dim)).to(device)

        # Both searches start from the same random points. Each gets its own copy
        # so that, should neural_adjoint_search update its starting tensor in
        # place, the second run still starts from the original points.
        history_noloss, losses = adj.neural_adjoint_search(
            simulator, initial_point.clone(), test_point_y, lr=LR, epochs=EPOCHS
        )
        history_semloss, losses_sem = adj.neural_adjoint_search(
            simulator, initial_point.clone(), test_point_y, lr=LR, epochs=EPOCHS, sloss=sloss
        )

        writer_no_loss.writerows(_history_rows(history_noloss, mat_idx, LR))
        writer_loss.writerows(_history_rows(history_semloss, mat_idx, LR))

        # Persist after every test target so partial results survive an interrupted run.
        file_no_loss.flush()
        file_loss.flush()