import os

import torch
import numpy as np

import matplotlib.pyplot as plt


def normalize_trials(value_list, axis=0, eps=1e-9):
    """
    Z-score normalize ``value_list`` along ``axis``.

    Parameters
    ----------
    value_list : array_like
        Numeric data; converted with ``np.asarray``.
    axis : int or None, optional
        Axis along which the mean and standard deviation are computed.
        The default 0 normalizes each column of a 2D array across its rows.
    eps : float, optional
        Added to the standard deviation so that constant slices yield 0
        instead of dividing by zero (default 1e-9).

    Returns
    -------
    np.ndarray
        ``(value_list - mean) / (std + eps)``, same shape as the input.
    """
    values = np.asarray(value_list)
    # keepdims=True keeps the reduced axis so the statistics broadcast back
    # against ``values`` for any ``axis``, not only axis=0.
    mu = np.mean(values, axis=axis, keepdims=True)
    std = np.std(values, axis=axis, keepdims=True)
    return (values - mu) / (std + eps)


def radar_area_matrix(values, normalize=True):
    """
    Compute the radar-chart polygon area for each row of ``values``.

    The ``n`` entries of a row are placed as radial distances on ``n``
    equally spaced spokes (angles ``2*pi*k/n``, k = 0..n-1). The enclosed
    polygon area is computed with the vectorized shoelace formula.

    Parameters
    ----------
    values : array_like of shape (m, n)
        - m: number of rows (models/configurations)
        - n: number of radar dimensions (objectives)
    normalize : bool, optional
        If True (the default, and the previous behavior), each column is
        first z-score normalized across the ``m`` rows via
        ``normalize_trials``. The areas are then relative to the other
        rows: a single row always yields 0. Normalized radii can also be
        negative, in which case the polygon may self-intersect and the
        shoelace value is not a true geometric area. Pass False to use the
        raw values as radial distances.

    Returns
    -------
    areas : np.ndarray of shape (m,)
        The polygon area for each row.

    Raises
    ------
    ValueError
        If ``values`` is not 2-dimensional.
    """
    values = np.asarray(values)
    if values.ndim != 2:
        raise ValueError(
            f"values must be a 2D (m, n) array, got shape {values.shape}")
    if normalize:
        values = normalize_trials(values)
    n = values.shape[1]

    # Spokes equally spaced around the circle.
    angles = np.linspace(0, 2 * np.pi, n, endpoint=False)

    x = values * np.cos(angles)
    y = values * np.sin(angles)

    # Shoelace formula: sum x_i*y_{i+1} - y_i*x_{i+1} over consecutive
    # vertices; np.roll wraps the last vertex back to the first.
    cross_term = x * np.roll(y, -1, axis=1) - y * np.roll(x, -1, axis=1)
    return 0.5 * np.abs(np.sum(cross_term, axis=1))


def _list_dataset_names(directory):
    """Return sorted base names (last extension stripped) of the non-hidden entries in ``directory``."""
    # Hidden entries (e.g. .DS_Store) are presumably not dataset files; the
    # sort makes the processing order deterministic across runs.
    return [
        os.path.splitext(entry)[0]
        for entry in sorted(os.listdir(directory))
        if not entry.startswith('.')
    ]


def main(root='/mnt/data01/public/aad_data', method='TSNE',
         datasets=('PBMC68K', 'Campbell', 'Mouse_retina', 'Baron Human')):
    """
    Select and save the best-scoring embedding for every file of every dataset.

    For each non-hidden entry ``<name>.<ext>`` under
    ``<root>/gene_filtered/<dataset>``, load
    ``<root>/bo/gene_filtered/<dataset>/data/visual-method-<method>_dataset-<name>``.
    Pick the index of the largest value in its ``'scores'``, print that score
    with the matching ``'hps'`` entry, and save ``(embs[idx], hps[idx])`` to
    ``<root>/bo/gene_filtered/<dataset>/visual-method-<method>_dataset-<name>_selected_emb.tar``.

    Other datasets previously run with this script: 'mnist', 'fmnist', 'cifar10'.
    """
    for dataset in datasets:
        bo_dir = os.path.join(root, 'bo', 'gene_filtered', dataset)
        source_dir = os.path.join(root, 'gene_filtered', dataset)
        for d_name in _list_dataset_names(source_dir):
            stem = f'visual-method-{method}_dataset-{d_name}'
            # NOTE(review): weights_only=False lets torch.load unpickle
            # arbitrary objects, which can execute code; only use it on
            # trusted result files.
            save = torch.load(os.path.join(bo_dir, 'data', stem), weights_only=False)
            scores = save['scores']
            # NOTE(review): np.argmax flattens its input, so this assumes
            # `scores` holds one scalar per trial. If it is (trials,
            # objectives), reduce per row first, e.g.
            # np.argmax(radar_area_matrix(scores)).
            selected_idx = np.argmax(scores)

            selected_emb = save['embs'][selected_idx]
            selected_hp = save['hps'][selected_idx]

            print(scores[selected_idx], selected_hp)

            torch.save((selected_emb, selected_hp),
                       os.path.join(bo_dir, f'{stem}_selected_emb.tar'))


if __name__ == "__main__":
    main()



