import os
import shutil
import numpy as np
import pandas as pd
import torch
import pickle
# !pip install "transformers==2.5.1"

# pip install sentence-transformers

def _pad_sequences(seqs, target_len, fill_value):
    """Return right-padded copies of ``seqs`` without mutating the inputs.

    Each sequence is a list of rows. It is extended with new rows of
    ``fill_value`` (each as wide as the sequence's first row) until it has
    ``target_len`` rows.

    Raises:
        ValueError: if a sequence is longer than ``target_len``, or is empty
            but needs padding (its row width cannot be inferred).
    """
    padded = []
    for idx, seq in enumerate(seqs):
        rows = list(seq)  # shallow copy so the caller's list is left untouched
        missing = target_len - len(rows)
        if missing < 0:
            # An equality-based pad loop would spin forever here; fail loudly.
            raise ValueError(
                f"sequence {idx} has {len(rows)} rows, "
                f"more than the target length {target_len}"
            )
        if missing > 0:
            if not rows:
                raise ValueError(
                    f"sequence {idx} is empty; cannot infer its row width"
                )
            width = len(rows[0])
            rows.extend([fill_value] * width for _ in range(missing))
        padded.append(rows)
    return padded


def make_data(data_dir="../data/ec/final"):
    """Load the pickled audio/text features, pad them to a common length, and
    return them as tensors.

    Audio sequences are right-padded with rows of ``1`` up to the longest
    audio sequence. NaNs are then replaced by 0, and the whole tensor is
    standardised with one global mean and standard deviation (padding rows
    included). Text sequences are right-padded with rows of ``1e-5`` to the
    same length as the audio, so both tensors share their second axis.

    Args:
        data_dir: Directory containing ``audio.pkl`` and ``text.pkl``.

    Returns:
        ``(audio, text)`` tensors with ``max_len`` entries along axis 1. This
        presumably gives ``(num_samples, max_len, feature_dim)`` when every
        row is a flat list of numbers; confirm against the pickled data.

    Raises:
        ValueError: if ``audio.pkl`` holds no sequences, or a sequence cannot
            be padded (see ``_pad_sequences``).
    """
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(os.path.join(data_dir, "audio.pkl"), "rb") as handle:
        audio = pickle.load(handle)
    with open(os.path.join(data_dir, "text.pkl"), "rb") as handle:
        text = pickle.load(handle)

    lens = [len(seq) for seq in audio]
    if not lens:
        raise ValueError("audio.pkl contains no sequences")
    # Text is padded to the *audio* maximum so both tensors share a time axis.
    max_len = max(lens)

    audio = torch.tensor(_pad_sequences(audio, max_len, 1))
    audio = torch.nan_to_num(audio, nan=0.0)
    audio = (audio - torch.mean(audio)) / torch.std(audio)

    text = torch.tensor(_pad_sequences(text, max_len, 1e-5))
    return audio, text


# Load the padded audio/text tensors, join them along axis 2, then collapse
# every axis after the first so each sample becomes a single feature vector.
audio, text = make_data()
cat = torch.cat([audio, text], dim=2).flatten(start_dim=1)
print(cat.shape)

from geoopt.manifolds.stereographic import math as math1

import numpy as np
from tqdm import tqdm
import torch

# NOTE(review): in geoopt's kappa-stereographic model, k < 0 gives the Poincare
# ball (hyperbolic geometry) and k > 0 the stereographic projection of a
# sphere. This matrix is labelled "hyperbolic" below, yet k = +1 is used. The
# sign is kept unchanged so the saved output does not change; confirm which
# geometry is intended before relying on these distances.
CURVATURE = torch.tensor(1.)

# Map every flattened sample onto the manifold with the exponential map at the
# origin.
sentence_embeddings = math1.expmap0(cat, k=CURVATURE)
n_samples = sentence_embeddings.shape[0]
matrix = np.zeros((n_samples, n_samples))

# Build the symmetric pairwise-distance matrix one row at a time. math1.dist
# broadcasts sample i against all samples j >= i in one call, which replaces
# the per-pair Python inner loop. Memory stays at O(n * d) per row instead of
# the O(n^2 * d) a full pairwise broadcast would need.
for i in tqdm(range(n_samples)):
    row = math1.dist(sentence_embeddings[i], sentence_embeddings[i:], k=CURVATURE)
    matrix[i, i:] = row.numpy()
    matrix[i:, i] = matrix[i, i:]  # mirror into the lower triangle

# Use this matrix for hyperbolic training

# Euclidean alternative (disabled). To train in Euclidean space instead,
# replace the manifold distances above with plain L2 distances, e.g.:
#
#     matrix = np.zeros((n_samples, n_samples))
#     for i in tqdm(range(n_samples)):
#         for j in range(i, n_samples):
#             matrix[i][j] = np.linalg.norm(sentence_embeddings[i] - sentence_embeddings[j])
#             matrix[j][i] = matrix[i][j]
#
# TODO(review): confirm whether the Euclidean variant should measure the raw
# `cat` features or the manifold-mapped `sentence_embeddings`.

np.save("../data/ec/final/train.npy", matrix)
