clc;
close all;
addpath('ecai-algo1')

%% Setup
% Network model and problem dimensions
model = load_model_sioux_falls;
gamma = 0.8;            % presumably the discount factor; not used in this script
S = model.E + 1;        % number of states: one per edge, plus one extra
% (4-1, 1-2, 1-3, 2-3, 2-4, 3-4)
A = S;                  % number of actions: actions are edges as well

% Initial population: uniform over all states
% (alternative shape: M = sin(linspace(0,pi,S)); M = M';)
M = ones(S,1);
M0 = M / sum(M);

% What gets evaluated. Q_avg is not assigned in this script, so it must
% already be in the workspace (note that the script never calls `clear`).
Q_eval = Q_avg;
% Alternatives:
%   M_eval = m_opt;     % population to evaluate
%   Q_eval = Q_opt;     % policy to evaluate
s0 = 0;                 % not used in this script

% Build the admissibility mask: filter(i,j) is 1 exactly when edge j starts
% at the node where edge i ends, i.e. edge j can follow edge i.
% Row 1 of `edges` is the extra edge from the destination node to node 1.
edges = [model.destination, 1; model.edges];
nEdges = size(edges,1);
filter = zeros(S,A);
% One vectorised comparison of every head node (column 2) against every
% tail node (column 1); writing into the preallocated matrix keeps the
% result a double S-by-A array.
filter(1:nEdges,1:nEdges) = bsxfun(@eq, edges(:,2), edges(:,1).');
% NOTE(review): Q0 is neither assigned before this line nor read afterwards
% in this script; presumably it is left over from a training script. Confirm.
Q0(filter == 1) = 0;

% Rollout budget
total = 1e3;
L = 1e0;
epochs = 100;           % alternative: round(total / L / K), with K = 1e1
K = 1e1 * S * A;        % steps per epoch, proportional to the table size S*A
FP = false;             % fictitious play

% Accumulators for the per-epoch results (turned into means after the loop)
R_avg = 0;
M_avg = zeros(S,1);

% Helper functions
% scale:   min-max rescaling of an array onto [0,1].
scale = @(v) (v - min(v)) ./ (max(v) - min(v));
% draw:    sample an index from the probability vector by inverting its CDF.
draw = @(probs) find(rand(1) < cumsum(probs), 1, 'first');
% softmax: sample an index with probability proportional to exp(beta * q);
%          subtracting max(q) keeps exp() from overflowing, and -inf entries
%          receive probability zero.
softmax = @(qvals, beta) draw(exp(beta * (qvals - max(qvals))) / sum(exp(beta * (qvals - max(qvals)))));
% bonus:   NOTE(review): captures `del`, which this script never assigns, so
%          calling it fails unless `del` already exists in the workspace.
bonus = @(s) 0.2 * (1 + sin(4*pi*s*del));
% r:       reward of +10 in state 1, otherwise a congestion penalty of
%          -1e5 * M(s)^2; the action argument is ignored.
r = @(s,a,M) 1e1 * (s == 1) + (s ~= 1) * 1e5 * (-(M(s))^2);% + (s == 5) * (-1e3);
% P_det:   deterministic transition: the next state is the chosen action.
P_det = @(s_prev, a_next) a_next;

%% Learn
% Monte-Carlo evaluation: roll out the policy Q_eval for `epochs` episodes of
% K steps each, accumulating the return R against the population M_eval and
% the empirical state-visitation distribution M. The per-epoch results are
% averaged into R_avg and M_avg after the loop.
% M_eval is not assigned by this script and must already exist in the
% workspace.
for e = 1:epochs
    fprintf('epoch: %d\n', e)
    s1 = 1;              % every episode starts in state 1
    R = 0;
    M = M0;              % overwritten at k = 1, where the factor (1 - 1/k) is zero
    for k = 1:K
        % Sample an admissible action from the policy under evaluation.
        % (Previously this read `Q`, which the script never assigns, leaving
        % Q_eval unused.)
        s = s1;
        filtered_Q = Q_eval(s, :);
        filtered_Q(filter(s,:) ~= 1) = -inf;   % mask edges that cannot follow s
        a = softmax(filtered_Q, 1e-4);
        if isempty(a)
            % Without this guard an empty sample would silently turn s1, and
            % then R, into [] and corrupt the averages.
            error('No admissible action could be sampled from state %d.', s);
        end
        s1 = a;          % actions are edges, so the chosen edge is the next state

        % Accumulate the reward and update the running visitation average
        R = R + r(s,a,M_eval);
        M = (1 - 1/k) * M;
        M(s1) = M(s1) + 1/k;
    end
    R_avg = R_avg + R;
    M_avg = M_avg + M;
end

% Turn the accumulated sums into per-epoch means
R_avg = R_avg / epochs;
M_avg = M_avg / epochs;

%% Plot
