#!/usr/bin/env python3
# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from hype.graph import eval_reconstruction, load_adjacency_matrix
import argparse
import numpy as np
import torch
import os
import timeit
from hype import MANIFOLDS, MODELS

def acosh(x):
    """Return the inverse hyperbolic cosine of ``x``.

    Delegates to ``np.arccosh`` rather than evaluating
    ``log(sqrt(x*x - 1) + x)`` by hand: the hand-written form squares its
    argument and therefore overflows to ``inf`` for large inputs.

    Args:
        x: A scalar or NumPy array. Values below 1 are outside the domain
            and yield ``nan``.

    Returns:
        The inverse hyperbolic cosine, as a NumPy scalar or array.
    """
    return np.arccosh(x)

# Fix the NumPy seed so the node sample drawn with np.random.choice further
# down is the same on every run.
np.random.seed(42)

# The string 'none' is the "option not given" sentinel for all path options.
parser = argparse.ArgumentParser()
parser.add_argument('file', help='Path to checkpoint')
parser.add_argument('-workers', default=1, type=int, help='Number of workers')
parser.add_argument('-sample', type=int, help='Sample size')
parser.add_argument('-quiet', action='store_true', default=False,
                    help='Turn off progress reporting during evaluation')
parser.add_argument('-bfkl', default='none', type=str,
                    help='Input file: two header lines, then tab-separated '
                         '"name dist phi" rows (phi in degrees)')
parser.add_argument('-input3', default='none', type=str,
                    help='Input file: a name line followed by three lines '
                         'holding one coordinate each, repeated per object')
parser.add_argument('-coord', default='none', type=str,
                    help='Input file: a label line followed by a line of '
                         'space-separated coordinates, repeated per object')
parser.add_argument('-output', type=str, default="none",
                    help='Write "label dist phi" polar coordinates to this file')
parser.add_argument('-output3', type=str, default="none",
                    help='Write each label followed by its raw coordinates, '
                         'one per line, to this file')
# Originally this option had no action and so demanded a dummy value.  With
# nargs='?' a bare "-rank" now enables the evaluation, while "-rank VALUE"
# keeps working exactly as before (any non-empty VALUE is truthy).
parser.add_argument('-rank', nargs='?', const=True, default=False,
                    help='Evaluate reconstruction on the unmodified '
                         'checkpoint embeddings')
args = parser.parse_args()

# NOTE: torch.load unpickles the file, so only run this script on checkpoints
# that come from a trusted source.
chkpnt = torch.load(args.file)

# The checkpoint configuration records the path of the dataset it was
# trained on; that file must still be reachable from here.
dset_path = chkpnt['conf']['dset']

if not os.path.exists(dset_path):
    raise ValueError("Can't find dset!")

# Separate names for the path and the file format avoid shadowing the
# builtin `format` and reusing `dset` for two different things.
dset_format = 'hdf5' if dset_path.endswith('.h5') else 'csv'
dset = load_adjacency_matrix(dset_path, dset_format, objects=chkpnt['objects'])

ids = dset['ids']
offsets = dset['offsets']
neighbors = dset['neighbors']

# Evaluate on every node unless -sample requests a smaller subset; the
# subset is drawn without replacement.
sample_size = args.sample if args.sample else len(ids)
sample = np.random.choice(len(ids), size=sample_size, replace=False)

# Map each sampled node id to the set of its neighbors.  offsets[i] is where
# node i's run starts inside the flat `neighbors` sequence; the run ends at
# the next node's offset, or at the end of `neighbors` for the last node.
adj = {}
for node in sample:
    start = offsets[node]
    if node + 1 < len(offsets):
        stop = offsets[node + 1]
    else:
        stop = len(neighbors)
    adj[ids[node]] = set(neighbors[start:stop])
# Rebuild the model described by the checkpoint configuration and restore its
# saved parameters.  The manifold is instantiated once (the original built it
# twice in a row and discarded the first instance).
manifold = MANIFOLDS[chkpnt['conf']['manifold']]()
model = MODELS[chkpnt['conf']['model']](
    manifold,
    dim=chkpnt['conf']['dim'],
    size=chkpnt['embeddings'].size(0),
    sparse=chkpnt['conf']['sparse']
)
model.load_state_dict(chkpnt['model'])

# `lt` is not read again in the visible part of this script, so a machine
# without CUDA should not fail here: only move it to the GPU when one exists.
lt = chkpnt['embeddings']
if not isinstance(lt, torch.Tensor):
    lt = torch.from_numpy(lt)
    if torch.cuda.is_available():
        lt = lt.cuda()


# Embedding table of the restored model, one row per object.
# NOTE(review): the override sections below write into rows of `weights` and
# then re-evaluate `model`, so this tensor presumably aliases the model's
# parameter storage -- confirm against torch's state_dict semantics.
state = model.state_dict()
weights = state["lt.weight"]
qty = weights.size(0)

# Dump the loaded dataset structure to stdout.
print(dset)

# NumPy view of the embedding table, used by the export sections.
wn = weights.numpy()

# Object labels, in embedding-row order; also attached to the model.
object_table = dset["objects"]
objects = object_table["obj"]
model.obj = objects

# Export every embedding as polar coordinates: a fixed header line, a line
# starting with the object count, then one "label dist phi" line per object.
# dist is the hyperbolic distance from the origin; phi is in degrees,
# shifted by 180 so that it lies in [0, 360].
if args.output != "none":
    # `with` guarantees the file is closed even if a row fails to convert.
    with open(args.output, 'w') as f:
        print("n R alpha T", file=f)
        print(qty, "0 0 0", file=f)
        for key, val in enumerate(objects):
            v = wn[key]
            if v.shape[0] == 3:
                # Lorentz (hyperboloid) row: v[0] is the time-like
                # coordinate, so the distance from the origin is acosh(v[0]).
                dist = acosh(v[0])
                phi = np.arctan2(v[1], v[2]) * 180 / np.pi + 180
            else:
                # Poincare row: only the first two coordinates are used;
                # d is their squared Euclidean norm.
                d = v[0] * v[0] + v[1] * v[1]
                dist = acosh(1 + 2 * d / (1 - d))
                phi = np.arctan2(v[0], v[1]) * 180 / np.pi + 180
            print(val, dist, phi, file=f)

# Export the raw embedding coordinates: for each object, its label on one
# line followed by one coordinate per line.
if args.output3 != "none":
    # `with` guarantees the file is closed even if writing fails part-way.
    with open(args.output3, 'w') as f:
        for key, val in enumerate(objects):
            print(val, file=f)
            for co in wn[key]:
                # The multiplication by 1 is kept from the original.
                print(co * 1, file=f)

# Evaluate reconstruction on the embeddings exactly as stored in the
# checkpoint (this runs before any of the override sections below) and
# report the metrics together with the wall-clock time taken.
if args.rank:
    started = timeit.default_timer()
    meanrank, maprank = eval_reconstruction(
        adj,
        model,
        workers=args.workers,
        progress=not args.quiet,
    )
    etime = timeit.default_timer() - started
    print(f'Mean rank: {meanrank}, mAP rank: {maprank}, time: {etime}')

# Override the embeddings with polar coordinates read from the -bfkl file,
# then re-run the reconstruction evaluation on the modified model.
if args.bfkl != "none":

    # name -> (dist, phi in radians)
    coords = {}
    # `with` guarantees the file is closed even if a row fails to parse.
    with open(args.bfkl, "r") as f:
        # The first two lines are headers and carry no coordinates.
        f.readline()
        f.readline()

        for x in f:
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            if not x.strip():
                continue
            sx = x.split("\t")
            # Column 1 is the distance, column 2 the angle in degrees.
            coords[sx[0]] = (float(sx[1]), float(sx[2]) * np.pi / 180)

    for key, val in enumerate(objects):
        if val not in coords:
            print("not ", val, " in coords")
            continue
        dist, phi = coords[val]
        # `line` is a view of row `key`, so these writes update `weights`
        # in place.  Polar (dist, phi) becomes three hyperboloid
        # coordinates, satisfying line[0]^2 - line[1]^2 - line[2]^2 == 1.
        line = weights[key]
        line[0] = np.cosh(dist)
        line[1] = np.sinh(dist) * np.cos(phi)
        line[2] = np.sinh(dist) * np.sin(phi)

    tstart = timeit.default_timer()
    meanrank, maprank = eval_reconstruction(adj, model, workers=args.workers,
                                            progress=not args.quiet)
    etime = timeit.default_timer() - tstart

    print(f'Mean rank: {meanrank}, mAP rank: {maprank}, time: {etime}')


# Override the first three coordinates of each embedding with values read
# from the -input3 file, then re-run the reconstruction evaluation.
if args.input3 != "none":

    # name -> [x, y, z]
    coords = {}
    # `with` guarantees the file is closed even if a value fails to parse.
    with open(args.input3, "r") as f:
        # Records are four lines long: a name, then one coordinate per line.
        # An empty name (end of file or a blank line) ends the input.
        while True:
            name = f.readline().rstrip()
            if name == "":
                break
            x = float(f.readline())
            y = float(f.readline())
            z = float(f.readline())
            coords[name] = [x, y, z]

    for key, val in enumerate(objects):
        # File names are strings, so labels are looked up via str().
        label = str(val)
        if label not in coords:
            print("not '", val, "' in coords")
            continue
        x, y, z = coords[label]
        # `line` is a view of row `key`, so these writes update `weights`
        # in place.
        line = weights[key]
        line[0] = x
        line[1] = y
        line[2] = z

    tstart = timeit.default_timer()
    meanrank, maprank = eval_reconstruction(adj, model, workers=args.workers,
                                            progress=not args.quiet)
    etime = timeit.default_timer() - tstart

    print(f'Mean rank: {meanrank}, mAP rank: {maprank}, time: {etime}')

# Override the embeddings with coordinates read from the -coord file, then
# re-run the reconstruction evaluation on the modified model.
if args.coord != "none":

    # Number of leading coordinates copied into each embedding row.
    # NOTE(review): this was a bare literal in the original; presumably it is
    # the row width of the checkpoints this option was used with -- confirm
    # before using it with embeddings of another size.
    n_coords = 200

    # label -> list of coordinate strings (converted to float when copied)
    coords = {}
    # `with` guarantees the file is closed even if reading fails part-way.
    with open(args.coord, "r") as f:
        # Records are two lines long: a label, then space-separated values.
        # An empty label (end of file or a blank line) ends the input.
        while True:
            label = f.readline().rstrip('\n')
            values = f.readline()
            if label == "":
                break
            sx = values.split(" ")
            # The file stores the coordinates in a different order: its last
            # value becomes the first coordinate of the embedding row.
            coords[label] = sx[-1:] + sx[:-1]

    for key, val in enumerate(objects):
        if val not in coords:
            print("not ", val, " in coords")
            continue
        # `line` is a view of row `key`, so these writes update `weights`
        # in place.
        line = weights[key]
        row = coords[val]
        for k in range(n_coords):
            line[k] = float(row[k])

    tstart = timeit.default_timer()
    meanrank, maprank = eval_reconstruction(adj, model, workers=args.workers,
                                            progress=not args.quiet)
    etime = timeit.default_timer() - tstart

    print(f'Mean rank: {meanrank}, mAP rank: {maprank}, time: {etime}')
