% Comparison script for three identified models (SIAR, SINDy, PINN) of the
% plant simulated by replicator_sh. The session is reset first because the
% workspace is populated by load() below.
clear('all')
close('all')
clc
rng('default')   % fixed seed so every rand() draw further down is repeatable

% Bring the Julia export into the workspace. The script later reads
% all_siar_coeffs, all_siar_terms, all_sindy_coeffs, all_sindy_terms,
% x_all_data, x_train_data and w_all_data without defining them, so they are
% presumably stored in this file -- TODO confirm against the exporter.
load data_JULIA_sh.mat

% Time grid of the Julia data set.
deltat_data = 0.01; % set according to julia data "t_all_data = collect(LinRange(0.0, 0.3, 31))"
T_data = 0.3;       % set according to julia data "t_train_data = collect(LinRange(0.0, 0.3, 4))"

% Identified right-hand sides; these are later handed to matlabFunction and
% generateJacobian.
polynomials_siar = generate_polynomials_siar(all_siar_coeffs, all_siar_terms);
polynomials_sindy = generate_polynomials_sindy(all_sindy_coeffs, all_sindy_terms);

% Network exported to ONNX; importONNXFunction returns its parameters.
modelfile = "model_sh.onnx";
params = importONNXFunction(modelfile, "PINNFcn_sh");

% Unwrap the learnable parameters with extractdata. They are passed to
% myStateJacobian_PINN when the PINN Jacobian handle is created.
learnables = params.Learnables;
nn_w1 = extractdata(learnables.fc1_weight);
nn_w2 = extractdata(learnables.fc2_weight);
nn_w3 = extractdata(learnables.output_weight);
nn_b1 = extractdata(learnables.fc1_bias);
nn_b2 = extractdata(learnables.fc2_bias);
nn_b3 = extractdata(learnables.output_bias);
%% problem constants: plant matrix, reference, dimensions and MPC tuning
A0 = [4, 1; 3, 3];      % parameter matrix passed to replicator_sh
x_ref = [1, 1];         % reference handed to nlmpcmove

n_state = 2;            % number of states
n_input = 3;            % number of inputs

% Timing values forwarded to initializeMPC (presumably sample time, control
% horizon and prediction horizon -- confirm in initializeMPC).
Ts = 0.1;
Tc = 10;
Tp = 10;

% Bounds forwarded to initializeMPC. lb_c/ub_c are also used to scale the
% random inputs and as the y-limits of the input plots. lb_s/ub_s are
% presumably the state bounds -- confirm in initializeMPC.
[lb_s, ub_s] = deal(0, 1);
[lb_c, ub_c] = deal(0, 2);

% Cost weights forwarded to initializeMPC.
weight_s  = 1;
weight_c  = 0.1;
weight_dc = 0.01;

%% wrap every identified model as a state function plus a Jacobian function
methods = {'siar', 'sindy', 'pinn'};
polynomials = {polynomials_siar, polynomials_sindy};

n_models = numel(methods);
my_ests = cell(1, n_models);
my_Jacobians = cell(1, n_models);

for idx = 1:n_models
    if idx > numel(polynomials)
        % Last entry: the network model, evaluated through its own helpers.
        my_ests{idx} = @(x,u) my_est_pinn_sh(x, u, params);
        my_Jacobians{idx} = @(x,u) myStateJacobian_PINN(x, u, nn_w1, nn_w2, nn_w3, nn_b1, nn_b2, nn_b3);
        continue
    end

    % Polynomial models: turn the expression into a numeric function of
    % (v1, v2, w1, w2, w3) and wrap it in the project helper "estimated".
    f_fcn = matlabFunction(polynomials{idx}, 'Vars', {'v1','v2','w1','w2','w3'});
    my_ests{idx} = @(v,w) estimated(v, w, f_fcn);

    % Jacobian handles built from the same expression.
    [A_func, Bmv_func] = generateJacobian(polynomials{idx}, n_state, n_input);
    my_Jacobians{idx} = @(v,w) computeJacobian(v, w, A_func, Bmv_func);
end
disp('system models are imported');

%% construct one nlmpc controller per model (siar, sindy, pinn)
% Operating point used only by validateFcns. It does not depend on the
% model, so it is built once instead of on every loop iteration.
x0 = [0.2; 0.5];
w0 = zeros(n_input, 1);

nlobjs = cell(1, numel(my_ests));
for idx = 1:numel(my_ests)
    method = methods{idx};

    % Build the controller around this model's state function and Jacobian.
    nlobj = initializeMPC(my_ests{idx}, my_Jacobians{idx}, n_state, n_input, Ts, Tc, Tp, lb_s, ub_s, lb_c, ub_c, weight_s, weight_c, weight_dc);
    nlobjs{idx} = nlobj;

    % validateFcns reports on the controller's functions without naming the
    % model, so label the output to make a failure attributable.
    fprintf('Validating nlmpc functions for the %s model\n', method);
    validateFcns(nlobj, x0, w0);
end

disp('nlmpc is constructed');
%% training: data generation
% First and second state component of every recorded sample.
x1_data_all = cellfun(@(c) c(1), x_all_data);
x2_data_all = cellfun(@(c) c(2), x_all_data);

% Same extraction for the training subset.
x1_data = cellfun(@(c) c(1), x_train_data);
x2_data = cellfun(@(c) c(2), x_train_data);

% One column per recorded input sample. Expanding the cell directly avoids
% growing the array in a loop and, unlike iterating over size(...,1), does
% not silently keep only the first sample when the cell is stored as a row.
wHistory = [w_all_data{:}];

tHistory = 0:deltat_data:T_data;

% States as a 2-by-N matrix; (:).' makes this independent of whether the
% cell array (and hence the cellfun result) is a row or a column.
xHistory = [x1_data_all(:).'; x2_data_all(:).'];

% Model trajectories start from the last data point. The earlier columns
% are zero placeholders that only keep the column indices aligned with
% tHistory (explicit zeros, so NaN/Inf in the data cannot leak into them).
xHistory_siar = zeros(size(xHistory));
xHistory_siar(:, end) = xHistory(:, end);
xHistory_sindy = xHistory_siar;
xHistory_pinn = xHistory_siar;

% The histories are extended in lock-step below and finally plotted against
% tHistory, so a length mismatch here would only surface after the MPC run.
if size(xHistory, 2) ~= numel(tHistory) || size(wHistory, 2) ~= numel(tHistory)
    warning('sh:dataLengthMismatch', ...
        'Imported data has %d state samples and %d input samples, but the time grid has %d points.', ...
        size(xHistory, 2), size(wHistory, 2), numel(tHistory));
end

disp('data is exported');
%% compare xdot: true plant vs. siar, sindy and pinn estimates
% xdotHistory{1..3} hold the model estimates (same order as my_ests) and
% xdotHistory{4} holds the derivative returned by replicator_sh.
n_src = 4;
xdotHistory = cell(1, n_src);
History_w = [];
History_t = [];
History_x = xHistory(:,end);     % start from the last recorded data point
tspan = 0:0.01:0.1;

n_runs = 50;
% Input drawn in each run. Named inputSamples rather than "input" so that
% the MATLAB built-in input() is not shadowed in the workspace.
inputSamples = zeros(n_input, n_runs);

for ct = 1:n_runs
    % Random constant input for this segment.
    w0 = ub_c*rand(n_input,1);

    % Simulate the true plant over one segment from the current state.
    [t,x] = ode45(@(t,x)replicator_sh(x,w0,A0), tspan, History_x(:,end));
    n_pts = length(t);

    % Evaluate every derivative source along the simulated trajectory.
    x_dot_values = cell(1, n_src);
    for idx = 1:n_src
        if idx == n_src
            xdot_fcn = @(xi) replicator_sh(xi, w0, A0);
        else
            est_fcn = my_ests{idx};
            xdot_fcn = @(xi) est_fcn(xi, w0);
        end
        for i = 1:n_pts
            x_dot_values{idx}(:, i) = xdot_fcn(x(i, :)');
        end
    end

    % The first sample of a segment repeats the last sample of the previous
    % one, so it is skipped when appending states, inputs and time.
    History_x = [History_x x(2:end,:)'];

    if isempty(History_w)
        History_w = repmat(w0, 1, n_pts);
    else
        History_w = [History_w repmat(w0, 1, n_pts - 1)];
    end

    if isempty(History_t)
        History_t = t';
    else
        History_t = [History_t History_t(end) + t(2:end)'];
    end

    % For the derivatives the shared boundary sample is replaced by the
    % value computed with the new input.
    for idx = 1:n_src
        xdotHistory{idx} = [xdotHistory{idx}(:,1:end-1) x_dot_values{idx}];
    end

    inputSamples(:,ct) = w0;
end

% Figure 1: absolute xdot estimation error of each model (rows 1-3) and the
% random inputs that were applied (row 4).
figure(1)
xdoterror_siar = abs(xdotHistory{1, 1}(1:2,:)-xdotHistory{1, 4}(1:2,:));
xdoterror_sindy= abs(xdotHistory{1, 2}(1:2,:)-xdotHistory{1, 4}(1:2,:));
xdoterror_pinn= abs(xdotHistory{1, 3}(1:2,:)-xdotHistory{1, 4}(1:2,:));

colors = [0 0.4470 0.7410; 0.4660 0.6740 0.1880; 0.9290 0.6940 0.1250; 0.4940 0.1840 0.5560];

% NOTE: the labels are rendered by the LaTeX interpreter, where a bare "_"
% outside math mode is a syntax error; the underscore is therefore escaped.
subplot(4,1,1)
hold on;
for k = 1:2
    l1(k) = plot(History_t, xdoterror_siar(k,:), 'Color', colors(k,:), 'Linewidth', 2);
end
xlabel('time','Interpreter','latex','FontSize',12)
ylabel('error\_siar','Interpreter','latex','FontSize',12)
title('\textbf{Estimation Error in $\dot{x}$}','Interpreter','latex','FontSize',12)

subplot(4,1,2)
hold on;
for k = 1:2
    l2(k) = plot(History_t, xdoterror_sindy(k,:), 'Color', colors(k,:), 'Linewidth', 2,'HandleVisibility', 'off');
end
xlabel('time','Interpreter','latex','FontSize',12)
ylabel('error\_sindyc','Interpreter','latex','FontSize',12)

subplot(4,1,3)
hold on;
for k = 1:2
    l3(k) = plot(History_t, xdoterror_pinn(k,:), 'Color', colors(k,:), 'Linewidth', 2,'HandleVisibility', 'off');
end
xlabel('time','Interpreter','latex','FontSize',12)
ylabel('error\_pinn','Interpreter','latex','FontSize',12)

subplot(4,1,4)
hold on;

% Input traces get their own handle array instead of being written into l3,
% which already holds the Line handles of the previous subplot.
colors = {'#4DBEEE','#1AA640','#E68000','#8040E6'};
for i=1:n_input
    l4(i)=stairs(History_t,History_w(i,:),'Linewidth',2,'Color', colors{i});
end
xlabel('time','Interpreter','latex','FontSize',12)
ylabel('inputs','Interpreter','latex','FontSize',12)

%% validation of the estimated dynamics
% Open-loop comparison: random constant inputs are applied in segments of
% duration_ctrl until the time history passes t = 2. The sindy model is
% integrated first and its time grid is reused for the other simulations.
duration_ctrl = 0.2;
forplot = 0;   % becomes 1 once the sindy integration produced steps < 1e-10
while tHistory(end) <= 2
    w0 = ub_c*rand(n_input,1);
    x0_sindy = xHistory_sindy(:,end);

    if forplot ~= 0
        % sindy is no longer integrated: hold its last state on a fixed grid.
        t_sindy = (0:0.0025:duration_ctrl)';
        x_sindy = repmat(x0_sindy.', length(t_sindy), 1);
    else
        sindy_rhs = @(~, xs) my_ests{2}(xs, w0);
        [t_sindy, x_sindy] = ode45(sindy_rhs, [0 duration_ctrl], x0_sindy);
        if any(diff(t_sindy) < 1e-10)
            forplot = 1;
        end
    end

    % Remaining models and the true plant, sampled on the sindy time grid.
    [t_siar, x_siar] = ode45(@(~, xs) my_ests{1}(xs, w0), t_sindy, xHistory_siar(:,end)); %%!!
    [t_pinn, x_pinn] = ode45(@(~, xs) my_ests{3}(xs, w0), t_sindy, xHistory_pinn(:,end)); %%!!
    [t, x] = ode45(@(~, xs) replicator_sh(xs, w0, A0), t_sindy, xHistory(:,end));

    % Append the segment; its first sample duplicates the current last one.
    xHistory_siar  = horzcat(xHistory_siar,  x_siar(2:end,:)');
    xHistory_sindy = horzcat(xHistory_sindy, x_sindy(2:end,:)');
    xHistory_pinn  = horzcat(xHistory_pinn,  x_pinn(2:end,:)');
    xHistory       = horzcat(xHistory,       x(2:end,:)');

    wHistory = horzcat(wHistory, repmat(w0, 1, length(t_sindy) - 1));
    tHistory = horzcat(tHistory(1:end-1), (tHistory(end) + t_sindy)');
    inputHistory = xHistory(:,end);
end

% Sample indices (into tHistory) that mark the end of the recorded data and
% the end of the validation phase; used as region boundaries in the plots.
t_datagen = size(x_all_data,1);
t_valid = length(tHistory) - length(x_all_data) + t_datagen;

disp('validation is completed');

%% closed-loop MPC: each model's controller drives its own copy of the plant
Duration = 80;            % number of control steps
T_real = Ts;              % length of one control interval
deltat_real = Ts/10;      % plant sampling step inside one interval

% Integration grid for one control interval. The number of sub-steps is
% taken from this grid so that the state, input and time histories all grow
% by the same number of columns, independent of floating-point rounding in
% T_real/deltat_real.
tspan = 0:deltat_real:T_real;
n_sub = numel(tspan) - 1;

% Histories without suffix belong to the siar controller, _s to sindy and
% _p to pinn. All three continue from the end of the validation phase.
xHistory_s=xHistory;
wHistory_s=wHistory;
xHistory_p=xHistory;
wHistory_p=wHistory;

cumulative_cost_siar = zeros(1, Duration);
cumulative_cost_sindy = zeros(1, Duration);
cumulative_cost_pinn = zeros(1, Duration);

tic
for ct = 1:Duration
    % Compute optimal control moves; keep each solver report separately so
    % that none of them is overwritten before it is checked.
    [w_opt_siar, ~, info_siar] = nlmpcmove(nlobjs{1}, xHistory(:,end), wHistory(:,end), x_ref);
    [w_opt_sindy, ~, info_sindy] = nlmpcmove(nlobjs{2}, xHistory_s(:,end), wHistory_s(:,end), x_ref);
    [w_opt_pinn, ~, info_pinn] = nlmpcmove(nlobjs{3}, xHistory_p(:,end), wHistory_p(:,end), x_ref);

    % A negative ExitFlag means the solver did not return a feasible
    % solution; report it instead of silently applying the move.
    step_infos = {info_siar, info_sindy, info_pinn};
    for m = 1:numel(step_infos)
        if step_infos{m}.ExitFlag < 0
            warning('sh:mpcSolverFailed', ...
                'nlmpcmove reported ExitFlag = %d for the %s model at step %d.', ...
                step_infos{m}.ExitFlag, methods{m}, ct);
        end
    end

    % Implement first optimal control move and update plant states.
    [~, x]   = ode45(@(t,x)replicator_sh(x,w_opt_siar,A0),  tspan, xHistory(:,end));
    [~, x_s] = ode45(@(t,x)replicator_sh(x,w_opt_sindy,A0), tspan, xHistory_s(:,end));
    [~, x_p] = ode45(@(t,x)replicator_sh(x,w_opt_pinn,A0),  tspan, xHistory_p(:,end));

    % Save plant states for display. The first sample of each segment
    % duplicates the current last one and is skipped.
    xHistory = [xHistory x(2:end,:)'];
    inputHistory = [inputHistory x(end,:)'];
    wHistory = [wHistory kron(w_opt_siar, ones(1, n_sub))];
    tHistory = [tHistory(1:end-1) tHistory(end) + tspan];

    xHistory_s = [xHistory_s x_s(2:end,:)'];
    wHistory_s = [wHistory_s kron(w_opt_sindy, ones(1, n_sub))];
    xHistory_p = [xHistory_p x_p(2:end,:)'];
    wHistory_p = [wHistory_p kron(w_opt_pinn, ones(1, n_sub))];

    % Update cumulative costs (running sum of the applied inputs).
    if ct == 1
        prev_siar = 0;
        prev_sindy = 0;
        prev_pinn = 0;
    else
        prev_siar = cumulative_cost_siar(ct - 1);
        prev_sindy = cumulative_cost_sindy(ct - 1);
        prev_pinn = cumulative_cost_pinn(ct - 1);
    end
    cumulative_cost_siar(ct) = prev_siar + sum(w_opt_siar(:));
    cumulative_cost_sindy(ct) = prev_sindy + sum(w_opt_sindy(:));
    cumulative_cost_pinn(ct) = prev_pinn + sum(w_opt_pinn(:));

    fprintf('MPC step %d of %d\n', ct, Duration);
end
toc
disp('MPC is completed');

%% Plots
% Index into tHistory of the last control sample. It is used as an array
% subscript below, so the floating-point quotient is rounded to an integer.
t_ctrl = round(T_real/deltat_real)*Duration + t_valid;
time_vector = (1:Duration) * T_real; % Assuming T_real is the sample time in seconds

% Define background colors
color1 = [0.8, 0.8, 0.8];  % Darker gray (first shaded region)
color2 = [0.9, 0.9, 0.9];  % Gray (second shaded region)

figure(2)
figure_width = 1200;
figure_height = 400;
% Position is set before switching Units to normalized, as in the original
% ordering, so it is interpreted in the figure's units at that moment.
set(gcf, 'Position', [100, 100, figure_width, figure_height]);
set(gcf, 'Units', 'normalized');

% ---- left column, top: states under the sindy-based controller ----
subplot(2, 3, 1, 'Position', [0.05 0.58 0.28 0.4])
t_train_end = tHistory(t_datagen);
t_valid_end = tHistory(t_valid);
% Shaded regions; p1/p2/p3 are kept because the shared legend refers to them.
p1 = patch([0 t_train_end t_train_end 0], [0 0 1 1], color1, 'EdgeColor', 'none', 'HandleVisibility', 'off');
p2 = patch([t_train_end t_valid_end t_valid_end t_train_end], [0 0 1 1], color2, 'EdgeColor', 'none', 'HandleVisibility', 'off');
p3 = patch(NaN, NaN, 'w', 'EdgeColor', 'k', 'HandleVisibility', 'off');
hold on;

% Plant states (solid) ...
state_colors = [0 0.4470 0.7410; 0.9290 0.6940 0.1250];
for k = 1:2
    plot(tHistory, xHistory_s(k,:), 'Color', state_colors(k,:), 'Linewidth', 3);
end
% ... and the sindy model trajectory from the end of the data onwards (dashed).
est_idx = t_datagen:size(xHistory_sindy, 2);
plot(tHistory(est_idx), xHistory_sindy(1, est_idx), '--r', 'Linewidth', 2);
plot(tHistory(est_idx), xHistory_sindy(2, est_idx), '--r', 'Linewidth', 2, 'HandleVisibility', 'off');

for t_mark = [t_train_end, t_valid_end]
    xline(t_mark, ':', 'LineWidth', 2, 'HandleVisibility', 'off');
end
yline(x_ref(1), '--', 'LineWidth', 2, 'HandleVisibility', 'off');
if x_ref(2) ~= x_ref(1)
    yline(x_ref(2), '--', 'LineWidth', 2);
end
xlim([0 tHistory(t_ctrl)])
ylim([0 1])
set(gca, 'FontSize', 14);
ylabel('strategies', 'Interpreter', 'latex', 'FontSize', 18)

text(5, 0.5, '\bf SINDYc', 'FontSize', 13, 'HorizontalAlignment', 'center', 'Interpreter', 'latex')
hold off;

% ---- left column, bottom: inputs applied by the sindy-based controller ----
subplot(2, 3, 4, 'Position', [0.05 0.1 0.28 0.4]);
hold on;
patch([0 t_train_end t_train_end 0], [lb_c lb_c ub_c ub_c], color1, 'EdgeColor', 'none', 'HandleVisibility', 'off');
patch([t_train_end t_valid_end t_valid_end t_train_end], [lb_c lb_c ub_c ub_c], color2, 'EdgeColor', 'none', 'HandleVisibility', 'off');

input_colors = ["#4DBEEE", "#1AA640", "#E68000"];
for k = 1:3
    stairs(tHistory, wHistory_s(k,:), 'Color', input_colors(k), 'Linewidth', 2);
end

for t_mark = [t_train_end, t_valid_end]
    xline(t_mark, ':', 'LineWidth', 2, 'HandleVisibility', 'off')
end
ylim([lb_c ub_c])
xlim([0 tHistory(t_ctrl)])
set(gca, 'FontSize', 14);
xlabel('time', 'Interpreter', 'latex', 'FontSize', 18)
ylabel('inputs', 'Interpreter', 'latex', 'FontSize', 18)
hold off;

% ---- middle column, top: states under the pinn-based controller ----
subplot(2, 3, 2, 'Position', [0.37 0.58 0.28 0.4])
t_train_end = tHistory(t_datagen);
t_valid_end = tHistory(t_valid);
patch([0 t_train_end t_train_end 0], [0 0 1 1], color1, 'EdgeColor', 'none', 'HandleVisibility', 'off');
patch([t_train_end t_valid_end t_valid_end t_train_end], [0 0 1 1], color2, 'EdgeColor', 'none', 'HandleVisibility', 'off');
hold on;

% Plant states (solid) ...
state_colors = [0 0.4470 0.7410; 0.9290 0.6940 0.1250];
for k = 1:2
    plot(tHistory, xHistory_p(k,:), 'Color', state_colors(k,:), 'Linewidth', 3);
end
% ... and the pinn model trajectory from the end of the data onwards (dashed).
est_idx = t_datagen:size(xHistory_pinn, 2);
plot(tHistory(est_idx), xHistory_pinn(1, est_idx), '--r', 'Linewidth', 2);
plot(tHistory(est_idx), xHistory_pinn(2, est_idx), '--r', 'Linewidth', 2, 'HandleVisibility', 'off');

for t_mark = [t_train_end, t_valid_end]
    xline(t_mark, ':', 'LineWidth', 2, 'HandleVisibility', 'off');
end
yline(x_ref(1), '--', 'LineWidth', 2);
if x_ref(2) ~= x_ref(1)
    yline(x_ref(2), '--', 'LineWidth', 2);
end
xlim([0 tHistory(t_ctrl)])
ylim([0 1])

text(5, 0.5, '\bf PINN', 'FontSize', 13, 'HorizontalAlignment', 'center', 'Interpreter', 'latex')
set(gca, 'FontSize', 14);

hold off;

% ---- middle column, bottom: inputs applied by the pinn-based controller ----
subplot(2, 3, 5, 'Position', [0.37 0.1 0.28 0.4]);
hold on;
patch([0 t_train_end t_train_end 0], [lb_c lb_c ub_c ub_c], color1, 'EdgeColor', 'none', 'HandleVisibility', 'off');
patch([t_train_end t_valid_end t_valid_end t_train_end], [lb_c lb_c ub_c ub_c], color2, 'EdgeColor', 'none', 'HandleVisibility', 'off');

input_colors = ["#4DBEEE", "#1AA640", "#E68000"];
for k = 1:3
    stairs(tHistory, wHistory_p(k,:), 'Color', input_colors(k), 'Linewidth', 2);
end

for t_mark = [t_train_end, t_valid_end]
    xline(t_mark, ':', 'LineWidth', 2, 'HandleVisibility', 'off')
end
ylim([lb_c ub_c])
xlim([0 tHistory(t_ctrl)])
set(gca, 'FontSize', 14);
xlabel('time', 'Interpreter', 'latex', 'FontSize', 18)

hold off;

% ---- right column, top: states under the siar-based controller ----
% The handles l1..l7 created in this column are used by the shared legend.
subplot(2, 3, 3, 'Position', [0.69 0.58 0.28 0.4])
t_train_end = tHistory(t_datagen);
t_valid_end = tHistory(t_valid);
patch([0 t_train_end t_train_end 0], [0 0 1 1], color1, 'EdgeColor', 'none', 'HandleVisibility', 'off');
patch([t_train_end t_valid_end t_valid_end t_train_end], [0 0 1 1], color2, 'EdgeColor', 'none', 'HandleVisibility', 'off');
hold on;

% Plant states (solid) ...
state_colors = [0 0.4470 0.7410; 0.9290 0.6940 0.1250];
state_lines = gobjects(1, 2);
for k = 1:2
    state_lines(k) = plot(tHistory, xHistory(k,:), 'Color', state_colors(k,:), 'Linewidth', 3);
end
l1 = state_lines(1);
l2 = state_lines(2);
% ... and the siar model trajectory from the end of the data onwards (dashed).
est_idx = t_datagen:size(xHistory_siar, 2);
l3 = plot(tHistory(est_idx), xHistory_siar(1, est_idx), '--r', 'Linewidth', 2);
plot(tHistory(est_idx), xHistory_siar(2, est_idx), '--r', 'Linewidth', 2, 'HandleVisibility', 'off');

for t_mark = [t_train_end, t_valid_end]
    xline(t_mark, ':', 'LineWidth', 2, 'HandleVisibility', 'off');
end
l7 = yline(x_ref(1), '--', 'LineWidth', 2);
if x_ref(2) ~= x_ref(1)
    yline(x_ref(2), '--', 'LineWidth', 2);
end
xlim([0 tHistory(t_ctrl)])
ylim([0 1])

text(6, 0.5, '\bf SIARc', 'FontSize', 13, 'HorizontalAlignment', 'center', 'Interpreter', 'latex')
set(gca, 'FontSize', 14);
hold off;

% ---- right column, bottom: inputs applied by the siar-based controller ----
subplot(2, 3, 6, 'Position', [0.69 0.1 0.28 0.4]);
hold on;
patch([0 t_train_end t_train_end 0], [lb_c lb_c ub_c ub_c], color1, 'EdgeColor', 'none', 'HandleVisibility', 'off');
patch([t_train_end t_valid_end t_valid_end t_train_end], [lb_c lb_c ub_c ub_c], color2, 'EdgeColor', 'none', 'HandleVisibility', 'off');

input_colors = ["#4DBEEE", "#1AA640", "#E68000"];
input_lines = gobjects(1, 3);
for k = 1:3
    input_lines(k) = stairs(tHistory, wHistory(k,:), 'Color', input_colors(k), 'Linewidth', 2);
end
l4 = input_lines(1);
l5 = input_lines(2);
l6 = input_lines(3);

for t_mark = [t_train_end, t_valid_end]
    xline(t_mark, ':', 'LineWidth', 2, 'HandleVisibility', 'off')
end
ylim([lb_c ub_c])
xlim([0 tHistory(t_ctrl)])
set(gca, 'FontSize', 14);
xlabel('time', 'Interpreter', 'latex', 'FontSize', 18)

% Shared legend for figure 2, built from handles collected in the panels
% above: state lines, model estimate, reference line, input traces and the
% three patches.
legend_handles = [l1, l2, l3, l7, l4, l5, l6, p1, p2, p3];
legend_labels = {'$x_{1,1}$', '$x_{2,1}$', 'estimated', 'equilibrium', ...
                 '$w_{1,1}$', '$w_{1,2}$', '$w_{1,3}$', ...
                 'Training', 'Validation', 'Control'};
hL = legend(legend_handles, legend_labels, ...
            'Location', 'best', 'Orientation', 'vertical', 'NumColumns', 2, ...
            'Interpreter', 'latex', 'FontSize', 15);

% Move the legend to a fixed spot; Position is assigned before Units.
newPosition = [0.23 0.37 0.0 0.0];
newUnits = 'normalized';
hL.Position = newPosition;
hL.Units = newUnits;

% Export the current figure as EPS and as a MATLAB figure file.
fig_out = gcf;
print(fig_out, 'sh.eps', '-depsc', '-r300');
savefig(fig_out, 'sh.fig');