% Two-state, two-action MDP shared by every section of this script.
gamma = 0.87;                          % discount factor
% Reward table r(s, a): rows are states, columns are actions.
r = [-0.46, -0.48;
     -0.14,  0.28];
% Transition tensor P(s, s', a); each slice P(:,:,a) has rows summing to 1.
P = cat(3, [0.40, 0.60; 0.15, 0.85], ...
           [0.25, 0.75; 0.29, 0.71]);
% Starting value vector V^(0) for value iteration (1-by-2 row).
Vzero = [2.8583, 2.9751];

%% Policy evaluation for every deterministic policy
% Every deterministic policy of the 2-state/2-action MDP: row i gives the
% action taken in state 1 and in state 2 under policy i.
policies = [1 1; 1 2; 2 1; 2 2];
nPolicies = size(policies, 1);

V = zeros([2,nPolicies]);      % V(:,i) = V^pi for policy i
kappa = zeros([1,nPolicies]);  % kappa(i) = sigma_max/sigma_min of Q^pi
tol = 1e-10;
for i = 1:nPolicies
    % Fixed-point iteration of the policy-evaluation operator on Q:
    %   Q(s0,a) = r(s0,a) + gamma * sum_s P(s0,s,a) * Q(s, pi(s))
    % stopped once the sup-norm change drops below tol.
    Qprev = zeros([2,2]);
    not_conv = 1;
    while(not_conv)
        Qtmp = r;
        for s = 1:2
            Qtmp = Qtmp + gamma*squeeze(P(:,s,:))*Qprev(s,policies(i,s));
        end
        if max(abs(Qprev(:)-Qtmp(:))) < tol
            not_conv = 0;
        end
        Qprev = Qtmp;
    end
    % V^pi(s) = Q^pi(s, pi(s)). Taking the max over actions instead would
    % give the greedy backup of V^pi, which differs from V^pi for any
    % suboptimal policy. It would also not be the value vector from which
    % Q^pi (whose condition number is kappa(i)) is built.
    for s = 1:2
        V(s,i) = Qprev(s,policies(i,s));
    end
    sigmastmp = svd(Qprev);
    kappa(i) = sigmastmp(1)/sigmastmp(2);
end



%% Value iteration
% Value iteration from Vzero: apply the Bellman optimality backup until two
% successive iterates differ by less than tol in the sup norm.
% Viter(:,t) holds V^(t); kappaiter(t) is the condition number of the Q
% matrix that produced it. Both are trimmed to the iterations actually run.
N = 1000;                  % preallocated length; indexing past it grows the arrays
Viter = zeros([2,N]);
kappaiter = zeros([N,1]);
tol = 1e-10;
Vprev = Vzero;
cnt = 0;
while true
    cnt = cnt + 1;
    % Q(s0,a) = r(s0,a) + gamma * sum_s P(s0,s,a) * Vprev(s), summed in state order
    Qtmp = r + gamma*squeeze(P(:,1,:))*Vprev(1) + gamma*squeeze(P(:,2,:))*Vprev(2);
    sigmastmp = svd(Qtmp);
    kappaiter(cnt) = sigmastmp(1)/sigmastmp(2);
    Viter(:,cnt) = max(Qtmp, [], 2);   % best action value in each state
    converged = max(abs(Vprev(:)-Viter(:,cnt))) < tol;
    Vprev = Viter(:,cnt);
    if converged
        break
    end
end
kappaiter = kappaiter(1:cnt);
Viter = Viter(:,1:cnt);

% Condition number of the Q matrix
%   Q(V) = r + gamma*P(:,1,:)*V(s_1) + gamma*P(:,2,:)*V(s_2)
% evaluated on an Nx-by-Ny grid of candidate value vectors (V(s_1), V(s_2)).
Nx = 1000;
Ny = 1000;

% Points drawn on top of the map: V^(0) followed by the value-iteration
% iterates. The grid bounds are taken over these and the per-policy values
% in V, so every plotted marker lies inside the map.
Viter_ext = [Vzero' Viter];

% Push each bound outward by half its magnitude. Plain scaling by 1.5 only
% does that for a negative minimum or a positive maximum; a positive
% minimum or a negative maximum would move inward and crop points.
padlo = @(m) m - 0.5*abs(m);
padhi = @(m) m + 0.5*abs(m);
xVmin = padlo(min(min(V(1,:)),min(Viter_ext(1,:))));
xVmax = padhi(max(max(V(1,:)),max(Viter_ext(1,:))));
yVmin = padlo(min(min(V(2,:)),min(Viter_ext(2,:))));
yVmax = padhi(max(max(V(2,:)),max(Viter_ext(2,:))));

xspace = linspace(xVmin, xVmax, Nx);
yspace = linspace(yVmin, yVmax, Ny);
kappamap = zeros([Nx,Ny]);   % kappamap(i,j) <-> (xspace(i), yspace(j))

% The scaled transition slices do not depend on the grid point: build once.
gP1 = gamma*squeeze(P(:,1,:));
gP2 = gamma*squeeze(P(:,2,:));
% Inner loop over the first index so kappamap is filled in column-major order.
for j = 1:Ny
    for i = 1:Nx
        Qtmp = r + gP1*xspace(i) + gP2*yspace(j);
        sigmastmp = svd(Qtmp);
        kappamap(i,j) = sigmastmp(1)/sigmastmp(2);
    end
end


%% Plotting
% Colormap with Ncolor rows, blending from near-white (weight 0.1) to the
% base colour Cbase (weight 1). Row k is 1 - w_k*(1 - Cbase).
Ncolor = 1000;
colr_linspace = linspace(0.1,1,Ncolor);
Cbase = [0, 447/1000, 741/1000];
C = 1 - colr_linspace(:).*(1-Cbase);   % Ncolor-by-3 via implicit expansion
orange=[0.8500 0.3250 0.0980];         % marker colour for the iterates

labelFont = 18;   % axis labels, legend and colorbar label
annotFont = 14;   % in-plot text annotations

figure()
% log condition number over the value grid. kappamap is indexed (x, y), so
% it is transposed to put y along the image rows.
imagesc(xspace,yspace,log(kappamap'))
set(gca,'YDir','normal')   % y increases upward, as in an ordinary plot
colormap(C)
hcolorbar = colorbar();
set(hcolorbar,'TickLabelInterpreter','latex')
hold on
% Diamonds: V^(0) plus the value-iteration iterates. Crosses: the columns of V.
scatter(Viter_ext(1,:),Viter_ext(2,:),30,orange,'diamond','filled')
scatter(V(1,:),V(2,:),100,'k',"x",'linewidth',2);

xlabel("$V(s_1)$",'interpreter','latex','FontSize',labelFont);
ylabel("$V(s_2)$",'interpreter','latex','FontSize',labelFont);
legend('$V^{(t)}$','$V^{\pi}$','interpreter','latex','FontSize',labelFont)
ylabel(hcolorbar, "$\log$(condition number)",'interpreter','latex','FontSize',labelFont)
% Annotation coordinates are hard-coded; move them by hand if the MDP data change.
text(0.07,0.45,"$V^\star$",'interpreter','latex','FontSize',annotFont);
text(2.65,3.15,"$V^{(0)}$",'interpreter','latex','FontSize',annotFont);
text(1.9,3,"$V^{(1)}$",'interpreter','latex','FontSize',annotFont);
set(gca,'TickLabelInterpreter','latex')