import numpy as np
from scipy import io
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

import torch.onnx
import onnxruntime as ort

# Seed NumPy and PyTorch so random sampling and weight initialisation are reproducible.
np.random.seed(42)
torch.manual_seed(42)

# Load inputs X and targets xdot (MATLAB variables 'X' and 'y'), X_sh.mat first.
X, xdot = (
    io.loadmat(mat_file)[var_name]
    for mat_file, var_name in (("X_sh.mat", "X"), ("y_sh.mat", "y"))
)

class NoScaler:
    """Identity stand-in that mirrors the method names of a scikit-learn scaler.

    Every method returns the very object it was given (no copy, no conversion).
    The optional ``y`` argument is accepted for signature compatibility and ignored.
    """

    def transform(self, X, y=None):
        """Return ``X`` unchanged."""
        return X

    def fit_transform(self, X, y=None):
        """There is nothing to fit; return ``X`` unchanged."""
        return X

    def inverse_transform(self, X, y=None):
        """Return ``X`` unchanged."""
        return X
# Identity scaler: feature values reach the network unchanged.
scaler = NoScaler()
X_train = scaler.fit_transform(X)

# for limited data case: every loaded sample is used for training (no split applied).
xdot_train = xdot

# Convert the loaded arrays to tensors, inputs first, then targets.
X_train = torch.tensor(X_train)
xdot_train = torch.tensor(xdot_train)

# Cast to float32 and serve fixed-order (unshuffled) minibatches of 6 samples.
train_dataset = TensorDataset(X_train.float(), xdot_train.float())
train_loader = DataLoader(train_dataset, batch_size=6, shuffle=False)

# Define model
class NeuralNetwork(nn.Module):
    """Fully connected network: input -> tanh(hidden) -> tanh(hidden) -> linear output.

    Args:
        n_inputs: input feature count. ``None`` (the default) keeps the original
            behaviour of reading it from the module-level ``X_train`` tensor.
        n_hidden: width of both hidden layers (originally hard-coded to 5).
        n_outputs: output dimension (originally hard-coded to 2).
    """

    def __init__(self, n_inputs=None, n_hidden=5, n_outputs=2):
        super().__init__()
        if n_inputs is None:
            # Backward-compatible default: infer the input width from the training data.
            n_inputs = X_train.shape[1]
        # Layers are created in the same order as before so seeded init is unchanged.
        self.fc1 = nn.Linear(n_inputs, n_hidden)
        self.fc2 = nn.Linear(n_hidden, n_hidden)
        self.output = nn.Linear(n_hidden, n_outputs)

    def forward(self, x):
        """Map a (batch, n_inputs) tensor to a (batch, n_outputs) tensor."""
        x = torch.tanh(self.fc1(x))
        x = torch.tanh(self.fc2(x))
        x = self.output(x)
        return x
    
# Network, data-fit criterion (mean squared error), and optimizer.
model = NeuralNetwork()
criterion = nn.MSELoss()
optimizer = optim.Adam(params=model.parameters(), lr=0.01)

# forward invariance
def physics_loss_fun_FI(model, col1, col2, col3, col4):
    """Hinge penalty on the sign of the model output over four point sets.

    The penalised quantities are:
      * ``col1``: negative values of output column 0,
      * ``col2``: positive values of output column 0,
      * ``col3``: negative values of output column 1,
      * ``col4``: positive values of output column 1.

    In this script col1/col2 are built with x1 fixed to 0/1 and col3/col4 with
    x2 fixed to 0/1, so a zero penalty means the predicted (x1dot, x2dot) never
    points strictly outward from [0, 1] x [0, 1] on those samples.

    Args:
        model: callable returning a (batch, >=2) tensor for each point set.
        col1, col2, col3, col4: point sets passed to ``model``, in this order.

    Returns:
        Scalar tensor: sum of the four mean ReLU penalties.
    """
    # (points, output column, sign): relu(sign * y) penalises y < 0 when sign is -1
    # and y > 0 when sign is +1.
    boundary_terms = (
        (col1, 0, -1.0),
        (col2, 0, 1.0),
        (col3, 1, -1.0),
        (col4, 1, 1.0),
    )
    return sum(
        torch.relu(sign * model(points)[:, column]).mean()
        for points, column, sign in boundary_terms
    )

def A(w):
    """Build one 2x2 matrix per row of ``w``.

    Each matrix is ``[[4 + w0, 1 + w1], [3 + w2, 3]]``, using columns 0-2 of
    ``w``, which must therefore be a 2-D tensor with at least three columns.

    Returns:
        Tensor of shape (batch, 2, 2).
    """
    first_col = w[:, 0]
    flat_entries = torch.stack(
        [
            first_col + 4,
            w[:, 1] + 1,
            w[:, 2] + 3,
            torch.zeros_like(first_col) + 3,  # constant bottom-right entry
        ],
        dim=1,
    )
    # Row-major layout: the four entries are (0,0), (0,1), (1,0), (1,1).
    return flat_entries.reshape(-1, 2, 2)

def physics_loss_fun_PC(model, domainPC_nonscaled, domainPC_scaled):
    """Sign-agreement penalty between payoff advantages and predicted rates.

    The same matrices ``A(w)`` (shape (batch, 2, 2)) are used for both players.
    The value of row r for player 1 is ``x2 * (M[r, 0] - M[r, 1]) + M[r, 1]``,
    i.e. ``x2 * M[r, 0] + (1 - x2) * M[r, 1]``. Player 2 uses ``x1`` in the
    same formula. Each output column i is penalised with
    ``relu(-(value_0 - value_1) * xdot_i)``, which is positive only when the
    prediction and the row-0 advantage have opposite signs (presumably the
    "positive correlation" condition the name refers to).

    Args:
        model: callable mapping ``domainPC_scaled`` to a (batch, >=2) tensor.
        domainPC_nonscaled: points in original units; columns 0 and 1 are x1, x2,
            columns 2 onward are passed to ``A`` as w.
        domainPC_scaled: the same points after scaling; this is the model input.

    Returns:
        Scalar tensor: sum of the two mean penalties.
    """
    x1 = domainPC_nonscaled[:, 0]
    x2 = domainPC_nonscaled[:, 1]
    payoff = A(domainPC_nonscaled[:, 2:])  # shared by both players

    def action_value(opponent_share, row):
        # Linear interpolation between the row's two entries, weighted by opponent_share.
        return opponent_share * (payoff[:, row, 0] - payoff[:, row, 1]) + payoff[:, row, 1]

    advantage_1 = action_value(x2, 0) - action_value(x2, 1)
    advantage_2 = action_value(x1, 0) - action_value(x1, 1)

    xdot_hat = model(domainPC_scaled)

    penalty_1 = torch.relu(-advantage_1 * xdot_hat[:, 0]).mean()
    penalty_2 = torch.relu(-advantage_2 * xdot_hat[:, 1]).mean()
    return penalty_1 + penalty_2

# Sampling range for each component of w (see uniform_ below).
domain = [0, 2]

# Loss weights: data fit (lambda1), forward invariance (lambda2), positive correlation (lambda3).
lambda1, lambda2, lambda3 = 1e-1, 1e-2, 1e-2

# Collocation points: x1, x2 ~ U(0, 1) and w ~ U(domain). The draw order
# (x1, then x2, then w) is kept fixed so seeded runs stay reproducible.
col_size = 2500
x1_colocation = torch.rand(col_size, 1)
x2_colocation = torch.rand(col_size, 1)
w_colocation = torch.FloatTensor(col_size, 3).uniform_(domain[0], domain[1])

# Boundary sets for the forward-invariance penalty: x1 = 0, x1 = 1, x2 = 0, x2 = 1.
col1 = torch.cat((torch.zeros(col_size, 1), x2_colocation, w_colocation), dim=1)
col2 = torch.cat((torch.ones(col_size, 1), x2_colocation, w_colocation), dim=1)
col3 = torch.cat((x1_colocation, torch.zeros(col_size, 1), w_colocation), dim=1)
col4 = torch.cat((x1_colocation, torch.ones(col_size, 1), w_colocation), dim=1)
col1, col2, col3, col4 = (
    torch.FloatTensor(scaler.transform(points)) for points in (col1, col2, col3, col4)
)

# Interior points for the positive-correlation penalty, kept both unscaled
# (for the payoff computation) and scaled (as model input).
domainPC_nonscaled = torch.cat((x1_colocation, x2_colocation, w_colocation), dim=1)
domainPC_scaled = torch.FloatTensor(scaler.transform(domainPC_nonscaled))

# Train model.
# Every optimizer step combines the minibatch data-fit loss with the two physics
# penalties, which are re-evaluated on the fixed collocation sets at each step.
for epoch in range(800):
    for batch_X, batch_xdot in train_loader:
        optimizer.zero_grad()

        # Data-fit term on the current minibatch.
        outputs = model(batch_X)
        reg_loss = criterion(outputs, batch_xdot)

        # Physics-informed terms (independent of the minibatch contents).
        phy_loss_FI = physics_loss_fun_FI(model, col1, col2, col3, col4)
        phy_loss_PC = physics_loss_fun_PC(
            model, domainPC_nonscaled, domainPC_scaled
        )

        # Weighted objective, then one optimizer step.
        total_loss = (
            lambda1 * reg_loss
            + lambda2 * phy_loss_FI
            + lambda3 * phy_loss_PC
        )
        total_loss.backward()
        optimizer.step()

    # Every 10th epoch, report the losses from that epoch's last minibatch.
    if not epoch % 10:
        print(f"Epoch {epoch}, Mse Loss: {reg_loss.item(): .1e}, Phy Loss FI: {phy_loss_FI.item(): .1e}, Phy Loss PC: {phy_loss_PC.item(): .1e}")

model.eval()

# Exporting: a single random row fixes the feature dimension; the batch axis is
# declared dynamic below so the ONNX graph accepts any batch size.
n_features = X_train.shape[1]
dummy_input = torch.randn(1, n_features)

# Define the file path for the ONNX model
onnx_model_path = "model_sh.onnx"

# Export the model
torch.onnx.export(model,               # model being run
                  dummy_input,         # model input (or a tuple for multiple inputs)
                  onnx_model_path,     # where to save the model (can be a file or file-like object)
                  export_params=True,  # store the trained parameter weights inside the model file
                  opset_version=11,    # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names = ['input'],   # the model's input names
                  output_names = ['output'], # the model's output names
                  dynamic_axes={'input' : {0 : 'batch_size'},    # variable length axes
                                'output' : {0 : 'batch_size'}})

print("Model has been converted to ONNX")

# Load the ONNX model from the same path it was just written to (previously a
# duplicated string literal that could silently drift from onnx_model_path).
model_path = onnx_model_path
# Pass providers explicitly: since onnxruntime 1.9, builds with several execution
# providers (e.g. GPU builds) raise ValueError when this argument is omitted.
session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])

# Get model input details
model_input = session.get_inputs()[0]
input_name = model_input.name
input_shape = model_input.shape
input_type = model_input.type

print(f"Input name: {input_name}")
print(f"Input shape: {input_shape}")
print(f"Input type: {input_type}")