% Start from a clean session and a reproducible random stream.
clear all
clc
close all
rng("default")

% Data exported from Julia. The variables used below that this script never
% defines (all_siar_*, all_sindy_*, x_all_data, x_train_data, w_all_data) are
% presumably provided by this file -- confirm against the Julia export.
load('data_JULIA_rps.mat')

% Time grid of the Julia data (must match the export):
%   t_train_data = collect(LinRange(0.0, 1.0, 11))  -> horizon T_data = 1
%   t_all_data   = collect(LinRange(0.0, 1.0, 101)) -> step deltat_data = 0.01
T_data = 1;
deltat_data = 0.01;

% Symbolic right-hand sides of the identified SIAR and SINDy models.
polynomials_siar = generate_polynomials_siar_rps(all_siar_coeffs, all_siar_terms);
polynomials_sindy = generate_polynomials_sindy_rps(all_sindy_coeffs, all_sindy_terms);

% PINN: importONNXFunction generates the model function PINNFcn_rps and
% returns its parameters. The learnables are dlarrays, hence extractdata.
modelfile = "model_rps.onnx";
params = importONNXFunction(modelfile,"PINNFcn_rps");
pinn_learnables = params.Learnables;
nn_w1 = extractdata(pinn_learnables.fc1_weight);
nn_b1 = extractdata(pinn_learnables.fc1_bias);
nn_w2 = extractdata(pinn_learnables.fc2_weight);
nn_b2 = extractdata(pinn_learnables.fc2_bias);
nn_w3 = extractdata(pinn_learnables.output_weight);
nn_b3 = extractdata(pinn_learnables.output_bias);
clear pinn_learnables
%% model and controller parameters
% Game matrix handed to replicator_rps (presumably the rock-paper-scissors
% payoff matrix). It is written with the leading minus sign already applied,
% i.e. A0 = -[epsilon 1 -1; -1 epsilon 1; 1 -1 epsilon].
epsilon = -0.25;
A0 = [-epsilon -1 1;
      1 -epsilon -1;
      -1 1 -epsilon];

% Dimensions of the identified models.
n_state = 4;       % number of state components
n_input = 4;       % number of inputs

% Settings forwarded to initializeMPC. The meanings below are inferred from
% the names -- TODO confirm against initializeMPC.
Ts = 0.1;          % sample time
Tc = 5;            % control horizon
Tp = 5;            % prediction horizon
lb_s = 0;          % lower / upper state bounds
ub_s = 1;
lb_c = -1;         % lower / upper input bounds
ub_c = 1;
weight_s = 1;      % weight on state tracking
weight_c = 0.01;   % weight on inputs
weight_dc = 0.1;   % weight on input increments

%% build prediction and Jacobian function handles for each model
% Slot k of my_ests / my_Jacobians belongs to methods{k}: each entry is a
% handle @(state, input) returning the model's derivative or its Jacobians.
methods = {'siar', 'sindy', 'pinn'};
polynomials = {polynomials_siar, polynomials_sindy};

my_ests = cell(1,3);
my_Jacobians = cell(1,3);

% Symbolic models (SIAR, SINDy): compile the polynomial right-hand side into a
% numeric function of v1..v4, w1..w4 and build matching Jacobian functions.
for idx = 1:numel(polynomials)
    f_fcn = matlabFunction(polynomials{idx}, 'Vars', {'v1','v2','v3','v4','w1','w2','w3','w4'});
    my_ests{idx} = @(v,w)estimated_rps(v,w, f_fcn);
    [A_func, Bmv_func] = generateJacobian_rps(polynomials{idx}, n_state, n_input);
    my_Jacobians{idx} = @(v,w)computeJacobian_rps(v,w,A_func,Bmv_func);
end

% PINN: forward pass from the imported parameters; Jacobian from the
% extracted layer weights and biases.
idx = numel(methods);
my_ests{idx} = @(x,u)my_est_pinn_rps(x,u,params);
my_Jacobians{idx} = @(x,u)myStateJacobian_PINN(x,u, nn_w1,nn_w2,nn_w3,nn_b1,nn_b2,nn_b3);

disp('system models are imported');

%% construct nlmpc for all three methods (SIAR, SINDy, PINN)
% One nonlinear MPC object per model, built from that model's prediction and
% Jacobian handles and checked once with validateFcns at a fixed point.

% Validation point, identical for every model (hoisted out of the loop).
v0 = [0.4;0.3;0.35;0.3];
w0 = zeros(n_input, 1);

nlobjs = cell(1, 3);
for idx = 1:3
    % Initialize the MPC object
    nlobj = initializeMPC(my_ests{idx}, my_Jacobians{idx}, n_state, n_input, Ts, Tc, Tp, ...
        lb_s, ub_s, lb_c, ub_c, weight_s, weight_c, weight_dc);
    nlobjs{idx} = nlobj;

    % Check the model/Jacobian functions against the controller setup.
    validateFcns(nlobj, v0, w0);
end

disp('nlmpc is constructed');
%% training: data generation
% Unpack the Julia trajectories (cell arrays whose entries hold at least four
% state components) into per-component series on the full 101-sample grid ...
x1_data_all = cellfun(@(c) c(1), x_all_data(1:101));
x2_data_all = cellfun(@(c) c(2), x_all_data(1:101));
y1_data_all = cellfun(@(c) c(3), x_all_data(1:101));
y2_data_all = cellfun(@(c) c(4), x_all_data(1:101));

% ... and on the 11-sample training subset.
x1_data = cellfun(@(c) c(1), x_train_data(1:11));
x2_data = cellfun(@(c) c(2), x_train_data(1:11));
y1_data = cellfun(@(c) c(3), x_train_data(1:11));
y2_data = cellfun(@(c) c(4), x_train_data(1:11));

% Input history, one column per sample. Expanding the cell into a
% comma-separated list works whether w_all_data is a row or a column cell
% array; the former size(w_all_data(1:101),1) loop bound was 1 for a row cell
% and silently dropped every sample after the first.
wHistory = [w_all_data{1:101}];

tHistory = 0:deltat_data:T_data;

% State history: one row per state component, one column per sample
% (the (:) makes this independent of the orientation of x_all_data).
vHistory = [x1_data_all(:), x2_data_all(:), y1_data_all(:), y2_data_all(:)]';

% Model-simulated histories. Only the last data sample (the starting point of
% the validation run) is meaningful; the earlier columns are zero placeholders
% that keep column indices aligned with tHistory.
vHistory_siar = [zeros(size(vHistory,1), size(vHistory,2)-1) vHistory(:,end)];
vHistory_sindy = vHistory_siar;
vHistory_pinn = vHistory_siar;

disp('data is exported');
%% compare xdot: true dynamics vs. SIAR, SINDy and PINN estimates
% Starting from the last data sample, drive the true replicator dynamics with
% n_segments random inputs, each held for 0.1 time units. Along the true
% trajectory, record the derivative predicted by each model (xdotHistory{1..3},
% same order as my_ests) and the true derivative (xdotHistory{4}).
xdotHistory = cell(1,4);
History_w = [];
History_t = [];
History_v = vHistory(:,end);
tspan = 0:0.01:0.1;
n_segments = 50;

% NOTE(review): this variable shadows MATLAB's built-in input() function for
% the rest of the script; the name is kept for compatibility with later use.
input = zeros(n_input, n_segments);

for ct = 1:n_segments
    % Uniform random input in [lb_c, ub_c]. This is bit-identical to the former
    % 2*ub_c*rand-1 for lb_c = -1, ub_c = 1, but also honours other bounds.
    w0 = lb_c + (ub_c - lb_c)*rand(n_input,1);

    [t,v] = ode45(@(t,v)replicator_rps(v,w0,A0),tspan,History_v(:,end));

    % Derivatives at every output time: models 1..3, then the true dynamics (4).
    x_dot_values = cell(1,4);
    for idx = 1:4
        for i = 1:length(t)
            if idx == 4
                x_dot_values{idx}(:, i) = replicator_rps(v(i, :)', w0, A0);
            else
                x_dot_values{idx}(:, i) = my_ests{idx}(v(i, :)', w0);
            end
        end
    end

    % History_v already holds the segment's start state, so only v(2:end) is
    % appended. History_w/History_t are seeded with the whole first segment and
    % then extended by t(2:end), keeping all three histories the same length.
    History_v = [History_v v(2:end,:)'];
    if ct == 1
        History_w = w0*ones(1,length(t));
        History_t = t';
    else
        History_w = [History_w w0*ones(1,length(t)-1)];
        History_t = [History_t History_t(end) + t(2:end)'];
    end

    % Each segment's first derivative sample (same state, new input) replaces
    % the previous segment's last one.
    for idx = 1:4
        xdotHistory{idx}=[xdotHistory{idx}(:,1:end-1) x_dot_values{idx}];
    end
    input(:,ct)=w0;
end

figure(1)
% Absolute error of each model's xdot against the true dynamics, for the
% first two state components.
xdoterror_siar = abs(xdotHistory{1, 1}(1:2,:)-xdotHistory{1, 4}(1:2,:));
xdoterror_sindy= abs(xdotHistory{1, 2}(1:2,:)-xdotHistory{1, 4}(1:2,:));
xdoterror_pinn= abs(xdotHistory{1, 3}(1:2,:)-xdotHistory{1, 4}(1:2,:));

colors = [0 0.4470 0.7410; 0.4660 0.6740 0.1880; 0.9290 0.6940 0.1250; 0.4940 0.1840 0.5560];

% SIAR error
subplot(4,1,1)
hold on;
for k = 1:2
    l1(k) = plot(History_t, xdoterror_siar(k,:), 'Color', colors(k,:), 'Linewidth', 2);
end
ylim([0 1])
xlabel('time','Interpreter','latex','FontSize',12)
ylabel('error','Interpreter','latex','FontSize',12)
title('\textbf{Estimation Error in $\dot{x}$}','Interpreter','latex','FontSize',12)
text(0.5, 0.5, '\bf SIAR','FontSize',12, 'HorizontalAlignment', 'center','Interpreter','latex')

% SINDy error (no ylim here).
% NOTE(review): the label sits at y = 100 in data units, so it is only visible
% when the SINDy error actually reaches that range.
subplot(4,1,2)
hold on;
for k = 1:2
    l2(k) = plot(History_t, xdoterror_sindy(k,:), 'Color', colors(k,:), 'Linewidth', 2,'HandleVisibility', 'off');
end
xlabel('time','Interpreter','latex','FontSize',12)
ylabel('error','Interpreter','latex','FontSize',12)
text(0.5, 100, '\bf SINDYc','FontSize',12, 'HorizontalAlignment', 'center','Interpreter','latex')

% PINN error
subplot(4,1,3)
hold on;
for k = 1:2
    l3(k) = plot(History_t, xdoterror_pinn(k,:), 'Color', colors(k,:), 'Linewidth', 2,'HandleVisibility', 'off');
end
ylim([0 1])
xlabel('time','Interpreter','latex','FontSize',12)
ylabel('error','Interpreter','latex','FontSize',12)
text(0.5, 0.8, '\bf PINN','FontSize',12, 'HorizontalAlignment', 'center','Interpreter','latex')

% Applied random inputs. Dedicated variable names keep the PINN error handles
% in l3 and the RGB matrix in colors from being overwritten, as they
% previously were.
subplot(4,1,4)
hold on;

input_colors = {'#4DBEEE','#1AA640','#E68000','#8040E6'};
for i=1:n_input
    hw(i)=stairs(History_t,History_w(i,:),'Linewidth',2,'Color', input_colors{i});
end
xlabel('time','Interpreter','latex','FontSize',12)
ylabel('inputs','Interpreter','latex','FontSize',12)

%% validation of the estimated dynamics
% Starting from the end of the data set, apply random inputs held for
% duration_ctrl time units and simulate the SINDy, SIAR and PINN models and the
% true replicator dynamics on a common time grid until tHistory passes 3.
% The grid comes from SINDy's adaptive ode45 run. Once two consecutive SINDy
% output times are closer than 1e-10 (its integrator has stalled), SINDy is
% frozen at its last state and a fixed 0.0025 grid is used from then on.
duration_ctrl = 0.2;
forplot=0;   % becomes 1 once the SINDy simulation has stalled
while tHistory(end)<=3
    % Uniform random input in [lb_c, ub_c]. This is bit-identical to the former
    % 2*ub_c*rand-1 for lb_c = -1, ub_c = 1, but also honours other bounds.
    w0 = lb_c + (ub_c - lb_c)*rand(n_input,1);
    if forplot == 0
        [t_sindy,v_sindy] = ode45(@(t_sindy,v_sindy)my_ests{2}(v_sindy,w0),[0 duration_ctrl],vHistory_sindy(:,end));
        if any(diff(t_sindy) < 1e-10)
            forplot = 1;
        end
    else
        t_sindy = (0:0.0025:duration_ctrl)';
        v_sindy = repmat(vHistory_sindy(:,end)', length(t_sindy), 1);
    end

    [t_siar,v_siar] = ode45(@(t_siar,v_siar)my_ests{1}(v_siar,w0),t_sindy,vHistory_siar(:,end));
    [t_pinn,v_pinn] = ode45(@(t_pinn,v_pinn)my_ests{3}(v_pinn,w0),t_sindy,vHistory_pinn(:,end));

    [t,v] = ode45(@(t,v)replicator_rps(v,w0,A0),t_sindy,vHistory(:,end));

    % Append the new samples; row 1 of each solution repeats the previous end state.
    vHistory_siar = [vHistory_siar v_siar(2:end,:)'];
    vHistory_sindy = [vHistory_sindy v_sindy(2:end,:)'];
    vHistory_pinn = [vHistory_pinn v_pinn(2:end,:)'];

    vHistory = [vHistory v(2:end,:)'];

    wHistory=[wHistory w0*ones(1,length(t_sindy)-1)];
    tHistory=[tHistory(1:end-1) (tHistory(end)+t_sindy)'];
end
% Final true state after validation. This variable was previously reassigned on
% every iteration but only the last value was ever kept.
inputHistory = vHistory(:,end);

% Column indices of the phase boundaries in tHistory: t_datagen is the last
% data-generation sample, t_valid the last validation sample. numel is used
% because size(...,1) returns 1 when x_all_data is a row cell array.
t_datagen = numel(x_all_data(1:101));
t_valid = length(tHistory)-101 + t_datagen;

disp('validation is completed');

%% initialize the MPC
% Closed-loop simulation. The SIAR, SINDy and PINN controllers each drive their
% own copy of the true replicator dynamics (vHistory/wHistory,
% vHistory_s/wHistory_s and vHistory_p/wHistory_p), all starting from the state
% reached at the end of validation.
v_ref=[1/3 1/3 1/3 1/3];   % reference passed to nlmpcmove
Duration = 120;            % number of MPC steps
T_real=0.1;                % plant time simulated per MPC step
deltat_real=Ts/10;         % plant output sampling step

% Plant output grid for one MPC step (loop-invariant). n_sub is the number of
% new samples each step appends; it is derived from tspan so the state, input
% and time histories always grow by the same amount.
tspan = 0:deltat_real:T_real;
n_sub = numel(tspan) - 1;

vHistory_s=vHistory;
wHistory_s=wHistory;
vHistory_p=vHistory;
wHistory_p=wHistory;

tic
for ct = 1:Duration
    % Compute optimal control moves (info is overwritten by each call and kept
    % only for interactive inspection).
    [w_opt_siar, ~, info] = nlmpcmove(nlobjs{1}, vHistory(:,end), wHistory(:,end), v_ref);
    [w_opt_sindy, ~, info] = nlmpcmove(nlobjs{2}, vHistory_s(:,end), wHistory_s(:,end), v_ref);
    [w_opt_pinn, ~, info] = nlmpcmove(nlobjs{3}, vHistory_p(:,end), wHistory_p(:,end), v_ref);

    % Implement first optimal control move and update plant states.
    [t,v] = ode45(@(t,v)replicator_rps(v,w_opt_siar,A0),tspan,vHistory(:,end));
    [t_s,v_s] = ode45(@(t,v)replicator_rps(v,w_opt_sindy,A0),tspan,vHistory_s(:,end));
    [t_p,v_p] = ode45(@(t,v)replicator_rps(v,w_opt_pinn,A0),tspan,vHistory_p(:,end));

    % Save plant states for display. Each move is held constant over the step,
    % so it is repeated n_sub times. Note that inputHistory records the plant
    % state at the end of each step despite its name.
    vHistory = [vHistory v(2:end,:)'];
    inputHistory=[inputHistory v(end,:)'];
    wHistory = [wHistory repmat(w_opt_siar,1,n_sub)];
    % Time stamps are the integration grid shifted to the current end time, so
    % tHistory gains exactly one entry per stored state column. A colon range
    % between two shifted floating-point endpoints does not guarantee that.
    tHistory=[tHistory(1:end-1) tHistory(end)+tspan];

    vHistory_s=[vHistory_s v_s(2:end,:)'];
    wHistory_s=[wHistory_s repmat(w_opt_sindy,1,n_sub)];
    vHistory_p=[vHistory_p v_p(2:end,:)'];
    wHistory_p=[wHistory_p repmat(w_opt_pinn,1,n_sub)];
    ct
end
toc

disp('MPC is completed');

%% Plots
% Index of the last closed-loop sample: the MPC run appends T_real/deltat_real
% samples per step after the validation run. round() guarantees an integer
% index even if that ratio is not exactly representable in floating point.
t_ctrl = round(T_real/deltat_real)*Duration + t_valid;

% comparison without sindy: PINN (left column) vs. SIAR (right column)
% Background colours for the training and validation phases.
color1 = [0.8, 0.8, 0.8];  % Darker gray
color2 = [0.9, 0.9, 0.9];  % Gray

figure(2)
figure_width = 1200;
figure_height = 600;
set(gcf, 'Position', [100, 100, figure_width, figure_height]);
set(gcf, 'Units', 'normalized');

% states of the PINN-controlled plant (top left)
subplot(2, 2, 1, 'Position', [0.05 0.58 0.43 0.4])
% p1/p2 shade the training/validation phases. p3 has NaN coordinates, so it
% draws nothing and only supplies the 'Control' legend entry.
p1=patch([0 tHistory(t_datagen) tHistory(t_datagen) 0], [0 0 1 1], color1, 'EdgeColor', 'none','HandleVisibility', 'off');
p2=patch([tHistory(t_datagen) tHistory(t_valid) tHistory(t_valid) tHistory(t_datagen)], [0 0 1 1], color2, 'EdgeColor', 'none','HandleVisibility', 'off');
p3=patch(NaN, NaN, 'w', 'EdgeColor', 'k','HandleVisibility', 'off');
hold on;

% Closed-loop plant states (solid) and the PINN model's own simulation over the
% validation run (dashed red, one legend entry).
plot(tHistory,vHistory_p(1,:),'Color',[0 0.4470 0.7410],'Linewidth',3); 
plot(tHistory,vHistory_p(2,:),'Color',[0.4660 0.6740 0.1880],'Linewidth',3); 
plot(tHistory,vHistory_p(3,:),'Color',[0.9290 0.6940 0.1250],'Linewidth',3); 
plot(tHistory,vHistory_p(4,:),'Color',[0.4940 0.1840 0.5560],'Linewidth',3);
plot(tHistory(t_datagen:length(vHistory_pinn(1,:))),vHistory_pinn(1,t_datagen:end),'--r','Linewidth',2); 
plot(tHistory(t_datagen:length(vHistory_pinn(2,:))),vHistory_pinn(2,t_datagen:end),'--r','Linewidth',2,'HandleVisibility', 'off');
plot(tHistory(t_datagen:length(vHistory_pinn(3,:))),vHistory_pinn(3,t_datagen:end),'--r','Linewidth',2,'HandleVisibility', 'off'); 
plot(tHistory(t_datagen:length(vHistory_pinn(4,:))),vHistory_pinn(4,t_datagen:end),'--r','Linewidth',2,'HandleVisibility', 'off');

% Phase boundaries and reference level(s).
xline(tHistory(t_datagen),':','LineWidth',2,'HandleVisibility', 'off');
xline(tHistory(t_valid),':','LineWidth',2,'HandleVisibility', 'off');
yline(v_ref(1),'--','LineWidth',2);
if v_ref(1) ~= v_ref(2)
    yline(v_ref(2),'--','LineWidth',2);
end
xlim([0 tHistory(t_ctrl)])
ylim([0 1])
set(gca, 'FontSize', 14);
ylabel('strategies','Interpreter','latex','FontSize',18)

text(6, 1-0.1, '\bf PINN','FontSize',13, 'HorizontalAlignment', 'center','Interpreter','latex')
hold off;

% inputs of the PINN-controlled plant (bottom left)
subplot(2, 2, 3, 'Position', [0.05 0.1 0.43 0.4]); 
hold on;
patch([0 tHistory(t_datagen) tHistory(t_datagen) 0], [lb_c lb_c ub_c ub_c], color1, 'EdgeColor', 'none','HandleVisibility', 'off');
patch([tHistory(t_datagen) tHistory(t_valid) tHistory(t_valid) tHistory(t_datagen)], [lb_c lb_c ub_c ub_c], color2, 'EdgeColor', 'none','HandleVisibility', 'off');

stairs(tHistory,wHistory_p(1,:),'Color',"#4DBEEE",'Linewidth',2);
stairs(tHistory,wHistory_p(2,:),'Color',"#1AA640",'Linewidth',2);
stairs(tHistory,wHistory_p(3,:),'Color',"#E68000",'Linewidth',2);
stairs(tHistory,wHistory_p(4,:),'Color',"#8040E6",'Linewidth',2);

xline(tHistory(t_datagen),':','LineWidth',2,'HandleVisibility', 'off')
xline(tHistory(t_valid),':','LineWidth',2,'HandleVisibility', 'off')
ylim([lb_c ub_c])
xlim([0 tHistory(t_ctrl)])
set(gca, 'FontSize', 14);

xlabel('time','Interpreter','latex','FontSize',18)
ylabel('inputs','Interpreter','latex','FontSize',18)
hold off;

% SIAR-controlled plant: states (top right of figure 2)
subplot(2, 2, 2, 'Position', [0.52 0.58 0.43 0.4])
% Phase boundaries on the time axis (temporaries, cleared at the end).
t_d = tHistory(t_datagen);   % end of the training data
t_v = tHistory(t_valid);     % end of the validation run
patch([0 t_d t_d 0], [0 0 1 1], color1, 'EdgeColor', 'none','HandleVisibility', 'off');
patch([t_d t_v t_v t_d], [0 0 1 1], color2, 'EdgeColor', 'none','HandleVisibility', 'off');
hold on;

% Handles l1..l10 feed the shared legend created at the end of this figure.
l1 = plot(tHistory, vHistory(1,:), 'Color', [0 0.4470 0.7410], 'Linewidth', 3);
l2 = plot(tHistory, vHistory(2,:), 'Color', [0.4660 0.6740 0.1880], 'Linewidth', 3);
l3 = plot(tHistory, vHistory(3,:), 'Color', [0.9290 0.6940 0.1250], 'Linewidth', 3);
l4 = plot(tHistory, vHistory(4,:), 'Color', [0.4940 0.1840 0.5560], 'Linewidth', 3);

% SIAR model simulation over the validation run: dashed red, and only the
% first component gets a legend entry.
l5 = plot(tHistory(t_datagen:length(vHistory_siar(1,:))), vHistory_siar(1,t_datagen:end), '--r', 'Linewidth', 2);
for iState = 2:4
    plot(tHistory(t_datagen:length(vHistory_siar(iState,:))), vHistory_siar(iState,t_datagen:end), ...
        '--r', 'Linewidth', 2, 'HandleVisibility', 'off');
end

xline(t_d, ':', 'LineWidth', 2, 'HandleVisibility', 'off');
xline(t_v, ':', 'LineWidth', 2, 'HandleVisibility', 'off');
l6 = yline(v_ref(1), '--', 'LineWidth', 2);
if v_ref(1) ~= v_ref(2)
    yline(v_ref(2), '--', 'LineWidth', 2);
end
t_end = tHistory(t_ctrl);    % end of the closed-loop run
xlim([0 t_end])
ylim([0 1])

set(gca, 'FontSize', 14);
text(6, 1-0.1, '\bf SIARc','FontSize',13, 'HorizontalAlignment', 'center','Interpreter','latex')
hold off;

% SIAR-controlled plant: inputs (bottom right of figure 2)
subplot(2, 2, 4, 'Position', [0.52 0.1 0.43 0.4]);
hold on;
patch([0 t_d t_d 0], [lb_c lb_c ub_c ub_c], color1, 'EdgeColor', 'none','HandleVisibility', 'off');
patch([t_d t_v t_v t_d], [lb_c lb_c ub_c ub_c], color2, 'EdgeColor', 'none','HandleVisibility', 'off');

l7 = stairs(tHistory, wHistory(1,:), 'Color', "#4DBEEE", 'Linewidth', 2);
l8 = stairs(tHistory, wHistory(2,:), 'Color', "#1AA640", 'Linewidth', 2);
l9 = stairs(tHistory, wHistory(3,:), 'Color', "#E68000", 'Linewidth', 2);
l10 = stairs(tHistory, wHistory(4,:), 'Color', "#8040E6", 'Linewidth', 2);

xline(t_d, ':', 'LineWidth', 2, 'HandleVisibility', 'off')
xline(t_v, ':', 'LineWidth', 2, 'HandleVisibility', 'off')
ylim([lb_c ub_c])
xlim([0 t_end])
set(gca, 'FontSize', 14);

xlabel('time','Interpreter','latex','FontSize',18)

% Shared legend for the whole figure (p1..p3 come from the PINN subplot),
% placed in normalized figure coordinates, then export.
hL = legend([l1,l2,l3,l4,l5,l6,l7,l8,l9,l10,p1,p2,p3],{'$x_{1,1}$','$x_{1,2}$','$x_{2,1}$','$x_{2,2}$','estimated','equilibrium','$w_{1,1}$','$w_{1,2}$','$w_{1,3}$','$w_{1,4}$','Training', 'Validation', 'Control'},'Location','northoutside','Orientation','horizontal','NumColumns', 2,'Interpreter','latex','FontSize',13);
newPosition = [0.86 0.88 0.0 0.0];
newUnits = 'normalized';
set(hL,'Position', newPosition,'Units', newUnits);

print(gcf, 'RPS.eps', '-depsc', '-r300');
savefig(gcf, 'RPS.fig');
clear t_d t_v t_end iState
%% comparison
% Three-column comparison in figure 3: SINDYc (left), PINN (middle),
% SIARc (right). Background colours for the training and validation phases:
color1 = [0.8, 0.8, 0.8];  % Darker gray
color2 = [0.9, 0.9, 0.9];  % Gray

figure(3)
figure_width = 1300;
figure_height = 600;
set(gcf, 'Position', [100, 100, figure_width, figure_height]);
set(gcf, 'Units', 'normalized');

% SINDy-controlled plant: states (top left)
subplot(2, 3, 1, 'Position', [0.05 0.58 0.28 0.4])
% Phase boundaries on the time axis (temporaries, cleared at the end).
t_d = tHistory(t_datagen);   % end of the training data
t_v = tHistory(t_valid);     % end of the validation run
% p1/p2 shade the phases; p3 (NaN coordinates) is a legend-only 'Control' entry.
p1 = patch([0 t_d t_d 0], [0 0 1 1], color1, 'EdgeColor', 'none','HandleVisibility', 'off');
p2 = patch([t_d t_v t_v t_d], [0 0 1 1], color2, 'EdgeColor', 'none','HandleVisibility', 'off');
p3 = patch(NaN, NaN, 'w', 'EdgeColor', 'k','HandleVisibility', 'off');
hold on;

l1 = plot(tHistory, vHistory_s(1,:), 'Color', [0 0.4470 0.7410], 'Linewidth', 3);
l2 = plot(tHistory, vHistory_s(2,:), 'Color', [0.4660 0.6740 0.1880], 'Linewidth', 3);
l3 = plot(tHistory, vHistory_s(3,:), 'Color', [0.9290 0.6940 0.1250], 'Linewidth', 3);
l4 = plot(tHistory, vHistory_s(4,:), 'Color', [0.4940 0.1840 0.5560], 'Linewidth', 3);

% SINDy model simulation over the validation run (dashed red; one legend entry).
l5 = plot(tHistory(t_datagen:length(vHistory_sindy(1,:))), vHistory_sindy(1,t_datagen:end), '--r', 'Linewidth', 2);
for iState = 2:4
    plot(tHistory(t_datagen:length(vHistory_sindy(iState,:))), vHistory_sindy(iState,t_datagen:end), ...
        '--r', 'Linewidth', 2, 'HandleVisibility', 'off');
end

xline(t_d, ':', 'LineWidth', 2, 'HandleVisibility', 'off');
xline(t_v, ':', 'LineWidth', 2, 'HandleVisibility', 'off');
l6 = yline(v_ref(1), '--', 'LineWidth', 2, 'HandleVisibility', 'off');
if v_ref(1) ~= v_ref(2)
    yline(v_ref(2), '--', 'LineWidth', 2);
end
t_end = tHistory(t_ctrl);    % end of the closed-loop run
xlim([0 t_end])
ylim([0 1])
ylabel('strategies','Interpreter','latex','FontSize',18)

text(8, 1-0.2, '\bf SINDYc','FontSize',13, 'HorizontalAlignment', 'center','Interpreter','latex')
hold off;

% SINDy-controlled plant: inputs (bottom left)
subplot(2, 3, 4, 'Position', [0.05 0.1 0.28 0.4]);
hold on;
patch([0 t_d t_d 0], [lb_c lb_c ub_c ub_c], color1, 'EdgeColor', 'none','HandleVisibility', 'off');
patch([t_d t_v t_v t_d], [lb_c lb_c ub_c ub_c], color2, 'EdgeColor', 'none','HandleVisibility', 'off');

l7 = stairs(tHistory, wHistory_s(1,:), 'Color', "#4DBEEE", 'Linewidth', 2);
l8 = stairs(tHistory, wHistory_s(2,:), 'Color', "#1AA640", 'Linewidth', 2);
l9 = stairs(tHistory, wHistory_s(3,:), 'Color', "#E68000", 'Linewidth', 2);
l10 = stairs(tHistory, wHistory_s(4,:), 'Color', "#8040E6", 'Linewidth', 2);

xline(t_d, ':', 'LineWidth', 2, 'HandleVisibility', 'off')
xline(t_v, ':', 'LineWidth', 2, 'HandleVisibility', 'off')
ylim([lb_c ub_c])
xlim([0 t_end])
xlabel('time','Interpreter','latex','FontSize',18)
ylabel('inputs','Interpreter','latex','FontSize',18)
hold off;
clear t_d t_v t_end iState

% PINN-controlled plant: states (top middle of figure 3)
subplot(2, 3, 2, 'Position', [0.37 0.58 0.28 0.4])
% Phase boundaries on the time axis (temporaries, cleared at the end).
t_d = tHistory(t_datagen);   % end of the training data
t_v = tHistory(t_valid);     % end of the validation run
patch([0 t_d t_d 0], [0 0 1 1], color1, 'EdgeColor', 'none','HandleVisibility', 'off');
patch([t_d t_v t_v t_d], [0 0 1 1], color2, 'EdgeColor', 'none','HandleVisibility', 'off');
hold on;

% Closed-loop plant states, one colour per component.
state_rgb = [0 0.4470 0.7410; 0.4660 0.6740 0.1880; 0.9290 0.6940 0.1250; 0.4940 0.1840 0.5560];
for iState = 1:4
    plot(tHistory, vHistory_p(iState,:), 'Color', state_rgb(iState,:), 'Linewidth', 3);
end
% PINN model simulation over the validation run (dashed red; only the first
% component stays legend-visible).
plot(tHistory(t_datagen:length(vHistory_pinn(1,:))), vHistory_pinn(1,t_datagen:end), '--r', 'Linewidth', 2);
for iState = 2:4
    plot(tHistory(t_datagen:length(vHistory_pinn(iState,:))), vHistory_pinn(iState,t_datagen:end), ...
        '--r', 'Linewidth', 2, 'HandleVisibility', 'off');
end

xline(t_d, ':', 'LineWidth', 2, 'HandleVisibility', 'off');
xline(t_v, ':', 'LineWidth', 2, 'HandleVisibility', 'off');
yline(v_ref(1), '--', 'LineWidth', 2);
if v_ref(1) ~= v_ref(2)
    yline(v_ref(2), '--', 'LineWidth', 2);
end
t_end = tHistory(t_ctrl);    % end of the closed-loop run
xlim([0 t_end])
ylim([0 1])

text(8, 1-0.2, '\bf PINN','FontSize',13, 'HorizontalAlignment', 'center','Interpreter','latex')
hold off;

% PINN-controlled plant: inputs (bottom middle)
subplot(2, 3, 5, 'Position', [0.37 0.1 0.28 0.4]);
hold on;
patch([0 t_d t_d 0], [lb_c lb_c ub_c ub_c], color1, 'EdgeColor', 'none','HandleVisibility', 'off');
patch([t_d t_v t_v t_d], [lb_c lb_c ub_c ub_c], color2, 'EdgeColor', 'none','HandleVisibility', 'off');

input_hex = ["#4DBEEE", "#1AA640", "#E68000", "#8040E6"];
for iIn = 1:4
    stairs(tHistory, wHistory_p(iIn,:), 'Color', input_hex(iIn), 'Linewidth', 2);
end

xline(t_d, ':', 'LineWidth', 2, 'HandleVisibility', 'off')
xline(t_v, ':', 'LineWidth', 2, 'HandleVisibility', 'off')
ylim([lb_c ub_c])
xlim([0 t_end])
xlabel('time','Interpreter','latex','FontSize',18)
hold off;
clear t_d t_v t_end state_rgb input_hex iState iIn

% SIAR-controlled plant: states (top right of figure 3)
subplot(2, 3, 3, 'Position', [0.69 0.58 0.28 0.4])
patch([0 tHistory(t_datagen) tHistory(t_datagen) 0], [0 0 1 1], color1, 'EdgeColor', 'none','HandleVisibility', 'off');
patch([tHistory(t_datagen) tHistory(t_valid) tHistory(t_valid) tHistory(t_datagen)], [0 0 1 1], color2, 'EdgeColor', 'none','HandleVisibility', 'off');
hold on;

% Closed-loop plant states (solid) and the SIAR model simulation over the
% validation run (dashed red, one legend-visible line).
plot(tHistory,vHistory(1,:),'Color',[0 0.4470 0.7410],'Linewidth',3); 
plot(tHistory,vHistory(2,:),'Color',[0.4660 0.6740 0.1880],'Linewidth',3); 
plot(tHistory,vHistory(3,:),'Color',[0.9290 0.6940 0.1250],'Linewidth',3); 
plot(tHistory,vHistory(4,:),'Color',[0.4940 0.1840 0.5560],'Linewidth',3);
plot(tHistory(t_datagen:length(vHistory_siar(1,:))),vHistory_siar(1,t_datagen:end),'--r','Linewidth',2); 
plot(tHistory(t_datagen:length(vHistory_siar(2,:))),vHistory_siar(2,t_datagen:end),'--r','Linewidth',2,'HandleVisibility', 'off');
plot(tHistory(t_datagen:length(vHistory_siar(3,:))),vHistory_siar(3,t_datagen:end),'--r','Linewidth',2,'HandleVisibility', 'off'); 
plot(tHistory(t_datagen:length(vHistory_siar(4,:))),vHistory_siar(4,t_datagen:end),'--r','Linewidth',2,'HandleVisibility', 'off');

xline(tHistory(t_datagen),':','LineWidth',2,'HandleVisibility', 'off');
xline(tHistory(t_valid),':','LineWidth',2,'HandleVisibility', 'off');
yline(v_ref(1),'--','LineWidth',2);
if v_ref(1) ~= v_ref(2)
    yline(v_ref(2),'--','LineWidth',2);
end
xlim([0 tHistory(t_ctrl)])
ylim([0 1])

text(8, 1-0.2, '\bf SIARc','FontSize',13, 'HorizontalAlignment', 'center','Interpreter','latex')
hold off;

% SIAR-controlled plant: inputs (bottom right)
subplot(2, 3, 6, 'Position', [0.69 0.1 0.28 0.4]);
hold on;
patch([0 tHistory(t_datagen) tHistory(t_datagen) 0], [lb_c lb_c ub_c ub_c], color1, 'EdgeColor', 'none','HandleVisibility', 'off');
patch([tHistory(t_datagen) tHistory(t_valid) tHistory(t_valid) tHistory(t_datagen)], [lb_c lb_c ub_c ub_c], color2, 'EdgeColor', 'none','HandleVisibility', 'off');

stairs(tHistory,wHistory(1,:),'Color',"#4DBEEE",'Linewidth',2);
stairs(tHistory,wHistory(2,:),'Color',"#1AA640",'Linewidth',2);
stairs(tHistory,wHistory(3,:),'Color',"#E68000",'Linewidth',2);
stairs(tHistory,wHistory(4,:),'Color',"#8040E6",'Linewidth',2);

xline(tHistory(t_datagen),':','LineWidth',2,'HandleVisibility', 'off')
xline(tHistory(t_valid),':','LineWidth',2,'HandleVisibility', 'off')
ylim([lb_c ub_c])
xlim([0 tHistory(t_ctrl)])
xlabel('time','Interpreter','latex','FontSize',18)

% General Legend Creation. The handles l1..l10 and p1..p3 are the ones stored
% by the SINDYc column of this figure. The 'equilibrium' label was previously
% misspelled 'equilibirum' (it now matches figure 2).
hL = legend([l1,l2,l3,l4,l5,l6,l7,l8,l9,l10,p1,p2,p3],{'$x_{1,1}$','$x_{1,2}$','$x_{2,1}$','$x_{2,2}$','estimated','equilibrium','$w_{1,1}$','$w_{1,2}$','$w_{1,3}$','$w_{1,4}$','Training', 'Validation', 'Control'},'Location','northoutside','Orientation','horizontal','Interpreter','latex','FontSize',13);
newPosition = [0.5 0.03 0.0 0.0];
newUnits = 'normalized';
set(hL,'Position', newPosition,'Units', newUnits);