%% environment setup
close all force; clear; clc;
plot_rays = false;  % draw individual rays (only used when visualize is true)
visualize = false;  % open a siteviewer and show sites / rays

% available maps: conference.stl, bedroom.stl, office.stl
mapFileName = "models/conference.stl";
[stl_data, ~] = stlread(mapFileName);

if visualize
    viewer = siteviewer("SceneModel", mapFileName, "Transparency", 0.25);
end

% environment bounding box, shrunk by a small margin on x/y and on the z ceiling.
% env_dims is 3x2: rows are x, y, z; columns are [lower, upper].
% The z lower bound is the mesh minimum clamped to be non-negative.
vertices = stl_data.Points;
faces = stl_data.ConnectivityList;

xy_offset = 0.1;
z_offset = 0.1;
lower_corner = min(vertices, [], 1);
upper_corner = max(vertices, [], 1);
env_dims = [lower_corner(1:2).' + xy_offset, upper_corner(1:2).' - xy_offset;
            max(0, lower_corner(3)),         upper_corner(3) - z_offset];

% point cloud generation params (consumed by generate_pc)
pc_params = struct( ...
    'edge_density', 2.1, ...
    'surface_density', 1.6, ...
    'volume_density', 0, ...
    'boundary_density', 87.3, ...
    'random_points', 0, ...
    'noise_std', 0, ...
    'edge_reduction', 1, ...
    'surface_reduction', 1);

%% generate point cloud
point_cloud = generate_pc(vertices, faces, pc_params, env_dims, visualize);

%% system config
fc = 5e9;                               % carrier frequency
lambda = physconst("lightspeed") / fc;  % carrier wavelength

% OFDM parameters (non-HT WLAN, 80 MHz)
cfg = wlanNonHTConfig('ChannelBandwidth', 'CBW80');

% extra config
use_single_sc = true;  % collect CSI for one subcarrier only
sc_idx = [];           % empty -> middle active subcarrier is picked during CSI collection
use_siso = false;      % false -> 4x2 URA at the AP, 2-element ULA at each user

half_wavelength = lambda / 2;
if ~use_siso
    % multiple-input multiple-output
    txArray = phased.URA("Size", [4 2], "ElementSpacing", half_wavelength);
    rxArray = phased.ULA("NumElements", 2, "ElementSpacing", half_wavelength);
    num_tx_ant = prod(txArray.Size);
    num_rx_ant = rxArray.NumElements;
else
    % single-input single-output
    txArray = phased.IsotropicAntennaElement();
    rxArray = phased.IsotropicAntennaElement();
    num_tx_ant = 1;
    num_rx_ant = 1;
end

%% AP setup
% AP antenna position per map:
%   conference [-1.5; 0.0; 1.7], bedroom [-1.5; 0.0; 2.7], office [0.15; 0.2; 2.8]
ap_antenna_position = [-1.5; 0.0; 1.7];
AP = txsite("cartesian", ...
    "Antenna", txArray, ...
    "AntennaPosition", ap_antenna_position, ...
    "TransmitterFrequency", fc, ...
    "TransmitterPower", 0.05);

%% user setup
approx_target_users = 452;

% fix the global random stream so user placement is reproducible
RandStream.setGlobalStream(RandStream("mt19937ar", "Seed", 17));

user_params = struct( ...
    'check_building_collision', false, ...
    'check_user_collision', true);
[Users, actual_users] = create_users(env_dims, approx_target_users, rxArray, user_params, []);

% abort if the requested number of users could not be placed
if actual_users < approx_target_users
    error('Failed to create all requested users. Only created %d out of %d users.', ...
        actual_users, approx_target_users);
end

%% RT simulation
method = "sbr"; % "image" | "sbr"
max_refs = 1;

pm = propagationModel("raytracing", ...
    "Method", method, ...
    "CoordinateSystem", "cartesian", ...
    "SurfaceMaterial", "wood", ...
    "TerrainMaterial", "wood", ...
    "MaxNumReflections", max_refs, ...
    "UseGPU", "auto");

rays = raytrace(AP, Users, pm, "Map", mapFileName);

% filter users to keep only those with at least one ray.
% Compare against the number of users actually traced (create_users may return
% more than approx_target_users), so every discarded user is reported.
valid_user_mask = ~cellfun(@isempty, rays);
num_candidates = numel(valid_user_mask);
Users = Users(valid_user_mask);
rays = rays(valid_user_mask);
num_users = nnz(valid_user_mask);

% an empty dataset is useless downstream; fail loudly instead of saving it
if num_users == 0
    error('None of the %d users received any rays; nothing to collect.', num_candidates);
end

if num_users < num_candidates
    warning('Only %d out of %d users had valid rays, discarding the rest.', ...
        num_users, num_candidates);
end

%% visualize
if visualize
    show(AP, "ShowAntennaHeight", false)
    show(Users, "ShowAntennaHeight", false)

    if plot_rays
        % plot rays for the first ~1/20 (~5%) of users, but at least one user
        % (the previous bound num_users/20 plotted nothing when num_users < 20)
        num_plotted = min(num_users, max(1, floor(num_users / 20)));
        ray_cmap = jet(256);  % computed once instead of on every iteration

        for userIdx = 1:num_plotted
            plot(rays{userIdx}, "Colormap", ray_cmap)
            pause(0.05)
        end
    end
end

%% extract positions
% stack each user's antenna position as one column: rx_positions is 3 x num_users
rx_positions = reshape([Users.AntennaPosition], 3, []);

%% CSI collection
ofdmInfo = wlanNonHTOFDMInfo('L-LTF', cfg.ChannelBandwidth);
activeIndices = ofdmInfo.ActiveFrequencyIndices;
num_active = numel(activeIndices);
% subcarrier spacing is the same for both branches below, so compute it once
sc_spacing = wlanSampleRate(cfg.ChannelBandwidth) / ofdmInfo.FFTLength;

if use_single_sc
    % default to the middle active subcarrier
    if isempty(sc_idx)
        sc_idx = ceil(num_active / 2);
    end

    % sc_idx may be set by hand in the config section; reject bad values early
    if ~isscalar(sc_idx) || sc_idx ~= fix(sc_idx) || sc_idx < 1 || sc_idx > num_active
        error('sc_idx must be an integer in [1, %d], got %s.', num_active, mat2str(sc_idx));
    end

    numSubcarriers = 1;
    freqs = fc + activeIndices(sc_idx) * sc_spacing;
else
    numSubcarriers = num_active;
    freqs = fc + activeIndices * sc_spacing;
end

H = zeros(num_users, num_tx_ant, num_rx_ant, numSubcarriers);
AoD_all = cell(num_users, 1);
AoA_all = cell(num_users, 1);
path_loss = zeros(num_users, 1);
path_loss_per_ray = cell(num_users, 1);

for userIdx = 1:num_users
    user_rays = rays{userIdx};
    [H(userIdx, :, :, :), AoD_all{userIdx}, AoA_all{userIdx}] = ...
        generate_csi(user_rays, fc, cfg, txArray, rxArray, method, 'indoor', use_single_sc, sc_idx);

    % gather the per-ray path losses once and reuse them for the mean.
    % comm.Ray PathLoss values are in dB, so this mean is taken over dB values.
    ray_pl = [user_rays.PathLoss];
    path_loss_per_ray{userIdx} = ray_pl;
    path_loss(userIdx) = mean(ray_pl);
end

% check for null values in channel matrix
if any(isnan(H(:)))
    warning('Channel matrix contains NaN values!');
end

%% ray marching and per-ray information (for NeWRF comparison)
[ray_steps, ray_points, ray_interactions, ray_coefficients] = deal(cell(num_users, 1));

for k = 1:num_users
    user_rays = rays{k};
    [ray_steps{k}, ray_points{k}] = ray_marching(user_rays);
    [ray_interactions{k}, ray_coefficients{k}] = get_ray_chan(user_rays, freqs, method);
end

% NaN predicates: one for a numeric array, one for a cell array of numeric arrays
array_has_nan = @(a) any(isnan(a(:)));
cells_have_nan = @(c) any(cellfun(array_has_nan, c));

% ray_steps / ray_points entries are cell arrays; ray_coefficients entries are arrays
steps_bad = any(cellfun(cells_have_nan, ray_steps));
if steps_bad || any(cellfun(cells_have_nan, ray_points)) || any(cellfun(array_has_nan, ray_coefficients))
    warning('Ray marching results contain NaN values!');
end

%% save
output_dir = "outputs";
% mkdir warns when the folder already exists, so only create it when missing
if ~isfolder(output_dir)
    mkdir(output_dir);
end

% fileparts returns a string scalar for a string input. Convert to char so that
% map_prefix below slices characters rather than string-array elements.
% Previously length(mapname) was 1, so the whole map name was used.
[~, mapname] = fileparts(mapFileName);
mapname = char(mapname);

% create dataset
dataset = struct();

dataset.config.tx_antennas = num_tx_ant;
dataset.config.rx_antennas = num_rx_ant;
dataset.config.frequency = fc;
dataset.config.wavelength = lambda;
dataset.config.num_users = num_users;
dataset.config.use_siso = use_siso;

dataset.environment.dimensions = env_dims;
dataset.environment.point_cloud = point_cloud;
dataset.environment.pc_params = pc_params;

% AP position is stored as a 1x3 row; user positions are stored as 3 x num_users columns
dataset.nodes.ap_position = AP.AntennaPosition';
dataset.nodes.users_positions = rx_positions;

dataset.channel.H = H;
dataset.channel.path_loss = path_loss;

% per-user angles and per-ray path loss, truncated to the 10 rays with the
% lowest path loss (only the truncated AoD is stored, not the full AoD_all)
[AoA_trunc, AoD_trunc, path_loss_per_ray_trunc] = truncate_data(AoA_all, AoD_all, path_loss_per_ray, 10);
dataset.channel.AoA = AoA_trunc;
dataset.channel.AoD = AoD_trunc;
dataset.channel.path_loss_per_ray = path_loss_per_ray_trunc;

dataset.channel.ray_steps = ray_steps;
dataset.channel.ray_points = ray_points;
dataset.channel.ray_interactions = ray_interactions;
dataset.channel.ray_coefficients = ray_coefficients;
dataset.channel.frequencies = freqs;

% single-subcarrier runs tag the filename with the subcarrier index
sc_str = '';
if use_single_sc
    sc_str = sprintf('_sc%d', sc_idx);
end

% first (up to) four characters of the map name, e.g. 'conf' for conference.stl
map_prefix = mapname(1:min(4, numel(mapname)));

filename = fullfile(output_dir, sprintf('%s_%dx%d_%du_%.1fghz_%sRT%s.mat', ...
    map_prefix, ...
    num_tx_ant, ...
    num_rx_ant, ...
    num_users, ...
    fc / 1e9, ...
    method, ...
    sc_str));

save(filename, 'dataset', '-v7.3');

function [AoA_trunc, AoD_trunc, path_loss_trunc] = truncate_data(AoA, AoD, path_loss, max_paths)
    % TRUNCATE_DATA Keep, per user, at most max_paths rays with the lowest path loss.
    %   AoA, AoD and path_loss are cell arrays with one entry per user. Column k
    %   of AoA{i} and AoD{i} is paired with element k of path_loss{i}.
    %   Returned entries are ordered by ascending path loss (lowest first).
    %   Users whose AoA entry is empty get [] in all three outputs.
    %   The three outputs are num_users x 1 cell arrays.
    n_users = length(AoA);
    AoA_trunc = cell(n_users, 1);        % cell() pre-fills every entry with []
    AoD_trunc = cell(n_users, 1);
    path_loss_trunc = cell(n_users, 1);

    for u = 1:n_users
        if isempty(AoA{u})
            continue;  % leave the pre-filled [] entries untouched
        end

        % ascending path loss: lowest-loss rays come first
        [pl_sorted, order] = sort(path_loss{u});
        keep = 1:min(length(pl_sorted), max_paths);

        % reorder the full angle matrices, then cut to the kept columns
        AoA_perm = AoA{u}(:, order);
        AoD_perm = AoD{u}(:, order);

        AoA_trunc{u} = AoA_perm(:, keep);
        AoD_trunc{u} = AoD_perm(:, keep);
        path_loss_trunc{u} = pl_sorted(keep);
    end
end
