import numpy as np
import torch

from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV

from sklearn.exceptions import ConvergenceWarning

from warnings import simplefilter
# Process-wide side effect at import time: silences scikit-learn ConvergenceWarning for
# the whole interpreter, presumably to quiet the many LogisticRegression fits performed
# by the grid search in Metric._train_lr. NOTE(review): this also hides convergence
# warnings raised by unrelated code — confirm that is intended.
simplefilter("ignore", category=ConvergenceWarning)

"""================================================================================================="""
# Module-level flags; not read anywhere in this file. Presumably consumed by an external
# metric registry to signal that this metric takes no extra options — verify against caller.
HAS_OPTIONS = False
OPTIONS = None

class Metric():
    """Linear-evaluation ROC-AUC metric.

    Fits a grid-searched ``LogisticRegression`` on ``features_cosine`` to predict
    ``target_labels`` and reports the best mean cross-validated one-vs-rest ROC-AUC
    (3-fold CV). For 2-D label matrices with more than one column, each column is
    scored independently and the per-column scores are aggregated with ``aggfunc``.

    Args:
        aggfunc: ``'mean'`` (average over label columns) or ``'worst'`` (minimum
            over label columns). Only affects multi-column labels.
        **kwargs: Accepted and ignored, so callers can pass a shared config.

    Raises:
        ValueError: If ``aggfunc`` is not ``'mean'`` or ``'worst'``.
    """

    _AGGFUNCS = ('mean', 'worst')

    def __init__(self, aggfunc = 'mean', **kwargs):
        # Names of the inputs this metric expects to receive in ``__call__``.
        self.requires = ['features_cosine', 'target_labels']
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        if aggfunc not in self._AGGFUNCS:
            raise ValueError(
                f"aggfunc must be one of {self._AGGFUNCS}, got {aggfunc!r}"
            )
        self.name     = f'lineval_roc_{aggfunc}'
        self.aggfunc = aggfunc

    @staticmethod
    def _to_numpy(values):
        """Return *values* as a NumPy array; torch tensors are detached and moved to CPU."""
        if torch.is_tensor(values):
            return values.detach().cpu().numpy()
        # np.asarray returns an existing ndarray unchanged and converts lists/sequences.
        return np.asarray(values)

    def _aggregate(self, scores):
        """Combine per-label-column scores according to ``self.aggfunc``."""
        if self.aggfunc == 'mean':
            return np.mean(scores)
        if self.aggfunc == 'worst':
            return np.min(scores)
        raise ValueError(f"unsupported aggfunc {self.aggfunc!r}")

    def _train_lr(self, x, y):
        """Grid-search a LogisticRegression on (x, y) and return the fitted GridSearchCV.

        Searches 20 log-spaced values of C in [1e-5, 1e1] and both ``multi_class``
        strategies with 3-fold CV, scored by one-vs-rest ROC-AUC, using all CPU cores.
        ``y`` is flattened with ``ravel`` so (N, 1) label arrays are accepted.
        """
        # NOTE(review): scikit-learn 1.5 deprecated LogisticRegression's ``multi_class``
        # parameter; confirm it is still accepted by the pinned sklearn version.
        return GridSearchCV(estimator = LogisticRegression(random_state = 42),
                            param_grid = {
                                'C': 10**np.linspace(-5, 1, 20),
                                'multi_class': ['ovr', 'multinomial']
                            }, n_jobs = -1, cv = 3, scoring = 'roc_auc_ovr').fit(x, y.ravel())

    def __call__(self, features_cosine, target_labels, **kwargs):
        """Compute the linear-evaluation ROC-AUC.

        Args:
            features_cosine: Feature matrix (torch tensor or array-like); one row per
                sample is assumed — TODO confirm shape with caller.
            target_labels: 1-D labels, an (N, 1) column, or an (N, K) matrix whose
                columns are scored separately (torch tensor or array-like).
            **kwargs: Ignored.

        Returns:
            The best cross-validated ROC-AUC for single-column labels, otherwise the
            per-column scores reduced by ``aggfunc``.
        """
        # Previously ``features`` was only bound for tensor input (NameError otherwise).
        features = self._to_numpy(features_cosine)
        target_labels = self._to_numpy(target_labels)

        if target_labels.ndim == 1 or target_labels.shape[1] == 1:
            return self._train_lr(features, target_labels).best_score_

        scores = [
            self._train_lr(features, target_labels[:, label_idx]).best_score_
            for label_idx in range(target_labels.shape[1])
        ]
        return self._aggregate(scores)

