# %%
import sys
from pathlib import Path
# Get the parent directory of the current file
parent_dir = Path(__file__).resolve().parent.parent
# Add the parent directory to sys.path
sys.path.insert(0, str(parent_dir))
from cblearn.datasets import LinearSubspace
import numpy as np
from cblearn.datasets import make_random_triplet_indices, noisy_triplet_response
from cblearn.datasets import triplet_response
from cblearn.metrics import query_accuracy
from cblearn.embedding import SOE, FORTE, TSTE, CKL
from lore.triplet_algorithm import train_lore
from lore.f import LogisticTripletLoss
from lore.g import SchattenPUnnormed
# Make the test set
from utils.test_set import sample_test_set
from utils.psnr import compute_psnr, normalized_procrustes_distance
import torch
import pandas as pd
import os
import math
from scipy.spatial.distance import pdist, squareform
from pqdm.processes import pqdm

# %%
# Experiment configuration shared by every function in this script.
MAX_EMBEDDING_DIM = 15  # embedding dimension used for all methods
# Values swept as LORE's `lamb`: 15 log-spaced points from 10**-2.2 to 10**-1.
lambs = list(np.logspace(-2.2, -1, 15))
MAX_ITERATIONS = 1000  # iteration cap handed to train_lore
p = 0.5  # argument of SchattenPUnnormed
NUM_TRAIN_FRACTION = 0.1  # share of NUM_QUERIES used for training
NUM_TEST_SET = 3000  # number of test triplets requested from sample_test_set
INTRINSIC_RANK = 5  # dimension of the subspace the true points are drawn from
NOISE = 0.1  # scale of the normal response noise; <= 0 disables noise
NUM_PERCEPTS = 50  # number of objects being embedded
# 3 * C(NUM_PERCEPTS, 3); derived from NUM_PERCEPTS so the two cannot drift apart.
NUM_QUERIES = math.comb(NUM_PERCEPTS, 3) * 3
NUM_TRAIN_QUERIES = math.floor(NUM_QUERIES * NUM_TRAIN_FRACTION)
NUM_WORKERS = 5  # parallel jobs requested from pqdm
print(f"NUM_TRAIN_QUERIES: {NUM_TRAIN_QUERIES}")

# %%
def get_true_points(seed):
    """Sample the ground-truth points and their pairwise distance matrix.

    ``NUM_PERCEPTS`` points are drawn from a ``LinearSubspace`` of dimension
    ``INTRINSIC_RANK`` inside a 15-dimensional space, then rescaled so that the
    point matrix has unit Frobenius norm.

    Args:
        seed: Seed for the subspace; the point sampling uses ``seed*69``.

    Returns:
        Tuple ``(true_points, true_distances)`` where ``true_distances`` is the
        square matrix of Euclidean distances between the rescaled points.
    """
    subspace = LinearSubspace(
        subspace_dimension=INTRINSIC_RANK,
        space_dimension=15,
        random_state=seed,
    )
    raw_points, _ = subspace.sample_points(NUM_PERCEPTS, random_state=seed*69)
    # Scale the whole configuration to unit Frobenius norm.
    frobenius = np.linalg.norm(raw_points, 'fro')
    unit_points = raw_points / frobenius
    # Distances are computed on the rescaled points, not the raw sample.
    condensed = pdist(unit_points, metric='euclidean')
    return unit_points, squareform(condensed)

def get_train_triplets(true_distances, seed):
    """Draw random training triplets and answer them from the true distances.

    Args:
        true_distances: Precomputed pairwise distance matrix of the true points.
        seed: Random state used for both the triplet sampling and the noise.

    Returns:
        Tuple ``(train_triplets, response)``. The response is noise-free when
        ``NOISE <= 0``; otherwise it comes from ``noisy_triplet_response`` with
        normal noise of scale ``NOISE``.
    """
    triplet_indices = make_random_triplet_indices(
        n_objects=NUM_PERCEPTS,
        size=NUM_TRAIN_QUERIES,
        repeat=True,
        random_state=seed,
    )
    if NOISE <= 0:
        # Noise disabled: answer every triplet exactly.
        exact = triplet_response(triplet_indices, true_distances, distance='precomputed')
        return triplet_indices, exact
    perturbed = noisy_triplet_response(
        triplet_indices,
        true_distances,
        distance='precomputed',
        noise='normal',
        noise_options={'scale': NOISE},
        random_state=seed,
    )
    return triplet_indices, perturbed

def get_true_test_response(train_triplets, true_distances, seed=1):
    """Sample a test set and answer it, noise-free, from the true distances.

    Args:
        train_triplets: Training triplets, forwarded to ``sample_test_set``.
        true_distances: Precomputed pairwise distance matrix of the true points.
        seed: Seed forwarded to ``sample_test_set``.

    Returns:
        Tuple ``(test_set, true_test_response)``.
    """
    held_out = sample_test_set(train_triplets, NUM_PERCEPTS, NUM_TEST_SET, seed=seed)
    return held_out, triplet_response(held_out, true_distances, distance='precomputed')

def get_metrics(embedding, true_points, test_set, true_test_response):
    """Score an embedding against the ground truth.

    Args:
        embedding: Estimated point configuration.
        true_points: Ground-truth point configuration.
        test_set: Test triplets to answer from ``embedding``.
        true_test_response: Reference response for ``test_set``.

    Returns:
        Dict with keys ``'rank'``, ``'test_accuracy'``,
        ``'procrustes_distance'`` and ``'psnr'``.
    """
    numerical_rank = np.linalg.matrix_rank(embedding)
    # Answer the test triplets using distances in the estimated embedding.
    predicted = triplet_response(test_set, embedding)
    return {
        'rank': numerical_rank,
        'test_accuracy': query_accuracy(true_test_response, predicted),
        'procrustes_distance': normalized_procrustes_distance(true_points, embedding),
        'psnr': compute_psnr(true_points, embedding),
    }

def train_forte(embedding_dim, noisy_train_response, seed):
    """Fit a FORTE embedding with ``embedding_dim`` components.

    Args:
        embedding_dim: Number of embedding components.
        noisy_train_response: Training data passed to ``fit_transform``.
        seed: Random state of the estimator.

    Returns:
        The result of ``FORTE.fit_transform`` for ``NUM_PERCEPTS`` objects.
    """
    estimator = FORTE(n_components=embedding_dim, random_state=seed)
    return estimator.fit_transform(noisy_train_response, n_objects=NUM_PERCEPTS)

def train_tste(embedding_dim, noisy_train_response, seed):
    """Fit a TSTE embedding (torch backend) with ``embedding_dim`` components.

    Args:
        embedding_dim: Number of embedding components.
        noisy_train_response: Training data passed to ``fit_transform``.
        seed: Random state of the estimator.

    Returns:
        The result of ``TSTE.fit_transform`` for ``NUM_PERCEPTS`` objects.
    """
    estimator = TSTE(n_components=embedding_dim, backend='torch', random_state=seed)
    return estimator.fit_transform(noisy_train_response, n_objects=NUM_PERCEPTS)

def train_ckl(embedding_dim, noisy_train_response, seed):
    """Fit a CKL embedding (torch backend, ``mu=1``) with ``embedding_dim`` components.

    Args:
        embedding_dim: Number of embedding components.
        noisy_train_response: Training data passed to ``fit_transform``.
        seed: Random state of the estimator.

    Returns:
        The result of ``CKL.fit_transform`` for ``NUM_PERCEPTS`` objects.
    """
    estimator = CKL(n_components=embedding_dim, mu=1, backend='torch', random_state=seed)
    return estimator.fit_transform(noisy_train_response, n_objects=NUM_PERCEPTS)

def train_soe(embedding_dim, noisy_train_response, seed):
    """Fit an SOE embedding (torch backend) with ``embedding_dim`` components.

    The estimator is configured with a single initialisation (``n_init=1``)
    and ``restart_optim=1``.

    Args:
        embedding_dim: Number of embedding components.
        noisy_train_response: Training data passed to ``fit_transform``.
        seed: Random state of the estimator.

    Returns:
        The result of ``SOE.fit_transform`` for ``NUM_PERCEPTS`` objects.
    """
    embedder = SOE(n_components=embedding_dim, backend='torch', n_init=1, restart_optim=1,  random_state=seed)
    embedding = embedder.fit_transform(noisy_train_response, n_objects=NUM_PERCEPTS)
    return embedding

def train_lore_algorithm(noisy_train_response, lamb, seed):
    """Run ``train_lore`` on the training triplets from a random starting point.

    The loss is a ``LogisticTripletLoss`` (margin 1) over the triplets and the
    second function is ``SchattenPUnnormed(p)``. Everything is placed on the
    CUDA device.

    Args:
        noisy_train_response: NumPy array of training triplets; converted to a
            long tensor on the device.
        lamb: Passed to ``train_lore`` as its second argument (presumably the
            regularisation weight — confirm against ``train_lore``).
        seed: Forwarded to ``train_lore`` as ``seed``.

    Returns:
        The ``"X"`` entry of the mapping returned by ``train_lore``.
    """
    device = torch.device("cuda")
    # The starting point is drawn under a fixed torch seed, so it is identical
    # for every call and does not depend on `seed`.
    torch.manual_seed(1)
    # Width follows MAX_EMBEDDING_DIM so it matches the dimension reported for LORE.
    X = torch.randn(NUM_PERCEPTS, MAX_EMBEDDING_DIM, device=device, requires_grad=True)
    # Initialize loss functions
    f = LogisticTripletLoss(triplets=torch.from_numpy(noisy_train_response).long().to(device), margin=1)
    g = SchattenPUnnormed(p)
    # Run the LORE optimisation
    results=train_lore(X, lamb, f, g, mu=0.1, max_iterations=MAX_ITERATIONS, tol=1e-6, seed=seed, zero=1e-15, verbose=False)

    return results["X"]


# %%
def _tag_metrics(metrics, method, seed, embedding_dim, lamb):
    """Add the run-identifying columns to a metrics dict and return it."""
    # Insertion order is kept stable because it determines the CSV column order.
    metrics['embedding_dim'] = embedding_dim
    metrics['method'] = method
    metrics['seed'] = seed
    metrics['lambda'] = lamb
    return metrics


def train_one_combination(seed):
    """Run every method once for one seed and collect their metrics.

    Generates the true points, the training triplets and the test set for
    ``seed``, fits the four baselines (SOE, FORTE, TSTE, CKL) with seed
    ``seed*20``, then runs LORE with seed ``seed*30`` for every value in
    ``lambs``.

    Args:
        seed: Seed identifying this run.

    Returns:
        List of metric dicts as produced by ``get_metrics``, each extended
        with ``'embedding_dim'``, ``'method'``, ``'seed'`` and ``'lambda'``.
    """
    # Get true points and distances
    true_points, true_distances = get_true_points(seed)
    # Get training triplets and responses
    train_triplets, noisy_train_response = get_train_triplets(true_distances, seed)
    # Get test set and true test response
    test_set, true_test_response = get_true_test_response(train_triplets, true_distances, seed=seed)

    # (method label, training function, value recorded in the 'lambda' column).
    # CKL records 1, matching the mu=1 it is trained with; the others record 0.
    baselines = (
        ('SOE', train_soe, 0),
        ('FORTE', train_forte, 0),
        ('TSTE', train_tste, 0),
        ('CKL', train_ckl, 1),
    )

    all_metrics = []
    for embedding_dim in range(MAX_EMBEDDING_DIM, MAX_EMBEDDING_DIM + 1):
        for method, trainer, recorded_lambda in baselines:
            # Each method gets its own copy of the training data.
            embedding = trainer(embedding_dim, noisy_train_response.copy(), seed*20)
            metrics = get_metrics(embedding, true_points, test_set, true_test_response)
            all_metrics.append(_tag_metrics(metrics, method, seed, embedding_dim, recorded_lambda))

    # Train LORE. Its dimension is recorded as MAX_EMBEDDING_DIM explicitly
    # instead of relying on the loop variable leaking out of the loop above.
    for lamb in lambs:
        lore_embedding = train_lore_algorithm(noisy_train_response.copy(), lamb, seed*30)
        lore_metrics = get_metrics(lore_embedding, true_points, test_set, true_test_response)
        all_metrics.append(_tag_metrics(lore_metrics, 'LORE', seed, MAX_EMBEDDING_DIM, lamb))
    return all_metrics

def log_results(all_metrics, output_dir):
    """Write the collected metrics to ``<output_dir>/results/figure1/figure1.csv``.

    Missing parent directories are created, and an existing file is overwritten.

    Args:
        all_metrics: List of metric dicts; each dict becomes one CSV row.
        output_dir: Base directory. An empty string makes the path relative to
            the current working directory.
    """
    # Create a DataFrame from the metrics
    df = pd.DataFrame(all_metrics)
    # Save the DataFrame to a CSV file
    filename = "results/figure1/figure1.csv"
    output_path = os.path.join(output_dir, filename)
    # to_csv does not create directories; make sure the target folder exists
    # so a finished experiment is not lost to a missing path.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    df.to_csv(output_path, index=False)


# %%
# Train Loop
run_ids = list(range(1, 31))
# NOTE(review): this runs at module level. Under the spawn/forkserver start
# methods, worker processes re-import the main module, so an
# `if __name__ == "__main__":` guard may be needed — confirm for the target platform.
all_results = pqdm(run_ids, train_one_combination, n_jobs=NUM_WORKERS, desc="Param Combos", unit='combo')


# %%
print(f"All results: {all_results}")
# With pqdm's default exception_behaviour ('ignore'), a failed run is returned
# as the exception object instead of a list of metrics. Report those runs and
# keep them out of the flattening, which would otherwise raise TypeError and
# discard the successful runs too.
failed_runs = [
    (run_id, result)
    for run_id, result in zip(run_ids, all_results)
    if isinstance(result, Exception)
]
for run_id, error in failed_runs:
    print(f"Run {run_id} failed: {error!r}")
all_results = [
    item
    for sublist in all_results
    if not isinstance(sublist, Exception)
    for item in sublist
]
log_results(all_results, "")


