% ---- Session reset, timer and random seed ----
clear all;
close all;
clc;
tic          % matched by the toc after plotting
rng(517);    % seed the client-side generator used for the game construction

% ---- Problem size ----
trial_num = 6;   % number of independent learning trials
n = 5;           % number of states
m = 3;           % number of actions per state

% ---- Horizon and recording grid ----
num_iter = 5e9;               % learning iterations per trial
LL = space125(10,num_iter);   % iteration indices at which values are recorded
                              % (NOTE(review): presumably a 1-2-5 spaced grid -- confirm in space125)
ll = length(LL);              % number of recording points

% ---- Discounting ----
gamma = 0.6;          % discount factor applied to continuation values
DD = 1/(1-gamma);     % passed to my_tau_to_zero together with rho

% ---- Step-size / temperature schedule parameters ----
rho_alpha = 0.9;      % passed to alpha, alpha2 and my_tau_to_zero
rho_beta = 1.0;       % passed to beta and beta2
rho = 0.7;            % passed to my_tau_to_zero
% tau0 = num_iter*0.01/n;
tau0 = 2;             % initial per-state value of tau (second argument of mySoftmax)


% Payoff matrices
% R1(:,:,i) is player 1's m-by-m stage payoff in state i; R2(:,:,i) is the
% negated transpose, so the stage game is zero-sum with player 2's own
% action indexing the rows of R2. Both are scaled by the largest absolute
% entry of the raw draw, so every entry lies in [-1,1].
R1 = zeros(m,m,n);
R2 = zeros(m,m,n);
for i = 1:n
    % Raw draw: uniform on [-1,1], multiplied by a state-dependent factor.
    R1(:,:,i) = exp(i^2)*(2*rand(m,m)-1);
    %    R1(:,:,i) = rand(m,m)*exp(i^2);

    % Largest magnitude of the raw draw, used to normalise both players.
    max_R = max(max(abs(R1(:,:,i))));

    % Player 2's payoff is taken from the raw (not yet normalised) R1.
    %    R2(:,:,i) = 2*(R2(:,:,i)./max_R-0.5);
    R2(:,:,i) = (-R1(:,:,i)')./max_R;

    %    R1(:,:,i) = 2*(R1(:,:,i)./max_R-0.5);
    R1(:,:,i) = R1(:,:,i)./max_R;
end
% transition probabilities
% pp{ii,i}(j1,j2) holds the probability of moving from state i to state ii
% under the action pair (j1,j2). For every (i,j1,j2) a fresh distribution
% over the n next states is drawn, with a 0.1 offset before normalisation
% so that no next state gets zero probability.
pp = cell(n,n);
pp(:) = {zeros(m,m)};   % size every kernel up front instead of growing it
for i = 1:n
    for j1 = 1:m
        for j2 = 1:m
            ppp = 0.1 + rand(n,1)*rand;   % random weights with a common random scale
            ppp = ppp./sum(ppp);          % normalise to a probability vector
            for ii = 1:n
                pp{ii,i}(j1,j2) = ppp(ii);
            end
        end
    end
end
% state = 1; % initial state

% Shapley iterations
% Reference solution: L synchronous sweeps that alternate between
%   (1) taking the minimax value of each state's current Q-matrix, and
%   (2) rebuilding each Q-matrix from the stage payoff plus the discounted
%       expected continuation value.
% Player 2's values are the negation of player 1's.
SQ1 = zeros(m,m,n);
Sq1 = zeros(m,n);
Sv1 = zeros(1,n);
SQ2 = zeros(m,m,n);
Sq2 = zeros(m,n);
Sv2 = zeros(1,n);
L = 300;   % number of sweeps
for k = 1:L
    % Step 1: state values from the current Q-matrices.
    for state = 1:n
        Sv1(state) = minimax(SQ1(:,:,state));
        %         Sv1(state) = softminimax(SQ1(:,:,state));
    end
    Sv2 = -Sv1;

    % Step 2: Q-matrices from the freshly computed state values.
    for state = 1:n
        SQ1(:,:,state) = R1(:,:,state) + gamma*exp_cont_payoff(Sv1,pp,state);
        SQ2(:,:,state) = R2(:,:,state) + gamma*exp_cont_payoff(Sv2,pp,state)';
    end
end

% Replicate the reference values along the recording grid (one row per
% entry of LL) so they can be drawn as horizontal lines in the plot.
Sv1_iter = repmat(Sv1,ll,1);
Sv2_iter = repmat(Sv2,ll,1);





% Model-free Decentralized Learning
% Each trial simulates num_iter steps of the stochastic game. Both players
% keep their own per-state action values (q11 / q12) and state values
% (v11 / v12), pick actions through mySoftmax, and update only from their
% own realised payoff and the realised next state.
%
% Outputs, one slice per trial:
%   v11_iter, v12_iter : state values of player 1 / 2 at the iterations in LL
%   vsum1_iter         : v11 + v12 at the same iterations
%   pstate_tt          : state visit counts divided by num_iter
%   pstateaction_tt    : (action1, action2, state) counts divided by num_iter
v11_iter = zeros(ll,n,trial_num);
v12_iter = zeros(ll,n,trial_num);
vsum1_iter = zeros(ll,n,trial_num);
pstate_tt= zeros(n,trial_num);
pstateaction_tt= zeros(m,m,n,trial_num);



parfor trial = 1:trial_num

    % Initialization
    state = randi(n,1); % initial state
    tau = tau0.*ones(n,1);      % per-state value passed to mySoftmax
    q11 = zeros(m,n);           % player 1: value estimate per (action, state)
    v11 = zeros(n,1);           % player 1: value estimate per state
    v11_iter_tmp = zeros(ll,n);
    % p11 = zeros(m,n);
    q12 = zeros(m,n);           % player 2: value estimate per (action, state)
    v12 = zeros(n,1);           % player 2: value estimate per state
    v12_iter_tmp = zeros(ll,n);
    % p12 = zeros(m,n);
    vsum1_iter_tmp = zeros(ll,n);
    % for i = 1:n
    %     p11(:,i) = rand(m,1);
    %     p11(:,i) = p11(:,i)/sum(p11(:,i));
    %     p12(:,i) = rand(m,1);
    %     p12(:,i) = p12(:,i)/sum(p12(:,i));
    % end

    % Iterations
    pstate= zeros(n,1);         % visit count per state
    pstateaction= zeros(m,m,n); % visit count per (action1, action2, state)

    % Per-state counters that drive the step-size schedules. All four start
    % at 1 and are incremented together each time the state is visited.
    countA1=ones(1,n);
    countA2=ones(1,n);
    countB1=ones(1,n);
    countB2=ones(1,n);


    % in case the number of iterations is not long enough
    % num_iter_new = 2.5*10^9;
    % num_iter_new_2 = 10*10^9;
    % LL = space125(10,num_iter_new);
    % ll = length(LL);
    % Sv1_iter = kron(ones(ll,1),Sv1);
    % Sv2_iter = kron(ones(ll,1),Sv2);


    for k = 1:num_iter
        % for k = num_iter+1:num_iter_new
        % for k = num_iter_new+1:num_iter_new_2
        if mod(k,10^4) == 0
            k   % progress output (no semicolon on purpose)
        end

        %     tau = tau0*beta(countB2(state)); % CountB2 and CountB1 are now the same
        % The action at this step is drawn with the value of tau stored
        % before the schedule update below.
        tau_old = tau(state);
        %     tau(state) = (1-beta(countA1(state)))*tau(state)+beta(countA1(state))*(eta(countA1(state)));
        tau(state) = my_tau_to_zero(countA1(state), rho_alpha, rho, DD);

        pstate(state) = pstate(state)+1;
        % NOTE(review): a11/a12 are presumably one-hot action indicators and
        % a11_prob/a12_prob the probability of the drawn action -- confirm
        % against mySoftmax.
        [softmaxv11, a11, a11_prob] = mySoftmax( q11(:,state),tau_old ); % Softmax choice
        [softmaxv12, a12, a12_prob] = mySoftmax( q12(:,state),tau_old ); % Softmax choice
        i11=find(a11); i12=find(a12);
        pstateaction(i11,i12,state) = pstateaction(i11,i12,state)+1;

        if isnan(a11_prob) || isnan(a12_prob)
            % Abort this trial, but say so: the rows of the *_iter_tmp
            % arrays after this point stay zero and the visit frequencies
            % below are still divided by num_iter.
            warning('Trial %d: mySoftmax returned a NaN probability at iteration %d; stopping this trial early.', trial, k);
            break;
        end

        n_state = next(pp,state,a11,a12);

        % Step size divided by the probability of the played action,
        % capped at 1.
        alpha_bar = min(1,alpha(countA1(state),rho_alpha)/a11_prob);
        alpha_bar_2 = min(1,alpha2(countA2(state),rho_alpha)/a12_prob);

        % Player 1: update the played action's value towards
        % payoff + discounted next-state value, then move the state value
        % towards the softmax value computed before this update.
        q11(i11,state) = q11(i11,state) + alpha_bar*...
            (R1(i11,i12,state) + gamma*v11(n_state) - q11(i11,state)); %
        v11(state) = v11(state) + beta(countB1(state),rho_beta)*(softmaxv11-v11(state));

        % Player 2: same updates with its own payoff matrix, indexed with
        % its own action first.
        q12(i12,state) = q12(i12,state) + alpha_bar_2*...
            (R2(i12,i11,state) + gamma*v12(n_state) - q12(i12,state)); %
        v12(state) = v12(state) + beta2(countB2(state),rho_beta)*(softmaxv12-v12(state));



        %     p11(:,state) = p11(:,state) + alpha(countA(state))*(a11-p11(:,state));
        %     p12(:,state) = p12(:,state) + alpha(countA(state))*(a12-p12(:,state));
        %     n_state = next(pp,state,a11,a12);
        %     Q11(i11,i12,state) = Q11(i11,i12,state) + beta(countB(i11,i12,state))*...
        %         (R1(i11,i12,state) + gamma*v11(n_state) - Q11(i11,i12,state));
        %     Q12(i12,i11,state) = Q12(i12,i11,state) + beta(countB(i11,i12,state))*...
        %         (R2(i12,i11,state) + gamma*v12(n_state) - Q12(i12,i11,state));
        %         v11(i) = max(Q11(:,:,i)*p12(:,i));
        %         v12(i) = max(Q12(:,:,i)*p11(:,i));

        % Record the current values if k is one of the recording points.
        % A single lookup per iteration; the sum v11+v12 is only formed
        % when it is actually stored.
        jj = find(LL==k);
        if ~isempty(jj)
            v11_iter_tmp(jj,:) = v11;
            v12_iter_tmp(jj,:) = v12;
            vsum1_iter_tmp(jj,:) = v11+v12;
        end

        countA1(state)=countA1(state)+1;
        countA2(state)=countA2(state)+1;
        countB1(state)=countB1(state)+1;
        countB2(state)=countB2(state)+1;

        state = n_state;
    end

    % Copy this trial's recordings into the sliced output arrays.
    v11_iter(:,:,trial) = v11_iter_tmp;
    v12_iter(:,:,trial) = v12_iter_tmp;
    vsum1_iter(:,:,trial) = vsum1_iter_tmp;

    pstate_tt(:,trial) = pstate/num_iter;
    pstateaction_tt(:,:,:,trial) = pstateaction./num_iter;

end




% ---- Statistics across trials (third array dimension) ----
% plot_ind = 1:(length(LL)-5);
plot_ind = 1:length(LL);                             % recording points to draw
x_label_ind = ceil(linspace(1,length(plot_ind),5));  % five tick positions
v11_iter_mean = mean(v11_iter,3);
v11_iter_std = std(v11_iter,0,3);
v12_iter_mean = mean(v12_iter,3);
v12_iter_std = std(v12_iter,0,3);
vsum_iter_mean = v11_iter_mean+v12_iter_mean;


% ---- Figure: values against log(iteration) ----
figure
LL_plot = log(LL(plot_ind));
x_vector = [LL_plot, fliplr(LL_plot)];   % out-and-back x path for the shaded bands
p1_color = [255, 177, 168]./255;         % band colour, player 1
p2_color = [189, 200, 255]./255;         % band colour, player 2
plot(LL_plot,zeros(length(plot_ind),1),'k')   % zero reference line
% Ticks sit at log positions but are labelled with the raw iteration counts.
set(gca,'XTick',LL_plot(x_label_ind),'XTickLabel',LL(x_label_ind));


hold on
for i = 1:n
    % Player 1: mean +/- one standard deviation band for state i.
    data_err_1 =  (v11_iter_std(plot_ind,i))';
    data_mean_1 = (v11_iter_mean(plot_ind,i))';
    patch_1 = fill(x_vector, [data_mean_1+data_err_1,fliplr(data_mean_1-data_err_1)], p1_color);
    set(patch_1, 'edgecolor', 'none', 'FaceAlpha', 0.3);

    % Player 2: same band for state i.
    data_err_2 =  (v12_iter_std(plot_ind,i))';
    data_mean_2 = (v12_iter_mean(plot_ind,i))';
    patch_2 = fill(x_vector, [data_mean_2+data_err_2,fliplr(data_mean_2-data_err_2)], p2_color);
    set(patch_2, 'edgecolor', 'none', 'FaceAlpha', 0.3);

    % Mean curves (solid), their sum (black) and the Shapley reference
    % values (dotted), drawn in the same order as a single plot call would.
    plot(LL_plot,data_mean_1,'r','linewidth',2)
    plot(LL_plot,data_mean_2,'b','linewidth',2)
    plot(LL_plot,vsum_iter_mean(plot_ind,i)','k','linewidth',2)
    plot(LL_plot,Sv1_iter(plot_ind,i)','r:','linewidth',2)
    plot(LL_plot,Sv2_iter(plot_ind,i)','b:','linewidth',2)
end
xlim([min(LL_plot) inf])
ylim([-0.75 0.8])
xlabel('Iterations')
ylabel('Value Functions')
set(gca,'FontSize',16)
toc

% Store the whole workspace.
save('results_multi_trial_n5_m3_data_new_alg_tau_to_zero_tau_bar_0068_7')