import torch
import numpy as np
from sklearn import svm
from sklearn.cluster import KMeans
from scipy.optimize import linear_sum_assignment

# SVM linear classification.
def classify(model, labels, view1, view1_valid, view1_test, device):
    """Fit linear SVMs on encoded view-1 features and report valid/test error.

    A linear ``svm.SVC`` is trained for each C in ``[0.1, 1.0, 10.0]`` on the
    encoded training inputs; the classifier with the lowest validation error
    is then scored on the encoded test inputs. Progress and the final errors
    are printed.

    Args:
        model: Object exposing ``encode``; it is called with a list holding
            the same input tensor twice, and the first entry of the first
            returned value is used as the feature matrix.
            NOTE(review): presumably a two-view encoder fed view 1 on both
            inputs -- confirm against the model definition.
        labels: Mapping with ``'train'``, ``'valid'`` and ``'test'`` entries
            holding the targets for the corresponding inputs.
        view1: Training inputs; converted with ``torch.tensor`` and given an
            extra axis at position 1 before encoding.
        view1_valid: Validation inputs, handled like ``view1``.
        view1_test: Test inputs, handled like ``view1``.
        device: Torch device the inputs are moved to before encoding.

    Returns:
        Tuple ``(best_valid_error, test_error)`` for the selected classifier.
    """
    def _features(view):
        # Encode one split and hand the features back as a CPU NumPy array,
        # which is what scikit-learn expects.
        x = torch.tensor(view).unsqueeze(1).to(device)
        s, _ = model.encode([x, x])
        return s[0].cpu().numpy()

    with torch.no_grad():
        print("Linear SVM Classification")

        # The encodings do not depend on C, so compute them once instead of
        # once per candidate; this also guarantees every candidate is
        # compared on exactly the same features.
        z_train = _features(view1)
        z_valid = _features(view1_valid)
        y_train = labels['train']
        y_valid = labels['valid']

        best_error_tune = float("inf")
        bestsvm = None

        for c in [0.1, 1.0, 10.0]:
            lin_clf = svm.SVC(C=c, kernel="linear")
            lin_clf.fit(z_train, y_train)

            pred = lin_clf.predict(z_valid)
            svm_error_tune = np.mean(pred != y_valid)
            print("c=%f, valid error %f" % (c, svm_error_tune))

            # Always keep the first candidate so that a classifier is
            # selected even when every validation error is 1.0 (or NaN);
            # later candidates must be strictly better to replace it.
            if bestsvm is None or svm_error_tune < best_error_tune:
                best_error_tune = svm_error_tune
                bestsvm = lin_clf

        # test
        z_test = _features(view1_test)
        pred = bestsvm.predict(z_test)
        best_error_test = np.mean(pred != labels['test'])
        print("validerr=%f, testerr=%f" % (best_error_tune, best_error_test))

    return best_error_tune, best_error_test

# Evaluate clustering accuracy with the Hungarian algorithm
def eval_kmeans_acc(pred, gt):
    """Return clustering accuracy under the best cluster-to-class matching.

    A contingency table counting how often each predicted cluster id
    co-occurs with each ground-truth class is built, and
    ``linear_sum_assignment`` picks the one-to-one cluster/class matching
    that maximizes the number of agreeing samples.

    Cluster ids and class labels may be arbitrary values (they are
    re-indexed with ``np.unique``), and the number of clusters may differ
    from the number of classes; clusters left unmatched count as errors.

    Args:
        pred: Predicted cluster id per sample (array-like).
        gt: Ground-truth class label per sample (array-like), same length
            as ``pred``.

    Returns:
        Fraction of samples whose cluster is matched to their true class.

    Raises:
        ValueError: If the inputs are empty or differ in length.
    """
    # Accept lists as well as arrays; flatten so np.unique's inverse
    # indices are one-dimensional.
    pred = np.asarray(pred).ravel()
    gt = np.asarray(gt).ravel()

    if pred.shape != gt.shape:
        raise ValueError(
            "pred and gt must have the same length, got %d and %d"
            % (pred.size, gt.size))
    if gt.size == 0:
        raise ValueError("cannot evaluate clustering accuracy on empty input")

    # Map both label sets onto contiguous indices 0..k-1.
    pred_ids, pred_idx = np.unique(pred, return_inverse=True)
    gt_ids, gt_idx = np.unique(gt, return_inverse=True)

    # mat[i, j] = number of samples in cluster i whose true class is j.
    mat = np.zeros((len(pred_ids), len(gt_ids)), dtype=np.int64)
    np.add.at(mat, (pred_idx, gt_idx), 1)

    # linear_sum_assignment minimizes cost, so negate to maximize matches.
    row_ind, col_ind = linear_sum_assignment(-mat)

    best_acc = mat[row_ind, col_ind].sum() / gt.size

    return best_acc

# K-means clustering
def cluster(model, labels, view1, view1_valid, view1_test, device,
            n_clusters=10):
    """Run K-means on encoded view-1 features and report test accuracy.

    The train, validation and test inputs are encoded, stacked in that
    order and clustered jointly with ``KMeans``. The cluster assignments of
    the test rows (the last ``len(labels['test'])`` rows) are then scored
    against the test labels with ``eval_kmeans_acc``, and the accuracy is
    printed.

    Args:
        model: Object exposing ``encode``; it is called with a list holding
            the same input tensor twice, and the first entry of the first
            returned value is used as the feature matrix.
            NOTE(review): presumably a two-view encoder fed view 1 on both
            inputs -- confirm against the model definition.
        labels: Mapping whose ``'test'`` entry holds the test labels.
        view1: Training inputs; converted with ``torch.tensor`` and given an
            extra axis at position 1 before encoding.
        view1_valid: Validation inputs, handled like ``view1``.
        view1_test: Test inputs, handled like ``view1``.
        device: Torch device the inputs are moved to before encoding.
        n_clusters: Number of K-means clusters. Defaults to 10, the value
            that was previously hard-coded.

    Returns:
        The clustering accuracy on the test split.
    """
    with torch.no_grad():
        print("K-means Clustering")

        rstate = 0
        test_label = labels['test']

        # Encode one split at a time so only a single input tensor has to
        # live on the device at once.
        feats = []
        for view in (view1, view1_valid, view1_test):
            x = torch.tensor(view).unsqueeze(1).to(device)
            s, _ = model.encode([x, x])
            feats.append(s[0].cpu().numpy())

        clustering = KMeans(n_clusters=n_clusters, random_state=rstate).fit(
                np.vstack(feats))

        # Test rows were stacked last. Slice from an explicit start offset:
        # a negative-length slice would return every row when the test
        # split is empty.
        assigned = clustering.labels_
        pred = assigned[len(assigned) - len(test_label):]

        acc = eval_kmeans_acc(pred, test_label)

        print('Kmeans acc={}'.format(acc))

    return acc
