# %%
import sys
from pathlib import Path
# Put the directory one level above this file's folder on sys.path so that
# local packages (presumably `lore` and `utils`) can be imported below
parent_dir = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(parent_dir))
import numpy as np
from cblearn.datasets import make_random_triplet_indices, noisy_triplet_response
from cblearn.datasets import triplet_response
from cblearn.metrics import query_accuracy
from cblearn.embedding import SOE
from cblearn.embedding import FORTE
from lore.triplet_algorithm import train_lore
from lore.f import LogisticTripletLoss
from lore.g import SchattenPUnnormed
# Make the test set
from utils.test_set import sample_test_set
from utils.psnr import compute_psnr, normalized_procrustes_distance
import torch
import pandas as pd
from scipy.spatial.distance import pdist, squareform
from joblib import Parallel, delayed
import math
import glob
import pickle
from utils.sentence_utils import encode_sentences, reduce_dimensionality


# %%
# Global experiment settings shared by every run.

# Problem size: number of objects to embed and the width of the initial
# (over-parameterised) embedding matrix.
NUM_PERCEPTS = 50
MAX_EMBEDDING_DIM = 15

# Number of held-out triplets requested from sample_test_set.
NUM_TEST_SET = 3000

# Optimiser settings handed to train_lore / SchattenPUnnormed:
# `lamb` is passed as the second positional argument of train_lore,
# `p` is the argument of the SchattenPUnnormed regulariser,
# MAX_ITERATIONS is train_lore's `max_iterations`.
MAX_ITERATIONS = 1000
lamb = 0.01
p = 0.5

# %%
# Load the food data
import pickle
# NOTE(security): unpickling runs arbitrary code embedded in the file, so this
# must only ever be pointed at a trusted, locally produced pickle.
# The context manager closes the file handle as soon as the data is loaded.
with open("real_data/food_100_complete.pkl", "rb") as food_file:
    food_100_data = pickle.load(food_file)


# %%
# Draw a random half of the food items and write them to a CSV file
food_100_names = food_100_data["names"]
food_100_tastes = food_100_data["tastes"]
# One row per food: its name and its taste entry
food_100_df = pd.DataFrame({"name": food_100_names, "taste": food_100_tastes})
# sample(frac=0.5) keeps a random 50% of the rows, returned in shuffled order.
# The fixed random_state makes every execution of the script select the same
# subset, so results from separate runs are computed on the same items.
# NOTE(review): the rest of the script uses NUM_PERCEPTS objects - confirm that
# half of this dataset is exactly NUM_PERCEPTS rows.
food_100_df = food_100_df.sample(frac=0.5, random_state=0).reset_index(drop=True)
# Written without index or header row
food_100_df.to_csv("real_data/food_100_sampled_50.csv", index=False, header=False)

# %%
# Encode the sampled CSV into sentence embeddings.
# glob returns a list of matching paths; the pattern has no wildcard, so the
# list holds at most the one file, and indexing [0] raises IndexError if the
# file is missing.
filename = glob.glob("real_data/food_100_sampled_50.csv")
sampled_csv_path = filename[0]
embeddings, sentences = encode_sentences(sampled_csv_path)

# %%
# Main Experiment Hyperparameters
NUM_RUNS = 30                          # repetitions per intrinsic rank
NUM_PERCEPTS = 50                      # number of objects in every triplet query
NOISE = 0.1                            # 'scale' of the normal response noise; <= 0 means noiseless
INTRINSIC_RANKS = list(range(1, 11))   # target dimensions 1..10 (inclusive)
NUM_TRAIN_FRACTION = 0.1               # share of all possible queries used for training

# Three queries for every unordered triple of sentences.
NUM_QUERIES = 3 * math.comb(len(sentences), 3)
NUM_TRAIN_QUERIES = math.floor(NUM_TRAIN_FRACTION * NUM_QUERIES)

# %%
# Order of the per-method and per-metric entries in the pickled result dict.
_METHOD_NAMES = ("FORTE", "SOE", "LORE")
_METRIC_NAMES = ("embedding", "test_accuracy", "psnr", "procrustes", "rank")


def _select_device():
    """Return the torch device on which the LORE optimisation runs."""
    # Prefer cuda:1, the device this experiment was written for; fall back so
    # the script still runs on hosts with a single GPU or no GPU at all.
    if torch.cuda.device_count() > 1:
        return torch.device("cuda:1")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


def _evaluate_embedding(embedding, test_set, true_test_response, reference_embedding):
    """Score one fitted embedding on the held-out triplets and against the reference embedding."""
    predicted_response = triplet_response(test_set, embedding)
    test_accuracy = query_accuracy(true_test_response, predicted_response)
    psnr = compute_psnr(reference_embedding, embedding)
    rank = np.linalg.matrix_rank(embedding)
    procrustes = normalized_procrustes_distance(reference_embedding, embedding)
    return {
        "embedding": embedding,
        "test_accuracy": test_accuracy,
        "psnr": psnr,
        "procrustes": procrustes,
        "rank": rank,
    }


def each_run(true_dimension, num_runs=30):
    """Compare LORE, SOE and FORTE for one intrinsic dimension and pickle the metrics.

    The module-level sentence ``embeddings`` are reduced to ``true_dimension``
    dimensions and their pairwise distances serve as ground truth. For each
    repetition, training triplets (optionally noisy) and a test set are drawn,
    the three embedding methods are fitted, and each result is scored.

    Args:
        true_dimension: Target dimension passed to ``reduce_dimensionality``;
            also embedded in the output file name.
        num_runs: Number of repetitions. The repetition index ``i`` is reused
            as the seed for triplet sampling, response noise, test-set
            sampling, the estimators' ``random_state`` and the initial matrix.

    Returns:
        None. The results are written to
        ``results/artificial_perceptual_experiment/`` as a pickled dict of the
        form ``{method: {metric: [value_per_run, ...]}}`` with methods
        ``FORTE``/``SOE``/``LORE`` and metrics
        ``embedding``/``test_accuracy``/``psnr``/``procrustes``/``rank``.
    """
    human_embeddings = reduce_dimensionality(embeddings, true_dimension)
    true_distances = squareform(pdist(human_embeddings))

    run_log = {method: {metric: [] for metric in _METRIC_NAMES} for method in _METHOD_NAMES}
    # The device does not depend on the repetition, so pick it once.
    device = _select_device()

    for i in range(num_runs):
        train_triplets = make_random_triplet_indices(n_objects=NUM_PERCEPTS, size=NUM_TRAIN_QUERIES, repeat=True, random_state=i)
        if NOISE <= 0:
            noisy_train_response = triplet_response(train_triplets, true_distances, distance='precomputed')
        else:
            noisy_train_response = noisy_triplet_response(train_triplets, true_distances, distance='precomputed', noise='normal', noise_options={'scale': NOISE}, random_state=i)
        # Held-out triplets, answered with the noise-free ground-truth distances
        test_set = sample_test_set(train_triplets, NUM_PERCEPTS, NUM_TEST_SET, seed=i)
        true_test_response = triplet_response(test_set, true_distances, distance='precomputed')

        # Shared random starting point. torch.manual_seed does not seed NumPy,
        # so the matrix is drawn from a generator seeded with the run index to
        # make it reproducible.
        torch.manual_seed(1)
        initial_X = np.random.default_rng(i).standard_normal((NUM_PERCEPTS, MAX_EMBEDDING_DIM))

        # LORE: logistic triplet loss with a Schatten-p regulariser
        lore_X = torch.tensor(initial_X, device=device, dtype=torch.float32)
        triplet_loss = LogisticTripletLoss(triplets=torch.from_numpy(noisy_train_response).long().to(device), margin=1)
        regularizer = SchattenPUnnormed(p)
        results = train_lore(lore_X, lamb, triplet_loss, regularizer, mu=0.1, max_iterations=MAX_ITERATIONS, tol=1e-6, seed=-1, zero=1e-15, verbose=False)

        # SOE
        # NOTE(review): assigning `embedding_` before fitting is presumably
        # meant as a warm start - confirm the estimator actually reads it.
        soe_embedder = SOE(n_components=MAX_EMBEDDING_DIM, random_state=i)
        soe_embedder.embedding_ = initial_X.copy()
        soe_embedding = soe_embedder.fit_transform(noisy_train_response, n_objects=NUM_PERCEPTS)

        # FORTE
        forte_embedder = FORTE(n_components=MAX_EMBEDDING_DIM, random_state=i)
        forte_embedder.embedding_ = initial_X.copy()
        forte_embedding = forte_embedder.fit_transform(noisy_train_response, n_objects=NUM_PERCEPTS)

        # Score every method the same way and append to its per-metric lists
        fitted_embeddings = {
            "LORE": results["X"],
            "SOE": soe_embedding,
            "FORTE": forte_embedding,
        }
        for method, embedding in fitted_embeddings.items():
            scores = _evaluate_embedding(embedding, test_set, true_test_response, human_embeddings)
            for metric, value in scores.items():
                run_log[method][metric].append(value)

    output_path = Path("results/artificial_perceptual_experiment") / f"artificial_perceptual_experiment_intrinsic_rank{true_dimension}.pkl"
    # Create the target directory so a missing folder cannot discard the results
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with output_path.open("wb") as out_file:
        pickle.dump(run_log, out_file)



# %%
# Run the per-rank experiments through joblib (6 jobs) instead of a serial loop.
# each_run writes its own result file, so the return values are not collected.
print("About to start Running")
rank_jobs = (delayed(each_run)(true_dimension, NUM_RUNS) for true_dimension in INTRINSIC_RANKS)
Parallel(n_jobs=6)(rank_jobs)


