import os

import numpy as np
import torch

import optuna

from BO_Objective import Objective


def normalize_trials(trials, *, eps=1e-9):
    """Z-score normalise the objective values of a collection of trials.

    Parameters
    ----------
    trials : iterable
        Objects exposing a ``values`` attribute (e.g. Optuna ``FrozenTrial``),
        each holding one number per objective. Assumes every trial has the
        same number of values.
    eps : float, optional, keyword-only
        Added to the per-objective standard deviation so that a constant
        objective maps to 0 instead of dividing by zero. Defaults to 1e-9.

    Returns
    -------
    np.ndarray
        One row per trial, one column per objective, each column shifted to
        mean 0 and scaled by ``std + eps``. An empty ``trials`` yields an
        empty array (no NaN-producing statistics are computed).
    """
    value_list = np.array([t.values for t in trials])
    if value_list.shape[0] == 0:
        # mean/std of an empty array warn and return NaN; nothing to normalise.
        return value_list
    mu = np.mean(value_list, axis=0)
    std = np.std(value_list, axis=0)
    return (value_list - mu) / (std + eps)


def radar_area_matrix(values):
    """
    Computes the radar-chart polygon area for each row in 'values'
    using the shoelace formula (vectorized).

    Parameters
    ----------
    values : array-like of shape (m, n)
        - m: number of rows (models/configurations)
        - n: number of radar dimensions (objectives)
        Each row is one set of radial distances for the radar chart.

    Returns
    -------
    areas : 1D np.ndarray of shape (m,)
        The polygon area for each row (absolute value, so vertex
        orientation does not matter).

    Raises
    ------
    ValueError
        If ``values`` is not two-dimensional.
    """
    values = np.asarray(values)
    if values.ndim != 2:
        raise ValueError(
            f"values must be a 2D array of shape (m, n); got shape {values.shape}"
        )
    n = values.shape[1]

    # Spoke k of the radar chart sits at angle 2*pi*k/n.
    angles = np.linspace(0, 2 * np.pi, n, endpoint=False)

    # Cartesian vertex coordinates; `angles` broadcasts across all rows.
    x = values * np.cos(angles)
    y = values * np.sin(angles)

    # Shoelace formula: roll(-1) pairs vertex k with vertex k+1 and wraps the
    # last vertex back to the first, closing each polygon.
    cross_term = x * np.roll(y, -1, axis=1) - y * np.roll(x, -1, axis=1)
    return 0.5 * np.abs(np.sum(cross_term, axis=1))




if __name__ == "__main__":
    datasets = [
        'Baron Human',
        'Mouse_retina',
        'PBMC68K',
        'Campbell',
    ]
    # Input root: contains <dataset>_extracted.tar and a <dataset>/ directory
    # whose files are searched one by one.
    root = '/mnt/data01/public/aad_data/gene_filtered_90pca'
    # Output root for search results, mirrored per dataset.
    save_root = '/mnt/data01/public/aad_data/bo/gene_filtered_90pca'

    # The Optuna storage URL below is relative to the working directory and
    # SQLite will not create a missing parent directory, so ensure it exists.
    os.makedirs('record', exist_ok=True)

    for dataset in datasets:
        feature_file = f'{root}/{dataset}_extracted.tar'
        # NOTE(review): weights_only=False unpickles arbitrary Python objects;
        # acceptable only because these archives are produced locally — never
        # point this at untrusted files.
        read_ = torch.load(feature_file, weights_only=False)
        features, labels = read_[:2]
        # NOTE: an earlier variant dropped classes holding < 3% of the samples
        # here before running the search.

        save_path = f'{save_root}/{dataset}'
        # sorted() makes the processing order reproducible across runs.
        for f_name in sorted(os.listdir(f'{root}/{dataset}')):
            dataset_path = f'{dataset}/{f_name}'
            print(dataset_path)
            subset_name = f_name.split('.')[0]
            # Skip subsets whose TSNE output already exists (presumably the
            # directory written by objective.save_result('TSNE') — confirm).
            if os.path.exists(f"{save_path}/data/visual-method-TSNE_dataset-{subset_name}"):
                print('e' + f_name)
                continue
            objective = Objective(dataset_name=subset_name,
                                  dataset_feature_name=feature_file,
                                  method='tsne',
                                  dataset_path=f'{root}/{dataset_path}',
                                  save_path=save_path,
                                  features_n_labels=(features, labels))
            # Fewer than two distinct values in objective.y: skip the subset.
            if len(set(objective.y)) < 2:
                print(set(objective.y))
                continue
            # Alternatives tried: optuna.samplers.NSGAIISampler() and
            # optuna.samplers.GPSampler(n_startup_trials=10).
            algo = optuna.samplers.TPESampler(n_startup_trials=5, n_ei_candidates=12)
            study = optuna.create_study(sampler=algo, direction='maximize', storage="sqlite:///record/ocsvm2.db")
            study.optimize(objective.tsne_search, n_trials=40, show_progress_bar=True)
            print(study.best_trials)

            objective.save_result('TSNE')
