% Clear the command window and close every open figure before starting.
clc;
close all;
% rng(1)

%% Setup
% Root-level defaults for axes
set(0, ...
    'DefaultAxesFontSize',   15, ...
    'DefaultAxesFontName',   'times', ...
    'DefaultAxesFontWeight', 'bold', ...
    'DefaultAxesLineWidth',  1.5);
% Root-level defaults for plotted lines
set(0, ...
    'DefaultLineLineWidth',      2, ...
    'DefaultAxesLineStyleOrder', '.-', ...
    'DefaultLineMarkerSize',     20);
set(0, 'DefaultLineMarker', 'none');
% Root-level defaults for text objects
set(0, ...
    'DefaultTextInterpreter', 'latex', ...
    'DefaultTextFontName',    'times', ...
    'DefaultTextFontWeight',  'bold');
% Root-level default for legends
set(0, 'DefaultLegendInterpreter', 'latex');
% Default colormap: 101 rows interpolated linearly between two RGB endpoints
set(0, 'DefaultFigureColormap', ...
    interp1([1 2], [33, 102, 172; 178, 24, 43] / 255, 1:0.01:2));
% Three RGB rows scaled to [0, 1]
colors = [ 77, 175,  74; ...
           55, 126, 184; ...
          228,  26,  28] / 255;


% Problem parameters
% Load the network description. This script reads its fields E, destination
% and edges; the struct itself is also handed on through opts.model.
model = load_model_sioux_falls;
opts.model = model;
% One more state than model.E (the action-filter section below prepends one
% extra row, [model.destination, 1], to model.edges).
S = model.E + 1;    % state space, edges as states
A = S;              % action space, edges as actions
opts.S = S; opts.A = A;
opts.gamma = 0.8;   % presumably the discount factor -- confirm in qmi/gd
% Disabled alternative: bimodal initial M, randomly shifted and normalized.
% M = normpdf(linspace(-1,1,S),0.5,1e-1)' + normpdf(linspace(-1,1,S),-0.5,1e-1)';
% M = circshift(M, randi(S));
% opts.M0 = M ./ sum(M); % initial M
% Every entry starts at -Inf; the action-filter section resets the entries
% selected by the filter to 0. Q0 is not passed on while the opts.Q0
% assignment in that section stays commented out.
Q0 = -Inf(S,A);  % initial Q
% Random integer in 1..S (presumably the initial state -- confirm in qmi/gd).
% Not reproducible across runs while the rng call at the top is commented out.
opts.s0 = randi(S);

% Construct action filter
% Edge list: row 1 is the extra edge [model.destination, 1], followed by
% the rows of model.edges.
edges = [model.destination, 1; model.edges];
% filter(i,j) is 1 exactly when column 2 of edge i equals column 1 of edge j,
% and 0 otherwise. The comparison is done for all pairs at once and written
% into the same index range the element-wise loops would have touched.
% NOTE(review): the variable name `filter` hides MATLAB's built-in filter
% function for as long as it stays in the workspace.
filter = zeros(S,A);
filter(1:size(edges,1), 1:size(edges,1)) = bsxfun(@eq, edges(:,2), edges(:,1).');
% Entries of Q0 selected by the filter get 0; all others keep their value.
Q0(filter == 1) = 0;
opts.filter = filter;
% opts.Q0 = Q0;

% Training parameters
% These fields are only stored here; their exact meaning is defined by the
% project functions that receive opts (opt, qmi, gd, expl).
opts.epochs = 10;
% NOTE(review): presumably the second argument h of opts.softmax(q, h), which
% multiplies the shifted Q-values (i.e. acts as an inverse temperature) --
% confirm in qmi. The ER section temporarily overrides this with 1e-3.
opts.temp = 1e2;
% opts.temp = 1e-5;
% opts.temp = 1e-3;
opts.GLIE = false;
% opts.temp = 1e-2;
% opts.temp = 1e0;
% opts.soft = 0.2;
opts.soft = 0.5;
opts.tol_ip = 1e-2;
opts.tol_br = 1e-2;
% opts.step = @(k,l) l; % step size
% kappas = [1,3,5,9,13]; % on
% Ts = [50, 100, 125, 250, 5e2];
% Remember the GLIE setting so it can be restored after the ER run changes it.
GLIE_hold = opts.GLIE;

% Helper functions
% draw(p): index of the first entry of cumsum(p) that exceeds one uniform
% random number (empty if no entry does).
draw = @(prob) find(rand(1) < cumsum(prob), 1, 'first');
% Exponential of the Q-values shifted by their maximum and scaled by h.
shifted_exp = @(q, h) exp((q - max(q)) * h);
% softmax(q, h): shifted exponentials normalized by their sum.
opts.softmax = @(q, h) shifted_exp(q, h) ./ sum(shifted_exp(q, h));
% get_softmax(q, h): one index sampled with draw from softmax(q, h).
opts.get_softmax = @(q, h) draw(opts.softmax(q, h));
% r(s, a, M): 1e1 when s == 1, otherwise -1e5 times the square of M(s,:,:).
% The result is multiplied by ones(size(a)), so it takes the shape of a while
% the values in a do not enter the computation.
opts.r = @(s, a, M) ( ...
    (s == 1) * 1e1 ...
    - (s ~= 1) .* (M(s,:,:).^2) * 1e5 ...
    ) .* ones(size(a)); %* ones(size(a));% + (s == 5) * (-1e3);
% opts.r = @(s,a,M) ((s == 1) * 1e1 - (s ~= 1) .* ((M(s)*S).^2)' * a * 1e-2); %* ones(size(a));% + (s == 5) * (-1e3);
% err(M, ref): squared differences summed along dimension 1, with singleton
% dimensions removed from the result.
err = @(M, m_ref) squeeze(sum((M - m_ref).^2, 1)); % sum(square((M_fixed-m_opt)*S))/S;

% Reference solution: reuse m_opt, V_opt and q_opt when all three are already
% in the workspace (the script never clears it), otherwise recompute them.
% All three are checked because V_opt and q_opt are needed further down
% (value-error normalization and the reference exploitability); checking
% m_opt alone would skip the call and fail later when only m_opt is present.
% NOTE(review): cached values are reused even if opts changed since they were
% computed -- clear them manually after editing the problem parameters.
if ~(exist('m_opt', 'var') && exist('V_opt', 'var') && exist('q_opt', 'var'))
    % load('opt.mat')
    [m_opt, V_opt, q_opt] = opt(opts);
end
opts.m_opt = m_opt;

%% Vanilla FPI
% Iteration counts shared by all qmi runs below; the SemiSGD section later
% replaces opts.T by opts.T * opts.K.
% NOTE(review): which of K / T is the outer and which the inner count is
% decided inside qmi -- confirm there.
opts.K = 2e2;
opts.T = 5e2;
% opts.TK = 5e3; %~= opts.T * opts.K * S * 10;

fprintf('Running FPI\n')
% Baseline run: both the FP and the OMD switch are off.
opts.FP = false; opts.OMD = false;
% This value stays in effect for every later run; no other section changes it.
opts.policy = 'off';
[M_fpi, Q_fpi] = qmi(opts);
% Squared error of the returned M against the reference m_opt.
err_fpi = err(M_fpi, m_opt);
% Maximum of Q along dimension 2 (V_fpi) and the index attaining it (u_fpi).
[V_fpi, u_fpi] = max(Q_fpi, [], 2);
% Squared error of V against V_opt, divided by sum(V_opt.^2, 1).
erv_fpi = err(V_fpi, V_opt) / sum(V_opt.^2,1);
% err_V_fpi = err(V_fpi_arr, V_opt);
% Project function expl applied to the Q array; the result is what the
% "Exploitability" figure plots.
expl_fpi = expl(Q_fpi(:,:,1:end,:),opts);

% ER run. There is no active cell marker here, so in cell mode this code
% belongs to the "Vanilla FPI" section above.
% It repeats the qmi call with opts.temp = 1e-3 and opts.GLIE = false; the
% FP / OMD / policy settings are whatever the preceding code left in opts.
fprintf('Running ER\n')
% Save the temperature so it can be restored once this run is finished.
temp_hold = opts.temp;
opts.temp = 1e-3;
opts.GLIE = false;
[M_er, Q_er] = qmi(opts);
% Same metrics as for the FPI run: error of M, max over dimension 2 of Q,
% normalized error of V, and expl of the Q array.
err_er = err(M_er, m_opt);
[V_er, u_er] = max(Q_er, [], 2);
erv_er = err(V_er, V_opt) / sum(V_opt.^2,1);
expl_er = expl(Q_er(:,:,1:end,:),opts);
% Restore the settings that were in effect before this run.
opts.temp = temp_hold;
opts.GLIE = GLIE_hold;

%% FPI + FP
fprintf('Running FP\n')
% Same qmi call as before with only the FP switch turned on. opts.policy,
% opts.K and opts.T still hold the values set in the "Vanilla FPI" section.
opts.FP = true; opts.OMD = false;
[M_fp, Q_fp] = qmi(opts);
% Same metrics as for the FPI run: error of M, max over dimension 2 of Q,
% normalized error of V, and expl of the Q array.
err_fp = err(M_fp, m_opt);
[V_fp, u_fp] = max(Q_fp, [], 2);
erv_fp = err(V_fp, V_opt) / sum(V_opt.^2,1);
expl_fp = expl(Q_fp(:,:,1:end,:),opts);

%% FPI + OMD
fprintf('Running OMD\n')
% Same qmi call as before with only the OMD switch turned on. These two flags
% are not reset afterwards, so they are still set when gd is called below.
opts.FP = false; opts.OMD = true;
[M_omd, Q_omd] = qmi(opts);
% Same metrics as for the FPI run: error of M, max over dimension 2 of Q,
% normalized error of V, and expl of the Q array.
err_omd = err(M_omd, m_opt);
[V_omd, u_omd] = max(Q_omd, [], 2);
erv_omd = err(V_omd, V_opt) / sum(V_opt.^2,1);
expl_omd = expl(Q_omd(:,:,1:end,:),opts);

%% SemiSGD
fprintf('Running SGD\n')
% opts.T becomes the product of the two counts used by the qmi runs.
% NOTE(review): opts.T is overwritten in place and never restored, so
% re-running this section on its own multiplies it by opts.K again, and the
% later expl(q_opt, opts) call in the "Exploitability" section receives the
% enlarged value.
opts.T = opts.T * opts.K;
% A single value is used for both opts.beta0 and opts.alpha0.
sync = 1e-3;
% opts.beta0 = sync/5; opts.alpha0 = sync;
opts.beta0 = sync; opts.alpha0 = sync;
% gd receives the same opts struct, including the FP / OMD flags left over
% from the OMD run -- NOTE(review): confirm that gd ignores them.
[M_gd, Q_gd] = gd(opts);
% Same metrics as for the qmi runs: error of M, max over dimension 2 of Q,
% normalized error of V, and expl of the Q array.
err_gd = err(M_gd, m_opt);
[V_gd, u_gd] = max(Q_gd, [], 2);
erv_gd = err(V_gd, V_opt) / sum(V_opt.^2,1);
expl_gd = expl(Q_gd(:,:,1:end,:),opts);


%% MSE
% One figure with the error curves of all five methods on a log-scaled y axis.
f = figure; hold on;
% Plot every skip-th sample; len is the first-dimension length of err_fpi and
% is reused as the x range for every curve.
% NOTE(review): this assumes all error arrays share that length -- confirm
% for err_gd. skip and len are also used by the "Exploitability" section.
skip = 1;
len = size(err_fpi,1);
% varplot is a project plotting helper; each call adds one method's curve.
varplot(1:skip:len,err_fpi(1:skip:end,:,:), 'marker', 'none', 'DisplayName', 'FPI')
varplot(1:skip:len,err_er(1:skip:end,:,:), 'marker', 'none', 'DisplayName', 'FPI+ER')
varplot(1:skip:len,err_fp(1:skip:end,:,:), 'marker', 'none', 'DisplayName', 'FPI+FP')
varplot(1:skip:len,err_omd(1:skip:end,:,:), 'marker', 'none', 'DisplayName', 'FPI+OMD')
varplot(1:skip:len,err_gd(1:skip:end,:,:), 'marker', 'none', 'DisplayName', 'SemiSGD')
% The handle is called ax rather than axis so that MATLAB's built-in axis
% function is not hidden by a workspace variable after the script has run.
ax = gca;
ax.YScale = 'log';
ax.YLim = [1e-3, 1e-1];
ax.XLim = [0, 200];
% f.Position(end) = 300;
% Make five children of the axes edge-less, translucent and hidden from the
% legend (presumably the shaded bands drawn by varplot -- confirm there).
% Setting HandleVisibility to 'off' removes an object from ax.Children, so
% index i refers to the list as it is after the earlier removals.
% NOTE(review): which objects are hit therefore depends on the order in which
% varplot creates its graphics objects.
for i = 1:5
    ax.Children(i).EdgeColor = 'none';
    ax.Children(i).FaceAlpha = 0.2;
    ax.Children(i).HandleVisibility = 'off';
end
legend('show', 'fontsize', 18, 'fontweight', 'bold')

%% Exploitability
% One figure with the expl curves of all five methods plus a dashed
% reference line for the value obtained from q_opt.
% skip and len come from the "MSE" section above.
f = figure; hold on;
ci = 0.95;
% All curves are created with HandleVisibility 'off', so none of them shows up
% in the legend. Every call passes the same 'ci' value; the ER call used to
% omit it and therefore fell back to whatever default varplot applies.
varplot(1:skip:len,expl_fpi(1:skip:end,:,:), 'ci', ci, 'marker', 'none', 'HandleVisibility', 'off')
varplot(1:skip:len,expl_er(1:skip:end,:,:), 'ci', ci, 'marker', 'none', 'HandleVisibility', 'off')
varplot(1:skip:len,expl_fp(1:skip:end,:,:), 'ci', ci, 'marker', 'none', 'HandleVisibility', 'off')
varplot(1:skip:len,expl_omd(1:skip:end,:,:), 'ci', ci, 'marker', 'none', 'HandleVisibility', 'off')
varplot(1:skip:len,expl_gd(1:skip:end,:,:), 'ci', ci, 'marker', 'none', 'HandleVisibility', 'off')
% Reference value: expl evaluated on q_opt (used as a scalar below).
expl_mfe = expl(q_opt,opts);
% The handle is called ax rather than axis so that MATLAB's built-in axis
% function is not hidden by a workspace variable after the script has run.
ax = gca;
ax.XLim = [0, 200];
% ax.YLim = [1400, 3200];
% f.Position(end) = 300;
% ax.YScale = 'log';
% Make five children of the axes edge-less, translucent and hidden from the
% legend (presumably the shaded bands drawn by varplot -- confirm there).
% Index 1 is used on every pass on purpose: setting HandleVisibility to 'off'
% removes the object from ax.Children, so the next one moves to the front.
for i = 1:5
    ax.Children(1).EdgeColor = 'none';
    ax.Children(1).FaceAlpha = 0.2;
    ax.Children(1).HandleVisibility = 'off';
end
% Plot the constant expl_mfe as dashed line for reference
plot(ax.XLim, [expl_mfe, expl_mfe], 'LineStyle', '--', 'marker', 'none', 'DisplayName', 'MFE')
legend('show', 'fontsize', 18, 'fontweight', 'bold')
% title('Exploitability')
% title('Exploitability')

% %% Off-policy QMI
% opts.policy = 'off';
% % results = struct;
% % for kappa = kappas
% % for T = Ts
%     % opts.kappa = kappa;
%     % opts.T = T;
%     % opts.K = opts.TK / T;
%     % opts.kappa = 1/S;
%     [err_off, expl_off, M_off, ~, ~] = qmi(opts);
%     % results.(sprintf('kappa%d', kappa)) = struct('err', err_off, 'expl', expl_off, 'M', M_off);
% %     results.(sprintf('T%d', T)) = struct('err', err_off, 'expl', expl_off, 'M', M_off);
% % end
%
% %% On-policy QMI
% opts.policy = 'on';
% results = struct;
% % for kappa = kappas
% % for T = Ts
%     % opts.kappa = kappa;
%     % opts.T = T;
%     % opts.K = opts.TK / T;
%     % opts.kappa = 1/S;
%     [err_on, expl_on, M_on, ~, ~] = qmi(opts);
%     % results.(sprintf('kappa%d', kappa)) = struct('err', err_on, 'expl', expl_on, 'M', M_on);
%     % results.(sprintf('T%d', T)) = struct('err', err_on, 'expl', expl_on, 'M', M_on);
% % end
%
%
% %% FPI
% opts.K = 50;
% opts.T = 5;
% [err_fpi, expl_fpi,M] = fpi(opts);
% err_fpi_line = repmat(err_fpi(end,:), [size(err_fpi,1),1]);
% expl_fpi_line = repmat(expl_fpi(end,:), [size(expl_fpi,1),1]);

%% Plot
% plot_list = {'error', 'expl'};
% save_flag = true;
% plot_results
