#include "../som/kohonen.h"

namespace dhrg {

/* build landscape embeddings and compute mAP for them */

using rogueviz::kohonen::kohvec;
using rogueviz::kohonen::columns;

/** dimension of the landscape embedding; set by landscape_build */
int landscape_dim;

/** landscape vector of every vertex, indexed like `vertices`; filled by landscape_build */
vector<kohvec> landscape_weights;

/** Squared Euclidean distance between two landscape vectors, taken over their first
 *  `columns` entries. No square root is applied.
 *  NOTE(review): this looks sufficient for rank-based measures, since they are
 *  invariant under monotone transforms. */
ld dist(const kohvec& a, const kohvec& b) {
  ld sum = 0;
  for(int k=0; k<columns; k++) {
    auto diff = a[k] - b[k];
    sum += diff * diff;
    }
  return sum;
  }

/** Build a `dim`-dimensional landscape embedding and store one vector per vertex
 *  in landscape_weights. Also sets landscape_dim and the Kohonen `columns` count to `dim`.
 *  NOTE(review): init_landscape presumably prepares the landscape generator for `columns`
 *  dimensions; confirm in rogueviz::embeddings. */
void landscape_build(int dim) {
  landscape_dim = dim;
  columns = landscape_dim;
  rogueviz::embeddings::init_landscape(columns);
  landscape_weights.resize(N);
  {
    // scoped so the progress bar is finished before the summary line below is printed
    progressbar pb(N, "compute landscape");
    for(int i=0; i<N; i++) {
      cell *c = vertices[i]->ascell();
      landscape_weights[i] = rogueviz::embeddings::get_landscape_at(c);
      pb++;
      }
    }
  println(hlog, "delta size = ", isize(rogueviz::embeddings::delta_at));
  }

/** Write the landscape embedding to `fname`.
 *  For every vertex it writes the vertex name on one line, then its coordinates on the next line,
 *  separated by spaces.
 *  All vectors are scaled by target_dist / max_dist, where max_dist is the largest Euclidean
 *  norm among them, so the farthest vertex ends up at norm target_dist.
 *  `shape` selects what is appended to each coordinate line:
 *   - "euclid", "poincare": nothing (NOTE: a valid Poincare-ball file needs target_dist < 1);
 *   - "lorentz": sqrt(|w|^2 + 1), so that (w, t) satisfies t^2 - |w|^2 = 1;
 *   - "horosphere": two values computed from |w|^2 / 2 and a shift of -asinh(target_dist) / 2.
 *  An unknown shape, or a file that cannot be opened, is reported and nothing is written. */
void landscape_output(string shape, ld target_dist, string fname) {
  // An unknown shape would otherwise produce coordinate rows without line terminators.
  if(shape != "euclid" && shape != "poincare" && shape != "lorentz" && shape != "horosphere") {
    println(hlog, "unknown shape: ", shape, " (expected euclid, poincare, lorentz or horosphere)");
    return;
    }

  ld max_dist = 0;
  for(int i=0; i<N; i++) {
    ld tot = 0;
    for(int d=0; d<columns; d++) tot += landscape_weights[i][d] * landscape_weights[i][d];
    tot = sqrt(tot);
    max_dist = max(max_dist, tot);
    }
  println(hlog, "max_dist = ", max_dist, " to ", target_dist);
  // All vectors are zero (or there are no vertices): avoid 0/0; every scaled coordinate is 0 anyway.
  if(max_dist == 0) max_dist = 1;

  fhstream f(fname, "w");
  if(!f.f) {
    println(hlog, "Cannot open file for writing: ", fname);
    return;
    }
  for(int i=0; i<N; i++) {
    ld tot = 0;
    println(f, rogueviz::vdata[i].name);
    for(int d=0; d<columns; d++) {
      auto w = landscape_weights[i][d] * target_dist / max_dist;
      if(d) print(f, " ");
      print(f, w);
      tot += w * w;
      }
    if(shape == "lorentz") println(f, " ", sqrt(tot+1));
    else if(shape == "horosphere") {
      tot /= 2;
      ld shift = -asinh(target_dist)/2;
      println(f, " ", cosh(shift) * tot + sinh(shift) * (1+tot), " ", cosh(shift) * (1+tot) + sinh(shift) * tot);
      }
    else println(f); // "euclid" and "poincare": coordinates only
    }
  }

/** Landscape distance between vertices i and j: the squared Euclidean distance
 *  between their landscape_weights vectors. */
ld landscape_dist(int i, int j) {
  const auto& wi = landscape_weights[i];
  const auto& wj = landscape_weights[j];
  return dist(wi, wj);
  }

/** Build a `dim`-dimensional landscape embedding, then compute continuous rank
 *  measures using landscape_dist as the metric. */
void rank_landscape(int dim) {
  landscape_build(dim);
  auto metric = &landscape_dist;
  continuous_ranks(metric);
  }

/* read embeddings (Poincare, Lorentz, distance tables, Mercator, Euclidean) and compute ranks and routing measures */

/** Read a Poincare-model embedding from `fname` into vertexcoords.
 *  Each record is a node label followed by GDIM coordinates. The coordinates are mapped
 *  with perspective_to_space(., 1) before being stored.
 *  Exits if the file cannot be opened; logs the largest hdist0 among the points read. */
void read_poincare(string fname) {
  fhstream g(fname, "rt");
  if(!g.f) {
    println(hlog, "Missing file: ", fname);
    exit(1);
    }
  println(hlog, "Reading Poincare (dim = ", GDIM, ")...");
  vertexcoords.resize(N);
  ld maxradius = 0;
  for(int id; (id = rogueviz::readLabel(g)) != -1; ) {
    hyperpoint p;
    for(int k=0; k<GDIM; k++) p[k] = scan<ld>(g);
    auto lifted = perspective_to_space(p, 1);
    vertexcoords[id] = lifted;
    maxradius = max(maxradius, hdist0(lifted));
    }
  println(hlog, "maximum radius = ", maxradius);
  }

/** Read a Lorentz (hyperboloid) embedding from `fname` into vertexcoords.
 *  Each record is a node label, then the coordinate stored at index GDIM, then the
 *  coordinates for indices 0..GDIM-1. Points are stored as given, without projection.
 *  Exits if the file cannot be opened; logs the largest hdist0 among the points read. */
void read_lorentz(string fname) {
  fhstream g(fname, "rt");
  if(!g.f) {
    println(hlog, "Missing file: ", fname);
    exit(1);
    }
  println(hlog, "Reading Lorentz (dim = ", GDIM, ")...");
  vertexcoords.resize(N);
  ld maxradius = 0;
  for(;;) {
    int id = rogueviz::readLabel(g);
    if(id == -1) break;
    hyperpoint p;
    p[GDIM] = scan<ld>(g); // the file lists this coordinate first
    for(int k=0; k<GDIM; k++) p[k] = scan<ld>(g);
    vertexcoords[id] = p;
    maxradius = max(maxradius, hdist0(p));
    }
  println(hlog, "maximum radius = ", maxradius);
  }

/** Number of undirected edges: the total length of the edge lists of the first N
 *  rogueviz vertices, halved.
 *  NOTE(review): this presumes every edge appears in both endpoints' lists. */
int count_edges() {
  int endpoints = 0;
  for(int v=0; v<N; v++)
    endpoints += isize(rogueviz::vdata[v].edges);
  return endpoints / 2;
  }

/** Evaluate an embedding given by a distance-from-origin function `zf` and a pairwise
 *  distance `df` in dimension `dim`.
 *  Greedy routing (routing_by) runs only when the directed edge count is exactly twice the
 *  undirected count, i.e. the graph looks symmetric. continuous_ranks and
 *  analyze_mdl_symmetric always run.
 *  NOTE(review): evaluation continues even after reporting that directed_edges is empty. */
void full_evaluation(const zerodistfun& zf, const distfun& df, int dim) {
  if(directed_edges.empty())
    println(hlog, "error: called full_evaluation without access to directed_edges");
  bool symmetric = count_directed_edges() == 2 * count_edges();
  if(symmetric)
    routing_by(df);
  continuous_ranks(df);
  analyze_mdl_symmetric(zf, df, dim, symmetric);
  }

/** Evaluate a 3-dimensional Poincare embedding from `fname`.
 *  The geometry is set to gSpace534 while reading and evaluating.
 *  NOTE(review): dynamicval presumably restores the previous geometry on scope exit. */
void eval_poincare3(string fname) {
  dynamicval<eGeometry> geometry_guard(geometry, gSpace534);
  read_poincare(fname);
  const int dim = 3;
  full_evaluation(vertex_zero_dist, vertex_dist, dim);
  }

/** Evaluate a 2-dimensional Poincare embedding from `fname`, in the current geometry. */
void eval_poincare2(string fname) {
  const int dim = 2;
  read_poincare(fname);
  full_evaluation(vertex_zero_dist, vertex_dist, dim);
  }

/** Evaluate a 3-dimensional Lorentz embedding from `fname`.
 *  The geometry is set to gSpace534 while reading and evaluating.
 *  NOTE(review): dynamicval presumably restores the previous geometry on scope exit. */
void eval_lorentz3(string fname) {
  dynamicval<eGeometry> geometry_guard(geometry, gSpace534);
  read_lorentz(fname);
  const int dim = 3;
  full_evaluation(vertex_zero_dist, vertex_dist, dim);
  }

/** Evaluate a 2-dimensional Lorentz embedding from `fname`, in the current geometry. */
void eval_lorentz2(string fname) {
  const int dim = 2;
  read_lorentz(fname);
  full_evaluation(vertex_zero_dist, vertex_dist, dim);
  }

/** Evaluate an embedding given as a precomputed distance table.
 *  `nodelist_fname` lists the N node labels; position p in this list is the table index of that node.
 *  `disttable_fname` holds the strict lower triangle of the distance matrix in that order:
 *  for p = 0..N-1, the distances to positions 0..p-1.
 *  Then runs routing_by and continuous_ranks on the symmetric table.
 *  Returns early, with a message, if a file cannot be opened or the node list is invalid. */
void eval_disttable(string nodelist_fname, string disttable_fname) {
  fhstream g(nodelist_fname, "rt");
  if(!g.f) {
    println(hlog, "Missing file: ", nodelist_fname);
    return;
    }
  vector<int> our_index;
  vector<char> seen(N, 0);
  while(true) {
    int i = rogueviz::readLabel(g);
    if(i == -1) break;
    // An out-of-range id would index past disttable; a repeated label would leave some row unset.
    if(i < 0 || i >= N || seen[i]) {
      println(hlog, "invalid or duplicate node label in ", nodelist_fname);
      return;
      }
    seen[i] = 1;
    our_index.push_back(i);
    }
  println(hlog, "read ", isize(our_index), " node labels from ", nodelist_fname);
  if(isize(our_index) != N) { println(hlog, "wrong number of labels, ", N, " expected"); return; }

  ld maxdist = 0;

  vector<vector<ld>> disttable(N, vector<ld>(N, 0));
  fhstream g1(disttable_fname, "rt");
  if(!g1.f) {
    println(hlog, "Missing file: ", disttable_fname);
    return;
    }
  println(hlog, "reading distance table from ", disttable_fname);
  for(int i=0; i<N; i++)
  for(int j=0; j<i; j++) {
    ld d = scan<ld>(g1);
    maxdist = max(d, maxdist);
    disttable[our_index[i]][our_index[j]] = d;
    disttable[our_index[j]][our_index[i]] = d;
    }

  println(hlog, "maximum distance = ", maxdist);

  auto disttable_dist = [&] (int i, int j) { return disttable[i][j]; };
  routing_by(disttable_dist);
  continuous_ranks(disttable_dist);
  }

/** One node of a 3D embedding as read by eval_mercator3. */
struct mercator_coord {
  ld distance;      // radial distance from the origin (used as a hyperbolic distance by eval_mercator3)
  array<ld, 3> h;   // direction vector; eval_mercator3 normalizes it to unit length
  };

/** Read a 3D embedding from `fname` and, unless `only_read`, evaluate it.
 *  NOTE(review): the format is presumably the coordinate output of the Mercator embedding tool.
 *  Each record is a node label, an ignored value (kappa), the radial distance, and three
 *  direction coordinates (normalized here). A standalone "#" token skips the rest of its line.
 *  Fills vertexcoords with (sinh(d) * dir, cosh(d)).
 *  Pairwise distances follow the hyperbolic law of cosines:
 *  cosh D = cosh(da - db) + sinh(da) sinh(db) (1 - cos phi).
 *  Exits if the file cannot be opened or a label does not map to a vertex in [0, N). */
void eval_mercator3(string fname, bool only_read = false) {
  fhstream g(fname, "rt");
  if(!g.f) {
    println(hlog, "Missing file: ", fname);
    exit(1);
    }
  vector<mercator_coord> mcs(N);

  ld maxradius = 0;

  int qty = 0;

  vertexcoords.resize(N);
  while(true) {
    string s = scan<string>(g);
    if(s == "") break;
    if(s == "#") { scanline_noblank(g); continue; }
    int id = rogueviz::getid(s);
    // getid may return -1; any id outside [0, N) would also index past mcs / vertexcoords
    if(id < 0 || id >= N) { println(hlog, "unknown node label: ", s); exit(1); }
    ld kappa = scan<ld> (g); hr::ignore(kappa);
    auto& mc = mcs[id];
    mc.distance = scan<ld> (g);
    for(int i=0; i<3; i++) mc.h[i] = scan<ld> (g);
    ld radius = sqrt(mc.h[0] * mc.h[0] + mc.h[1] * mc.h[1] + mc.h[2] * mc.h[2]);
    // a zero direction vector cannot be normalized; keep it zero instead of turning it into NaN
    if(radius > 0) for(int i=0; i<3; i++) mc.h[i] /= radius;
    qty++;
    maxradius = max(maxradius, mc.distance);
    for(int i=0; i<3; i++)
      vertexcoords[id][i] = sinh(mc.distance) * mc.h[i];
    vertexcoords[id][3] = cosh(mc.distance);
    }
  println(hlog, "read ", qty, " nodes, maximum radius = ", maxradius);

  if(only_read) return;

  auto mercator_dist = [&] (int a, int b) {
    if(a == b) return ld(0);
    ld da = mcs[a].distance;
    ld db = mcs[b].distance;

    ld cosphi = 0;
    for(int i=0; i<3; i++) cosphi += mcs[a].h[i] * mcs[b].h[i];

    ld co = sinh(da) * sinh(db) * (1 - cosphi);
    ld v = cosh(da - db) + co;
    if(v < 1) return ld(0);

    return acosh(v);
    };

  full_evaluation([&] (int id) { return mcs[id].distance; }, mercator_dist, 3);
  }

/** Evaluate a `dim`-dimensional Euclidean embedding read from `fname`.
 *  Each record is a node label followed by `dim` coordinates. Computes the full Euclidean
 *  distance table, then runs routing_by and continuous_ranks on it.
 *  Exits if the file cannot be opened, a label is out of range, or some node has no coordinates.
 *  Without the last check, the pairwise loop would read an empty coordinate vector. */
void eval_euclid(string fname, int dim) {
  vector<vector<ld>> coords;
  vector<ld> dist0;
  println(hlog, "Open file ", fname);
  fhstream g(fname, "rt");
  if(!g.f) {
    println(hlog, "Missing file: ", fname);
    exit(1);
    }
  println(hlog, "Reading Euclid (dim = ", dim, ")...");
  coords.resize(N);
  dist0.resize(N);
  vector<char> have(N, 0);
  ld maxradius = 0;
  while(true) {
    int i = rogueviz::readLabel(g);
    if(i == -1) break;
    if(i < 0 || i >= N) {
      println(hlog, "error: unknown node label in ", fname);
      exit(1);
      }
    vector<ld> co(dim);
    for(int k=0; k<dim; k++) co[k] = scan<ld>(g);

    ld d = 0; for(int k=0; k<dim; k++) d += co[k] * co[k];
    dist0[i] = sqrt(d);

    coords[i] = std::move(co);
    have[i] = 1;
    maxradius = max(maxradius, dist0[i]);
    }
  println(hlog, "maximum radius = ", maxradius);

  for(int i=0; i<N; i++) if(!have[i]) {
    println(hlog, "error: no coordinates for ", rogueviz::vdata[i].name, " in ", fname);
    exit(1);
    }

  vector<vector<ld>> disttable(N, vector<ld>(N, 0));
  for(int i=0; i<N; i++)
  for(int j=0; j<=i; j++) {
    ld d = 0;
    for(int k=0; k<dim; k++) { ld e = coords[i][k] - coords[j][k]; d += e * e; }
    disttable[i][j] = disttable[j][i] = sqrt(d);
    }

  auto disttable_dist = [&] (int i, int j) { return disttable[i][j]; };
  routing_by(disttable_dist);
  continuous_ranks(disttable_dist);
  }

/** Evaluate a `dim`-dimensional Poincare-ball embedding read from `fname`.
 *  Each record is a node label followed by `dim` ball coordinates x with |x| < 1.
 *  Every point is lifted to the hyperboloid model: the spatial part is 2x / (1 - |x|^2)
 *  (= sinh(r) * x/|x|) and the extra coordinate is (1 + |x|^2) / (1 - |x|^2) (= cosh(r)),
 *  where r is the hyperbolic radius. The pairwise distance is D = acosh(1 + q/2), where q is
 *  the Minkowski squared norm of the difference of two lifted points (q = 2 cosh D - 2).
 *  Then runs routing_by and continuous_ranks on the distance table.
 *  Exits if the file cannot be opened, a label is out of range, a point lies outside
 *  the ball, or some node has no coordinates. */
void eval_poincare_any(string fname, int dim) {
  // acosh clamped at 1: rounding can push the argument slightly below 1
  auto clamped_acosh = [] (ld x) { return x <= 1 ? ld(0) : acosh(x); };

  vector<vector<ld>> coords;
  vector<ld> dist0;
  println(hlog, "Open file ", fname);
  fhstream g(fname, "rt");
  if(!g.f) {
    println(hlog, "Missing file: ", fname);
    exit(1);
    }
  println(hlog, "Reading Poincare (dim = ", dim, ")...");
  coords.resize(N);
  dist0.resize(N);
  vector<char> have(N, 0);
  ld maxradius = 0;
  while(true) {
    int i = rogueviz::readLabel(g);
    if(i == -1) break;
    if(i < 0 || i >= N) {
      println(hlog, "error: unknown node label in ", fname);
      exit(1);
      }
    vector<ld> co(dim+1);
    for(int k=0; k<dim; k++) co[k] = scan<ld>(g);

    ld r2 = 0; for(int k=0; k<dim; k++) r2 += co[k] * co[k];
    if(r2 >= 1) {
      println(hlog, "error: point of ", rogueviz::vdata[i].name, " lies outside the Poincare ball");
      exit(1);
      }
    // exact Poincare -> hyperboloid lift; also well-defined at the origin (r2 == 0)
    ld denom = 1 - r2;
    for(int k=0; k<dim; k++) co[k] = 2 * co[k] / denom;
    co[dim] = (1 + r2) / denom;

    dist0[i] = clamped_acosh(co[dim]);

    coords[i] = std::move(co);
    have[i] = 1;
    maxradius = max(maxradius, dist0[i]);
    }
  println(hlog, "maximum radius = ", maxradius);

  for(int i=0; i<N; i++) if(!have[i]) {
    println(hlog, "error: no coordinates for ", rogueviz::vdata[i].name, " in ", fname);
    exit(1);
    }

  vector<vector<ld>> disttable(N, vector<ld>(N, 0));
  for(int i=0; i<N; i++)
  for(int j=0; j<=i; j++) {
    ld q = 0;
    for(int k=0; k<dim; k++) { ld e = coords[i][k] - coords[j][k]; q += e * e; }
    ld et = coords[i][dim] - coords[j][dim];
    q -= et * et;
    disttable[i][j] = disttable[j][i] = clamped_acosh(1 + q / 2);
    }

  auto disttable_dist = [&] (int i, int j) { return disttable[i][j]; };
  routing_by(disttable_dist);
  continuous_ranks(disttable_dist);
  }

}
