# %%
import sys
from pathlib import Path
# Get the parent of the directory that contains this file (two levels up from the file)
parent_dir = Path(__file__).resolve().parent.parent
# Add the parent directory to sys.path
sys.path.insert(0, str(parent_dir))
import numpy as np
from cblearn.datasets import triplet_response
from cblearn.metrics import query_accuracy
from cblearn.embedding import SOE
from cblearn.embedding import FORTE, TSTE, CKL
from lore.triplet_algorithm import train_lore
from lore.f import LogisticTripletLoss
from lore.g import SchattenPUnnormed
# Tensor backend, timing, tabulation, serialization and multiprocessing helpers
import torch
import time
import pandas as pd
import os
import time
import json
from pqdm.processes import pqdm

# %%
# Experiment configuration shared by every training function in this file.

# Number of objects: passed as `n_objects` to the cblearn embedders and used as
# the row count of the random initial embedding.
NUM_PERCEPTS = 100
# Embedding dimensionality: passed as `n_components` and used as the column
# count of the random initial embedding.
EMBEDDING_DIM = 15
# Argument handed to SchattenPUnnormed when building the LORE regularizer.
p = 0.5
# Second positional argument of train_lore (forwarded by train_one).
lamb = 0.01
# `max_iterations` passed to train_lore.
MAX_ITERATIONS = 1000
# `n_jobs` passed to pqdm when fanning the runs out over processes.
NUM_WORKERS = 5

# %%
import pickle

# The dataset path below is relative, so print the working directory to make a
# "file not found" failure easy to diagnose.
print(f"Current working directory: {os.getcwd()}")
# NOTE: pickle.load can execute arbitrary code; only point this at a trusted,
# locally produced dataset file.
# The context manager closes the file handle as soon as loading finishes
# (the previous bare open() left it to the garbage collector).
with open("real_data/food_100_complete.pkl", "rb") as dataset_file:
    food_100_data = pickle.load(dataset_file)

# %%
def make_food_train_data(seed, data=None):
    """Split the triplet data into a shuffled 90% train / 10% test partition.

    Args:
        seed: Seed given to ``np.random.seed`` before shuffling; the global
            NumPy RNG is left seeded with this value, as before.
        data: Optional array of triplets to split. Defaults to
            ``food_100_data["data"]``.
            (assumes an ndarray with one triplet per row - TODO confirm)

    Returns:
        Tuple ``(train_data, test_data)``: the first ``int(0.9 * len(data))``
        shuffled rows and the remaining rows.
    """
    source = food_100_data["data"] if data is None else data
    print(f"Total number of triplets: {len(source)}")
    # number of unique items
    unique_items = np.unique(source)
    print(f"Number of unique items: {len(unique_items)}")
    # Shuffle a private copy. np.random.shuffle works in place, so shuffling
    # the shared dataset made each split depend on every earlier call in the
    # same process instead of on `seed` alone.
    shuffled = np.array(source)
    np.random.seed(seed)
    np.random.shuffle(shuffled)
    split_at = int(0.9 * len(shuffled))
    return shuffled[:split_at], shuffled[split_at:]

def train_soe(init_embedding, noisy_train_response, seed):
    """Fit a cblearn SOE embedding on the given triplets and return it.

    Args:
        init_embedding: Array assigned to the estimator's ``embedding_``
            attribute before fitting.
        noisy_train_response: Training triplets passed to ``fit_transform``.
        seed: Value passed as the estimator's ``random_state``.

    Returns:
        Whatever ``SOE.fit_transform`` returns for the training triplets.
    """
    # The unused `device = torch.device("cuda")` local was removed; it was
    # never passed to the estimator.
    embedder = SOE(n_components=EMBEDDING_DIM, backend='torch', n_init=1, restart_optim=1, random_state=seed)
    # NOTE(review): whether fit_transform actually starts from a pre-set
    # `embedding_` is not visible here - confirm against cblearn.
    embedder.embedding_ = init_embedding
    return embedder.fit_transform(noisy_train_response, n_objects=NUM_PERCEPTS)

def train_forte(init_embedding, noisy_train_response, seed):
    """Fit a cblearn FORTE embedding on the given triplets and return it.

    Args:
        init_embedding: Array assigned to the estimator's ``embedding_``
            attribute before fitting.
        noisy_train_response: Training triplets passed to ``fit_transform``.
        seed: Value passed as the estimator's ``random_state``.

    Returns:
        Whatever ``FORTE.fit_transform`` returns for the training triplets.
    """
    model = FORTE(n_components=EMBEDDING_DIM, random_state=seed)
    # NOTE(review): whether fit_transform actually starts from a pre-set
    # `embedding_` is not visible here - confirm against cblearn.
    model.embedding_ = init_embedding
    return model.fit_transform(noisy_train_response, n_objects=NUM_PERCEPTS)

def train_tste(init_embedding, noisy_train_response, seed):
    """Fit a cblearn TSTE embedding (torch backend) on the triplets and return it.

    Args:
        init_embedding: Array assigned to the estimator's ``embedding_``
            attribute before fitting.
        noisy_train_response: Training triplets passed to ``fit_transform``.
        seed: Value passed as the estimator's ``random_state``.

    Returns:
        Whatever ``TSTE.fit_transform`` returns for the training triplets.
    """
    model = TSTE(n_components=EMBEDDING_DIM, backend='torch', random_state=seed)
    # NOTE(review): whether fit_transform actually starts from a pre-set
    # `embedding_` is not visible here - confirm against cblearn.
    model.embedding_ = init_embedding
    return model.fit_transform(noisy_train_response, n_objects=NUM_PERCEPTS)

def train_ckl(init_embedding, noisy_train_response, seed):
    """Fit a cblearn CKL embedding (torch backend, mu=1) on the triplets and return it.

    Args:
        init_embedding: Array assigned to the estimator's ``embedding_``
            attribute before fitting.
        noisy_train_response: Training triplets passed to ``fit_transform``.
        seed: Value passed as the estimator's ``random_state``.

    Returns:
        Whatever ``CKL.fit_transform`` returns for the training triplets.
    """
    model = CKL(n_components=EMBEDDING_DIM, mu=1, backend='torch', random_state=seed)
    # NOTE(review): whether fit_transform actually starts from a pre-set
    # `embedding_` is not visible here - confirm against cblearn.
    model.embedding_ = init_embedding
    return model.fit_transform(noisy_train_response, n_objects=NUM_PERCEPTS)

def train_lore_algorithm(init_embedding, noisy_train_response, lamb, seed):
    """Run ``train_lore`` on the CUDA device and return the ``"X"`` entry of its result.

    Args:
        init_embedding: Initial embedding; converted to a float32 CUDA tensor.
        noisy_train_response: NumPy array of training triplets; converted to
            a long CUDA tensor for the logistic triplet loss.
        lamb: Forwarded to ``train_lore`` as its second positional argument.
        seed: Forwarded to ``train_lore`` as ``seed``.

    Returns:
        ``results["X"]`` from the dictionary-like object ``train_lore`` returns.
    """
    cuda = torch.device("cuda")
    # NOTE(review): the torch seed is fixed to 1 regardless of `seed`;
    # confirm this is intentional.
    torch.manual_seed(1)
    start_point = torch.tensor(init_embedding, device=cuda, dtype=torch.float32)

    # Loss and regularizer; `p` is the module-level constant.
    triplets = torch.from_numpy(noisy_train_response).long().to(cuda)
    loss = LogisticTripletLoss(triplets=triplets, margin=1)
    regularizer = SchattenPUnnormed(p)

    # Run the optimizer (called the PIRNN algorithm in the original comment).
    outcome = train_lore(
        start_point,
        lamb,
        loss,
        regularizer,
        mu=0.1,
        max_iterations=MAX_ITERATIONS,
        tol=1e-6,
        seed=seed,
        zero=1e-15,
        verbose=False,
    )
    return outcome["X"]

def get_metrics(embedding, test_response):
    """Score an embedding against held-out triplets.

    Args:
        embedding: Embedding to evaluate; must support ``.tolist()`` and be
            accepted by ``np.linalg.matrix_rank``.
        test_response: Held-out triplets used as the queries and the reference
            answers.

    Returns:
        Dict with ``"test_accuracy"`` (``query_accuracy`` of the responses
        predicted from the embedding), ``"rank"`` (numerical matrix rank of
        the embedding) and ``"embedding"`` (the embedding as a JSON string).
    """
    predictions = triplet_response(test_response, embedding)
    return {
        "test_accuracy": query_accuracy(test_response, predictions),
        "rank": np.linalg.matrix_rank(embedding),
        "embedding": json.dumps(embedding.tolist()),
    }

# %%
def _run_algorithm(name, trainer, X, train_data, test_data, seed, algorithm_seed):
    """Train one algorithm from a fresh copy of X, time it, and score it on test_data."""
    start_time = time.time()
    # Each algorithm gets its own copy so none can modify the shared start point.
    init_embedding = X.copy()
    embedding = trainer(init_embedding, train_data, algorithm_seed)
    # The timed span covers the copy and the training call only, not the scoring.
    elapsed = time.time() - start_time
    metrics = get_metrics(embedding, test_data)
    metrics["time"] = elapsed
    metrics["algorithm"] = name
    metrics["seed"] = seed
    print(f"{name} time: {elapsed}")
    return metrics


def train_one(seed):
    """Train and evaluate SOE, LORE, FORTE, TSTE and CKL for one seed.

    The data split and the random initial embedding both derive from `seed`;
    the algorithms receive `seed + 1` through `seed + 5` respectively and all
    start from the same initial embedding.

    Args:
        seed: Integer run seed.

    Returns:
        List of five metric dicts in the order SOE, LORE, FORTE, TSTE, CKL.
        Each holds the keys produced by ``get_metrics`` plus ``"time"``,
        ``"algorithm"`` and ``"seed"``.
    """
    train_data, test_data = make_food_train_data(seed)
    np.random.seed(seed)
    X = np.random.randn(NUM_PERCEPTS, EMBEDDING_DIM)
    print(f"Made the data for seed {seed}")
    # Train all algorithms
    print("Starting SOE")

    # (label, trainer, per-algorithm seed). LORE additionally needs the
    # module-level regularization weight `lamb`, hence the small adapter.
    algorithms = (
        ("SOE", train_soe, seed + 1),
        ("LORE", lambda init, data, s: train_lore_algorithm(init, data, lamb, s), seed + 2),
        ("FORTE", train_forte, seed + 3),
        ("TSTE", train_tste, seed + 4),
        ("CKL", train_ckl, seed + 5),
    )
    return [
        _run_algorithm(name, trainer, X, train_data, test_data, seed, algorithm_seed)
        for name, trainer, algorithm_seed in algorithms
    ]

def log_results(all_metrics, output_dir):
    """Write the collected metrics to ``<output_dir>/results/food100/food100.csv``.

    Args:
        all_metrics: Records accepted by ``pd.DataFrame`` (here, a list of
            metric dicts); each becomes one CSV row.
        output_dir: Base directory the fixed relative path is joined onto.
            An empty string makes the path relative to the working directory.
    """
    # Create a DataFrame from the metrics
    df = pd.DataFrame(all_metrics)
    filename = "results/food100/food100.csv"
    output_path = os.path.join(output_dir, filename)
    # to_csv does not create missing directories; make sure the target folder
    # exists so a finished run is not lost to a missing "results/food100".
    target_dir = os.path.dirname(output_path)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)
    # Save the DataFrame to a CSV file
    df.to_csv(output_path, index=False)


# %%
# Guard the driver so that worker processes which re-import this module
# (e.g. under the "spawn" start method) do not launch the experiment again.
if __name__ == "__main__":
    run_ids = list(range(1, 31))
    all_results = pqdm(run_ids, train_one, n_jobs=NUM_WORKERS, desc="Run number", unit='run')

    print(all_results)

    # pqdm's default exception behaviour hands a worker's exception back as
    # that run's result instead of raising. Separate those from the successful
    # runs so flattening does not crash and completed work is still saved.
    failed_runs = [outcome for outcome in all_results if isinstance(outcome, BaseException)]
    all_results = [
        item
        for sublist in all_results
        if not isinstance(sublist, BaseException)
        for item in sublist
    ]
    if all_results:
        log_results(all_results, "")
    if failed_runs:
        for failure in failed_runs:
            print(f"Run failed: {failure!r}", file=sys.stderr)
        raise RuntimeError(f"{len(failed_runs)} of {len(run_ids)} runs failed")



