import sys
import os

sys.path.append(os.path.abspath("../../")) 
sys.path.append(os.path.abspath("../../..")) 

import torch
import torch.nn as nn

from configs.metamat_ds_config import MetamatDsConfig
from data_utils.load_data import get_x_rt_data
from data_utils.utils import create_train_val_split, scaler, unscaler
from base_models.simulator_Nf import ForwardSimulator
import commons.semantic_loss as sem_loss
from semantic_loss_pytorch import SemanticLoss
import numpy as np
import csv
from methods.adjoint import neural_adjoint as adj
import importlib

# Re-execute these project modules so edits to them are picked up when the
# script is re-run inside an already-running interpreter (e.g. IPython %run).
importlib.reload(sem_loss)
importlib.reload(adj)

# csv.writer stringifies non-string cells with str(), so these print options
# decide how tensors/arrays appear in the result files. Disable line wrapping in
# BOTH libraries to keep every tensor on a single line of its CSV cell: torch's
# default linewidth is 80 characters, so without an explicit value long tensors
# would be wrapped across several lines inside one cell.
torch.set_printoptions(sci_mode=True, linewidth=sys.maxsize)
np.set_printoptions(linewidth=np.inf) # type: ignore

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# ---- EXPERIMENT PARAMETERS ------

# It is sufficient to change this parameter
experiment_loss = sem_loss.SemanticExperiment.PALINDROME_2


# ---- LOAD DATASET -------
metamat_config = MetamatDsConfig()
metamat_config.print_config()

# Problem size as declared by the dataset configuration.
n_layer, n_material = metamat_config.num_lay, metamat_config.num_mat

TEST_SIZE = 500

# Fetch the train/test split and turn each array into a float32 tensor that
# lives on `device`.
x_train, x_test, y_train, y_test = (
    torch.tensor(arr, dtype=torch.float32).to(device)
    for arr in get_x_rt_data(metamat_config, TEST_SIZE)
)

# Rescale material thicknesses from the nanometer range to (-1, 1).
# scaler's return value is discarded, so it presumably rescales the tensor in
# place — NOTE(review): confirm in data_utils.utils.
scaler(x_train, n_layer)
scaler(x_test, n_layer)

# ---- LOAD MODELS -----

# Pretrained forward simulator weights (the "5matlay" in the file name
# presumably refers to 5 materials / 5 layers — confirm against the config).
SIMULATOR_CKPT = "../../trained_models/simulator_noscale/Nf_175200train_43800val_5matlay_150e_1024b_0.005lr.pt"

mat_simulator = ForwardSimulator(metamat_config.num_lay, 2)
# map_location remaps the stored tensors onto `device`; without it torch.load
# restores them to the device they were saved from, which fails for a
# GPU-saved checkpoint on a CPU-only machine.
mat_simulator.load_state_dict(torch.load(SIMULATOR_CKPT, map_location=device))
mat_simulator = mat_simulator.to(device)


# ---- CREATE LOG FILES -------
# exist_ok avoids the check-then-create race of a separate os.path.exists test.
os.makedirs("logs/", exist_ok=True)

# newline="" is required by the csv module so that it controls line endings
# itself (otherwise every row is followed by a blank line on Windows).
file_no_loss = open("logs/[][No_loss]results.csv", "w", newline="")
file_loss = open(f"logs/[{experiment_loss.get_log_filenames()[0]}][Sem_loss]results.csv", "w", newline="")


# Content description of csv file
# 0: learning rate
# 1: idx of material of the test set
# 2: current epoch
# 3: idx of the generated material (out of 128 for every epoch)
# 4: individual simulator loss (mean of squared errors -> during analysis compute sqrt(simulator loss) to obtain srmse)
# 5: individual semantic loss
# 6: one-hot accuracy in [0,1] of the generated point
# 7: a tensor representing the effective generated material
header = ["Lr", "Mat_idx", "Epochs", "Point", "Simulator loss", "Semantic loss", "Onehot", "Decoded mat"]
csv.writer(file_no_loss).writerow(header)
csv.writer(file_loss).writerow(header)


# ------------- BEGIN EXPERIMENT --------------------

# Problem size for the semantic-loss encoding. Taken from the dataset config
# (the same source used to build ForwardSimulator and to scale the data)
# instead of being hard-coded, so the point dimension derived from it cannot
# silently disagree with what the simulator expects.
num_lay = metamat_config.num_lay
num_mat = metamat_config.num_mat

# Instantiate the variables that will be associated with x_i_j (presumably one
# per (layer, material) pair), plus the SDD manager and vtree that own them.
x, manager, vtree = sem_loss.construct_vars(num_lay, num_mat)

# Define the formula used by semantic loss: the constraint selected by
# `experiment_loss`, instantiated for this problem size and applied to x.
formula = experiment_loss.get_constraint_function(num_lay, num_mat)(x) # type: ignore

# SemanticLoss is built from files, so serialize the formula and its vtree to
# the working directory first.
manager.save(b'formula.sdd', formula)
vtree.save(b'order.vtree')
sloss = SemanticLoss('formula.sdd', 'order.vtree')


# x_test / y_test were already moved to `device` when the dataset was loaded,
# so no extra .to(device) is needed here.
test_data_x = x_test[:TEST_SIZE, :] # type: ignore
test_data_y = y_test[:TEST_SIZE, :]

# Dimension of one candidate point: num_lay * num_mat entries (presumably the
# per-layer material scores) plus num_lay more (presumably one thickness per layer).
point_dim = num_lay * num_mat + num_lay

# Search hyper-parameters shared by both runs; LR is also logged in the "Lr" column.
NUM_POINTS = 128
LR = 0.05
EPOCHS = 200

# Build the CSV writers once instead of once per row.
writer_no_loss = csv.writer(file_no_loss)
writer_loss = csv.writer(file_loss)


def _csv_row(mat_idx, epoch, point, record):
    """Build one result row from a history record (fields 1..4 of the record)."""
    return [LR, mat_idx, epoch, point, record[1], record[2], record[3], record[4]]


torch.manual_seed(0)
try:
    for mat_idx, (test_point_x, test_point_y) in enumerate(zip(test_data_x, test_data_y)):
        # Add a leading batch dimension of size 1.
        test_point_x = test_point_x.unsqueeze(0)
        test_point_y = test_point_y.unsqueeze(0)

        initial_point = torch.randn((NUM_POINTS, point_dim)).to(device)

        # Give each search its own copy, so both runs are guaranteed to start
        # from identical points even if neural_adjoint_search updates its input
        # in place (its implementation is not visible here).
        history_noloss, _ = adj.neural_adjoint_search(
            mat_simulator, initial_point.clone(), test_point_x, test_point_y,
            metamat_config, lr=LR, epochs=EPOCHS,
        )
        history_semloss, _ = adj.neural_adjoint_search(
            mat_simulator, initial_point.clone(), test_point_x, test_point_y,
            metamat_config, lr=LR, epochs=EPOCHS, sloss=sloss,
        )

        # Walk the no-loss history and pair each entry with the same
        # (epoch, point) entry of the semantic-loss history; indexing (rather
        # than zip) keeps a loud IndexError if the semantic run is shorter.
        for epoch, epoch_noloss in enumerate(history_noloss):
            epoch_semloss = history_semloss[epoch]
            for point, record_noloss in enumerate(epoch_noloss):
                writer_no_loss.writerow(_csv_row(mat_idx, epoch, point, record_noloss))
                writer_loss.writerow(_csv_row(mat_idx, epoch, point, epoch_semloss[point]))

        # Persist results after every test material so partial runs are usable.
        file_no_loss.flush()
        file_loss.flush()
finally:
    file_no_loss.close()
    file_loss.close()