%Creates a model of the SIR data with an LSTM neural network trained on
%data with constant b values, and evaluates that model on the test data,
%whose b values vary over time.


% ---- Data ------------------------------------------------------------
% Training curves come from SIR_GP_Data_Train and test curves from
% SIR_GP_Data_Test; the argument 0 is passed straight through to those
% project loaders (its meaning is defined there).
[curves, inputs, pureinputs] = SIR_GP_Data_Train(0);
[Testcurves, Testinputs, Testpureinputs] = SIR_GP_Data_Test(0);

% ---- Network definition ----------------------------------------------
% Each time step feeds 3 channels (previous b, I and R) and the network
% regresses 2 responses (current I and R) through a 128-unit LSTM.
numChannels  = 3;
numResponses = 2;

layers = [sequenceInputLayer(numChannels); lstmLayer(128); ...
          fullyConnectedLayer(numResponses); regressionLayer];

% Adam for 125 epochs; sequences are left-padded, reshuffled every epoch,
% with the live training-progress plot and no command-window output.
options = trainingOptions("adam", "MaxEpochs", 125, ...
    "SequencePaddingDirection", "left", "Shuffle", "every-epoch", ...
    "Plots", "training-progress", "Verbose", 0);

XTrain = cell(1,numel(inputs)); %Prep X,T for Cell creation
TTrain = cell(1,numel(inputs));

% Build one-step-ahead pairs from each training curve, stored
% channels-by-time: the input at step k is columns [2 4 5] (b, I, R) of
% row k, and the target is columns [4 5] (I, R) of row k+1.
for i = 1:numel(inputs)
   XTrain{i} = inputs{i}(1:end-1,[2 4 5])'; %Our inputs: b, I and R at the previous step
   TTrain{i} = inputs{i}(2:end,[4 5])'; %Our outputs: I and R at the current step
end

% Prepping train data

%Z-score normalization using only training data (the same statistics are
%reused later for the test data and to denormalize the predictions):
muX = mean(cat(2,XTrain{:}),2);
sigmaX = std(cat(2,XTrain{:}),0,2);

muT = mean(cat(2,TTrain{:}),2);
sigmaT = std(cat(2,TTrain{:}),0,2);

% A channel that is constant over the whole training set (for example b,
% if every training curve happened to use the same constant b) has a zero
% standard deviation; dividing by it would fill that channel with NaN/Inf
% for training and test data alike. Leave such channels unscaled instead.
sigmaX(sigmaX == 0) = 1;
sigmaT(sigmaT == 0) = 1;

for n = 1:numel(XTrain)
    XTrain{n} = (XTrain{n} - muX) ./ sigmaX;
    TTrain{n} = (TTrain{n} - muT) ./ sigmaT;
end

%Create the model using the train data; T_Train is the wall-clock
%training time in seconds.
T_TrainS = tic;
trainNet = trainNetwork(XTrain,TTrain,layers,options);
T_Train = toc(T_TrainS);

% ---- Test set --------------------------------------------------------
% Lay every test curve out exactly like the training curves (input column
% k = columns [2 4 5] of row k, i.e. b, I, R; target column k = columns
% [4 5] of row k+1, i.e. I, R) and scale both with the *training*
% statistics muX/sigmaX and muT/sigmaT, so the network sees test data on
% the same scale it was trained on. Results are 1-by-N cell arrays.
XTest = cellfun(@(curve) (curve(1:end-1,[2 4 5])' - muX) ./ sigmaX, ...
    Testinputs(:).', 'UniformOutput', false);
TTest = cellfun(@(curve) (curve(2:end,[4 5])' - muT) ./ sigmaT, ...
    Testinputs(:).', 'UniformOutput', false);

% Prediction (closed-loop forecasting)
%
% For every test curve the network is first warmed up on the first
% `offset` true input steps. After that it runs closed loop: the known b
% for the next step (row 1 of XTest, per the input layout above) is paired
% with the network's own previous I/R prediction and fed back in, one step
% at a time.
%
% Result: Y{i} is numChannels-by-numPredictionTimeSteps. Column t holds
%   row 1    - the normalized b value fed in at that step, and
%   rows 2:3 - the predicted I and R (target normalization muT/sigmaT),
% and corresponds to column offset+t of TTest{i}, i.e. row offset+1+t of
% the test curve.

offset = 50; %Use the 1st 50 dps of each test set to get the prediction started
maxPredictionTimeSteps = 350; %Upper bound on the forecast horizon
numTest = numel(XTest);
Y = cell(numTest,1); %One forecast per test curve
T_Predict = zeros(1,numTest); %Preallocated so no stale workspace value survives

% The network emits I and R in the target normalization (muT, sigmaT), but
% it was trained on I and R inputs in the input normalization (rows 2:3 of
% muX, sigmaX). Map each output back into input space before feeding it in.
outToIn = @(z) (z .* sigmaT + muT - muX(2:3)) ./ sigmaX(2:3);

for i = 1:numTest
    X = XTest{i};

    % Only steps whose true b is known can be forecast; never run past them.
    numAvailable = size(X,2) - offset;
    if numAvailable < 1
        error('Test curve %d has only %d steps; need more than offset = %d.', ...
            i, size(X,2), offset);
    end
    numPredictionTimeSteps = min(maxPredictionTimeSteps, numAvailable);

    h_net = resetState(trainNet);
    [h_net,Z] = predictAndUpdateState(h_net,X(:,1:offset));

    Y{i} = zeros(numChannels,numPredictionTimeSteps);
    % b at the same step as the I/R it is paired with (rows offset+1 onward),
    % not from the start of the curve.
    Y{i}(1,:) = X(1,offset+1:offset+numPredictionTimeSteps);

    T_PredictS = tic;
    prevOut = Z(:,end); %Prediction for step offset+1 (TTest column offset)
    for t = 1:numPredictionTimeSteps
        Xt = [Y{i}(1,t); outToIn(prevOut)];
        [h_net,prevOut] = predictAndUpdateState(h_net,Xt);
        Y{i}(2:3,t) = prevOut;
    end
    T_Predict(i) = toc(T_PredictS);
end
%Finally compare both R and I

% Denormalize predictions for comparison: keep only the predicted I and R
% rows and map them back to original units with the training-target
% statistics (muT, sigmaT).
for i = 1:numel(Y)
    Y{i} = Y{i}(2:3,:).*sigmaT + muT;
end

% Set up comparisons to the true values for our plots: I and R (columns
% [4 5]) of each Testpureinputs curve, laid out like TTest (2-by-steps,
% column k = row k+1 of the curve). Created fresh so a leftover variable of
% the same name in the workspace is never reused or mis-indexed.
Testpureinputs2 = cell(1,numel(Testpureinputs));
for i = 1:numel(Testpureinputs)
    Testpureinputs2{i} = Testpureinputs{i}(2:end,[4 5])';
end

Mean_T_Predict = mean(T_Predict) % average time in sec of Predictions
Mean_T_Train = mean(T_Train) % time in sec of Model Building (a single measurement)

%Plot our predicted results against the actual test values.
% The x-axis is the row index within each test curve, so truth and forecast
% line up: column k of Testpureinputs2{i} is row k+1 of the curve, and
% column t of Y{i} (forecast from step `offset`) is row offset+1+t.
% Row 1 of both holds I and row 2 holds R.
for i = 1:numel(Y)
    truth = Testpureinputs2{i};
    pred = Y{i};
    truthSteps = (1:size(truth,2)) + 1;
    predSteps = offset + 1 + (1:size(pred,2));

    figure
    plot(truthSteps,truth(2,:),'ro','MarkerSize',1.5)
    hold on
    plot(predSteps,pred(2,:),'k','LineWidth',1.5)
    legend('Test Data','LSTM Model')
    xlabel('Time step')
    ylabel('Recovered Individuals')

    figure
    plot(truthSteps,truth(1,:),'ro','MarkerSize',1.5)
    hold on
    plot(predSteps,pred(1,:),'k','LineWidth',1.5)
    legend('Test Data','LSTM Model')
    xlabel('Time step')
    ylabel('Infected Individuals')
end
