import numpy as np
import ot
from tqdm import tqdm
from sklearn.metrics import pairwise_distances
from sklearn.neighbors import KNeighborsClassifier
import scipy.io

# Fixed seed for NumPy's global random generator, echoed so that each run's
# output records which seed produced it.
rs = 0
print('rs: {}'.format(rs))
np.random.seed(rs)

def get_acc(P, X1, Y1, X2, Y2, X2_test, Y2_test):
    """Print and return the 1-NN transfer accuracy obtained from a weight matrix.

    Every row of ``P`` is turned into one projected point: the average of the
    rows of ``X2`` weighted by that row of ``P`` and divided by the row sum.
    A 1-nearest-neighbour classifier is fitted on the projected points with
    labels ``Y1`` and scored on ``(X2_test, Y2_test)``.

    Args:
        P: 2-D weight matrix; needs one row per entry of ``Y1`` and one
            column per row of ``X2``.  A row whose sum is zero causes a
            division by zero in the projection.
        X1: Not used by this function; kept so existing calls keep working.
        Y1: Labels attached to the projected points (one per row of ``P``).
        X2: 2-D array whose rows are averaged to build the projection.
        Y2: Not used by this function; kept so existing calls keep working.
        X2_test: Samples the fitted classifier is evaluated on.
        Y2_test: Labels for ``X2_test``.

    Returns:
        The fraction of ``X2_test`` rows whose predicted label equals
        ``Y2_test``.  The same value is printed.
    """
    # Row sums are the normalisers of the weighted average below.
    weights = np.sum(P, axis=1)
    X1_proj = np.matmul(P, X2) / weights[:, None]

    knn = KNeighborsClassifier(n_neighbors=1)
    knn.fit(X1_proj, Y1)
    pred = knn.predict(X2_test)
    acc = (pred == Y2_test).mean()
    print('[!] K: 1 Transfer Acc: {}'.format(acc))
    # Hand the score back as well, so callers can record or compare it.
    return acc

def ratio(P, K_x, K_y):
    """Return ``(K_x @ P @ K_y.T) / outer(f_x, f_y)`` element-wise.

    ``f_x`` and ``f_y`` are the row means of ``K_x`` and ``K_y``.  In this
    script ``K_x`` and ``K_y`` are Gaussian kernel matrices, so the result
    compares a plan-weighted joint kernel estimate with the product of the
    two marginal kernel estimates.

    Args:
        P: 2-D array with ``K_x.shape[1]`` rows and ``K_y.shape[1]`` columns.
        K_x: 2-D kernel matrix for the first domain.
        K_y: 2-D kernel matrix for the second domain.

    Returns:
        Array of shape ``(K_x.shape[0], K_y.shape[0])``.  Entries whose
        denominator is zero come out as inf or nan; no clamping is applied.
    """
    f_x = K_x.sum(1) / K_x.shape[1]
    f_y = K_y.sum(1) / K_y.shape[1]
    # Same values as -ot.gromov.tensor_product(zeros, K_x, K_y, P), computed
    # directly so no zero matrix has to be allocated and no sign flipped.
    f_xy = K_x @ P @ K_y.T
    return f_xy / np.outer(f_x, f_y)

def migrad(P, K_x, K_y):
    """Return the negated sum of a log-ratio term and a kernel-smoothed term.

    With ``f_x``/``f_y`` the row means of ``K_x``/``K_y`` and
    ``f_xy = K_x @ P @ K_y.T``, the result is::

        -( log(f_xy / outer(f_x, f_y)) + K_x @ (P / f_xy) @ K_y.T )

    The script adds ``lam`` times this value to the transport cost on every
    iteration.  Presumably this is the negative gradient of a kernel-density
    estimate of mutual information with respect to ``P`` (InfoOT) -- confirm
    against the reference derivation.

    Args:
        P: 2-D array; the two products above require ``K_x`` and ``K_y`` to
            be square, with ``P`` of shape ``(len(K_x), len(K_y))``.
        K_x: Square kernel matrix for the first domain.
        K_y: Square kernel matrix for the second domain.

    Returns:
        Array with the same shape as ``P``.  Zero entries in ``f_xy`` lead to
        inf or nan; no clamping is applied.
    """
    f_x = K_x.sum(1) / K_x.shape[1]
    f_y = K_y.sum(1) / K_y.shape[1]
    f_x_f_y = np.outer(f_x, f_y)
    # Both products below equal -ot.gromov.tensor_product(zeros, K_x, K_y, .)
    # and are written out directly to skip the zero matrix and the sign flip.
    f_xy = K_x @ P @ K_y.T
    smoothed = K_x @ (P / f_xy) @ K_y.T
    P_grad = np.log(f_xy / f_x_f_y) + smoothed
    return -P_grad

def dist(z1, z2, label_penalty=5000):
    """Euclidean distance between two label-augmented vectors.

    The last entry of each vector is read as a label and the remaining
    entries as features.  The result is the Euclidean norm of the feature
    difference, plus ``label_penalty`` when the two labels differ.

    Args:
        z1: 1-D array whose last element is the label.
        z2: 1-D array of the same length as ``z1``.
        label_penalty: Constant added when the labels differ.  Defaults to
            5000, the value that used to be hard-coded, so the two-argument
            call made by a pairwise-distance routine behaves as before.

    Returns:
        The distance as a scalar.
    """
    feature_gap = np.linalg.norm(z1[:-1] - z2[:-1])
    if z1[-1] != z2[-1]:
        return feature_gap + label_penalty
    return feature_gap


# Load the two .mat files: mat1 provides the source samples, mat2 the target.
mat1 = scipy.io.loadmat('decaf6/caltech_decaf.mat')
mat2 = scipy.io.loadmat('decaf6/dslr_decaf.mat')

X1, Y1 = mat1['feas'], mat1['labels'].reshape(-1)
X2, Y2 = mat2['feas'], mat2['labels'].reshape(-1)

# Append the label as a trailing column, the layout `dist` reads.
# NOTE(review): Z2 is never read again in this script, and it is built before
# the target shuffle below, so its rows follow the unshuffled order.
Z1 = np.concatenate((X1, Y1.reshape(-1, 1)), axis=1)
Z2 = np.concatenate((X2, Y2.reshape(-1, 1)), axis=1)

# Shuffle the target samples with NumPy's global RNG, then keep the first
# 90% for training and the remainder for testing.
idx = np.arange(len(X2))
np.random.shuffle(idx)
X2, Y2 = X2[idx], Y2[idx]

n_train = int(len(X2)*0.9)
X2_train, X2_test = X2[:n_train], X2[n_train:]
Y2_train, Y2_test = Y2[:n_train], Y2[n_train:]

# KDE
# C is the source-to-target cost; C1 and C2 are within-domain distances,
# with C1 made label-aware through the custom `dist` metric.
C = pairwise_distances(X1, X2_train)
C1 = pairwise_distances(Z1, Z1, metric=dist)
C2 = pairwise_distances(X2_train, X2_train)

# Bandwidth per domain: half of sqrt(mean squared distance / 2).
std1 = np.sqrt((C1**2).mean() / 2)
std2 = np.sqrt((C2**2).mean() / 2)
h1, h2 = 0.5 * std1, 0.5 * std2

# Gaussian
sqrt_two_pi = (2 * np.pi)**0.5
K1 = np.exp(-(C1 / h1)**2 / 2) / sqrt_two_pi / h1
K2 = np.exp(-(C2 / h2)**2 / 2) / sqrt_two_pi / h2

# Info OT
# Uniform marginals; their outer product is the starting plan.
p = np.full(len(X1), 1. / len(X1))
q = np.full(len(X2_train), 1. / len(X2_train))
P = np.outer(p, q)
lam = 100
for _ in tqdm(range(50)):
    # Re-solve Sinkhorn with the distance cost shifted by lam times the
    # output of migrad evaluated at the current plan.
    grad_P = migrad(P, K1, K2)
    C_ = C + lam * grad_P
    P = ot.bregman.sinkhorn(p, q, C_, reg=1)

# Score the plan itself as the projection weights ...
print('InfoOT Barycentric Proj')
get_acc(P, X1, Y1, X2_train, Y2_train, X2_test, Y2_test)

# ... and then the output of ratio() used as the weights instead.
print('InfoOT Conditional Proj')
get_acc(ratio(P, K1, K2), X1, Y1, X2_train, Y2_train, X2_test, Y2_test)


