# %%
# Imports
# import all necessary packages
from cblearn.datasets import LinearSubspace
import numpy as np
from cblearn.datasets import make_random_triplet_indices, noisy_triplet_response
from cblearn.datasets import triplet_response
from cblearn.metrics import query_accuracy
from cblearn.embedding import SOE, FORTE, CKL, TSTE
from lore.triplet_algorithm import train_lore
from lore.f import LogisticTripletLoss
from lore.g import SchattenPUnnormed
from utils.test_set import sample_test_set
from utils.psnr import compute_psnr, normalized_procrustes_distance
import torch
import pandas as pd
import math
import time
import sklearn
import os
from pqdm.processes import pqdm
from scipy.spatial.distance import pdist, squareform
from pathlib import Path

# Parent of the current working directory (not of this file); the results
# CSV is written underneath it.
parent_dir = Path().resolve().parent

# %%
# --- Experiment grid -------------------------------------------------------
# Number of objects to embed (passed as N / n_objects).
NUM_PERCEPTS = [1_000, 500, 100, 50]
# NUM_PERCEPTS = [50]  (alternative single-value setting, left disabled)
# Number of training triplets drawn per run.
NUM_TRIPLETS = [10_000, 5_000, 1_000, 500, 100]
# Number of triplets requested from sample_test_set for evaluation.
TEST_SET_SIZE = 1_000
# Dimensionality each method is asked to embed into.
EMBEDDING_DIMS = [50, 30, 15, 5]
# EMBEDDING_DIMS = [10]  (alternative single-value setting, left disabled)

# --- Data generation -------------------------------------------------------
# Dimension of the linear subspace the ground-truth points are sampled from.
INTRINSIC_RANK = 5
# Scale of the normal noise on training responses; a value <= 0 switches to
# noise-free responses.
NOISE = 0.1

# --- LORE hyper-parameters -------------------------------------------------
# Argument handed to SchattenPUnnormed.
p = 0.5
# Regularisation weight handed to train_lore and logged in the 'lambda' column.
lamb = 0.01
MAX_ITERATIONS = 1_000

# --- Execution -------------------------------------------------------------
# Seeds 1..NUM_RUNS are used for every grid point.
NUM_RUNS = 30
# Number of pqdm worker processes.
NUM_WORKERS = 5

# %%
def get_true_points(N, seed):
    """Sample ``N`` ground-truth points and their pairwise distances.

    Points are drawn from a ``LinearSubspace`` built with
    ``subspace_dimension=INTRINSIC_RANK`` and ``space_dimension=15``, then
    divided by the Frobenius norm of the whole point matrix.

    Returns
    -------
    (points, distances)
        The rescaled points and the square matrix of Euclidean distances
        between them.
    """
    subspace = LinearSubspace(
        subspace_dimension=INTRINSIC_RANK,
        space_dimension=15,
        random_state=seed,
    )
    # The sampling step uses a seed-derived random state that differs from
    # the one used to construct the subspace.
    raw_points, _ = subspace.sample_points(N, random_state=seed * 69)

    # Rescale so the point matrix has unit Frobenius norm.
    unit_points = raw_points / np.linalg.norm(raw_points, 'fro')

    # Distances are computed after rescaling.
    condensed = pdist(unit_points, metric='euclidean')
    return unit_points, squareform(condensed)

# %%
def get_train_triplets(N, true_distances, NUM_TRAIN_TRIPLETS, seed):
    """Draw random training triplets and answer them from the true distances.

    ``NUM_TRAIN_TRIPLETS`` index triplets over ``N`` objects are generated
    with ``repeat=True``. When the module-level ``NOISE`` is positive the
    answers come from ``noisy_triplet_response`` with normal noise of scale
    ``NOISE``; otherwise the noise-free ``triplet_response`` is used.

    Returns
    -------
    (triplets, responses)
        The sampled index triplets and the answers computed for them.
    """
    triplets = make_random_triplet_indices(
        n_objects=N,
        size=NUM_TRAIN_TRIPLETS,
        repeat=True,
        random_state=seed,
    )

    if NOISE <= 0:
        # Noise disabled: answer straight from the precomputed distances.
        return triplets, triplet_response(triplets, true_distances, distance='precomputed')

    responses = noisy_triplet_response(
        triplets,
        true_distances,
        distance='precomputed',
        noise='normal',
        noise_options={'scale': NOISE},
        random_state=seed,
    )
    return triplets, responses

# %%
def get_true_test_response(N, train_triplets, true_distances, seed):
    """Build the evaluation triplets and their noise-free answers.

    ``sample_test_set`` is asked for ``TEST_SET_SIZE`` triplets over ``N``
    objects and is given the training triplets (presumably so that it can
    avoid them -- verify in ``utils.test_set``). The answers are computed
    from the true distances without noise.

    Returns
    -------
    (test_set, true_test_response)
    """
    held_out = sample_test_set(train_triplets, N, TEST_SET_SIZE, seed=seed)
    answers = triplet_response(held_out, true_distances, distance='precomputed')
    return held_out, answers

# %%
def get_metrics(embedding, true_points, test_set, true_test_response):
    """Score an embedding against the ground truth.

    Returns a dict with four entries:

    * ``rank`` -- ``np.linalg.matrix_rank`` of the embedding,
    * ``test_accuracy`` -- ``query_accuracy`` between the true test answers
      and the answers the embedding gives on ``test_set``,
    * ``procrustes_distance`` -- ``normalized_procrustes_distance`` between
      the true points and the embedding,
    * ``psnr`` -- ``compute_psnr`` between the true points and the embedding.
    """
    numerical_rank = np.linalg.matrix_rank(embedding)

    # Answer the held-out triplets using the embedding itself.
    predicted = triplet_response(test_set, embedding)
    accuracy = query_accuracy(true_test_response, predicted)

    return {
        'rank': numerical_rank,
        'test_accuracy': accuracy,
        'procrustes_distance': normalized_procrustes_distance(true_points, embedding),
        'psnr': compute_psnr(true_points, embedding),
    }

# %%
def train_forte(N, embedding_dim, noisy_train_response, seed):
    """Fit FORTE on the training responses and return the embedding.

    ``N`` is forwarded as ``n_objects``, ``embedding_dim`` as
    ``n_components`` and ``seed`` as ``random_state``.
    """
    model = FORTE(n_components=embedding_dim, random_state=seed)
    return model.fit_transform(noisy_train_response, n_objects=N)

def train_tste(N, embedding_dim, noisy_train_response, seed):
    """Fit TSTE (torch backend) on the training responses and return the embedding.

    ``N`` is forwarded as ``n_objects``, ``embedding_dim`` as
    ``n_components`` and ``seed`` as ``random_state``.
    """
    model = TSTE(n_components=embedding_dim, backend='torch', random_state=seed)
    return model.fit_transform(noisy_train_response, n_objects=N)

def train_ckl(N, embedding_dim, noisy_train_response, seed):
    """Fit CKL (torch backend, ``mu=1``) on the training responses and return the embedding.

    ``N`` is forwarded as ``n_objects``, ``embedding_dim`` as
    ``n_components`` and ``seed`` as ``random_state``.
    """
    model = CKL(n_components=embedding_dim, mu=1, backend='torch', random_state=seed)
    return model.fit_transform(noisy_train_response, n_objects=N)

def train_soe(N, embedding_dim, noisy_train_response, seed):
    """Fit SOE (torch backend) on the training responses and return the embedding.

    ``N`` is forwarded as ``n_objects``, ``embedding_dim`` as
    ``n_components`` and ``seed`` as ``random_state``. The estimator is
    configured with ``n_init=1`` and ``restart_optim=1``.
    """
    # The former `device = torch.device("cuda")` local was never passed to
    # SOE (or used anywhere else), so it has been dropped.
    embedder = SOE(n_components=embedding_dim, backend='torch', n_init=1, restart_optim=1, random_state=seed)
    return embedder.fit_transform(noisy_train_response, n_objects=N)

def train_lore_algorithm(N, embedding_dim, noisy_train_response, lamb, seed):
    """Run LORE on the training responses and return the ``"X"`` entry of its result.

    Parameters
    ----------
    N : int
        Number of objects; the initial iterate has ``N`` rows.
    embedding_dim : int
        Number of columns of the initial iterate.
    noisy_train_response : numpy.ndarray
        Training triplets; converted to a ``long`` tensor for the loss.
    lamb : float
        Regularisation weight forwarded to ``train_lore``.
    seed : int
        Forwarded to ``train_lore`` as ``seed``.
    """
    # Prefer the GPU, but fall back to the CPU instead of crashing on
    # machines where CUDA is not available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE(review): the initial iterate is always drawn with torch seed 1,
    # independent of `seed`; `seed` only reaches train_lore. Confirm that a
    # shared initialisation across runs is intended.
    torch.manual_seed(1)
    X = torch.randn(N, embedding_dim, device=device, requires_grad=True)

    # Loss (f) and regulariser (g); `p` is the module-level constant.
    triplets = torch.from_numpy(noisy_train_response).long().to(device)
    f = LogisticTripletLoss(triplets=triplets, margin=1)
    g = SchattenPUnnormed(p)

    # Run the solver (the original comment refers to it as PIRNN).
    results = train_lore(
        X, lamb, f, g,
        mu=0.1,
        max_iterations=MAX_ITERATIONS,
        tol=1e-6,
        seed=seed,
        zero=1e-15,
        verbose=False,
    )

    return results["X"]

# %%
def _score_method(method, reg_strength, fit, true_points, test_set,
                  true_test_response, embedding_dim, num_percepts,
                  num_train_triplets, seed):
    """Time ``fit()``, score the embedding it returns, and tag the metrics row."""
    start_time = time.time()
    embedding = fit()
    elapsed = time.time() - start_time

    row = get_metrics(embedding, true_points, test_set, true_test_response)
    # Key order is kept stable because it determines the CSV column order.
    row['embedding_dim'] = embedding_dim
    row['method'] = method
    row['seed'] = seed
    row["num_train_triplets"] = num_train_triplets
    row['num_percepts'] = num_percepts
    row['lambda'] = reg_strength
    row['time'] = elapsed
    return row


def train_one_combination( EMBEDDING_DIM, N, NUM_TRAIN_TRIPLETS, seed):
    """Run every embedding method on one grid point and collect their metrics.

    Ground-truth points, training triplets and the test set are generated
    once from ``seed`` and shared by all methods. SOE, FORTE, TSTE and CKL
    are trained with ``seed * 20``; LORE is trained with ``seed * 30`` and
    the module-level ``lamb``.

    Parameters
    ----------
    EMBEDDING_DIM : int
        Dimensionality each method embeds into.
    N : int
        Number of objects.
    NUM_TRAIN_TRIPLETS : int
        Number of training triplets to draw.
    seed : int
        Run identifier; also the base seed for data generation.

    Returns
    -------
    list of dict
        One metrics row per method, in the order SOE, FORTE, TSTE, CKL,
        LORE. Each row holds the ``get_metrics`` entries plus
        ``embedding_dim``, ``method``, ``seed``, ``num_train_triplets``,
        ``num_percepts``, ``lambda`` and ``time`` (wall-clock seconds spent
        in the training call).
    """
    # Shared data for this grid point.
    true_points, true_distances = get_true_points(N, seed)
    train_triplets, noisy_train_response = get_train_triplets(N, true_distances, NUM_TRAIN_TRIPLETS, seed)
    test_set, true_test_response = get_true_test_response(N, train_triplets, true_distances, seed=seed)

    baseline_seed = seed * 20
    lore_seed = seed * 30

    # (method name, value logged as 'lambda', zero-argument training call).
    # Each call receives its own copy of the responses, made inside the
    # timed region exactly as before. The logged 'lambda' is 0 for
    # SOE/FORTE/TSTE, 1 for CKL and `lamb` for LORE.
    methods = [
        ('SOE', 0,
         lambda: train_soe(N, EMBEDDING_DIM, noisy_train_response.copy(), baseline_seed)),
        ('FORTE', 0,
         lambda: train_forte(N, EMBEDDING_DIM, noisy_train_response.copy(), baseline_seed)),
        ('TSTE', 0,
         lambda: train_tste(N, EMBEDDING_DIM, noisy_train_response.copy(), baseline_seed)),
        ('CKL', 1,
         lambda: train_ckl(N, EMBEDDING_DIM, noisy_train_response.copy(), baseline_seed)),
        ('LORE', lamb,
         lambda: train_lore_algorithm(N, EMBEDDING_DIM, noisy_train_response.copy(), lamb, lore_seed)),
    ]

    all_metrics = []
    for method, reg_strength, fit in methods:
        all_metrics.append(
            _score_method(
                method, reg_strength, fit,
                true_points, test_set, true_test_response,
                EMBEDDING_DIM, N, NUM_TRAIN_TRIPLETS, seed,
            )
        )
    return all_metrics

# %%
# A bare `import sklearn` does not guarantee that the `model_selection`
# submodule is loaded, so import it explicitly before using it.
import sklearn.model_selection

# One run per seed 1..NUM_RUNS for every (N, triplet count, dimension) triple.
run_ids = list(range(1, NUM_RUNS + 1))
# Keys match the parameter names of train_one_combination, because each
# combination is later unpacked into it as keyword arguments.
param_grid = {'N': NUM_PERCEPTS, 'NUM_TRAIN_TRIPLETS': NUM_TRIPLETS, 'EMBEDDING_DIM': EMBEDDING_DIMS, 'seed': run_ids}
param_combinations = list(sklearn.model_selection.ParameterGrid(param_grid))
print(param_combinations)

# %%
# for param_combination in param_combinations:
#     N = param_combination['N']
#     NUM_TRAIN_TRIPLETS = param_combination['NUM_TRAIN_TRIPLETS']
#     EMBEDDING_DIM = param_combination['EMBEDDING_DIM']
#     seed = param_combination['seed']
    
#     print(f"Running for N={N}, NUM_TRAIN_TRIPLETS={NUM_TRAIN_TRIPLETS}, EMBEDDING_DIM={EMBEDDING_DIM}, seed={seed}")
    
#     all_results = train_one_combination(EMBEDDING_DIM, N, NUM_TRAIN_TRIPLETS, seed)
    
#     break

# %%
# Fan the grid out over NUM_WORKERS processes; with argument_type='kwargs'
# each combination dict is unpacked into train_one_combination as keyword
# arguments. `desc` and `unit` only affect the progress bar.
all_results = pqdm(
    param_combinations,
    train_one_combination,
    n_jobs=NUM_WORKERS,
    argument_type='kwargs',
    desc="Param Combos",
    unit='combo',
)

# %%
def log_results(all_metrics, output_dir):
    """Write the collected metric rows to ``scaling2.csv``.

    Parameters
    ----------
    all_metrics : list of dict
        Rows to store; they are turned into a ``pandas.DataFrame``.
    output_dir : str or path-like
        Directory to write into. An empty value selects
        ``<parent_dir>/results/scaling``, the location the file was
        previously always written to (the old code joined ``output_dir``
        with an absolute path, which discarded ``output_dir``).

    The target directory is created if it does not exist.
    """
    if output_dir:
        target_dir = os.fspath(output_dir)
    else:
        target_dir = os.path.join(parent_dir, "results", "scaling")

    # Creating the directory up front avoids losing a finished experiment
    # to a missing folder at write time.
    os.makedirs(target_dir, exist_ok=True)

    df = pd.DataFrame(all_metrics)
    df.to_csv(os.path.join(target_dir, "scaling2.csv"), index=False)



# %%
all_results

# %%
print(f"All results: {all_results}")
# With pqdm's default exception_behaviour ('ignore'), a job that raised
# contributes its exception object instead of a list of metric rows.
# Iterating over such an entry would raise TypeError and throw away every
# successful result, so failed jobs are reported and skipped.
failed_jobs = [outcome for outcome in all_results if isinstance(outcome, Exception)]
if failed_jobs:
    print(f"{len(failed_jobs)} parameter combination(s) failed; first error: {failed_jobs[0]!r}")
# Flatten the per-combination lists of rows into one list of rows.
all_results = [
    row
    for outcome in all_results
    if not isinstance(outcome, Exception)
    for row in outcome
]
log_results(all_results, "")


