import time
import numpy as np
import torch
from torch.nn import functional as F
from sklearn.linear_model import LogisticRegression, SGDClassifier
from utils.clustering import DIST
    
def logistic_metrics(source_dataset, target_dataset, reverse=False, max_samples=10000):
    """Fit a logistic-regression probe on source features and evaluate it on target features.

    Args:
        source_dataset: ``(features, outputs, labels)`` tuple of torch tensors used
            to fit the probe.
        target_dataset: ``(features, outputs, labels)`` tuple of torch tensors used
            for evaluation (``outputs`` is ignored here).
        reverse: if True, fit on ``argmax(outputs, 1)`` of the source set instead of
            its ``labels``.
        max_samples: if the source set has more rows than this, keep every
            ``n // max_samples``-th row before fitting (so fewer than
            ``2 * max_samples`` rows remain). ``None``/``0`` disables subsampling.

    Returns:
        ``(entropy, inception_score, source_acc, acc)``:
        ``entropy`` is the *negative* mean per-sample entropy of the probe's
        predicted class distribution on the target set; ``inception_score`` is
        ``entropy`` plus the entropy of the mean predicted class distribution
        (i.e. the log Inception Score, up to the 1e-5 smoothing); ``source_acc``
        and ``acc`` are the probe's accuracy on the (subsampled) source set and on
        the target set.
    """
    features, outputs, labels = source_dataset
    if reverse:
        labels = torch.argmax(outputs, 1)
    features, labels = features.numpy(), labels.numpy()

    # Strided subsampling bounds the cost of fitting on large source sets.
    n_samples = features.shape[0]
    if max_samples and n_samples > max_samples:
        step = n_samples // max_samples
        features = features[::step]
        labels = labels[::step]

    start_time = time.time()
    # NOTE(review): `multi_class` is deprecated in recent scikit-learn releases;
    # kept to preserve the current fitting behaviour.
    lin_model = LogisticRegression(max_iter=200, solver='lbfgs', multi_class='multinomial',
                                   tol=1e-3, n_jobs=-1)
    lin_model.fit(features, labels)
    print(f"logistic_metrics computing time: {time.time()-start_time}")
    source_acc = lin_model.score(features, labels)

    features, outputs, labels = target_dataset
    target_features = features.numpy()
    # predict_proba returns a NumPy array of shape (n_target, n_classes), so NumPy
    # ops are required (torch.log does not accept ndarrays).
    prob = lin_model.predict_proba(target_features)
    # Per-sample entropy sums over the class axis (1); the marginal class
    # distribution averages over the sample axis (0).
    entropy = - (-np.log(prob + 1e-5) * prob).sum(1).mean()
    marginal = prob.mean(0)
    inception_score = entropy + (-np.log(marginal + 1e-5) * marginal).sum()
    acc = lin_model.score(target_features, labels.numpy())
    return entropy, inception_score, source_acc, acc


from sklearn.neural_network import MLPClassifier

def mlp_metrics(source_dataset, target_dataset, reverse=False, hidden_layer_sizes=(512,),
                early_stopping=True, solver='lbfgs', pred_input=False, return_clf=False,
                max_samples=10000):
    """Fit an MLP probe on source features and evaluate it on target features.

    Args:
        source_dataset: ``(features, outputs, labels)`` tuple of torch tensors used
            to fit the probe.
        target_dataset: ``(features, outputs, labels)`` tuple of torch tensors used
            for evaluation.
        reverse: if True, fit on ``argmax(outputs, 1)`` of the source set instead of
            its ``labels``.
        hidden_layer_sizes: passed to ``MLPClassifier``.
        early_stopping: for ``solver='adam'`` it is forwarded to ``MLPClassifier``.
            For ``solver='lbfgs'`` it selects between a shorter/looser
            configuration (True) and a longer/tighter one (False).
        solver: ``'lbfgs'`` or ``'adam'``.
        pred_input: if True, use ``softmax(outputs, dim=1)`` as the input features
            (for both source and target) instead of ``features``.
        return_clf: if True, also return the fitted classifier.
        max_samples: if the source set has more rows than this, keep every
            ``n // max_samples``-th row before fitting (so fewer than
            ``2 * max_samples`` rows remain). ``None``/``0`` disables subsampling.

    Returns:
        ``(entropy, inception_score, source_acc, acc)``, plus the fitted
        classifier when ``return_clf`` is True. ``entropy`` is the *negative* mean
        per-sample entropy of the predicted class distribution on the target set;
        ``inception_score`` adds the entropy of the mean predicted distribution;
        the accuracies are on the (subsampled) source set and the target set.

    Raises:
        ValueError: if ``solver`` is not ``'lbfgs'`` or ``'adam'``.
    """
    # Validate up front; previously an unknown solver left `clf` unbound and
    # failed with UnboundLocalError only after all preprocessing.
    if solver not in ('lbfgs', 'adam'):
        raise ValueError(f"Unsupported solver {solver!r}; expected 'lbfgs' or 'adam'.")

    features, outputs, labels = source_dataset
    if reverse:
        labels = torch.argmax(outputs, 1)
    if pred_input:
        features = F.softmax(outputs, dim=1)
    features, labels = features.numpy(), labels.numpy()

    # Strided subsampling bounds the cost of fitting on large source sets.
    n_samples = features.shape[0]
    if max_samples and n_samples > max_samples:
        step = n_samples // max_samples
        features = features[::step]
        labels = labels[::step]

    start_time = time.time()
    if solver == 'lbfgs':
        # scikit-learn documents `early_stopping` as effective only for the
        # 'sgd'/'adam' solvers, so for lbfgs the flag mainly picks max_iter/tol.
        if early_stopping:
            clf = MLPClassifier(hidden_layer_sizes=hidden_layer_sizes, solver='lbfgs', alpha=1.0,
                                max_iter=200, tol=1e-3, early_stopping=True)
        else:
            clf = MLPClassifier(hidden_layer_sizes=hidden_layer_sizes, solver='lbfgs', alpha=1.0,
                                max_iter=1000, tol=1e-4, early_stopping=False)
    else:  # solver == 'adam'
        clf = MLPClassifier(hidden_layer_sizes=hidden_layer_sizes, solver='adam', alpha=1e-4,
                            max_iter=200, early_stopping=early_stopping)
    clf.fit(features, labels)
    print(f"mlp_metrics computing time: {time.time()-start_time}")
    source_acc = clf.score(features, labels)

    features, outputs, labels = target_dataset
    if pred_input:
        features = F.softmax(outputs, dim=1)
    target_features = features.numpy()
    # predict_proba returns shape (n_target, n_classes): sum entropy over classes
    # (axis 1), average the class distribution over samples (axis 0).
    prob = clf.predict_proba(target_features)
    entropy = - (-np.log(prob + 1e-5) * prob).sum(1).mean()
    marginal = prob.mean(0)
    inception_score = entropy + (-np.log(marginal + 1e-5) * marginal).sum()
    acc = clf.score(target_features, labels.numpy())
    if return_clf:
        return entropy, inception_score, source_acc, acc, clf
    return entropy, inception_score, source_acc, acc