% Two-player zero-sum stochastic game experiment: problem setup.
clear all; close all; clc;
tic
rng(9);                        % fixed seed so the whole run is reproducible
n = 3;                         % number of states
m = 4;                         % number of actions per state (same for both players)
num_iter = 10^8;               % number of learning iterations
LL = space125(10,num_iter);    % iterations at which learned values are recorded
                               % (spacing comes from space125 -- presumably 1-2-5 steps, TODO confirm)
ll = length(LL);               % number of recording points
gamma = 0.8;                   % discount factor

% Stage payoffs. R1(:,:,s) holds player 1's payoff in state s (rows = player 1's
% action), with magnitudes scaled by s^2. Player 2's payoff is the negated
% transpose, i.e. R2(a2,a1,s) = -R1(a1,a2,s), so the game is zero-sum.
R1 = zeros(m,m,n);
R2 = zeros(m,m,n);
for s = 1:n
    R1(:,:,s) = s^2 * rand(m,m);
    R2(:,:,s) = -transpose(R1(:,:,s));
end
% Transition kernel: pp{s_next,s}(a1,a2) is the probability of moving from
% state s to state s_next when player 1 plays a1 and player 2 plays a2.
pp = cell(n,n);
for s = 1:n
    for a1 = 1:m
        for a2 = 1:m
            % Random non-negative weights plus a 0.1 floor, so every next state
            % keeps strictly positive probability; then normalize to sum to 1.
            w = rand(n,1)*rand + 0.1*ones(n,1);
            w = w/sum(w);
            for s_next = 1:n
                pp{s_next,s}(a1,a2) = w(s_next);
            end
        end
    end
end
state = 1; % initial state (reassigned by randi before the learning loop)

% Shapley iterations: a fixed number of value-iteration sweeps. Each sweep first
% solves every state's stage game with minimax, using the Q matrices from the
% previous sweep. It then backs up reward plus discounted expected continuation.
SQ1 = zeros(m,m,n);     % player 1 Q matrices, one m-by-m slice per state
Sq1 = zeros(m,n);
Sv1 = zeros(1,n);       % player 1 state values
SQ2 = zeros(m,m,n);     % player 2 Q matrices (rows = player 2's action)
Sq2 = zeros(m,n);
Sv2 = zeros(1,n);       % player 2 state values
L = 100;                % number of sweeps
for it = 1:L
    for s = 1:n
        Sv1(s) = minimax(SQ1(:,:,s));
    end
    Sv2 = -Sv1;         % zero-sum: player 2's values are the negation
    for s = 1:n
        SQ1(:,:,s) = R1(:,:,s) + gamma*exp_cont_payoff(Sv1,pp,s);
        % Transposed so rows correspond to player 2's action, matching R2.
        SQ2(:,:,s) = R2(:,:,s) + gamma*exp_cont_payoff(Sv2,pp,s)';
    end
end
% One copy of the final Shapley values per recording point, for plotting.
Sv1_iter = repmat(Sv1,ll,1);
Sv2_iter = repmat(Sv2,ll,1);

% % Model-based
% % Initialization
% Q1 = R1;
% v1 = zeros(n,1);
% v1_iter = zeros(ll,n);
% vsum = zeros(n,1);
% vsum_iter = zeros(ll,n);
% p1 = zeros(m,n);
% Q2 = R2;
% v2 = zeros(n,1);
% v2_iter = zeros(ll,n);
% p2 = zeros(m,n);
% for i = 1:n
%     p1(:,i) = rand(m,1);
%     p1(:,i) = p1(:,i)/sum(p1(:,i));
%     p2(:,i) = rand(m,1);
%     p2(:,i) = p2(:,i)/sum(p2(:,i));
% end
% % Iterations
% pstate= zeros(n,1);
% count = zeros(1,n);
% for k = 1:num_iter
%     if mod(k,10^4) == 0
%         k
%     end
%     pstate(state) = pstate(state)+1;
%     a1 = best(Q1(:,:,state)*p2(:,state)); % uniform tie breaking rule
%     a2 = best(Q2(:,:,state)*p1(:,state)); % uniform tie breaking rule
%     p1(:,state) = p1(:,state) + alpha(count(state))*(a1-p1(:,state));
%     p2(:,state) = p2(:,state) + alpha(count(state))*(a2-p2(:,state));
%     Q1(:,:,state) = Q1(:,:,state) + beta(count(state))*...
%         (R1(:,:,state) + gamma*exp_cont_payoff(v1,pp,state) - Q1(:,:,state));
%     Q2(:,:,state) = Q2(:,:,state) + beta(count(state))*...
%         (R2(:,:,state) + gamma*exp_cont_payoff(v2,pp,state)' - Q2(:,:,state));
%     for i = 1:n
%         v1(i) = max(Q1(:,:,i)*p2(:,i));
%         v2(i) = max(Q2(:,:,i)*p1(:,i));
%         vsum(i) = v1(i)+v2(i);
%     end
%     if sum(find(LL==k))>0
%         jj = find(LL==k);
%         v1_iter(jj,:) = v1;
%         v2_iter(jj,:) = v2;
%         vsum_iter(jj,:) = vsum;
%     end
%     count(state)=count(state)+1;
%     state = next(pp,state,a1,a2);
% end
% 
% pstate = pstate/num_iter;
% 
% figure(1)
% semilogx(LL,zeros(ll,1),'k')
% hold on
% for i = 1:n
%     semilogx(LL,v1_iter(:,i)','r',...
%              LL,v2_iter(:,i)','b',...
%              LL,vsum_iter(:,i)','k',...
%              LL,Sv1_iter(:,i)','r:',...
%              LL,Sv2_iter(:,i)','b:')
% end
% xlabel('Stages')
% ylabel('Continuation Payoffs')
% axis([10 10^8 -20 20])



% Model-free decentralized learning: initialization.
% Each player keeps its own action values q (m-by-n, one column per state)
% and state values v (n-by-1). The *_iter arrays hold snapshots taken at the
% iterations listed in LL.
state = randi(n,1);                    % random initial state

% Player 1
q11 = zeros(m,n);
v11 = zeros(n,1);
v11_iter = zeros(ll,n);

% Player 2
q12 = zeros(m,n);
v12 = zeros(n,1);
v12_iter = zeros(ll,n);

% Sum of the two players' state values, and its snapshots
vsum1 = zeros(n,1);
vsum1_iter = zeros(ll,n);

% Visit statistics
pstate = zeros(n,1);                   % visits per state
pstateaction = zeros(m,m,n);           % visits per (a1, a2, state)

% Step-size counters: alpha is driven by per-(action,state) counts,
% beta by per-state counts.
countA1 = zeros(m,n);
countA2 = zeros(m,n);
countB1 = zeros(1,n);
countB2 = zeros(1,n);

% Main model-free learning loop. At each iteration both players act from their
% own q-values, the game moves to a sampled next state, and each player updates:
% - its action value for the action it played, with step size alpha(count);
% - its state value toward max(q) as it was BEFORE this step's q update,
%   with step size beta(count).
for k = 1:num_iter
    % Progress indicator every 10^4 iterations (printed once).
    if mod(k,10^4) == 0
        k
    end
    pstate(state) = pstate(state)+1;

    % Action selection. bestExperiment presumably returns a one-hot m-vector
    % with 0.02 exploration -- TODO confirm against its definition.
    a11 = bestExperiment(q11(:,state),0.02); % uniform tie breaking rule
    a12 = bestExperiment(q12(:,state),0.02); % uniform tie breaking rule
    i11 = find(a11);                         % player 1's action index
    i12 = find(a12);                         % player 2's action index
    pstateaction(i11,i12,state) = pstateaction(i11,i12,state)+1;

    n_state = next(pp,state,a11,a12);

    % Player 1: R1 rows are player 1's action, so its payoff is R1(i11,i12).
    old_Q11_tmp = q11(:,state);
    q11(i11,state) = q11(i11,state) + alpha(countA1(i11,state))*...
        (R1(i11,i12,state) + gamma*v11(n_state) - q11(i11,state));
    v11(state) = v11(state) + beta(countB1(state))*(max(old_Q11_tmp)-v11(state));

    % Player 2: R2 = -R1' is indexed (player 2 action, player 1 action), so its
    % payoff for this joint action is R2(i12,i11) = -R1(i11,i12). Indexing it as
    % R2(i11,i12) would read -R1(i12,i11), the wrong entry, and break zero-sum.
    old_Q12_tmp = q12(:,state);
    q12(i12,state) = q12(i12,state) + alpha(countA2(i12,state))*...
        (R2(i12,i11,state) + gamma*v12(n_state) - q12(i12,state));
    v12(state) = v12(state) + beta(countB2(state))*(max(old_Q12_tmp)-v12(state));

    % Sum of both players' state values (zero at a zero-sum solution).
    vsum1 = v11 + v12;

    % Snapshot the values when k is one of the recording iterations in LL.
    jj = find(LL==k);
    if ~isempty(jj)
        v11_iter(jj,:) = v11;
        v12_iter(jj,:) = v12;
        vsum1_iter(jj,:) = vsum1;
    end

    countA1(i11,state) = countA1(i11,state)+1;
    countA2(i12,state) = countA2(i12,state)+1;
    countB1(state) = countB1(state)+1;
    countB2(state) = countB2(state)+1;

    state = n_state;
end

% Empirical visit frequencies of states and of (a1, a2, state) triples.
pstate = pstate/num_iter;
pstateaction = pstateaction/num_iter;

% Learned values (solid) against the Shapley values (dotted), one set per state:
% red = player 1, blue = player 2, black = their sum. A black line marks zero.
figure(2)
semilogx(LL+1,zeros(ll,1),'k')
hold on
for i = 1:n
    semilogx(LL,v11_iter(:,i)','r',...
             LL,v12_iter(:,i)','b',...
             LL,vsum1_iter(:,i)','k',...
             LL,Sv1_iter(:,i)','r:',...
             LL,Sv2_iter(:,i)','b:')
end
xlabel('Stages')
ylabel('Continuation Payoffs')
% The x-range follows num_iter, so changing the run length keeps the plot framed.
axis([10 num_iter -20 20])
toc