function obj = readTransformerNetwork(file_path, varargin)
% readTransformerNetwork - reads and converts a transformer network whose
%    parameters (weights, biases, layer settings) are stored in a JSON file
%
% Syntax:
%    obj = neuralNetwork.readTransformerNetwork(file_path)
%    obj = neuralNetwork.readTransformerNetwork(file_path, verbose)
%
% Inputs:
%    file_path - path to a JSON file with the fields
%       token_embedding_weights, positional_embedding_weights,
%       final_dense_weights.weights, final_dense_weights.bias,
%       and the per-block parameters self_attention (num_heads,
%       query_weights, key_weights, value_weights, output_weights),
%       layernorm1 (gamma, beta), feedforward (dense_1_weights,
%       dense_1_bias, dense_2_weights, dense_2_bias), layernorm2
%       (gamma, beta). The per-block parameters are read either from
%       the array 'transformer_blocks' (entry i holds all parameters of
%       block i) or from the parallel top-level arrays 'self_attention',
%       'layernorm1', 'feedforward', 'layernorm2' (entry i of each
%       belongs to block i).
%    verbose - (optional) logical, display the constructed network
%       (default: false)
%
% Outputs:
%    obj - generated neuralNetwork object
%
% Other m-files required: none
% Subfunctions: aux_isNestedLayout, aux_numBlocks, aux_blockParams,
%    aux_entry
% MAT-files required: none
%
% See also: neuralNetwork/readONNXNetwork

% ------------------------------ BEGIN CODE -------------------------------

% parse input
if nargin < 1
    throw(CORAerror('CORA:notEnoughInputArgs',1))
elseif nargin > 2
    throw(CORAerror('CORA:tooManyInputArgs',2));
end
verbose = setDefaultValues({false}, varargin);
inputArgsCheck({ ...
    {file_path, 'att', {'char', 'string'}}; ...
    {verbose, 'att', 'logical'}; ...
})
if ~isfile(file_path)
    throw(CORAerror('CORA:fileNotFound', file_path));
end

% read and decode the model file containing weights and biases
% (fileread opens and closes the file itself, so no handle can leak)
model = jsondecode(fileread(file_path));

% epsilon of both layer normalizations; it is not read from the file
epsilon = 0.000001;

% layers: cell array containing different types of layers
layers = {};

% token and positional embedding layer
token_emb_weights = double(model.token_embedding_weights);
pos_emb_weights = double(model.positional_embedding_weights);
layers{end+1} = nnTransformerEmbeddingLayer(token_emb_weights, pos_emb_weights);

% number of columns of the positional embedding; passed on to all
% subsequent transformer layers (presumably the model dimension)
emb_dim = size(pos_emb_weights, 2);

% transformer blocks (layout detected from the decoded JSON structure)
nested = aux_isNestedLayout(model, file_path);
num_blocks = aux_numBlocks(model, nested);
for i = 1:num_blocks
    [sa_weights, ln1, ff_weights, ln2] = aux_blockParams(model, i, nested);

    % multi-head self attention layer
    num_heads = double(sa_weights.num_heads);
    query_weights = double(sa_weights.query_weights);
    key_weights = double(sa_weights.key_weights);
    value_weights = double(sa_weights.value_weights);
    output_weights = double(sa_weights.output_weights);

    % layer normalization 1
    gamma1 = double(ln1.gamma);
    beta1 = double(ln1.beta);

    layers{end+1} = nnTNNResidualConnectionLayer(...
        query_weights, key_weights, value_weights, output_weights, num_heads, beta1, gamma1, ...
        epsilon, emb_dim, "multihead att", "layernorm1", ...
        ['block', num2str(i), 'residual']);

    % feed forward network (dense 1); weights are transposed after loading
    w_1 = (double(ff_weights.dense_1_weights))';
    b_1 = double(ff_weights.dense_1_bias);

    % feed forward network (dense 2)
    w_2 = (double(ff_weights.dense_2_weights))';
    b_2 = double(ff_weights.dense_2_bias);

    % layer normalization 2
    gamma2 = double(ln2.gamma);
    beta2 = double(ln2.beta);

    layers{end+1} = nnFFNResidualConnectionLayer(...
        ['block', num2str(i), 'residual2'], ...
        w_1, b_1, w_2, b_2, beta2, gamma2, ...
        epsilon, emb_dim, "layernorm2", "lin1", "lin2");
end

layers{end+1} = nnTNNGlobalAveragePoolingLayer(emb_dim, "global average pooling layer");

% final dense layer
final_dense_weights = (double(model.final_dense_weights.weights))';
final_dense_bias = double(model.final_dense_weights.bias);

layers{end+1} = nnTNNLinearLayer(final_dense_weights, final_dense_bias, "lin3");
layers{end+1} = nnSoftmaxLayer();

obj = neuralNetwork(layers);

% print the layers if verbose is true
if verbose
    disp(obj)
end

end


% Auxiliary functions -----------------------------------------------------

function nested = aux_isNestedLayout(model, file_path)
% returns true if the per-block parameters are stored in
% model.transformer_blocks, false if they are stored in the parallel
% top-level arrays model.self_attention, model.layernorm1, ...
hasNested = isfield(model, 'transformer_blocks');
hasFlat = isfield(model, 'self_attention');
if hasNested && hasFlat
    % ambiguous file: fall back to the legacy file-name heuristic
    nested = contains(file_path, "2b") || contains(file_path, "yelp");
elseif hasNested || hasFlat
    nested = hasNested;
else
    throw(CORAerror('CORA:specialError', ...
        'Model file contains neither ''transformer_blocks'' nor ''self_attention''.'));
end
end

function num_blocks = aux_numBlocks(model, nested)
% number of transformer blocks stored in the model
if nested
    num_blocks = numel(model.transformer_blocks);
else
    num_blocks = numel(model.self_attention);
end
end

function [sa, ln1, ff, ln2] = aux_blockParams(model, i, nested)
% returns the self-attention, layernorm1, feedforward, and layernorm2
% parameter structs of the i-th transformer block
if nested
    blk = aux_entry(model.transformer_blocks, i);
    sa = blk.self_attention;
    ln1 = blk.layernorm1;
    ff = blk.feedforward;
    ln2 = blk.layernorm2;
else
    sa = aux_entry(model.self_attention, i);
    ln1 = aux_entry(model.layernorm1, i);
    ff = aux_entry(model.feedforward, i);
    ln2 = aux_entry(model.layernorm2, i);
end
end

function e = aux_entry(arr, i)
% returns the i-th entry of a decoded JSON array; jsondecode yields a
% struct array when all objects share the same fields, but a cell array
% when they do not, so both forms are handled
if iscell(arr)
    e = arr{i};
else
    e = arr(i);
end
end


% ------------------------------ END OF CODE ------------------------------
