import numpy as np
from scipy import io
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

import torch.onnx
import onnxruntime as ort

# Reproducibility: fix the global NumPy and PyTorch random number generators
# before any network weights are initialised or collocation points are drawn.
np.random.seed(42)
torch.manual_seed(42)

# Training data from MATLAB files: inputs under key "X", targets under key "y".
X = io.loadmat("X_rps.mat")["X"]
xdot = io.loadmat("y_rps.mat")["y"]

class NoScaler:
    """Identity stand-in for an sklearn-style scaler.

    Provides ``fit_transform`` / ``transform`` / ``inverse_transform`` with the
    usual ``(X, y=None)`` signature so it can sit where a real scaler would.
    Every method returns ``X`` itself (the same object, not a copy); ``y`` is
    accepted and ignored.
    """

    def _identity(self, X, y=None):
        return X

    # All three scaler entry points share the same pass-through implementation.
    fit_transform = _identity
    transform = _identity
    inverse_transform = _identity
scaler = NoScaler()

# NoScaler hands its input back unchanged; the call is presumably kept so a
# real scaler can be swapped in later without touching the rest of the script.
# Every sample is used for training (original note: "for limited data case").
X_train = torch.tensor(scaler.fit_transform(X))
xdot_train = torch.tensor(xdot)

# Mini-batches of 20 in a fixed (unshuffled) order, cast to float32.
train_dataset = TensorDataset(X_train.float(), xdot_train.float())
train_loader = DataLoader(train_dataset, batch_size=20, shuffle=False)

# Define model
class NeuralNetwork(nn.Module):
    """Small fully connected network with two tanh hidden layers.

    Architecture: ``n_inputs -> hidden -> hidden -> n_outputs``, with ``tanh``
    after each hidden layer and a plain linear output layer.

    Args:
        n_inputs: input width. When ``None`` (the default) it is read from the
            module-level ``X_train.shape[1]`` at construction time, which keeps
            the original zero-argument ``NeuralNetwork()`` call working.
        hidden: width of each of the two hidden layers (default 5).
        n_outputs: output width (default 4).
    """

    def __init__(self, n_inputs=None, hidden=5, n_outputs=4):
        super().__init__()
        if n_inputs is None:
            # Backward-compatible fallback to the script's training tensor.
            n_inputs = X_train.shape[1]
        # Layers are created in the original order, so with the same seed the
        # default configuration gets the same initial weights as before.
        self.fc1 = nn.Linear(n_inputs, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.output = nn.Linear(hidden, n_outputs)

    def forward(self, x):
        """Map a ``(..., n_inputs)`` tensor to ``(..., n_outputs)``."""
        x = torch.tanh(self.fc1(x))
        x = torch.tanh(self.fc2(x))
        return self.output(x)
    
# Network, data-fit loss (mean squared error) and Adam optimiser.
model = NeuralNetwork()
criterion = nn.MSELoss(reduction="mean")
optimizer = optim.Adam(params=model.parameters(), lr=0.01)

# forward invariance
def physics_loss_fun_FI(model,col1,col2,col3,col4):
    """Forward-invariance penalty on the boundary faces of the two simplices.

    The model output is read as ``(xdot1, xdot2, ydot1, ydot2)`` (the same
    column convention ``physics_loss_fun_PC`` uses). For a trajectory starting
    in the simplex to remain there, a coordinate that is currently zero must
    not decrease. Each collocation set lies on one such face, and the matching
    output component is penalised when it is negative:

    * ``col1``: points with ``x1 = 0`` -> penalise ``xdot1 < 0``
    * ``col2``: points with ``x2 = 0`` -> penalise ``xdot2 < 0``
    * ``col3``: points with ``y1 = 0`` -> penalise ``ydot1 < 0``
    * ``col4``: points with ``y2 = 0`` -> penalise ``ydot2 < 0``

    Args:
        model: callable mapping an ``(n, n_features)`` tensor to ``(n, 4)``.
        col1, col2, col3, col4: collocation points on the faces listed above.

    Returns:
        Scalar tensor: the sum of the four mean ReLU penalties; it is 0 when
        every constraint is satisfied.
    """
    total = 0
    for component, points in enumerate((col1, col2, col3, col4)):
        # Rate of change of the coordinate that is zero on this face; only a
        # negative rate (moving out of the simplex) is penalised.
        rate = model(points)[:, component]
        total = total + torch.mean(torch.relu(-rate))
    return total

def A(w):
    """Build one 3x3 matrix per row of ``w``.

    For a row ``(w0, w1, w2, w3)`` the matrix is::

        [[ 0.25,     w0 - 1,  w1 + 1 ],
         [ w2 + 1,   0.25,    -1     ],
         [ w3 - 1,   1,       0.25   ]]

    Args:
        w: 2-D tensor of shape ``(batch, 4)``; only columns 0-3 are read.

    Returns:
        Tensor of shape ``(batch, 3, 3)`` with the dtype and device of ``w``.
    """
    quarter = torch.full_like(w[:, 0], 0.25)
    ones = torch.ones_like(w[:, 0])

    # Entries listed in row-major order, then folded into 3x3 matrices.
    entries = [
        quarter, w[:, 0] - 1, w[:, 1] + 1,
        w[:, 2] + 1, quarter, -ones,
        w[:, 3] - 1, ones, quarter,
    ]
    return torch.stack(entries, dim=1).view(-1, 3, 3)

def physics_loss_fun_PC(model,domainPC_nonscaled,domainPC_scaled):
    """Penalise dynamics that move against the payoff advantage.

    Each row of ``domainPC_nonscaled`` is ``(x1, x2, y1, y2, w...)``; the third
    coordinates are implied as ``x3 = 1 - x1 - x2`` and ``y3 = 1 - y1 - y2``.
    With ``A_w = A(w)`` (one 3x3 matrix per row) and ``B_w = -A_w^T``:

        u1 = A_w @ (y1, y2, y3)        u2 = B_w @ (x1, x2, x3)

    The model output on ``domainPC_scaled`` is read as
    ``(xdot1, xdot2, ydot1, ydot2)``, and ReLU penalties are applied wherever

        (u1_1 - u1_3) * xdot1 + (u1_2 - u1_3) * xdot2
        (u2_1 - u2_3) * ydot1 + (u2_2 - u2_3) * ydot2

    is negative.

    Args:
        model: callable mapping ``(batch, n_features)`` to ``(batch, 4)``.
        domainPC_nonscaled: collocation points in raw coordinates, used to
            evaluate ``u1`` and ``u2``.
        domainPC_scaled: the same points in the form that is fed to the model.

    Returns:
        Scalar tensor: the sum of the two mean ReLU penalties.
    """
    x = domainPC_nonscaled[:, 0:2]
    y = domainPC_nonscaled[:, 2:4]

    # One 3x3 matrix per collocation point (three strategies per player).
    A_w_batch = A(domainPC_nonscaled[:, 4:])
    B_w_batch = -A_w_batch.transpose(1, 2)

    # Complete each point with its implied third coordinate -> (batch, 3).
    x_full = torch.cat((x, 1 - x.sum(dim=1, keepdim=True)), dim=1)
    y_full = torch.cat((y, 1 - y.sum(dim=1, keepdim=True)), dim=1)

    # Batched matrix-vector products: (batch, 3, 3) @ (batch, 3, 1) -> (batch, 3).
    u1 = (A_w_batch @ y_full.unsqueeze(-1)).squeeze(-1)
    u2 = (B_w_batch @ x_full.unsqueeze(-1)).squeeze(-1)

    output_PC = model(domainPC_scaled)

    # Advantage of components 1 and 2 over component 3, shape (batch, 2).
    adv1 = u1[:, :2] - u1[:, 2:]
    adv2 = u2[:, :2] - u2[:, 2:]

    loss1 = torch.mean(torch.relu(-(adv1 * output_PC[:, 0:2]).sum(dim=1)))
    loss2 = torch.mean(torch.relu(-(adv2 * output_PC[:, 2:4]).sum(dim=1)))
    return loss1 + loss2

# Sampling interval for each of the four w parameters (columns 4.. of the
# collocation rows) that are passed to A(w).
domain = [-1,1]
# Weights of the three terms in the training loss: data-fit (MSE),
# forward-invariance penalty and payoff-correlation penalty, respectively.
lambda1 = 1e-1
lambda2 = 1e-2
lambda3 = 1e-2

# Collocation points for the physics-informed losses.
# NOTE: the torch.rand / uniform_ calls below draw from the global torch RNG in
# this exact order; reordering them changes the sampled points for a given seed.
col_size = 2500
# Raw uniform [0, 1) draws, one column each.
x1_colocation = torch.rand(col_size, 1)
x2_colocation = torch.rand(col_size, 1)
x3_colocation = torch.rand(col_size, 1)

y1_colocation = torch.rand(col_size, 1)
y2_colocation = torch.rand(col_size, 1)
y3_colocation = torch.rand(col_size, 1)


# Dividing by the sum maps (x1, x2, x3) onto the probability simplex; only the
# first two normalized coordinates are kept (the third is 1 - x1 - x2).
sum_colocation_x1x2x3 = x1_colocation + x2_colocation + x3_colocation
x1_normalized = x1_colocation / sum_colocation_x1x2x3
x2_normalized = x2_colocation / sum_colocation_x1x2x3

# Same normalization for the y coordinates.
sum_colocation_y1y2y3 = y1_colocation + y2_colocation + y3_colocation
y1_normalized = y1_colocation / sum_colocation_y1y2y3
y2_normalized = y2_colocation / sum_colocation_y1y2y3

# One row of w per collocation point, uniform in [domain[0], domain[1]).
w_colocation = torch.FloatTensor(col_size,4).uniform_(domain[0], domain[1])

# Boundary sets for the forward-invariance loss. Column layout of every set:
# (x1, x2, y1, y2, w[0], w[1], w[2], w[3]).
# col1: face x1 = 0; col2: face x2 = 0. The remaining x coordinate is a raw
# [0, 1) draw, so the point stays inside the simplex.
col1 = torch.cat((torch.zeros(col_size, 1), x2_colocation, y1_normalized, y2_normalized, w_colocation ), dim=1)
col2 = torch.cat((x1_colocation, torch.zeros(col_size, 1), y1_normalized, y2_normalized, w_colocation ), dim=1)

# col3: face y1 = 0; col4: face y2 = 0.
col3 = torch.cat((x1_normalized, x2_normalized, torch.zeros(col_size, 1), y2_colocation, w_colocation ), dim=1)
col4 = torch.cat((x1_normalized, x2_normalized, y1_colocation, torch.zeros(col_size, 1), w_colocation), dim=1)

# col5: x1 + x2 = 1 (the x3 = 0 face); col6: y1 + y2 = 1 (the y3 = 0 face).
# NOTE(review): col5 and col6 are not passed to either loss function in the
# code visible here - confirm whether the x3 = 0 / y3 = 0 faces should also be
# constrained, or drop these two sets.
col5 = torch.cat((x1_colocation, 1-x1_colocation, y1_normalized, y2_normalized, w_colocation ), dim=1)
col6 = torch.cat((x1_normalized, x2_normalized, y1_colocation, 1-y1_colocation, w_colocation ), dim=1)

# Pass every set through the scaler (NoScaler, i.e. identity, in this script)
# and store the result as float32 tensors.
col1 = torch.FloatTensor(scaler.transform(col1))
col2 = torch.FloatTensor(scaler.transform(col2))
col3 = torch.FloatTensor(scaler.transform(col3))
col4 = torch.FloatTensor(scaler.transform(col4))
col5 = torch.FloatTensor(scaler.transform(col5))
col6 = torch.FloatTensor(scaler.transform(col6))

# Interior points for the payoff-correlation loss: the raw coordinates are used
# to evaluate A(w) and the payoffs; the scaled copy is what the model receives.
domainPC_nonscaled = torch.cat((x1_normalized, x2_normalized, y1_normalized, y2_normalized, w_colocation ), dim=1)
domainPC_scaled = torch.FloatTensor(scaler.transform(domainPC_nonscaled ))

# Train model: every optimiser step combines the mini-batch data-fit term with
# the two physics-informed penalties evaluated on the fixed collocation sets.
for epoch in range(900):
    for inputs, targets in train_loader:
        optimizer.zero_grad()
        predictions = model(inputs)
        reg_loss = criterion(predictions, targets)

        # Physics-informed penalties (they do not depend on the mini-batch).
        phy_loss_FI = physics_loss_fun_FI(model, col1, col2, col3, col4)
        phy_loss_PC = physics_loss_fun_PC(model, domainPC_nonscaled, domainPC_scaled)

        total_loss = lambda1 * reg_loss + lambda2 * phy_loss_FI + lambda3 * phy_loss_PC
        total_loss.backward()
        optimizer.step()

    # Report the last mini-batch's losses every 10 epochs.
    if epoch % 10 != 0:
        continue
    print(f"Epoch {epoch}, Mse Loss: {reg_loss.item(): .1e}, Phy Loss FI: {phy_loss_FI.item(): .1e}, Phy Loss PC: {phy_loss_PC.item(): .1e}")

model.eval()

# Export the trained network to ONNX.
n_features = X_train.shape[1]
# One example row fixes the traced input width; the batch axis is declared
# dynamic below, so the exported graph accepts any batch size.
dummy_input = torch.randn(1, n_features)

# Define the file path for the ONNX model
onnx_model_path = "model_rps.onnx"

torch.onnx.export(
    model,                      # model being run
    dummy_input,                # example input used for tracing
    onnx_model_path,            # destination file
    export_params=True,         # embed the trained weights in the file
    opset_version=11,           # ONNX opset to target
    do_constant_folding=True,   # fold constant sub-expressions at export time
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch_size"},     # variable-length batch axis
                  "output": {0: "batch_size"}},
)

print("Model has been converted to ONNX")

# Reload the exported file with ONNX Runtime as a sanity check. Reuse the same
# path variable so the check always inspects the file that was just written.
model_path = onnx_model_path
# Request the CPU provider explicitly: ONNX Runtime >= 1.9 raises if
# `providers` is omitted on builds that ship more than the CPU provider.
session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])

# Get model input details (fetched once).
model_input = session.get_inputs()[0]
input_name = model_input.name
input_shape = model_input.shape
input_type = model_input.type

print(f"Input name: {input_name}")
print(f"Input shape: {input_shape}")
print(f"Input type: {input_type}")