from spaces.euclid import Euclid
from spaces.sphere import Sphere
from spaces.Matsumoto import Matsumoto
from spaces.Poincare import Poincare
from spaces.car_like_disk import CarLikeDisk
from spaces.panda import Panda

import torch as th
import numpy as np
import os
def norm2(c):
    """Return a function ``s -> c * sum(s**2 over dim 1)``.

    Args:
        c: scale factor applied to the summed squares.

    Returns:
        callable: takes a tensor ``s`` and returns ``c * (s**2).sum(dim=1)``.
    """
    def scaled_sq_norm(s):
        return c * (s ** 2).sum(dim=1)
    return scaled_sq_norm


def d_norm2(c):
    """Return a function ``s -> 2 * c * s``.

    This is the elementwise derivative of ``c * sum(s**2)`` with respect to
    ``s``, i.e. the gradient of the quantity ``norm2(c)`` computes.

    Args:
        c: scale factor (same meaning as in ``norm2``).

    Returns:
        callable: takes a tensor ``s`` and returns ``2 * c * s``.
    """
    def scaled_grad(s):
        # Same association as the original expression: (2 * c) * s.
        return 2 * c * s
    return scaled_grad

def load_data(file_path):
    """Load an array with ``np.load`` and turn each entry into a pair of float32 tensors.

    Args:
        file_path: path to a file readable by ``np.load``.

    Returns:
        list: one ``[first, second]`` list per index along axis 0 of the
        loaded array. ``first`` and ``second`` are ``th.float32`` tensors
        copied from ``entry[0]`` and ``entry[1]``. Any further sub-entries
        of an entry are ignored.
    """
    array = np.load(file_path)

    def as_f32(values):
        return th.tensor(values, dtype=th.float32)

    return [
        [as_f32(array[idx][0]), as_f32(array[idx][1])]
        for idx in range(array.shape[0])
    ]

def _eval_states_path(name):
    """Return the path of ``<name>_eval_states.npy`` in the ``data`` directory one level above this file.

    The location is resolved from this source file's real path, so it does not
    depend on the current working directory.
    """
    here = os.path.dirname(os.path.realpath(__file__))
    return os.path.join(here, "..", "data", name + "_eval_states.npy")


def pick_space(name):
    """Build the space called *name* together with its evaluation states.

    Supported names:
      * ``"Plane"``         -- ``Euclid(2)`` with one hard-coded pair of points.
      * ``"Sphere3"``       -- ``Sphere(3)`` with one hard-coded pair of points.
      * ``"Poincare3"``     -- ``Poincare(3)`` with one hard-coded pair of points.
      * ``"CarLikeDisk"``   -- ``CarLikeDisk()``; states loaded from disk.
      * ``"Matsumoto_<c>"`` -- ``Matsumoto(dh=d_norm2(c))`` where ``<c>`` is the
        suffix parsed as a float; states loaded from disk.
      * ``"Panda"``         -- ``Panda()``; states loaded from disk.

    On-disk states are read with ``load_data`` from
    ``../data/<name>_eval_states.npy`` relative to this file.

    Args:
        name: identifier of the space to build.

    Returns:
        tuple: ``(space, eval_states)``. ``eval_states`` is a list of
        two-element lists of tensors (presumably start/goal pairs -- confirm
        against callers).

    Raises:
        ValueError: if *name* is not recognised, or if the ``Matsumoto_``
            suffix cannot be parsed as a float.
    """
    if name == "Plane":
        return Euclid(2), [[th.tensor([-1/2., -1/2.]), th.tensor([1/2., 1/2.])]]
    if name == "Sphere3":
        return Sphere(3), [[th.tensor([1/2, 1/2, -1/2]), th.tensor([1/2, 1/2, 1/2])]]
    if name == "Poincare3":
        return Poincare(3), [[th.tensor([1/2, 1/2, -1/2]), th.tensor([1/2, 1/2, 1/2])]]
    if name == "CarLikeDisk":
        return CarLikeDisk(), load_data(_eval_states_path(name))
    if name.startswith("Matsumoto_"):
        suffix = name[len("Matsumoto_"):]
        try:
            c = float(suffix)
        except ValueError as err:
            raise ValueError(
                f"Invalid Matsumoto coefficient {suffix!r} in space name {name!r}"
            ) from err
        # The metric function is c * |s|^2, supplied to Matsumoto via its derivative.
        return Matsumoto(dh=d_norm2(c)), load_data(_eval_states_path(name))
    if name == "Panda":
        return Panda(), load_data(_eval_states_path(name))
    raise ValueError(f"Unknown Space: {name!r}")
