clc;
% 'clear all' also clears functions (and any persistent state they hold),
% so each run of this script starts from a clean slate.
clear all;

% Estimation of a sparse transition matrix from S&P 500 index stock data

% Put the solver package and all of its subfolders on the MATLAB path.
% fullfile inserts the platform's file separator, so the script is no longer
% tied to Windows' '\'. The path is relative to the current folder.
addpath(genpath(fullfile('..', 'Software_Linear_Convergence_VAR')));

%%%%%Read firm names and their sectors
% Brings the table SP500firms into the workspace; below it is accessed through
% its Symbol variable and its second column (the sector label).
% Loading it once is enough.
load('SP500firms.mat')

%%%%%Read stock data
% Numeric price matrix; csvread starts at row offset 1, skipping the header
% row, so each column of stockDataFull belongs to one firm.
stockDataFull = csvread('StockPrice.csv', 1, 0);
% size(...,2) instead of width(): width() on numeric arrays needs R2020b+.
numFirm = size(stockDataFull, 2);
% Header row with the firm symbols, transposed into a numFirm-by-1 cell.
stockFirmName = readcell('StockPrice.csv', 'Range', [1 1 1 numFirm])';
% Coerce every header cell to a character vector before matching.
firmTickers = cellfun(@char, stockFirmName, 'UniformOutput', false);

%%%%%Keep only firms that appear in SP500firms and look up their sector
% One vectorized ismember replaces the per-firm loop. loc holds the FIRST
% matching row of SP500firms, so every kept firm contributes exactly one
% sector label and stockFirmSector stays aligned with Index even if a
% symbol occurs more than once in SP500firms.
[isListed, loc] = ismember(firmTickers, SP500firms.Symbol);
stockFirmSector = SP500firms{loc(isListed), 2};

% Column positions (in stockDataFull) of the kept firms, as a column vector.
Index = find(isListed);
% group: sector labels in sorted order; id: position of each within the kept firms.
[group, id] = sort(stockFirmSector);


%%%%%Keep only the price columns of firms that received a sector label
stockDataSelected = stockDataFull(:, Index);

%%%%%Log-returns log(p_t / p_{t-1}) for every pair of consecutive rows
stockLogReturn = log(stockDataSelected(2:end, :) ./ stockDataSelected(1:end-1, :));

%%%%%Lag-one linear model Y = X * A: each row of returns is regressed on the previous row
X = stockLogReturn(1:end-1, :);   % returns at time t
Y = stockLogReturn(2:end, :);     % returns at time t+1

% n = number of (X, Y) row pairs, d = number of firms
[n, d] = size(X);

%%%%%Estimate with PGD
% Solver settings shared with the FNSL estimate further down.
kmax = 1000;         % iteration cap
A0 = zeros(d, d);    % starting transition matrix (all zeros)
D = 150;             % presumably the l1-ball radius (per the solver's name) -- confirm

% NOTE(review): the data are passed as (X, Y) here; the fit is checked
% below as Y - X * APGD.
APGD = PGD_l1ball_Stock_Crossvalidation(X, Y, A0, kmax, D);

% In-sample Frobenius residual of the fit, then the fraction of nonzero entries.
fprintf('%s\n', ['Estimation error ' num2str(norm(Y - X * APGD, 'fro'))]);
fprintf('%s\n', ['Sparsity ' num2str(nnz(APGD) / d / d)]);

%%%%%Sector of stocks
% Positions (columns of X / Y, rows and columns of the estimated matrices)
% of the firms in each sector: group holds the sorted sector labels and id the
% matching firm positions, so id(mask) lists one sector's firms.
% Comparing through string() works whether the sector labels are categorical,
% a string array or a cell array of character vectors ('==' alone errors on
% cells and on an empty [] label list).
sectorLabels = string(group);
firmsInSector = @(sectorName) id(sectorLabels == sectorName);

IndexEnergy = firmsInSector("Energy");
IndexIT = firmsInSector("Information Technology");
IndexHealthCare = firmsInSector("Health Care");
IndexFinancials = firmsInSector("Financials");
IndexIndustrials = firmsInSector("Industrials");
IndexMaterials = firmsInSector("Materials");
IndexUtilities = firmsInSector("Utilities");
IndexConsumerDiscretionary = firmsInSector("Consumer Discretionary");
IndexConsumerStaples = firmsInSector("Consumer Staples");

%%%%%PGD estimate restricted to pairs of sectors
% Each figure shows the sub-matrix of APGD whose rows and columns are the
% firms of two sectors, drawn with the bluewhitered colormap.

% Materials sector and Energy sector
pgdBlockMatEnergy = [IndexMaterials; IndexEnergy];
figure(1)
imagesc(APGD(pgdBlockMatEnergy, pgdBlockMatEnergy));
colormap(bluewhitered(256))

% Consumer staples sector and Financials sector
pgdBlockStaplesFin = [IndexConsumerStaples; IndexFinancials];
figure(2)
imagesc(APGD(pgdBlockStaplesFin, pgdBlockStaplesFin));
colormap(bluewhitered(256))


%%%%%Estimate with FNSL
% Regularization parameter: a tuning constant times sqrt(log(d) * n), where
% d is the number of firms and n the number of (X, Y) sample pairs.
lambdaScale = 0.00015;
lambda = lambdaScale * sqrt(log(d) * n);

% Reuses the starting matrix A0 and iteration cap kmax from the PGD estimate.
% NOTE(review): this solver receives (Y, X, ...) whereas the PGD solver
% receives (X, Y, ...); both fits are evaluated as Y - X * A, so confirm the
% argument order against FNSL_Stock_Crossvalidation's signature.
AFNSL = FNSL_Stock_Crossvalidation(Y, X, A0, kmax, lambda);

% In-sample Frobenius residual of the fit, then the fraction of nonzero entries.
disp(['Estimation error ', num2str(norm(Y - X * AFNSL, 'fro'))]);
disp(['Sparsity ', num2str(nnz(AFNSL) / d / d)]);

%%%%%FNSL estimate restricted to pairs of sectors
% Each figure shows the sub-matrix of AFNSL whose rows and columns are the
% firms of two sectors, drawn with the bluewhitered colormap.

% Materials sector and Energy sector
fnslBlockMatEnergy = [IndexMaterials; IndexEnergy];
figure(3)
imagesc(AFNSL(fnslBlockMatEnergy, fnslBlockMatEnergy));
colormap(bluewhitered(256))

% Consumer staples sector and Financials sector
fnslBlockStaplesFin = [IndexConsumerStaples; IndexFinancials];
figure(4)
imagesc(AFNSL(fnslBlockStaplesFin, fnslBlockStaplesFin));
colormap(bluewhitered(256))