% ---- Problem size -------------------------------------------------------
% S states and A actions; d is passed to Q_generation and is also the number
% of singular vectors used to score anchors; gamma is the discount factor.
S = 70;
A = 50;
d = 5;
gamma = 0.9;

% ---- Algorithm settings -------------------------------------------------
deff = 15;               % number of anchor states and of anchor actions
std_noise_reward = 1e-2; % std of the Gaussian noise added to sampled rewards
tol = 1e-4;              % NOTE(review): not referenced in the code shown here
N_steps = 50;            % value-iteration epochs per experiment
N_experiments = 100;     % independent repetitions

% ---- Result storage: one row per experiment, one column per epoch -------
uniform_list = zeros(N_experiments,N_steps);
leveraged_list = zeros(N_experiments,N_steps);

for i_experiment = 1:N_experiments
    % Draw a new MDP every 5 experiments (experiments 1, 6, 11, ...); the
    % experiments in between reuse it and differ only in their random draws.
    if mod(i_experiment,5)==1
        [Qstar,r,P] = Q_generation(S,A,d,gamma);
    end
    error_inf_uniform = zeros(1,N_steps);
    error_inf_leveraged = zeros(1,N_steps);

    % Noisy empirical Bellman backup for entry (s,a) given value vector v:
    % the mean of n samples r(s,a) + std_noise_reward*randn, plus
    % gamma * P(s,a,:) * v. Defined after Q_generation so that it captures
    % the current r and P.
    sample_entry = @(s,a,v,n) mean(r(s,a) + randn([1,n])*std_noise_reward) ...
        + gamma*squeeze(P(s,a,:))'*v;

    % VI ITERATION with uniform anchors
    V = zeros([S,1]);
    N_samples = 10;
    for curr_epoch = 1:N_steps
        % Sample budget per anchored entry grows by 10% each epoch.
        N_samples = ceil(N_samples*1.1);

        % Anchors are drawn uniformly at random, without replacement.
        anchor_states = randperm(S,deff);
        anchor_actions = randperm(A,deff);

        % Estimate the anchor rows, then the anchor columns, from samples.
        % (Entries at the intersection are sampled twice; the second wins.)
        Q = zeros([S,A]);
        for i_s = 1:numel(anchor_states)
            for i_a = 1:A
                Q(anchor_states(i_s),i_a) = sample_entry(anchor_states(i_s),i_a,V,N_samples);
            end
        end
        for i_a = 1:numel(anchor_actions)
            for i_s = 1:S
                Q(i_s,anchor_actions(i_a)) = sample_entry(i_s,anchor_actions(i_a),V,N_samples);
            end
        end

        % Regularization passed to my_pinv follows the previous epoch's error.
        if curr_epoch == 1
            epsilon_parameter = 1e-3;
        else
            epsilon_parameter = error_inf_uniform(curr_epoch-1)*1e-3;
        end
        my_pinv_mat = my_pinv(epsilon_parameter, Q(anchor_states,anchor_actions));

        % CUR completion of every non-anchor entry in one block product:
        % Q(i,j) = Q(i,anchors) * pinv(core) * Q(anchors,j).
        free_states = setdiff(1:S, anchor_states);
        free_actions = setdiff(1:A, anchor_actions);
        Q(free_states,free_actions) = ...
            Q(free_states,anchor_actions)*my_pinv_mat*Q(anchor_states,free_actions);

        % Greedy value update and sup-norm error against the true Q*.
        V = max(Q,[],2);
        error_inf_uniform(curr_epoch) = max(abs(Qstar(:) - Q(:)));
    end

    % VI ITERATION with leveraged anchors
    V = zeros([S,1]);
    N_samples = 10;
    for curr_epoch = 1:N_steps
        N_samples = ceil(N_samples*1.1);

        % Exact backup r + gamma * sum_i P(:,:,i)*V(i), computed from the
        % true model ("oracle"). It is used only to choose the anchors.
        Qtmp = r;
        for i = 1:S
            Qtmp = Qtmp + gamma*P(:,:,i)*V(i);
        end
        [UQ,~,WQ] = svds(Qtmp,d);
        UQd = UQ(:,1:d);
        WQd = WQ(:,1:d);

        % Score each state/action by the Euclidean norm of its row in the
        % top-d left/right singular vectors.
        ellU = vecnorm(UQd,2,2)';
        ellW = vecnorm(WQd,2,2)';

        % Take the deff highest-scoring states/actions as anchors, and build
        % diagonal 1./sqrt(p) rescalings for the CUR core.
        probsU = ellU/sum(ellU);
        [~,anchor_states] = maxk(probsU,deff);
        D = diag(1./sqrt(probsU(anchor_states)));

        probsW = ellW/sum(ellW);
        [~,anchor_actions] = maxk(probsW,deff);
        DW = diag(1./sqrt(probsW(anchor_actions)));

        % Estimate the anchor rows, then the anchor columns, from samples.
        Q = zeros([S,A]);
        for i_s = 1:numel(anchor_states)
            for i_a = 1:A
                Q(anchor_states(i_s),i_a) = sample_entry(anchor_states(i_s),i_a,V,N_samples);
            end
        end
        for i_a = 1:numel(anchor_actions)
            for i_s = 1:S
                Q(i_s,anchor_actions(i_a)) = sample_entry(i_s,anchor_actions(i_a),V,N_samples);
            end
        end

        if curr_epoch == 1
            epsilon_parameter = 1e-3;
        else
            epsilon_parameter = error_inf_leveraged(curr_epoch-1)*1e-3;
        end
        my_pinv_mat = my_pinv(epsilon_parameter, D*Q(anchor_states,anchor_actions)*DW);

        % Rescaled CUR completion of every non-anchor entry in one product.
        free_states = setdiff(1:S, anchor_states);
        free_actions = setdiff(1:A, anchor_actions);
        Q(free_states,free_actions) = ...
            Q(free_states,anchor_actions)*DW*my_pinv_mat*D*Q(anchor_states,free_actions);

        V = max(Q,[],2);
        error_inf_leveraged(curr_epoch) = max(abs(Qstar(:) - Q(:)));
    end

    uniform_list(i_experiment,:) = error_inf_uniform;
    leveraged_list(i_experiment,:) = error_inf_leveraged;
end

% Per-epoch mean and standard deviation across experiments (down the rows).
% The dimension is explicit so that a single experiment (one row) still
% gives 1-by-N_steps vectors instead of collapsing to scalars.
mean_uniform = mean(uniform_list,1);
mean_leveraged = mean(leveraged_list,1);

std_uniform = std(uniform_list,0,1);
std_leveraged = std(leveraged_list,0,1);



%%
% Plot the mean sup-norm error per epoch for both anchor strategies, with
% shaded +/- one standard deviation bands.
figure()
set(gcf,'renderer','Painters')
x = 1:N_steps;
x2 = [x, fliplr(x)];

% Shaded bands. They are excluded from the legend via HandleVisibility.
band_leveraged = [mean_leveraged + std_leveraged, fliplr(mean_leveraged - std_leveraged)];
fill(x2, band_leveraged, 'b','FaceAlpha',0.05,'LineStyle','none','HandleVisibility','off');
hold on
band_uniform = [mean_uniform + std_uniform, fliplr(mean_uniform - std_uniform)];
fill(x2, band_uniform, 'r','FaceAlpha',0.05,'LineStyle','none','HandleVisibility','off');

% Mean curves, colored to match their bands.
h_leveraged = plot(x, mean_leveraged, 'b');
h_uniform = plot(x, mean_uniform, 'r');

% Crop the y-axis relative to the uniform curve. ylim rejects a zero or
% non-finite upper limit, so the limit is applied only when it is valid.
y_top = 1.5*max(mean_uniform);
if isfinite(y_top) && y_top > 0
    ylim([0,y_top])
end

xlabel("iteration number $t$", 'interpreter', 'latex','FontSize',12)
ylabel("$\Vert \widehat{Q}^{(t)} - Q^\star \Vert_{\infty}$", 'interpreter', 'latex','FontSize',16)
legend([h_leveraged, h_uniform], {'oracle anchors','uniform anchors'}, 'interpreter','latex','FontSize',12)
if N_steps > 1
    xlim([1,N_steps])
end
set(gca,'TickLabelInterpreter','latex')

