% --- Synthetic MDP: S states x A actions ---
S = 1e3;            % number of states (rows of every Q matrix below)
A = 1e3;            % number of actions (columns of every Q matrix below)
d = 4;              % rank used by svds; also passed to Q_generation (presumably the rank of Qstar - confirm)
gamma = 0.1;        % discount factor applied as gamma^(step-1) in the rollouts

% --- Sampling / estimation parameters ---
tau = 5;                    % number of reward steps collected per rollout
deff = 10;                  % number of anchor states and of anchor actions
std_noise_reward = 1e-2;    % std of the Gaussian noise added to every observed reward
tol = 1e-6;                 % NOTE(review): not referenced anywhere else in this script

% Qstar: reference Q the errors are measured against; r(s,a): reward table;
% P(s,a,:): next-state distribution sampled in the rollouts.
[Qstar, r, P] = Q_generation(S, A, d, gamma);

%%
N_steps = 20;        % PI epochs run by each method
N_experiments = 3;   % independent repetitions

% Error traces: one row per experiment, one column per epoch.
[uniform_list_inf, uniform_list_fro, ...
 leveraged_list_inf, leveraged_list_fro, ...
 vanilla_list_inf, vanilla_list_fro] = deal(zeros(N_experiments,N_steps));

% Main experiment loop. Each of the N_experiments trials starts all three
% policy-iteration (PI) variants from the same random initial policy and
% records, per epoch, the entrywise (max-abs) and Frobenius errors of the
% estimated Q against Qstar. Progress is printed once per epoch as
% [method, experiment, epoch] with method 1 = leveraged anchors,
% 2 = uniform anchors, 3 = full-matrix PI.
%
% Every rollout below has the same shape: start at (s0,a0), collect the
% noisy reward r(s0,a0), then for tau-1 steps draw the next state from
% P(s,a,:), switch to the current policy's action and add the discounted
% noisy reward r(s,policy(s)).
for i_experiment = 1:N_experiments
    policy0 = randi(A,[1,S]);

    %% LEVERAGED PI ITERATION
    N_samples = 10;                 % per-entry rollouts; grown by 15% each epoch
    samples_epoch_leveraged = zeros([1,N_steps]);
    policy = policy0;

    for curr_epoch = 1:N_steps
        disp([1,i_experiment,curr_epoch])
        N_samples = ceil(N_samples*1.15);

        % Step 1: rough estimate of the full Q from T_SVD rollouts started at
        % uniformly drawn entries; entry (i_s,i_a) has linear index i_a+(i_s-1)*A.
        T_SVD = N_samples*(S+A)*deff;
        visited_states = randi(S*A,[1,T_SVD]);
        % Passing the size keeps one count per entry (length S*A) even when
        % the highest indices were never drawn, without touching real counts.
        acc_states = accumarray(visited_states(:),1,[S*A,1]);

        Qtmp = zeros([S,A]);
        for idx = find(acc_states)'     % visited entries, in increasing (i_s,i_a) order
            [i_a,i_s] = ind2sub([A,S],idx);
            for i_N = 1:acc_states(idx)
                state_tmp = i_s;
                action_tmp = i_a;
                ret = r(state_tmp,action_tmp) + randn()*std_noise_reward;
                for i_tau = 2:tau
                    cum_probs = cumsum(squeeze(P(state_tmp,action_tmp,:)));
                    % Scaling by the total mass guarantees a match even if
                    % rounding leaves cum_probs(end) slightly below 1.
                    state_tmp = find(cum_probs > rand()*cum_probs(end),1);
                    action_tmp = policy(state_tmp);
                    ret = ret + gamma^(i_tau-1)*( r(state_tmp,action_tmp) + randn()*std_noise_reward );
                end
                Qtmp(i_s,i_a) = Qtmp(i_s,i_a) + ret;
            end
        end
        % Each entry is drawn T_SVD/(S*A) times in expectation.
        Qtmp = Qtmp*(S*A)/T_SVD;

        % Step 2: row norms of the rank-d singular vectors act as leverage
        % scores; the deff largest select the anchor states/actions, and
        % D/DW hold 1/sqrt(normalized score) for those anchors.
        [UQ,~,WQ] = svds(Qtmp,d);
        ellU = vecnorm(UQ(:,1:d),2,2)';
        ellW = vecnorm(WQ(:,1:d),2,2)';

        probsU = ellU/sum(ellU);
        [~,anchor_states] = maxk(probsU,deff);
        D = diag(1./sqrt(probsU(anchor_states)));

        probsW = ellW/sum(ellW);
        [~,anchor_actions] = maxk(probsW,deff);
        DW = diag(1./sqrt(probsW(anchor_actions)));

        % Step 3: N_samples rollouts for every entry of the anchor rows
        % (anchor state, any action), then of the anchor columns (any state,
        % anchor action). Entries in both are therefore sampled twice.
        anchor_pairs = [repelem(anchor_states(:),A), repmat((1:A)',numel(anchor_states),1); ...
                        repmat((1:S)',numel(anchor_actions),1), repelem(anchor_actions(:),S)];
        Q = zeros([S,A]);
        for k = 1:size(anchor_pairs,1)
            s0 = anchor_pairs(k,1);
            a0 = anchor_pairs(k,2);
            for i_N = 1:N_samples
                state_tmp = s0;
                action_tmp = a0;
                ret = r(state_tmp,action_tmp) + randn()*std_noise_reward;
                for i_tau = 2:tau
                    cum_probs = cumsum(squeeze(P(state_tmp,action_tmp,:)));
                    state_tmp = find(cum_probs > rand()*cum_probs(end),1);
                    action_tmp = policy(state_tmp);
                    ret = ret + gamma^(i_tau-1)*( r(state_tmp,action_tmp) + randn()*std_noise_reward );
                end
                Q(s0,a0) = Q(s0,a0) + ret;
            end
        end
        Q = Q/N_samples;
        Q(anchor_states,anchor_actions) = Q(anchor_states,anchor_actions)/2;

        % Parameter for my_pinv (presumably a regularized pseudo-inverse -
        % confirm). NOTE(review): it is scaled by the previous epoch's error
        % against the oracle Qstar.
        if curr_epoch == 1
            epsilon_parameter = 1e-3;
        else
            epsilon_parameter = leveraged_list_inf(i_experiment,curr_epoch-1)*1e-3;
        end
        my_pinv_mat = my_pinv(epsilon_parameter, D*Q(anchor_states,anchor_actions)*DW);

        % Step 4: fill the non-anchor block from the anchor rows/columns.
        % Only anchor entries are read, so one block product equals the
        % per-entry computation.
        other_states = setdiff(1:S,anchor_states);
        other_actions = setdiff(1:A,anchor_actions);
        Q(other_states,other_actions) = Q(other_states,anchor_actions)*DW*my_pinv_mat*D*Q(anchor_states,other_actions);

        % Greedy policy improvement (max returns the first maximizer on ties).
        [~,greedy_actions] = max(Q,[],2);
        policy = greedy_actions';

        leveraged_list_inf(i_experiment,curr_epoch) = max(abs(Qstar(:) - Q(:)));
        samples_epoch_leveraged(curr_epoch) = 2*N_samples*deff*(S+A);
        leveraged_list_fro(i_experiment,curr_epoch) = norm(Qstar - Q,'fro');
    end


    %% UNIFORM PI ITERATION
    % Same anchor scheme, with anchors drawn uniformly at random and the
    % per-entry budget chosen to match the leveraged method's sample count.
    policy = policy0;
    samples_epoch_uniform = zeros([1,N_steps]);

    for curr_epoch = 1:N_steps
        disp([2,i_experiment,curr_epoch])
        N_samples = ceil(samples_epoch_leveraged(curr_epoch)/(deff*(S+A)));
        anchor_states = randperm(S,deff);
        anchor_actions = randperm(A,deff);

        % Anchor rows first, then anchor columns; the intersection is sampled twice.
        anchor_pairs = [repelem(anchor_states(:),A), repmat((1:A)',numel(anchor_states),1); ...
                        repmat((1:S)',numel(anchor_actions),1), repelem(anchor_actions(:),S)];
        Q = zeros([S,A]);
        for k = 1:size(anchor_pairs,1)
            s0 = anchor_pairs(k,1);
            a0 = anchor_pairs(k,2);
            for i_N = 1:N_samples
                state_tmp = s0;
                action_tmp = a0;
                ret = r(state_tmp,action_tmp) + randn()*std_noise_reward;
                for i_tau = 2:tau
                    cum_probs = cumsum(squeeze(P(state_tmp,action_tmp,:)));
                    state_tmp = find(cum_probs > rand()*cum_probs(end),1);
                    action_tmp = policy(state_tmp);
                    ret = ret + gamma^(i_tau-1)*( r(state_tmp,action_tmp) + randn()*std_noise_reward );
                end
                Q(s0,a0) = Q(s0,a0) + ret;
            end
        end
        Q = Q/N_samples;
        Q(anchor_states,anchor_actions) = Q(anchor_states,anchor_actions)/2;

        if curr_epoch == 1
            epsilon_parameter = 1e-3;
        else
            epsilon_parameter = uniform_list_inf(i_experiment,curr_epoch-1)*1e-3;
        end
        my_pinv_mat = my_pinv(epsilon_parameter, Q(anchor_states,anchor_actions));

        % Fill the non-anchor block from the anchor rows/columns.
        other_states = setdiff(1:S,anchor_states);
        other_actions = setdiff(1:A,anchor_actions);
        Q(other_states,other_actions) = Q(other_states,anchor_actions)*my_pinv_mat*Q(anchor_states,other_actions);

        [~,greedy_actions] = max(Q,[],2);
        policy = greedy_actions';

        uniform_list_inf(i_experiment,curr_epoch) = max(abs(Qstar(:) - Q(:)));
        samples_epoch_uniform(curr_epoch) = N_samples*deff*(S+A);
        uniform_list_fro(i_experiment,curr_epoch) = norm(Qstar - Q,'fro');
    end


    %% VANILLA PI ITERATION
    % Every entry of Q is estimated directly, with the leveraged method's
    % per-epoch budget spread evenly over all S*A entries.
    policy = policy0;
    samples_epoch_vanilla = zeros([1,N_steps]);

    for curr_epoch = 1:N_steps
        disp([3,i_experiment,curr_epoch])
        Q = zeros([S,A]);
        N_samples = ceil(samples_epoch_leveraged(curr_epoch)/(S*A));
        for i_s = 1:S
            for i_a = 1:A
                for i_N = 1:N_samples
                    state_tmp = i_s;
                    action_tmp = i_a;
                    ret = r(state_tmp,action_tmp) + randn()*std_noise_reward;
                    for i_tau = 2:tau
                        cum_probs = cumsum(squeeze(P(state_tmp,action_tmp,:)));
                        state_tmp = find(cum_probs > rand()*cum_probs(end),1);
                        action_tmp = policy(state_tmp);
                        ret = ret + gamma^(i_tau-1)*( r(state_tmp,action_tmp) + randn()*std_noise_reward );
                    end
                    Q(i_s,i_a) = Q(i_s,i_a) + ret;
                end
            end
        end
        Q = Q/N_samples;

        [~,greedy_actions] = max(Q,[],2);
        policy = greedy_actions';

        vanilla_list_inf(i_experiment,curr_epoch) = max(abs(Qstar(:) - Q(:)));
        samples_epoch_vanilla(curr_epoch) = N_samples*S*A;
        vanilla_list_fro(i_experiment,curr_epoch) = norm(Qstar - Q,'fro');
    end


end


%%
% Per-epoch mean and standard deviation across experiments (rows of the
% *_list_* arrays are experiments, columns are epochs). The dimension is
% given explicitly so a single experiment (N_experiments = 1) still yields
% one value per epoch instead of collapsing the row to a scalar. All three
% methods run in every experiment, so all are summarized the same way.
mean_uniform_inf = mean(uniform_list_inf,1);
mean_leveraged_inf = mean(leveraged_list_inf,1);
mean_vanilla_inf = mean(vanilla_list_inf,1);

mean_uniform_fro = mean(uniform_list_fro,1);
mean_leveraged_fro = mean(leveraged_list_fro,1);
mean_vanilla_fro = mean(vanilla_list_fro,1);

std_uniform_inf = std(uniform_list_inf,0,1);
std_leveraged_inf = std(leveraged_list_inf,0,1);
std_vanilla_inf = std(vanilla_list_inf,0,1);

std_uniform_fro = std(uniform_list_fro,0,1);
std_leveraged_fro = std(leveraged_list_fro,0,1);
std_vanilla_fro = std(vanilla_list_fro,0,1);


%% Entrywise plot
figure()
set(gcf,'renderer','Painters')
% One entry per method, in legend order: leveraged, uniform, full-matrix.
% Each mean curve is drawn in the same colour as its +/- 1 std band, so the
% legend lines identify the bands too.
plot_x = {cumsum(samples_epoch_leveraged), cumsum(samples_epoch_uniform), cumsum(samples_epoch_vanilla)};
plot_mean = {mean_leveraged_inf, mean_uniform_inf, mean_vanilla_inf};
plot_std = {std_leveraged_inf, std_uniform_inf, std_vanilla_inf};
plot_color = {'b','r','g'};

% Bands first: they take the three hidden ("") legend slots.
for k = 1:numel(plot_x)
    x = plot_x{k};
    fill([x, fliplr(x)], [plot_mean{k} + plot_std{k}, fliplr(plot_mean{k} - plot_std{k})], ...
        plot_color{k},'FaceAlpha',0.05,'LineStyle','none');
    hold on
end
for k = 1:numel(plot_x)
    plot(plot_x{k}, plot_mean{k}, plot_color{k});
end
ylim([0,0.25])
ylabel("$\Vert \widehat{Q}^{(t)} - Q^\star \Vert_{\infty}$", 'interpreter', 'latex','FontSize',16)
xlabel("total number of trajectories",'interpreter','latex','FontSize',12)
legend("","","","PI with leveraged anchors","PI with uniform anchors","full-matrix PI",'interpreter','latex','FontSize',12)
set(gca,'TickLabelInterpreter','latex')


%% Frobenius plot
figure()
set(gcf,'renderer','Painters')
% One entry per method, in legend order: leveraged, uniform, full-matrix.
% Each mean curve is drawn in the same colour as its +/- 1 std band, so the
% legend lines identify the bands too.
plot_x = {cumsum(samples_epoch_leveraged), cumsum(samples_epoch_uniform), cumsum(samples_epoch_vanilla)};
plot_mean = {mean_leveraged_fro, mean_uniform_fro, mean_vanilla_fro};
plot_std = {std_leveraged_fro, std_uniform_fro, std_vanilla_fro};
plot_color = {'b','r','g'};

% Bands first: they take the three hidden ("") legend slots.
for k = 1:numel(plot_x)
    x = plot_x{k};
    fill([x, fliplr(x)], [plot_mean{k} + plot_std{k}, fliplr(plot_mean{k} - plot_std{k})], ...
        plot_color{k},'FaceAlpha',0.05,'LineStyle','none');
    hold on
end
for k = 1:numel(plot_x)
    plot(plot_x{k}, plot_mean{k}, plot_color{k});
end
ylim([0,11])
ylabel("$\Vert \widehat{Q}^{(t)} - Q^\star \Vert_{\mathrm{F}}$", 'interpreter', 'latex','FontSize',16)
xlabel("total number of trajectories",'interpreter','latex','FontSize',12)
legend("","","","PI with leveraged anchors","PI with uniform anchors","full-matrix PI",'interpreter','latex','FontSize',12)
set(gca,'TickLabelInterpreter','latex')

