#pragma once

#include <algorithm>
#include <cmath>
#include <memory>
#include <vector>

#include "config.h"
#include "dataset.h"

// -----------------------------
// Distance matrix: packed symmetric storage and pairwise distance queries
// -----------------------------
/// Dense storage for a symmetric n x n float matrix with an implicit zero
/// diagonal.
///
/// Only the strict upper triangle (i < j) is stored. It is packed row by row
/// into one contiguous vector of n*(n-1)/2 floats: row 0 holds (0,1)..(0,n-1),
/// row 1 holds (1,2)..(1,n-1), and so on. get()/set() accept the two indices in
/// either order, so the matrix behaves as symmetric. Indices are NOT
/// bounds-checked; callers must pass values in [0, n).
class TriMatrix {
public:
    /// @param n     Matrix dimension; must be non-negative.
    /// @param init  Initial value of every off-diagonal entry.
    explicit TriMatrix(int n, float init = 0.0f)
        : n_(n), data_(pair_count(n), init), offset_(n) {
        // offset_[i] is the packed position of element (i, i + 1). The offsets
        // are accumulated in size_t as a running sum of row lengths (n - 1 - i).
        // The closed form (n-1)*i - i*(i-1)/2 evaluated in int overflows (UB)
        // for n above roughly 46,000.
        std::size_t row_start = 0;
        for (int i = 0; i < n_; ++i) {
            offset_[static_cast<std::size_t>(i)] = row_start;
            row_start += static_cast<std::size_t>(n_ - 1 - i);
        }
    }

    /// Returns entry (i, j). The diagonal is not stored and always reads as 0.
    float get(int i, int j) const {
        if (i == j) {
            return 0.0f;
        }
        return data_[index(i, j)];
    }

    /// Sets entries (i, j) and (j, i) to @p value. Writes to the diagonal are
    /// silently ignored because the diagonal is not stored.
    void set(int i, int j, float value) {
        if (i == j) {
            return;
        }
        data_[index(i, j)] = value;
    }

    /// Sets every off-diagonal entry to @p value.
    void fill(float value) {
        std::fill(data_.begin(), data_.end(), value);
    }

    /// Matrix dimension n passed to the constructor.
    int size() const { return n_; }

    /// Direct access to the packed upper-triangle storage (layout described
    /// on the class).
    std::vector<float> &raw() { return data_; }
    const std::vector<float> &raw() const { return data_; }

private:
    int n_;
    std::vector<float> data_;
    std::vector<std::size_t> offset_;

    // Number of stored elements n*(n-1)/2, computed in size_t; 0 when n < 2.
    static std::size_t pair_count(int n) {
        if (n < 2) {
            return 0;
        }
        const auto m = static_cast<std::size_t>(n);
        return m * (m - 1) / 2;
    }

    // Packed position of off-diagonal element (i, j), with i and j in any order.
    std::size_t index(int i, int j) const {
        if (i > j) {
            std::swap(i, j);
        }
        return offset_[static_cast<std::size_t>(i)] +
               static_cast<std::size_t>(j - i - 1);
    }
};

/// Answers pairwise distance queries for an Instance.
///
/// NOTE(review): the member functions are defined out of line (not visible
/// here). What the declaration itself shows:
///  - `inst_` is a stored reference, presumably bound to the constructor's
///    `inst`. If so, that Instance must outlive this object.
///  - `cfg` is not stored (there is no Config member), so it can only be
///    consulted during construction.
///  - `precompute_` defaults to false and presumably selects whether distances
///    are cached in `dist_`; confirm in the implementation.
///  - `dist_` is a TriMatrix, which stores float. Cached values may therefore
///    carry only float precision even though the accessors return double
///    (TODO confirm).
///  - The reference member deletes the implicit copy/move assignment, and the
///    unique_ptr member deletes the implicit copy constructor. The class can be
///    move-constructed but not copied or assigned.
class DistanceProvider {
public:
    // Presumably reads cfg to decide whether to precompute distances; TODO confirm.
    DistanceProvider(const Instance &inst, const Config &cfg);

    // Distance between a and b (presumably indices into inst_; metric defined in the .cpp).
    double distance(int a, int b) const;
    // Presumably the squared distance between a and b; confirm in the .cpp.
    double distance_sq(int a, int b) const;
    // Presumably 1 / distance(a, b)^beta (name-based guess). Confirm, including
    // how a == b (zero distance) is handled.
    double inv_dist_beta(int a, int b, double beta) const;

private:
    const Instance &inst_;
    bool precompute_ = false;
    std::unique_ptr<TriMatrix> dist_;
};
