#include "relabel_cpu.h"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <tuple>
#include <unordered_map>
#include <vector>

// Returns the permutation of indices that sorts `v` in ascending order,
// i.e. result[k] is the position in `v` of the k-th smallest element.
// `v` itself is not modified. Requires `operator<` on T.
template <typename T>
std::vector<size_t> sort_indexes(const std::vector<T> &v) {
  // Start from the identity permutation 0, 1, ..., v.size() - 1.
  std::vector<size_t> idx(v.size());
  std::iota(idx.begin(), idx.end(), 0);
  // Sort the indices by the values they refer to. std::stable_sort (rather
  // than std::sort) keeps indices of equal values in their original order,
  // avoiding unnecessary index re-orderings. The call is qualified with
  // std:: so it does not depend on argument-dependent lookup.
  std::stable_sort(idx.begin(), idx.end(),
                   [&v](size_t lhs, size_t rhs) { return v[lhs] < v[rhs]; });
  return idx;
}

// Extracts the rows listed in `idx` from a CSR matrix (`rowptr`, `col`,
// optional per-entry `optional_value`) and relabels the column ids so that
// they are local to the extracted sub-matrix.
//
// Local ids are assigned as follows:
//   * the node at position i of `idx` gets local id i (if a node appears more
//     than once in `idx`, the last position wins);
//   * every column id that is not in `idx` gets the next free id
//     idx.numel(), idx.numel() + 1, ... in order of first appearance while
//     scanning the selected rows.
//
// Returns (out_rowptr, out_col, out_value, out_idx):
//   * out_rowptr: row pointers of the selected rows (idx.numel() + 1
//     entries). When `bipartite` is false, one extra entry equal to
//     out_col.numel() is appended per newly discovered node, i.e. those
//     nodes get empty rows.
//   * out_col: relabeled column ids of the selected rows.
//   * out_value: the values of the selected entries, or nullopt if no values
//     were passed.
//   * out_idx: `idx` followed by the original ids of the newly discovered
//     nodes, so out_idx[local_id] is the original id.
//
// All tensors must live on the CPU; `rowptr`, `col` and `idx` are accessed as
// int64 data. Fails with an assertion error if `idx` contains a node id
// outside the rows described by `rowptr`, or if a selected row range does not
// fit into `col` (and the value tensor, if given).
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>,
           torch::Tensor>
relabel_one_hop_cpu(torch::Tensor rowptr, torch::Tensor col,
                    torch::optional<torch::Tensor> optional_value,
                    torch::Tensor idx, bool bipartite) {

  AT_ASSERTM(!rowptr.is_cuda(), "Rowptr tensor must be a CPU tensor");
  AT_ASSERTM(!col.is_cuda(), "Col tensor must be a CPU tensor");
  if (optional_value.has_value()) {
    AT_ASSERTM(!optional_value.value().is_cuda(),
               "Value tensor must be a CPU tensor");
    AT_ASSERTM(optional_value.value().dim() == 1,
               "Value tensor must be one-dimensional");
  }
  AT_ASSERTM(!idx.is_cuda(), "Index tensor must be a CPU tensor");

  // The raw-pointer indexing below is only valid for densely laid out
  // tensors; contiguous() is a no-op when the input already is contiguous.
  rowptr = rowptr.contiguous();
  col = col.contiguous();
  idx = idx.contiguous();
  if (optional_value.has_value())
    optional_value = optional_value.value().contiguous();

  const auto rowptr_data = rowptr.data_ptr<int64_t>();
  const auto col_data = col.data_ptr<int64_t>();
  const auto idx_data = idx.data_ptr<int64_t>();

  const int64_t num_seeds = idx.numel();
  const int64_t num_rows = rowptr.numel() - 1;
  // Largest entry position that may be read from col (and value, if given).
  int64_t nnz_limit = col.numel();
  if (optional_value.has_value() &&
      optional_value.value().numel() < nnz_limit)
    nnz_limit = optional_value.value().numel();

  std::vector<int64_t> n_ids; // original ids of nodes that are not in idx
  std::unordered_map<int64_t, int64_t> n_id_map; // original id -> local id
  n_id_map.reserve(static_cast<size_t>(num_seeds));

  auto out_rowptr = torch::empty(idx.numel() + 1, rowptr.options());
  auto out_rowptr_data = out_rowptr.data_ptr<int64_t>();

  // Pass 1: give every seed node its local id and build the output row
  // pointers, validating each selected row so later passes can index freely.
  out_rowptr_data[0] = 0;
  int64_t offset = 0;
  for (int64_t i = 0; i < num_seeds; i++) {
    const int64_t v = idx_data[i];
    AT_ASSERTM(v >= 0 && v < num_rows,
               "Index tensor contains an out-of-range node id");
    const int64_t row_start = rowptr_data[v];
    const int64_t row_end = rowptr_data[v + 1];
    AT_ASSERTM(row_start >= 0 && row_start <= row_end && row_end <= nnz_limit,
               "Rowptr tensor is inconsistent with the col/value tensors");
    n_id_map[v] = i;
    offset += row_end - row_start;
    out_rowptr_data[i + 1] = offset;
  }

  auto out_col = torch::empty(offset, col.options());
  auto out_col_data = out_col.data_ptr<int64_t>();

  // Pass 2: relabel the column ids. This does not depend on the value dtype,
  // so it is done once, outside of the dtype dispatch.
  offset = 0;
  for (int64_t i = 0; i < num_seeds; i++) {
    const int64_t v = idx_data[i];
    const int64_t row_end = rowptr_data[v + 1];
    for (int64_t j = rowptr_data[v]; j < row_end; j++) {
      const int64_t w = col_data[j];
      // emplace() only inserts if w is unseen, so a single hash lookup both
      // finds an existing label and registers a new one.
      const int64_t fresh_id = num_seeds + static_cast<int64_t>(n_ids.size());
      const auto inserted = n_id_map.emplace(w, fresh_id);
      if (inserted.second)
        n_ids.push_back(w);
      out_col_data[offset] = inserted.first->second;
      offset++;
    }
  }

  // Pass 3 (optional): gather the values of the selected entries in the same
  // order in which the columns were written.
  torch::optional<torch::Tensor> out_value = torch::nullopt;
  if (optional_value.has_value()) {
    out_value = torch::empty(offset, optional_value.value().options());

    AT_DISPATCH_ALL_TYPES(optional_value.value().scalar_type(), "relabel", [&] {
      auto value_data = optional_value.value().data_ptr<scalar_t>();
      auto out_value_data = out_value.value().data_ptr<scalar_t>();

      int64_t pos = 0;
      for (int64_t i = 0; i < num_seeds; i++) {
        const int64_t v = idx_data[i];
        const int64_t row_end = rowptr_data[v + 1];
        for (int64_t j = rowptr_data[v]; j < row_end; j++) {
          out_value_data[pos] = value_data[j];
          pos++;
        }
      }
    });
  }

  // In the non-bipartite case the newly discovered nodes become rows as
  // well; they have no outgoing entries, so each row pointer equals nnz.
  if (!bipartite)
    out_rowptr = torch::cat(
        {out_rowptr, torch::full({static_cast<int64_t>(n_ids.size())},
                                 out_col.numel(), rowptr.options())});

  // from_blob only wraps the memory of n_ids; torch::cat copies it into a
  // freshly allocated tensor, so nothing refers to n_ids after this line.
  idx = torch::cat(
      {idx, torch::from_blob(n_ids.data(),
                             {static_cast<int64_t>(n_ids.size())},
                             idx.options())});

  return std::make_tuple(out_rowptr, out_col, out_value, idx);
}

// Extracts the rows listed in `idx` from a CSR matrix (`rowptr`, `col`,
// optional per-entry `optional_value`) and relabels the column ids so that
// they are local to the extracted sub-matrix.
//
// Local ids are assigned as follows:
//   * the node at position i of `idx` gets local id i (if a node appears more
//     than once in `idx`, the last position wins);
//   * the column ids that are not in `idx` are collected, sorted in ascending
//     order of their original id, and numbered idx.numel(),
//     idx.numel() + 1, ... in that sorted order.
//
// Returns (out_rowptr, out_col, out_value, out_idx):
//   * out_rowptr: row pointers of the selected rows (idx.numel() + 1
//     entries). When `bipartite` is false, one extra entry equal to
//     out_col.numel() is appended per newly discovered node, i.e. those
//     nodes get empty rows.
//   * out_col: relabeled column ids of the selected rows.
//   * out_value: the values of the selected entries, or nullopt if no values
//     were passed.
//   * out_idx: `idx` followed by the sorted original ids of the newly
//     discovered nodes, so out_idx[local_id] is the original id.
//
// All tensors must live on the CPU; `rowptr`, `col` and `idx` are accessed as
// int64 data. Fails with an assertion error if `idx` contains a node id
// outside the rows described by `rowptr`, or if a selected row range does not
// fit into `col` (and the value tensor, if given).
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>,
           torch::Tensor>
generate_contiguous_heterograph_cpu(torch::Tensor rowptr, torch::Tensor col,
                    torch::optional<torch::Tensor> optional_value,
                    torch::Tensor idx, bool bipartite) {

  AT_ASSERTM(!rowptr.is_cuda(), "Rowptr tensor must be a CPU tensor");
  AT_ASSERTM(!col.is_cuda(), "Col tensor must be a CPU tensor");
  if (optional_value.has_value()) {
    AT_ASSERTM(!optional_value.value().is_cuda(),
               "Value tensor must be a CPU tensor");
    AT_ASSERTM(optional_value.value().dim() == 1,
               "Value tensor must be one-dimensional");
  }
  AT_ASSERTM(!idx.is_cuda(), "Index tensor must be a CPU tensor");

  // The raw-pointer indexing below is only valid for densely laid out
  // tensors; contiguous() is a no-op when the input already is contiguous.
  rowptr = rowptr.contiguous();
  col = col.contiguous();
  idx = idx.contiguous();
  if (optional_value.has_value())
    optional_value = optional_value.value().contiguous();

  const auto rowptr_data = rowptr.data_ptr<int64_t>();
  const auto col_data = col.data_ptr<int64_t>();
  const auto idx_data = idx.data_ptr<int64_t>();

  const int64_t num_seeds = idx.numel();
  const int64_t num_rows = rowptr.numel() - 1;
  // Largest entry position that may be read from col (and value, if given).
  int64_t nnz_limit = col.numel();
  if (optional_value.has_value() &&
      optional_value.value().numel() < nnz_limit)
    nnz_limit = optional_value.value().numel();

  std::vector<int64_t> n_ids; // original ids of nodes that are not in idx
  std::unordered_map<int64_t, int64_t> n_id_map; // original id -> local id
  n_id_map.reserve(static_cast<size_t>(num_seeds));

  auto out_rowptr = torch::empty(idx.numel() + 1, rowptr.options());
  auto out_rowptr_data = out_rowptr.data_ptr<int64_t>();

  // Pass 1: give every seed node its local id and build the output row
  // pointers, validating each selected row so later passes can index freely.
  out_rowptr_data[0] = 0;
  int64_t offset = 0;
  for (int64_t i = 0; i < num_seeds; i++) {
    const int64_t v = idx_data[i]; // original node id
    AT_ASSERTM(v >= 0 && v < num_rows,
               "Index tensor contains an out-of-range node id");
    const int64_t row_start = rowptr_data[v];
    const int64_t row_end = rowptr_data[v + 1];
    AT_ASSERTM(row_start >= 0 && row_start <= row_end && row_end <= nnz_limit,
               "Rowptr tensor is inconsistent with the col/value tensors");
    n_id_map[v] = i; // original id v -> local id i
    offset += row_end - row_start; // number of entries in row v
    out_rowptr_data[i + 1] = offset;
  }

  // Pass 2: collect every column id that is not a seed node. The -1 is only
  // a placeholder marking the id as seen; the final label is assigned below.
  for (int64_t i = 0; i < num_seeds; i++) {
    const int64_t v = idx_data[i];
    const int64_t row_end = rowptr_data[v + 1];
    for (int64_t j = rowptr_data[v]; j < row_end; j++) {
      const int64_t w = col_data[j];
      if (n_id_map.emplace(w, -1).second)
        n_ids.push_back(w);
    }
  }

  // Number the new nodes in ascending order of their original id. The ids in
  // n_ids are unique, so a plain (unstable) sort gives a well-defined order.
  std::sort(n_ids.begin(), n_ids.end());
  for (size_t k = 0; k < n_ids.size(); k++)
    n_id_map[n_ids[k]] = num_seeds + static_cast<int64_t>(k);

  auto out_col = torch::empty(offset, col.options());
  auto out_col_data = out_col.data_ptr<int64_t>();

  // Pass 3: write the relabeled column ids. Every column id of the selected
  // rows was registered in pass 1 or pass 2, so at() always finds it; unlike
  // assert() the lookup stays checked in release builds.
  offset = 0;
  for (int64_t i = 0; i < num_seeds; i++) {
    const int64_t v = idx_data[i];
    const int64_t row_end = rowptr_data[v + 1];
    for (int64_t j = rowptr_data[v]; j < row_end; j++) {
      out_col_data[offset] = n_id_map.at(col_data[j]);
      offset++;
    }
  }

  // Pass 4 (optional): gather the values of the selected entries in the same
  // order in which the columns were written. Only this copy depends on the
  // value dtype, so only it lives inside the dtype dispatch.
  torch::optional<torch::Tensor> out_value = torch::nullopt;
  if (optional_value.has_value()) {
    out_value = torch::empty(offset, optional_value.value().options());

    AT_DISPATCH_ALL_TYPES(optional_value.value().scalar_type(), "relabel", [&] {
      auto value_data = optional_value.value().data_ptr<scalar_t>();
      auto out_value_data = out_value.value().data_ptr<scalar_t>();

      int64_t pos = 0;
      for (int64_t i = 0; i < num_seeds; i++) {
        const int64_t v = idx_data[i];
        const int64_t row_end = rowptr_data[v + 1];
        for (int64_t j = rowptr_data[v]; j < row_end; j++) {
          out_value_data[pos] = value_data[j];
          pos++;
        }
      }
    });
  }

  // In the non-bipartite case the newly discovered nodes become rows as
  // well; they have no outgoing entries, so each row pointer equals nnz.
  if (!bipartite)
    out_rowptr = torch::cat(
        {out_rowptr, torch::full({static_cast<int64_t>(n_ids.size())},
                                 out_col.numel(), rowptr.options())});

  // from_blob only wraps the memory of n_ids; torch::cat copies it into a
  // freshly allocated tensor, so no separate clone() is needed and nothing
  // refers to n_ids after this line.
  idx = torch::cat(
      {idx, torch::from_blob(n_ids.data(),
                             {static_cast<int64_t>(n_ids.size())},
                             idx.options())});

  return std::make_tuple(out_rowptr, out_col, out_value, idx);
}