clear all; close all; clc;
tic
rng(517);   % fixed seed so payoffs, transitions and the learning run are reproducible

% --- Game size -----------------------------------------------------------
n = 5;      % number of states
m = 3;      % number of actions per state

% --- Run length and logging grid -----------------------------------------
num_iter = 3*10^9;
LL = space125(10,num_iter);   % iteration indices at which value estimates are logged
ll = length(LL);

% --- Discounting and schedule parameters ---------------------------------
gamma = 0.6;                  % discount factor
DD = 1/(1-gamma);             % passed to my_tau_to_zero in the learning loop
rho_alpha = 0.9;              % exponent argument for the alpha/alpha2 step sizes
rho_beta = 1.0;               % exponent argument for the beta/beta2 step sizes
rho = 0.7;                    % passed to my_tau_to_zero in the learning loop
% tau0 = num_iter*0.01/n;
tau0 = 2;                     % initial softmax temperature for every state


% Payoff matrices: one m-by-m matrix per state for each player.
% Player 2's payoff is the negated transpose of player 1's (zero-sum), and
% both are scaled so the largest |entry| in each state equals 1.
R1 = zeros(m,m,n);
R2 = zeros(m,m,n);
for i = 1:n
    % entries uniform on (-1,1) times a per-state scale (divided out below)
    R1(:,:,i) = (2*rand(m,m)-1)*exp(i^2);
%    R1(:,:,i) = rand(m,m)*exp(i^2);
    max_R = max(abs(reshape(R1(:,:,i),[],1)));
    R1(:,:,i) = R1(:,:,i)/max_R;
    % negating after the division is bit-identical to dividing the negated matrix
    R2(:,:,i) = -R1(:,:,i)';
end
% Transition probabilities: for every current state i and action pair
% (j1,j2), the values pp{1:n,i}(j1,j2) form a distribution over ii
% (they are normalized to sum to 1) -- presumably the next-state law.
pp = cell(n,n);
for i = 1:n
    for j1 = 1:m
        for j2 = 1:m
            % random weights with a 0.1 floor, so every entry stays strictly positive
            ppp = rand(n,1)*rand + 0.1;
            ppp = ppp/sum(ppp);
            for ii = 1:n
                pp{ii,i}(j1,j2) = ppp(ii);
            end
        end
    end
end
% state = 1; % initial state

% Shapley iterations (reference solution for the learners).
% Each of the L sweeps first evaluates the stage game at every state using
% the current Q-matrices, and only then refreshes all Q-matrices from those
% values (a synchronous update).
SQ1 = zeros(m,m,n);
Sq1 = zeros(m,n);    % not read later in this script; kept for the saved workspace
Sv1 = zeros(1,n);
SQ2 = zeros(m,m,n);
Sq2 = zeros(m,n);    % not read later in this script; kept for the saved workspace
Sv2 = zeros(1,n);
L = 300;             % fixed number of sweeps
for k = 1:L
    for state = 1:n
        % NOTE(review): minimax is external -- presumably the value of the matrix game
        Sv1(state) = minimax(SQ1(:,:,state));
%         Sv1(state) = softminimax(SQ1(:,:,state));
    end
    Sv2 = -Sv1;      % zero-sum: player 2's values are the negatives of player 1's
    for state = 1:n
        SQ1(:,:,state) = R1(:,:,state) + gamma*exp_cont_payoff(Sv1,pp,state);
        SQ2(:,:,state) = R2(:,:,state) + gamma*exp_cont_payoff(Sv2,pp,state)';
    end
end
% Shapley values repeated once per logged iteration (ll rows), for plotting.
Sv1_iter = repmat(Sv1,ll,1);
Sv2_iter = repmat(Sv2,ll,1);

% % Model-based
% % Initialization
% Q1 = R1;
% v1 = zeros(n,1);
% v1_iter = zeros(ll,n);
% vsum = zeros(n,1);
% vsum_iter = zeros(ll,n);
% p1 = zeros(m,n);
% Q2 = R2;
% v2 = zeros(n,1);
% v2_iter = zeros(ll,n);
% p2 = zeros(m,n);
% for i = 1:n
%     p1(:,i) = rand(m,1);
%     p1(:,i) = p1(:,i)/sum(p1(:,i));
%     p2(:,i) = rand(m,1);
%     p2(:,i) = p2(:,i)/sum(p2(:,i));
% end
% % Iterations
% pstate= zeros(n,1);
% count = zeros(1,n);
% for k = 1:num_iter
%     if mod(k,10^4) == 0
%         k
%     end
%     pstate(state) = pstate(state)+1;
%     a1 = best(Q1(:,:,state)*p2(:,state)); % uniform tie breaking rule
%     a2 = best(Q2(:,:,state)*p1(:,state)); % uniform tie breaking rule
%     p1(:,state) = p1(:,state) + alpha(count(state))*(a1-p1(:,state));
%     p2(:,state) = p2(:,state) + alpha(count(state))*(a2-p2(:,state));
%     Q1(:,:,state) = Q1(:,:,state) + beta(count(state))*...
%         (R1(:,:,state) + gamma*exp_cont_payoff(v1,pp,state) - Q1(:,:,state));
%     Q2(:,:,state) = Q2(:,:,state) + beta(count(state))*...
%         (R2(:,:,state) + gamma*exp_cont_payoff(v2,pp,state)' - Q2(:,:,state));
%     for i = 1:n
%         v1(i) = max(Q1(:,:,i)*p2(:,i));
%         v2(i) = max(Q2(:,:,i)*p1(:,i));
%         vsum(i) = v1(i)+v2(i);
%     end
%     if sum(find(LL==k))>0
%         jj = find(LL==k);
%         v1_iter(jj,:) = v1;
%         v2_iter(jj,:) = v2;
%         vsum_iter(jj,:) = vsum;
%     end
%     count(state)=count(state)+1;
%     state = next(pp,state,a1,a2);
% end
% 
% pstate = pstate/num_iter;
% 
% figure(1)
% semilogx(LL,zeros(ll,1),'k')
% hold on
% for i = 1:n
%     semilogx(LL,v1_iter(:,i)','r',...
%              LL,v2_iter(:,i)','b',...
%              LL,vsum_iter(:,i)','k',...
%              LL,Sv1_iter(:,i)','r:',...
%              LL,Sv2_iter(:,i)','b:')
% end
% xlabel('Stages')
% ylabel('Continuation Payoffs')
% axis([10 10^8 -20 20])



% Model-free Decentralized Learning
% Initialization: each agent keeps q-values over its OWN actions only
% (m-by-n, one column per state) plus a per-state value estimate (n-by-1).
state = randi(n,1);                                    % random initial state
tau = repmat(tau0,n,1);                                % per-state softmax temperature
[q11, q12] = deal(zeros(m,n));
[v11, v12, vsum1] = deal(zeros(n,1));
[v11_iter, v12_iter, vsum1_iter] = deal(zeros(ll,n));  % rows filled at iterations LL

% Empirical visit counts of states and (action1, action2, state) triples.
pstate = zeros(n,1);
pstateaction = zeros(m,m,n);

% Per-state visit counters used by the step-size and temperature schedules.
[countA1, countA2, countB1, countB2] = deal(ones(1,n));

% in case the number of iterations is not long enough
% num_iter_new = 2.5*10^9;
% num_iter_new_2 = 10*10^9;
% LL = space125(10,num_iter_new);
% ll = length(LL);
% Sv1_iter = kron(ones(ll,1),Sv1);
% Sv2_iter = kron(ones(ll,1),Sv2);


% Main learning loop. To continue an earlier run, swap in one of the
% commented loop headers (and rebuild LL/ll/Sv*_iter accordingly).
for k = 1:num_iter
% for k = num_iter+1:num_iter_new
% for k = num_iter_new+1:num_iter_new_2
    if mod(k,10^4) == 0
        fprintf('k = %d\n', k);
    end

    % The softmax at this visit uses the temperature stored at the previous
    % visit (tau_old); tau(state) is then advanced for the next visit.
    % Alternative schedules tried previously:
    %     tau = tau0*beta(countB2(state)); % CountB2 and CountB1 are now the same
    %     tau(state) = (1-beta(countA1(state)))*tau(state)+beta(countA1(state))*(eta(countA1(state)));
    tau_old = tau(state);
    tau(state) = my_tau_to_zero(countA1(state), rho_alpha, rho, DD);

    % Each agent samples an action from a softmax over its own q-values.
    % NOTE(review): a11/a12 appear to be action indicator vectors (find() is
    % used to recover the index) -- confirm against mySoftmax.
    [softmaxv11, a11, a11_prob] = mySoftmax( q11(:,state),tau_old );
    [softmaxv12, a12, a12_prob] = mySoftmax( q12(:,state),tau_old );

    % Stop BEFORE recording anything for this iteration, and say why, so the
    % visit statistics only contain completed iterations.
    if isnan(a11_prob) || isnan(a12_prob)
        warning('Softmax produced a NaN action probability at iteration %d (state %d); stopping early.', k, state);
        break;
    end

    i11 = find(a11); i12 = find(a12);
    pstate(state) = pstate(state)+1;
    pstateaction(i11,i12,state) = pstateaction(i11,i12,state)+1;

    n_state = next(pp,state,a11,a12);

    % Step size divided by the probability of the action actually taken,
    % capped at 1. NOTE(review): alpha/beta must resolve to the project's
    % step-size functions -- they shadow MATLAB built-ins of the same name.
    alpha_bar = min(1,alpha(countA1(state),rho_alpha)/a11_prob);
    alpha_bar_2 = min(1,alpha2(countA2(state),rho_alpha)/a12_prob);

    % Agent 1: update the q-value of its own chosen action only, then move
    % the state value toward the softmax value returned by mySoftmax.
    q11(i11,state) = q11(i11,state) + alpha_bar*...
        (R1(i11,i12,state) + gamma*v11(n_state) - q11(i11,state));
    v11(state) = v11(state) + beta(countB1(state),rho_beta)*(softmaxv11-v11(state));

    % Agent 2: same update with its own payoff (indexed own-action first).
    q12(i12,state) = q12(i12,state) + alpha_bar_2*...
        (R2(i12,i11,state) + gamma*v12(n_state) - q12(i12,state));
    v12(state) = v12(state) + beta2(countB2(state),rho_beta)*(softmaxv12-v12(state));

    % Sum of the two agents' values (zero in a zero-sum equilibrium).
    vsum1 = v11 + v12;

    % Log the estimates at the iterations listed in LL (single search).
    jj = find(LL==k);
    if ~isempty(jj)
        v11_iter(jj,:) = v11;
        v12_iter(jj,:) = v12;
        vsum1_iter(jj,:) = vsum1;
    end

    countA1(state)=countA1(state)+1;
    countA2(state)=countA2(state)+1;
    countB1(state)=countB1(state)+1;
    countB2(state)=countB2(state)+1;

    state = n_state;
end

% Empirical visit frequencies. Normalize by the number of iterations actually
% recorded rather than num_iter, because the learning loop can stop early
% (NaN softmax). When the run completes, the two denominators are equal.
% max(...,1) avoids 0/0 when nothing was recorded.
pstate = pstate/max(sum(pstate),1);
pstateaction = pstateaction/max(sum(pstateaction(:)),1);



% Plot each state's learned values against the Shapley reference values on
% a log-scaled iteration axis, report elapsed time, and save the workspace.
figure
semilogx(LL+1,zeros(ll,1),'k')   % zero reference line
hold on
for i = 1:n
    semilogx(LL,v11_iter(:,i)','r','LineWidth',2)      % agent 1 learned value
    semilogx(LL,v12_iter(:,i)','b','LineWidth',2)      % agent 2 learned value
    semilogx(LL,vsum1_iter(:,i)','k','LineWidth',2)    % sum of the two
    semilogx(LL,Sv1_iter(:,i)','r:','LineWidth',2)     % Shapley value, player 1
    semilogx(LL,Sv2_iter(:,i)','b:','LineWidth',2)     % Shapley value, player 2
end
xlim([min(LL) inf])
xlabel('Iterations')
ylabel('Value Functions')
set(gca,'FontSize',16)
toc

save('results_n5_m3_data_new_alg_tau_to_zero')