%Comparison study: five-fold cross-validation of a recurrent neural-network
%(GRU) model of the cascaded-tanks data.

%Load the two cascaded-tanks experiments. Each MAT file provides variables
%u and y; transposing them puts time along the rows.
S = load('Tank1.mat');
S2 = load('Tank2.mat');
u1 = S.u';
u2 = S2.u';
y1 = S.y';
y2 = S2.y';

%Column 1 of y is h1 and column 2 is h2 (suffix _T1 = data from Tank1.mat,
%_T2 = data from Tank2.mat).
h1_T1 = y1(:,1);
h2_T1 = y1(:,2);
h1_T2 = y2(:,1);
h2_T2 = y2(:,2);

%Time stamp of every data point: points of the first experiment are 5 time
%units apart, points of the second are 4 apart (units are not stated here).
t1 = 5*(1:length(u1));
t2 = 4*(1:length(u2));
t = [t1, t2]';

% Set up the recurrent network: one model predicts both h1 and h2 one step
% ahead from the previous values of h1, h2 and u.
numChannels = 3;      %Three inputs: previous values of h1, h2 and u
numResponses = 2;     %Two outputs: h1 and h2 at the current step
numHiddenUnits = 150; %Size of the GRU hidden state
maxEpochs = 150;      %Number of passes over the training sequences

layers = [
    sequenceInputLayer(numChannels)
    gruLayer(numHiddenUnits)   %Gated recurrent unit (GRU) layer
    fullyConnectedLayer(numResponses)
    regressionLayer];

%Adam optimiser. Shorter sequences in a mini-batch are padded on the left,
%the sequence order is reshuffled every epoch, and progress is shown in the
%training-progress plot instead of the command window.
options = trainingOptions("adam", ...
    MaxEpochs=maxEpochs, ...
    SequencePaddingDirection="left", ...
    Shuffle="every-epoch", ...
    Plots="training-progress", ...
    Verbose=0);

%Model input rows for each experiment: [h1, h2, u] at every time step. The
%one-step-ahead pairs (input at step k, target h1/h2 at step k+1) are cut
%from these per fold in the cross-validation loop.
inputs_T1 = [h1_T1, h2_T1, u1];
inputs_T2 = [h1_T2, h2_T2, u2];

%Per-fold results (one entry per fold): mean absolute error (err_*), mean
%absolute relative error (per_err_*), and wall-clock training and
%prediction times.
[err_h1, err_h2, per_err_h1, per_err_h2] = deal(zeros(5,1));
[T_Train, T_Predict] = deal(zeros(5,1));

%Divide the data up into the same 5 folds as in CasTanks_GP_RK4Int. Each
%fold is described only by the test segments it holds out, one row per
%segment: [dataset, first row, last row] (dataset 1 = Tank1, 2 = Tank2;
%Inf means "to the end of that dataset"). Everything else is training
%data. Every contiguous run of rows becomes its own sequence, so no
%sequence crosses a train/test boundary or joins the two experiments.
%Fold 2 spans the end of Tank1 and the start of Tank2, so it has two test
%curves.
data_inputs = {inputs_T1, inputs_T2};
fold_tests = {
    [1    1  2000]
    [1 2001   Inf; 2 1 1500]
    [2 1501  3500]
    [2 3501  5500]
    [2 5501   Inf]};
offset = 50; %Number of measured steps used to warm up each test curve

for i = 1:numel(fold_tests)
    segs = fold_tests{i};

    %Build one-step-ahead pairs inside each contiguous run: the input at
    %step k is [h1(k); h2(k); u(k)] and the target is [h1(k+1); h2(k+1)].
    train_inputs = {};
    train_h = {};
    test_inputs = {};
    test_h = {};
    for d = 1:numel(data_inputs)
        X = data_inputs{d};
        n_rows = size(X,1);
        in_test = false(n_rows,1);
        for s = 1:size(segs,1)
            if segs(s,1) == d
                in_test(segs(s,2):min(segs(s,3),n_rows)) = true;
            end
        end
        run_start = [1; find(diff(in_test)) + 1];
        run_stop = [run_start(2:end) - 1; n_rows];
        for r = 1:numel(run_start)
            rows = run_start(r):run_stop(r);
            if numel(rows) < 2
                continue; %Too short to form an input/target pair
            end
            Xr = X(rows(1:end-1),:)'; %Inputs h1, h2, u at step k
            Hr = X(rows(2:end),1:2)'; %Targets h1, h2 at step k+1
            if in_test(rows(1))
                test_inputs{end+1} = Xr; %#ok<SAGROW>
                test_h{end+1} = Hr; %#ok<SAGROW>
            else
                train_inputs{end+1} = Xr; %#ok<SAGROW>
                train_h{end+1} = Hr; %#ok<SAGROW>
            end
        end
    end

    %Normalize the data using z score normalization based upon the current
    %training set only
    all_train_in = [train_inputs{:}];
    all_train_h = [train_h{:}];
    mu_in = mean(all_train_in,2);
    sigma_in = std(all_train_in,0,2);
    mu_h = mean(all_train_h,2);
    sigma_h = std(all_train_h,0,2);
    for n = 1:numel(train_inputs)
        train_inputs{n} = (train_inputs{n} - mu_in) ./ sigma_in;
        train_h{n} = (train_h{n} - mu_h) ./ sigma_h;
    end

    %Train the network, and measure the time it takes
    T_TrainS = tic;
    h_net = trainNetwork(train_inputs,train_h,layers,options);
    T_Train(i) = toc(T_TrainS);

    %Predict each test curve with a hybrid open/closed-loop scheme.
    %- Warm-up: the first `offset` measured input steps are fed in
    %  (teacher forcing).
    %- Free run: the h1/h2 inputs are the model's own previous
    %  predictions, while u (the known forcing input) is always taken from
    %  the data at the matching step.
    %Column k of Y is the prediction of test_h{c}(:,k). Prediction time
    %includes the warm-up for every curve.
    abs_err = zeros(numResponses,0);
    rel_err = zeros(numResponses,0);
    T_PredictS = tic;
    for c = 1:numel(test_inputs)
        Xn = (test_inputs{c} - mu_in) ./ sigma_in;
        H_true = test_h{c};
        N = size(Xn,2);
        if N <= offset
            error('Test curve %d of fold %d has %d steps; need more than offset = %d.', ...
                c, i, N, offset);
        end

        h_net = resetState(h_net);
        Yn = zeros(numResponses,N);
        [h_net,Yn(:,1:offset)] = predictAndUpdateState(h_net,Xn(:,1:offset));
        for k = offset+1:N
            %The previous prediction is in output (mu_h/sigma_h) units.
            %Convert it to physical units, then to input (mu_in/sigma_in)
            %units, before feeding it back together with the measured u(k).
            h_prev = Yn(:,k-1).*sigma_h + mu_h;
            Xt = [(h_prev - mu_in(1:2)) ./ sigma_in(1:2); Xn(3,k)];
            [h_net,Yn(:,k)] = predictAndUpdateState(h_net,Xt);
        end
        Y = Yn.*sigma_h + mu_h; %Denormalize for comparison with the data

        %Score only the free-running part, after the warm-up
        scored = offset+1:N;
        e = Y(:,scored) - H_true(:,scored);
        abs_err = [abs_err, abs(e)]; %#ok<AGROW>
        rel_err = [rel_err, abs(e ./ H_true(:,scored))]; %#ok<AGROW>
    end
    T_Predict(i) = toc(T_PredictS);

    %MAE and MAPE (as a fraction) over every scored point of the fold.
    %With several test curves the points are pooled, so each curve is
    %weighted by the number of points it contributes.
    err_h1(i) = mean(abs_err(1,:));
    err_h2(i) = mean(abs_err(2,:));
    per_err_h1(i) = mean(rel_err(1,:));
    per_err_h2(i) = mean(rel_err(2,:));
end

%Summary over the folds: mean of each per-fold error metric (printed,
%since the statements have no semicolon)
MAE_h1 = mean(err_h1)
MAE_h2 = mean(err_h2)
MAPE_h1 = mean(per_err_h1)
MAPE_h2 = mean(per_err_h2)

%Fold-to-fold spread of each metric as a population standard deviation:
%squared deviations from the fold mean are divided by the number of folds
%(not n-1)
MAE_h1_stdev = (sum((err_h1 - MAE_h1).^2) / numel(err_h1))^0.5
MAE_h2_stdev = (sum((err_h2 - MAE_h2).^2) / numel(err_h2))^0.5
MAPE_h1_stdev = (sum((per_err_h1 - MAPE_h1).^2) / numel(per_err_h1))^0.5
MAPE_h2_stdev = (sum((per_err_h2 - MAPE_h2).^2) / numel(per_err_h2))^0.5

%Mean wall-clock time per fold spent training and spent predicting
Mean_Train_Time = mean(T_Train)
Mean_Predict_Time = mean(T_Predict)
