% ---- Workspace reset and experiment configuration ----
clear all
close all
clc
tic                           % start the wall-clock timer (read by toc at the end)
rng(9);                       % fixed seed so every run draws the same random game

n = 10;                       % number of states
m = 4;                        % number of actions per state
gamma = 0.8;                  % discount applied to continuation payoffs
num_iter = 1e8;               % number of simulated stages
LL = space125(10,num_iter);   % stages at which the value estimates are recorded
ll = length(LL);              % how many recording stages there are
% Stage payoffs: one m-by-m matrix per state and per player.
% Player 1's entries are uniform random draws scaled by the squared state
% index; player 2 gets the negated transpose, so the two payoffs of any
% action pair sum to zero.
R1 = zeros(m,m,n);
R2 = zeros(m,m,n);
for s = 1:n
    stagePayoff = rand(m,m)*s^2;
    R1(:,:,s) = stagePayoff;
    R2(:,:,s) = -stagePayoff';
end
% Transition kernel, stored as an n-by-n cell array of m-by-m matrices:
% pp{to,from}(r,c) is the probability of reaching state `to` from state
% `from` when the action pair indexed (r,c) is played.
pp = cell(n,n);
for from = 1:n
    for r = 1:m
        for c = 1:m
            % Random weights, shifted up by 0.1 so that every successor
            % state keeps a strictly positive probability, then normalised.
            w = rand(n,1)*rand + 0.1*ones(n,1);
            w = w/sum(w);
            for to = 1:n
                pp{to,from}(r,c) = w(to);
            end
        end
    end
end
state = 1; % state in which the simulated play starts

% Shapley iterations
% Value iteration for the discounted zero-sum game. The resulting Sv1
% (player 1) and Sv2 = -Sv1 (player 2) serve as the reference values that
% the learned estimates are plotted against.
%
% The sweep index is deliberately NOT named `state`: a MATLAB for-loop
% variable keeps its last value after the loop ends, so sweeping with
% `state` would silently replace the initial state set earlier by n.
SQ1 = zeros(m,m,n);
Sv1 = zeros(1,n);
SQ2 = zeros(m,m,n);
Sv2 = zeros(1,n);
L = 100; % number of value-iteration sweeps
for k = 1:L
    % First evaluate the matrix game of every state with the current Q's
    % (minimax is an external helper; presumably it returns the value of
    % the zero-sum matrix game for the row player -- confirm there) ...
    for s = 1:n
        Sv1(s) = minimax(SQ1(:,:,s));
        Sv2(s) = -Sv1(s);
    end
    % ... then back up all Q-matrices from that common set of values, so
    % every state in a sweep sees the same continuation values.
    for s = 1:n
        SQ1(:,:,s) = R1(:,:,s) + gamma*exp_cont_payoff(Sv1,pp,s);
        SQ2(:,:,s) = R2(:,:,s) + gamma*exp_cont_payoff(Sv2,pp,s)';
    end
end
% One copy of the final values per recording stage (ll-by-n), so they can
% be drawn as flat reference lines over the recording grid LL.
Sv1_iter = repmat(Sv1,ll,1);
Sv2_iter = repmat(Sv2,ll,1);

% Model-based
% Initialization
% Q1 = R1;
% v1 = zeros(n,1);
% v1_iter = zeros(ll,n);
% vsum = zeros(n,1);
% vsum_iter = zeros(ll,n);
% p1 = zeros(m,n);
% Q2 = R2;
% v2 = zeros(n,1);
% v2_iter = zeros(ll,n);
% p2 = zeros(m,n);
% for i = 1:n
%     p1(:,i) = rand(m,1);
%     p1(:,i) = p1(:,i)/sum(p1(:,i));
%     p2(:,i) = rand(m,1);
%     p2(:,i) = p2(:,i)/sum(p2(:,i));
% end
% % Iterations
% pstate= zeros(n,1);
% count = zeros(1,n);
% for k = 1:num_iter
%     if mod(k,10^4) == 0
%         k
%     end
%     pstate(state) = pstate(state)+1;
%     a1 = best(Q1(:,:,state)*p2(:,state)); % uniform tie breaking rule
%     a2 = best(Q2(:,:,state)*p1(:,state)); % uniform tie breaking rule
%     p1(:,state) = p1(:,state) + alpha(count(state))*(a1-p1(:,state));
%     p2(:,state) = p2(:,state) + alpha(count(state))*(a2-p2(:,state));
%     Q1(:,:,state) = Q1(:,:,state) + beta(count(state))*...
%         (R1(:,:,state) + gamma*exp_cont_payoff(v1,pp,state) - Q1(:,:,state));
%     Q2(:,:,state) = Q2(:,:,state) + beta(count(state))*...
%         (R2(:,:,state) + gamma*exp_cont_payoff(v2,pp,state)' - Q2(:,:,state));
%     for i = 1:n
%         v1(i) = max(Q1(:,:,i)*p2(:,i));
%         v2(i) = max(Q2(:,:,i)*p1(:,i));
%         vsum(i) = v1(i)+v2(i);
%     end
%     if sum(find(LL==k))>0
%         jj = find(LL==k);
%         v1_iter(jj,:) = v1;
%         v2_iter(jj,:) = v2;
%         vsum_iter(jj,:) = vsum;
%     end
%     count(state)=count(state)+1;
%     state = next(pp,state,a1,a2);
% end
% 
% pstate = pstate/num_iter;
% 
% figure(1)
% semilogx(LL,zeros(ll,1),'k')
% hold on
% for i = 1:n
%     semilogx(LL,v1_iter(:,i)','r',...
%              LL,v2_iter(:,i)','b',...
%              LL,vsum_iter(:,i)','k',...
%              LL,Sv1_iter(:,i)','r:',...
%              LL,Sv2_iter(:,i)','b:')
% end
% xlabel('Stages')
% ylabel('Continuation Payoffs')
% axis([10 10^8 -20 20])

% Model-free
% Both players learn along one simulated trajectory. Per state they keep
%   p11 / p12 : running averages of the actions played by player 1 / 2,
%   Q11 / Q12 : Q-matrices, with the owner's action indexing the rows,
%   v11 / v12 : value estimates max(Q * opponent average),
% and only the Q entry of the action pair actually played is updated,
% using the realised next state instead of an expectation over pp.
% Initialization
Q11 = zeros(m,m,n);
v11 = zeros(n,1);
v11_iter = zeros(ll,n);
p11 = zeros(m,n);
Q12 = zeros(m,m,n);
v12 = zeros(n,1);
v12_iter = zeros(ll,n);
p12 = zeros(m,n);
vsum1 = zeros(n,1);
vsum1_iter = zeros(ll,n);
for i = 1:n
    p11(:,i) = rand(m,1);
    p11(:,i) = p11(:,i)/sum(p11(:,i));
    p12(:,i) = rand(m,1);
    p12(:,i) = p12(:,i)/sum(p12(:,i));
    % Values implied by the initial Q's and averages (all zero while the
    % Q's start at zero). Keeping them consistent here is what allows the
    % main loop to refresh only the visited state.
    v11(i) = max(Q11(:,:,i)*p12(:,i));
    v12(i) = max(Q12(:,:,i)*p11(:,i));
    vsum1(i) = v11(i)+v12(i);
end
% Iterations
pstate1= zeros(n,1);          % visit count per state
pstateaction1= zeros(m,m,n);  % visit count per (action 1, action 2, state)
countA=zeros(1,n);            % visits per state, drives the alpha step size
countB=zeros(m,m,n);          % visits per action pair and state, drives beta
% NOTE(review): second argument of bestExperiment, presumably the
% exploration probability -- confirm against that helper.
explore = 0.02;
for k = 1:num_iter
    if mod(k,10^4) == 0
        fprintf('k = %d\n',k); % progress indicator, once per checkpoint
    end
    pstate1(state) = pstate1(state)+1;
    % Each player responds to the opponent's running average in this state.
    % NOTE(review): a11/a12 are assumed to be one-hot m-by-1 vectors, which
    % is what find() and the averaging step below rely on.
    a11 = bestExperiment(Q11(:,:,state)*p12(:,state),explore);
    a12 = bestExperiment(Q12(:,:,state)*p11(:,state),explore);
    i11=find(a11); i12=find(a12);
    pstateaction1(i11,i12,state) = pstateaction1(i11,i12,state)+1;
    p11(:,state) = p11(:,state) + alpha(countA(state))*(a11-p11(:,state));
    p12(:,state) = p12(:,state) + alpha(countA(state))*(a12-p12(:,state));
    n_state = next(pp,state,a11,a12);
    % Sample-based Q update of the played pair only; player 2's matrices
    % are indexed (own action, opponent action), hence the swapped indices.
    Q11(i11,i12,state) = Q11(i11,i12,state) + beta(countB(i11,i12,state))*...
        (R1(i11,i12,state) + gamma*v11(n_state) - Q11(i11,i12,state));
    Q12(i12,i11,state) = Q12(i12,i11,state) + beta(countB(i11,i12,state))*...
        (R2(i12,i11,state) + gamma*v12(n_state) - Q12(i12,i11,state));
    % Only slice `state` of the Q's and column `state` of the averages
    % changed in this stage, so only this state's values can differ from
    % the previous stage; the other entries are already up to date.
    v11(state) = max(Q11(:,:,state)*p12(:,state));
    v12(state) = max(Q12(:,:,state)*p11(:,state));
    vsum1(state) = v11(state)+v12(state);
    % Record a snapshot when k is one of the recording stages.
    jj = find(LL==k);
    if ~isempty(jj)
         v11_iter(jj,:) = v11;
         v12_iter(jj,:) = v12;
         vsum1_iter(jj,:) = vsum1;
    end
    countA(state)=countA(state)+1;
    countB(i11,i12,state)=countB(i11,i12,state)+1;
    state = n_state;
end

% Turn the visit counts into empirical frequencies over the whole run.
pstate1 = pstate1/num_iter;
pstateaction1 = pstateaction1/num_iter;

% Learned value estimates of the model-free run against the Shapley
% reference values, on a logarithmic stage axis.
figure(2)
% Zero baseline.
% NOTE(review): the baseline uses LL+1 while every curve uses LL -- confirm
% whether the shift is intended.
semilogx(LL+1,zeros(ll,1),'k')
hold on
for i = 1:n
    % Solid: learned values of player 1 (red), player 2 (blue) and their
    % sum (black). Dotted: Shapley reference values of the same state.
    semilogx(LL,v11_iter(:,i)','r',...
             LL,v12_iter(:,i)','b',...
             LL,vsum1_iter(:,i)','k',...
             LL,Sv1_iter(:,i)','r:',...
             LL,Sv2_iter(:,i)','b:')
end
xlabel('Stages')
ylabel('Continuation Payoffs')
% Upper x-limit follows the simulation length instead of repeating 10^8.
axis([10 num_iter -20 20])
toc