#include "qid2.h"

#include <algorithm>
#include <cmath>
#include <iostream>
#include <limits>
#include <utility>


/**
 * Reseeds the random engine `gen` that every sampler in this file draws from.
 * NOTE(review): `gen` is declared outside this file (presumably qid2.h).
 *
 * @param seed_ new seed value.
 */
void rng_seed(int seed_)
{
    auto &engine = gen;
    engine.seed(seed_);
}




/**
 * Creates an empty leaf: no children, no stored entry, zero norm.
 */
SQVec ::SQVec()
{
    left = right = NULL;
    index = -1;
    size_val = 0;
    norm_val = 0;
    // Entries count as non-negative until build_util/update records a negative one.
    sign = 1;
}

/**
 * One draw from U[0, 1) using the shared engine `gen`.
 */
double SQVec ::unif()
{
    return uniform_real_distribution<double>(0.0, 1.0)(gen);
}

/**
 * Recursively builds the subtree holding arr[l..r] (inclusive).
 *
 * Leaves record the array index, the squared entry in norm_val and the entry's
 * sign. Internal nodes split the range at its midpoint and store the sum of the
 * children's norm_val and size_val.
 *
 * @param arr source values.
 * @param l   first index of the range.
 * @param r   last index of the range; r < l means an empty range.
 */
void SQVec ::build_util(vector<double> &arr, int l, int r)
{
    if (r < l)
        return;

    if (l != r)
    {
        const int split = l + (r - l) / 2;
        left = new SQVec();
        right = new SQVec();
        left->build_util(arr, l, split);
        right->build_util(arr, split + 1, r);
        norm_val = left->norm_val + right->norm_val;
        size_val = left->size_val + right->size_val;
        return;
    }

    // Single element: this node becomes a leaf.
    const double value = arr[l];
    index = l;
    norm_val = value * value;
    size_val = 1;
    if (value < 0)
        sign = -1;
}

/**
 * Builds the sampling tree over all entries of arr.
 *
 * An empty arr leaves this node as an empty leaf (size 0, norm 0).
 *
 * @param arr values to store.
 */
void SQVec ::build(vector<double> &arr)
{
    // Convert before subtracting. On an empty vector, arr.size() - 1 wraps to
    // SIZE_MAX and only becomes -1 through an implementation-defined narrowing.
    const int last = static_cast<int>(arr.size()) - 1;
    build_util(arr, 0, last);
}

/**
 * Number of entries (leaves) stored under this node.
 */
int SQVec ::size() const
{
    const int count = size_val;
    return count;
}

/**
 * Returns the signed entry stored at position i.
 *
 * Leaves keep the squared magnitude of their entry in norm_val (so parents can
 * sum it for length-squared sampling) and its sign in `sign`. The entry itself
 * is therefore reconstructed as sign * sqrt(norm_val).
 *
 * @param i position in [0, size()); not bounds-checked.
 * @return  the i-th entry of the vector this tree was built from.
 */
double SQVec ::get(int i) const
{
    if (left == NULL && right == NULL)
    {
        return sign * sqrt(norm_val);
    }

    if (i < left->size())
    {
        return left->get(i);
    }

    return right->get(i - left->size());
}

/**
 * Squared l2 norm of the subtree: the sum of the squared entries below this node.
 */
double SQVec ::norm2() const
{
    return norm_val;
}

/**
 * Draws a leaf index with probability proportional to its squared entry.
 *
 * At each internal node the walk goes left with probability
 * left->norm2() / norm2(). If norm2() is 0 the ratio is NaN, the comparison
 * is false, and the walk always goes right.
 *
 * @return the array index recorded in the chosen leaf (-1 for an empty tree).
 */
int SQVec ::sample()
{
    if (left == NULL && right == NULL)
    {
        return index;
    }

    double r = unif();
    double p = left->norm2() / norm_val;
    if (r < p)
    {
        return left->sample();
    }

    return right->sample();
}

/**
 * Overwrites entry i with val and refreshes the norms on the path to the root.
 *
 * The leaf stores val * val and the sign of val, exactly as build_util does.
 * This keeps norm2(), sample() and get() consistent after an update.
 *
 * @param i   position in [0, size()); not bounds-checked.
 * @param val new value of the entry.
 */
void SQVec ::update(int i, double val)
{
    if (left == NULL && right == NULL)
    {
        // Squared magnitude (not |val|) preserves the length-squared invariant.
        // The sign is reset explicitly so a negative entry can become positive.
        norm_val = val * val;
        sign = (val < 0) ? -1 : 1;
        return;
    }

    if (i < left->size())
    {
        left->update(i, val);
    }
    else
    {
        right->update(i - left->size(), val);
    }

    norm_val = left->norm_val + right->norm_val;
}

/**
 * Recursively frees both subtrees. delete on a NULL child is a no-op, so
 * leaves need no special case (and `this` is never NULL here).
 */
SQVec ::~SQVec()
{
    delete left;
    delete right;
}

/**
 * Creates an empty matrix; build() allocates the trees.
 */
SQMat ::SQMat()
{
    row_vec = NULL;
    rows = NULL;
    norm_val = 0;
}

/**
 * Builds one SQVec per row plus a tree over the row norms.
 *
 * rows[i] stores row i, and row_vec stores ||row_i|| for every row, so
 * row_vec->sample() picks rows proportionally to ||row_i||^2. norm_val becomes
 * the sum of squared row norms (squared Frobenius norm).
 *
 * @param arr matrix given as a vector of rows.
 */
void SQMat ::build(vector<vector<double>> &arr)
{
    // Release any previous build. Both pointers start out NULL, so this is safe
    // on the first call too.
    delete row_vec;
    delete[] rows;

    row_vec = new SQVec();
    rows = new SQVec[arr.size()];
    double sum = 0;
    vector<double> row_norms;
    row_norms.reserve(arr.size());
    for (size_t i = 0; i < arr.size(); i++)
    {
        rows[i].build(arr[i]);
        const double norm_i = rows[i].norm2();
        row_norms.push_back(sqrt(norm_i));
        sum += norm_i;
    }
    row_vec->build(row_norms);
    norm_val = sum;
}

/**
 * Number of rows (requires build() to have been called).
 */
int SQMat ::size() const
{
    return row_vec->size();
}

SQMat::~SQMat()
{
    delete row_vec;
    delete[] rows;
}

/**
 * Proposal sampler for distances from the rows of V to a single centre c.
 *
 * p is the probability that bound_sample() draws a row proportionally to
 * ||V_i||^2 instead of uniformly. It equals ||V||_F^2 / (||V||_F^2 + n ||c||^2).
 *
 * @param V data matrix (pointer is kept, not owned).
 * @param c centre vector (pointer is kept, not owned).
 */
SQCentre ::SQCentre(SQMat *V, SQVec *c) 
{
    sqV = V;
    sqc = c;
    double normV2 = sqV->row_vec->norm2();
    double normc2 = sqc->norm2();
    p = normV2 / (normV2 + sqV->size() * normc2);
}

/**
 * Upper bound on ||V_i - c||, using ||a - b||^2 <= 2(||a||^2 + ||b||^2).
 *
 * @param i row index.
 * @return  sqrt(2 * (||V_i||^2 + ||c||^2)).
 */
double SQCentre ::bound_query(int i) 
{
    // Read ||V_i||^2 directly from the row tree rather than from row_vec entries.
    const double Vinorm2 = sqV->rows[i].norm2();
    const double cnorm2 = sqc->norm2();

    return sqrt(2 * (Vinorm2 + cnorm2));
}

/**
 * Draws a row index from the proposal distribution.
 *
 * With probability p the row comes from row_vec->sample() (proportional to
 * ||V_i||^2). Otherwise it is uniform over all rows.
 *
 * @return the sampled row index (as a double).
 */
double SQCentre ::bound_sample() 
{
    const bool from_rows = unif() < p;
    if (!from_rows)
    {
        return unif_int(sqV->size());
    }
    return sqV->row_vec->sample();
}

double SQCentre ::get(int i) 
{
    double sum = 0;
    for (int j = 0; j < sqV->rows[0].size(); j++)
    {
        sum += (sqV->rows[i].get(j) - sqc->get(j)) * (sqV->rows[i].get(j) - sqc->get(j));
    }
    return sqrt(sum);
}

/**
 * Rejection sampling of a row index.
 *
 * Proposes idx via bound_sample() and accepts it with probability
 * (get(idx) / bound_query(idx))^2. Proposals with distance 0 (or a zero
 * bound, giving NaN) are never accepted. If every row coincides with the
 * centre, the loop does not terminate.
 *
 * @return the accepted row index.
 */
int SQCentre ::sample() 
{
    while (true)
    {
        const int idx = static_cast<int>(bound_sample());
        const double r = unif();
        // Evaluate the distance and the bound once per proposal; both walk the trees.
        const double d = get(idx);
        const double b = bound_query(idx);
        if (r < (d * d) / (b * b))
        {
            return idx;
        }
    }
}

/**
 * One draw from U[0, 1) using the shared engine `gen`.
 */
double SQCentre ::unif()
{
    uniform_real_distribution<double> u01(0.0, 1.0);
    const double draw = u01(gen);
    return draw;
}


/**
 * Uniform integer in [0, n - 1] drawn from the shared engine `gen`.
 *
 * @param n number of outcomes; must be positive.
 */
int SQCentre ::unif_int(int n)
{
    const int hi = n - 1;
    return uniform_int_distribution<int>(0, hi)(gen);
}


/**
 * One draw from U[0, 1) using the shared engine `gen`.
 */
double SQCentres ::unif()
{
    return uniform_real_distribution<double>{0.0, 1.0}(gen);
}



/**
 * Sampler over the distance from each row of V to its nearest centre.
 *
 * Each centre j gets an SQCentre whose proposal is proportional to
 * ||V_i||^2 + ||c_j||^2, with total mass ||V||_F^2 + n ||c_j||^2.
 * The centre used for a proposal is picked with exactly that mass as its
 * weight, so the mixture is proportional to sum_j (||V_i||^2 + ||c_j||^2).
 * That is proportional to bound_query(i)^2, the envelope used by sample().
 *
 * @param V     data matrix; only the pointer is kept, so it must outlive this object.
 * @param idxes row indices of V used as centres (copied).
 */
SQCentres::SQCentres(SQMat *V, vector<int> &idxes) 
{
    sqV = V;
    indices = idxes;
    probs.resize(indices.size());

    const double normV2 = sqV->row_vec->norm2();
    const double n_rows = sqV->size();
    for (size_t j = 0; j < indices.size(); j++)
    {
        // Proposal mass of centre j (the common factor 2 cancels). Weighting by
        // ||c_j||^2 alone would bias the sampler, and it degenerates when every
        // centre has zero norm.
        probs[j] = normV2 + n_rows * sqV->rows[indices[j]].norm2();
    }

    dist = discrete_distribution<int>(probs.begin(), probs.end());

    // NOTE(review): these SQCentre objects are allocated with `new`, and no
    // matching delete appears in this file. Confirm that SQCentres' destructor
    // (qid2.h) releases them.
    for (int idx : indices)
    {
        centres.push_back(new SQCentre(V, &V->rows[idx]));
    }
    k = indices.size();
}

/**
 * Distance from row i to its nearest centre: the minimum of SQCentre::get
 * over all k centres.
 *
 * Starts from +infinity rather than a finite sentinel, so arbitrarily large
 * distances are reported correctly. With k == 0 the result is +infinity.
 *
 * @param i row index.
 */
double SQCentres ::get(int i)
{
    double min_dist = std::numeric_limits<double>::infinity();
    for (int j = 0; j < k; j++)
    {
        min_dist = std::min(min_dist, centres[j]->get(i));
    }
    return min_dist;
}

/**
 * Root-mean-square of the per-centre bounds SQCentre::bound_query(i).
 *
 * The squared minimum distance is at most the mean squared distance over the
 * centres. So if each per-centre value bounds its distance, this bounds the
 * distance to the nearest centre.
 *
 * @param i row index.
 */
double SQCentres ::bound_query(int i)
{
    double acc = 0;
    int j = 0;
    while (j < k)
    {
        const double b = centres[j]->bound_query(i);
        acc += b * b;
        ++j;
    }
    return sqrt(acc / k);
}

double SQCentres ::bound_sample()
{
    int cidx = dist(gen);
    return centres[cidx]->bound_sample();
}

/**
 * Rejection sampling of a row index against the nearest-centre distance.
 *
 * Proposes idx via bound_sample() and accepts it with probability
 * (get(idx) / bound_query(idx))^2. Rows at distance 0 from a centre are
 * never accepted. If every row coincides with some centre, the loop does not
 * terminate.
 *
 * @return the accepted row index.
 */
int SQCentres ::sample()
{
    while (true)
    {
        const int idx = static_cast<int>(bound_sample());
        const double r = unif();
        // get() is O(k * d) tree queries, so evaluate it and the bound only once.
        const double d = get(idx);
        const double b = bound_query(idx);
        if (r < (d * d) / (b * b))
        {
            return idx;
        }
    }
}

/**
 * k-means++-style seeding driven by the SQ samplers.
 *
 * The first centre is a uniformly random row of V. Each further centre is drawn
 * by SQCentres::sample() over the rows' distances to the centres chosen so far.
 *
 * @param V data matrix (must have been built).
 * @param k number of centres; k <= 0 or an empty V yields an empty result.
 * @return  row indices of the chosen centres, in selection order.
 */
vector<int> QIkpp(SQMat* V, int k)
{
    vector<int> indices;
    const int n = V->size();
    // Previously k <= 0 still returned one index, and n == 0 built an invalid
    // uniform_int_distribution(0, -1).
    if (k <= 0 || n <= 0)
    {
        return indices;
    }
    indices.reserve(k);

    uniform_int_distribution<int> unif_int(0, n - 1);
    indices.push_back(unif_int(gen));

    for (int i = 1; i < k; i++)
    {
        // NOTE(review): each SQCentres allocates k SQCentre objects with `new`.
        // Confirm in qid2.h that they are freed when sqc goes out of scope.
        SQCentres sqc(V, indices);
        indices.push_back(sqc.sample());
    }

    return indices;
}

/**
 * Classical k-means++ seeding (D^2 sampling) on an in-memory dataset.
 *
 * The first centre is chosen uniformly. Each later centre is drawn with
 * probability proportional to its squared distance (dist) to the nearest
 * centre chosen so far. If every point is at distance 0 from the chosen
 * centres, the next centre is picked uniformly instead, because
 * discrete_distribution needs a positive total weight.
 *
 * @param data points, one row per point (rows assumed to share a dimension).
 * @param k    number of centres; k <= 0 or empty data yields an empty result.
 * @return     indices into data of the chosen centres, in selection order.
 */
vector<int> kpp(vector<vector<double>>& data, int k)
{
    vector<int> indices;
    const int n = static_cast<int>(data.size());
    if (n == 0 || k <= 0)
    {
        return indices;
    }
    indices.reserve(k);

    uniform_int_distribution<int> unif_int(0, n - 1);
    int idx = unif_int(gen);
    indices.push_back(idx);

    vector<double> min_dists(n);
    for (int i = 0; i < n; i++)
    {
        min_dists[i] = dist(data[i], data[idx]);
    }

    for (int c = 1; c < k; c++)
    {
        double total = 0;
        for (double d : min_dists)
        {
            total += d;
        }

        if (total > 0)
        {
            // Build the distribution from the CURRENT distances. The original
            // built it before min_dists was filled, so the second centre was
            // drawn from an all-zero (degenerate) distribution.
            discrete_distribution<int> p(min_dists.begin(), min_dists.end());
            idx = p(gen);
        }
        else
        {
            idx = unif_int(gen);
        }
        indices.push_back(idx);

        for (int j = 0; j < n; j++)
        {
            min_dists[j] = std::min(min_dists[j], dist(data[j], data[idx]));
        }
    }

    return indices;
}

/**
 * Squared Euclidean distance between x and y (no square root is taken).
 *
 * Iterates over x's dimension; y must be at least as long as x.
 *
 * @return sum_t (x_t - y_t)^2.
 */
double dist(vector<double>& x, vector<double>& y)
{
    double acc = 0;
    const size_t dim = x.size();
    for (size_t t = 0; t < dim; ++t)
    {
        const double delta = x[t] - y[t];
        acc += delta * delta;
    }
    return acc;
}

/**
 * Euclidean (l2) length of x.
 *
 * @return sqrt(sum_t x_t^2); 0 for an empty vector.
 */
double norm(vector<double>& x)
{
    double squares = 0;
    for (size_t t = 0; t < x.size(); ++t)
    {
        squares += x[t] * x[t];
    }
    return sqrt(squares);
}


/**
 * Reads a whitespace-separated numeric matrix, one row per line.
 *
 * Lines that contain no parseable number (blank or whitespace-only) are
 * skipped. An empty row would otherwise reach dist()/cost(), which index the
 * other point by the first point's dimension and would read out of bounds.
 * Parsing of a line stops at its first non-numeric token.
 *
 * @param filename path of the file to read.
 * @return the parsed rows; empty if the file cannot be opened (a diagnostic
 *         is written to std::cerr).
 */
std::vector<std::vector<double>> readDataFromFile(const std::string& filename) {
    std::vector<std::vector<double>> data;
    std::ifstream file(filename);

    if (!file.is_open()) {
        // Diagnostics go to stderr and end with a newline so they don't run into other output.
        std::cerr << "Unable to open file: " << filename << '\n';
        return data;
    }

    std::string line;
    while (std::getline(file, line)) {
        std::vector<double> row;
        std::stringstream ss(line);
        double value;
        while (ss >> value) {
            row.push_back(value);
        }
        if (!row.empty()) {
            data.push_back(std::move(row));
        }
    }
    // The ifstream closes itself when it goes out of scope.
    return data;
}


/**
 * k-means objective: the sum over all points of the squared Euclidean distance
 * to the nearest centre.
 *
 * Each point's minimum starts at +infinity. The previous 1e9 sentinel silently
 * capped any squared distance above 1e9. With an empty indices and non-empty
 * data the result is +infinity.
 *
 * @param data    points, one row per point.
 * @param indices row indices into data that act as centres.
 */
double cost(vector<vector<double>>& data, vector<int> & indices)
{
    double total = 0;
    for (const vector<double> &point : data)
    {
        double best = std::numeric_limits<double>::infinity();
        for (int c : indices)
        {
            const vector<double> &centre = data[c];
            double d2 = 0;
            for (size_t t = 0; t < point.size(); t++)
            {
                const double diff = point[t] - centre[t];
                d2 += diff * diff;
            }
            best = std::min(best, d2);
        }
        total += best;
    }
    return total;
}