% Start from a clean workspace and console.
% NOTE: `warning off` disables ALL warnings globally and is not restored
% when the script ends.
clear; warning off; clc;

% Experiment configuration shared by the data loader (data_rd_all4m) and the
% model (sub_fpnn_fcm_fourier); the meaning of most fields is defined by
% those functions. Fields used directly in this script:
%   max_layers      - layer budget passed to the model and shown in reports
%   inputs_per_node - passed as the model's last argument
%   seed            - seeds the global random number generator below
setup = struct( ...
    'dataset_id',       18, ...
    'is_timeseries',    true, ...
    'lookback',         96, ...
    'horizon',          96, ...
    'target_feature',   'OT', ...
    'learning',         'LSE', ...
    'OF',               'MSE', ...
    'eval_horizon',     'ALL', ...
    'max_pairs',        1000, ...
    'use_single',       true, ...
    'distance_type',    1, ...
    'poly_type',        2, ...
    'inputs_per_node',  2, ...
    'seed',             0, ...
    'internal_val_ratio', 0.2, ...
    'ZO',               30, ...
    'max_layers',      3);

% Seed from the configuration so setup.seed is the value actually in effect.
% For seed 0 this reproduces the former rng('default'); rng(0,'twister').
rng(setup.seed, 'twister');


% Hyper-parameter grid: every combination of the listed values is tried.
% k_range is a cell array so that each entry can be a [k_min, k_max] pair.
param_grid = struct();
param_grid.No_C    = 3;
param_grid.ex_w    = 2.0;
param_grid.L       = 0.01;
param_grid.k_range = {[0.2, 2.0]};


% Load inputs X and targets Y for the configured dataset (loader is defined
% elsewhere and receives the full setup struct), then print a banner.
[X, Y, dataset_name] = data_rd_all4m(setup.dataset_id, setup);
[n_samples, n_features] = size(X);
H = size(Y, 2);   % number of target columns

banner_rule = repmat('=', 1, 90);
fprintf('\n%s\n', banner_rule);
fprintf(' Dataset: %s | Samples: %d | Features: %d | H: %d | max_layers: %d\n', ...
    dataset_name, n_samples, n_features, H, setup.max_layers);
fprintf('%s\n', banner_rule);


% Chronological (unshuffled) train/val/test split by row order.
% Datasets whose name contains "ett" use 6:2:2; all others use 7:1:2.
split_ratio = [7, 1, 2];
if contains(lower(dataset_name), "ett")
    split_ratio = [6, 2, 2];
end

% Cumulative cut points: rows 1..cut(1) -> train, cut(1)+1..cut(2) -> val,
% the remainder -> test.
cut = round(cumsum(split_ratio) / sum(split_ratio) * n_samples);
rows_tr = 1:cut(1);
rows_va = cut(1)+1:cut(2);

X_train = X(rows_tr, :);
Y_train = Y(rows_tr, :);
X_val   = X(rows_va, :);
Y_val   = Y(rows_va, :);
X_test  = X(cut(2)+1:end, :);
Y_test  = Y(cut(2)+1:end, :);

train_size = size(X_train, 1);
val_size   = size(X_val, 1);
test_size  = size(X_test, 1);

fprintf(' Split %d:%d:%d -> Train=%d | Val=%d | Test=%d\n', ...
    split_ratio, train_size, val_size, test_size);


% Z-score every target column using statistics of the training split only;
% val/test are scaled with the same train statistics. Columns whose std is
% below eps keep sigma = 1 to avoid dividing by (near) zero.
y_mu    = mean(Y_train, 1);
y_sigma = std(Y_train, [], 1);
y_sigma(y_sigma < eps) = 1;

zscore_train_stats = @(A) (A - y_mu) ./ y_sigma;
Y_train_norm = zscore_train_stats(Y_train);
Y_val_norm   = zscore_train_stats(Y_val);
Y_test_norm  = zscore_train_stats(Y_test);

fprintf(' Y norm: mu(mean)=%.4f, sigma(mean)=%.4f\n', mean(y_mu), mean(y_sigma));


% Expand the hyper-parameter grid: for configuration idx, g1..g4(idx) index
% into No_C, ex_w, L and k_range respectively.
grid_axes = {param_grid.No_C, param_grid.ex_w, param_grid.L, param_grid.k_range};
lens = cellfun(@numel, grid_axes);
[g1, g2, g3, g4] = ndgrid(1:lens(1), 1:lens(2), 1:lens(3), 1:lens(4));
n_configs = numel(g1);

fprintf(' Grid Search: %d configs × %d layers\n\n', n_configs, setup.max_layers);

% Summary table: row 1 holds the headers, rows 2..n_configs+1 one config each.
headers = {'ID', 'No_C', 'ex_w', 'max_layers', 'L', 'k_min', 'k_max', ...
           'BestLayer', 'Val_MSE', 'Val_MAE', 'Time', 'Status'};
summary_results = cell(n_configs + 1, numel(headers));
summary_results(1, :) = headers;

% Running best over the search (best_config stays empty if every config fails).
best_val_mse = inf;
best_config  = [];
tic_total    = tic;

% Grid search: train one model per configuration on the training split,
% score every returned layer on the validation split (normalized targets),
% and keep the configuration/layer pair with the lowest validation MSE.
% Any error for a configuration is recorded in summary_results and the
% search continues with the next configuration.
for idx = 1:n_configs
    cfg = setup;
    cfg.No_C    = param_grid.No_C(g1(idx));
    cfg.ex_w    = param_grid.ex_w(g2(idx));
    cfg.L       = param_grid.L(g3(idx));
    cfg.k_range = param_grid.k_range{g4(idx)};

    fprintf('[%d/%d] No_C=%d, ex_w=%.1f, L=%.4f, k=[%.2f,%.2f] ... ', ...
        idx, n_configs, cfg.No_C, cfg.ex_w, cfg.L, cfg.k_range(1), cfg.k_range(2));

    row = idx + 1;   % row 1 of summary_results holds the headers
    summary_results(row, 1:7) = {idx, cfg.No_C, cfg.ex_w, cfg.max_layers, ...
                                  cfg.L, cfg.k_range(1), cfg.k_range(2)};

    % Start the clock before `try` so toc(t) in `catch` is always valid.
    t = tic;
    try
        % All 7 outputs are still requested (in case the model depends on
        % nargout); only the 3rd, the validation predictions, is used here.
        [~, ~, Y_val_pred_norm, ~, ~, ~, ~] = sub_fpnn_fcm_fourier( ...
            cfg, cfg.max_layers, ...
            X_train, Y_train_norm, ...
            X_val, Y_val_norm, ...
            cfg.inputs_per_node);

        elapsed = toc(t);

        % Predictions must be N x H per layer (layers along dim 3). Without
        % this check, implicit expansion in the residual below would silently
        % broadcast a wrong-shaped result into a meaningless MSE.
        pred_sz = size(Y_val_pred_norm);
        if isempty(Y_val_pred_norm) || ~isequal(pred_sz(1:2), size(Y_val_norm))
            error('Validation predictions are %s, expected %d x %d (x layers).', ...
                mat2str(pred_sz), size(Y_val_norm, 1), size(Y_val_norm, 2));
        end

        n_layers = size(Y_val_pred_norm, 3);
        val_mse_per_layer = zeros(1, n_layers);
        for li = 1:n_layers
            r = Y_val_norm - Y_val_pred_norm(:, :, li);
            val_mse_per_layer(li) = mean(r(:).^2);
        end

        [best_layer_mse, best_layer] = min(val_mse_per_layer);

        r_best = Y_val_norm - Y_val_pred_norm(:, :, best_layer);
        best_layer_mae = mean(abs(r_best(:)));

        fprintf('OK (%.1fs) Layer=%d Val_MSE=%.6f\n', elapsed, best_layer, best_layer_mse);

        summary_results(row, 8:12) = {best_layer, best_layer_mse, best_layer_mae, elapsed, 'OK'};

        % Keep the overall winner (a NaN score never wins this comparison).
        if best_layer_mse < best_val_mse
            best_val_mse = best_layer_mse;
            best_config = cfg;
            best_config.best_layer = best_layer;
            best_config.val_mse = best_layer_mse;
            best_config.val_mae = best_layer_mae;
        end

    catch ME
        fprintf('FAILED: %s\n', ME.message);
        summary_results(row, 8:12) = {NaN, NaN, NaN, toc(t), ['FAILED: ' ME.message]};
    end
end

search_minutes = toc(tic_total) / 60;
fprintf('\n Grid search done: %.1f min\n', search_minutes);

% Abort when no configuration produced a usable validation score.
if isempty(best_config)
    error('All configurations failed!');
end

bc = best_config;
fprintf('\n Best: No_C=%d, ex_w=%.1f, L=%.4f, k=[%.2f,%.2f], Layer=%d, Val_MSE=%.6f\n', ...
    bc.No_C, bc.ex_w, bc.L, bc.k_range(1), bc.k_range(2), bc.best_layer, bc.val_mse);


rule = repmat('=', 1, 90);
fprintf('\n%s\n', rule);
fprintf(' Final Test (using Layer %d from Val)\n', best_config.best_layer);
fprintf('%s\n', rule);

% Retrain the winning configuration on train+val and predict the test
% split; test targets are passed as [] so the model is only given X_test.
X_trainval      = [X_train; X_val];
Y_trainval_norm = [Y_train_norm; Y_val_norm];

t_retrain = tic;
[~, ~, Y_test_pred_norm, ~, ~, ~, info_best] = sub_fpnn_fcm_fourier( ...
    best_config, best_config.max_layers, ...
    X_trainval, Y_trainval_norm, ...
    X_test, [], ...
    best_config.inputs_per_node);
% Stop the clock right after training so the diagnostic dump below is not
% counted as retrain time.
test_time = toc(t_retrain);

% Per-layer diagnostics from info_best; stop at the first layer that is
% missing or has an empty type.
has_layers = isfield(info_best, 'layers');
for l = 1:best_config.max_layers
    if ~has_layers || l > numel(info_best.layers) || isempty(info_best.layers(l).type)
        break;
    end
    Linfo = info_best.layers(l);
    fprintf('\n=== Layer %d (BEST CONFIG) ===\n', l);
    fprintf('type=%s | full=%d | cand=%d | sampled=%d\n', Linfo.type, Linfo.num_full, Linfo.num_cand, Linfo.sampled);
    fprintf('best_idx_local=%d | best_idx_global=%d\n', Linfo.best_idx_local, Linfo.best_idx_global);
    % Print every selected input index, so this also works when
    % inputs_per_node ~= 2 (identical output for two inputs).
    fprintf('best_pair=[%s]\n', strtrim(sprintf('%d ', Linfo.best_pair)));

    P = Linfo.best_params;
    fprintf('V_norm:\n');
    disp(P.V_norm);
    fprintf('omega:\n');
    disp(P.omega);
end

fprintf(' Retrain time: %.2f s\n\n', test_time);

% Score the retrained model on the test split (normalized targets). The
% headline layer is the one chosen on validation. If the retrained model
% returned fewer layers, its last layer is used instead; this fallback is
% announced and labelled as such in the results table.
pred_sz = size(Y_test_pred_norm);
if isempty(Y_test_pred_norm) || ~isequal(pred_sz(1:2), size(Y_test_norm))
    % Guard against silent implicit expansion in the residuals below.
    error('Test predictions are %s, expected %d x %d (x layers).', ...
        mat2str(pred_sz), size(Y_test_norm, 1), size(Y_test_norm, 2));
end

n_layers_te = size(Y_test_pred_norm, 3);
use_layer   = min(best_config.best_layer, n_layers_te);
if use_layer ~= best_config.best_layer
    fprintf(' NOTE: retrained model has %d layer(s); using Layer %d instead of Val-best Layer %d.\n', ...
        n_layers_te, use_layer, best_config.best_layer);
end

% Per-layer test metrics, computed once and reused for the headline numbers.
n_layers = n_layers_te;
test_mse_per_layer = zeros(1, n_layers);
test_mae_per_layer = zeros(1, n_layers);
for li = 1:n_layers
    r_L = Y_test_norm - Y_test_pred_norm(:, :, li);
    test_mse_per_layer(li) = mean(r_L(:).^2);
    test_mae_per_layer(li) = mean(abs(r_L(:)));
end

test_mse = test_mse_per_layer(use_layer);
test_mae = test_mae_per_layer(use_layer);
fprintf(' Test (Layer %d): MSE=%.6f | MAE=%.6f\n', use_layer, test_mse, test_mae);

% Results table: header plus one row per layer; the used layer is annotated.
test_results = [{'Layer', 'Test_MSE', 'Test_MAE', 'Note'}; cell(n_layers, 4)];
for li = 1:n_layers
    note = '';
    if li == use_layer && use_layer == best_config.best_layer
        note = 'Val-Best (Used)';
    elseif li == use_layer
        note = 'Fallback (Used)';
    end
    test_results(li + 1, :) = {li, test_mse_per_layer(li), test_mae_per_layer(li), note};
end

% Workbook name: Results_<dataset>_H<H>_L<max_layers>_<yyyyMMdd_HHmmss>.xlsx.
% dataset_name comes from the loader, so characters that are not safe in a
% file name (path separators, ':', spaces, ...) are replaced with '_'.
safe_dataset_name = regexprep(char(dataset_name), '[^A-Za-z0-9_\-.]', '_');
run_stamp = char(datetime('now', 'Format', 'yyyyMMdd_HHmmss'));
excel_file = sprintf('Results_%s_H%d_L%d_%s.xlsx', ...
    safe_dataset_name, H, setup.max_layers, run_stamp);

% Sheet 1: one row per grid-search configuration (plus header row).
writecell(summary_results, excel_file, 'Sheet', '1_GridSearch');

% Sheet 2: Parameter/Value rows in three sections (run metadata, selected
% configuration, test metrics) separated by blank rows.
blank_row = {'', ''};
run_rows = {
    'Parameter',        'Value'
    'Dataset',          dataset_name
    'Horizon(H)',       H
    'Split_Ratio',      sprintf('%d:%d:%d', split_ratio)
    'Train/Val/Test',   sprintf('%d/%d/%d', train_size, val_size, test_size)
    'max_layers',       setup.max_layers
};
cfg_rows = {
    'Best Config',      ''
    'No_C',             best_config.No_C
    'ex_w',             best_config.ex_w
    'L',                best_config.L
    'k_range',          sprintf('[%.2f, %.2f]', best_config.k_range)
    'Best_Layer (Val)', best_config.best_layer
    'Val_MSE',          best_config.val_mse
    'Val_MAE',          best_config.val_mae
};
test_rows = {
    'Test Results',     ''
    'Used_Layer',       use_layer
    'Test_MSE',         test_mse
    'Test_MAE',         test_mae
    'Test_Time(s)',     test_time
};
best_info = [run_rows; blank_row; cfg_rows; blank_row; test_rows];

writecell(best_info,    excel_file, 'Sheet', '2_BestConfig');
% Sheet 3: per-layer test metrics.
writecell(test_results, excel_file, 'Sheet', '3_TestAllLayers');

fprintf('\n Saved: %s\n', excel_file);
fprintf('%s\n', repmat('=', 1, 90));