import numpy as np
import cupy as cp
import scipy as sp
import scipy.linalg as la
import time
import os, sys
from sklearn.linear_model import LogisticRegression
from utils_solver import *

def cal_acc(y, yex):
    """Return the percentage of entries of ``y`` that agree with ``yex``.

    The number of element-wise mismatches between ``y`` and ``yex`` is divided
    by the length of the first axis of ``yex`` and converted to an accuracy in
    the range 0-100.
    """
    n_wrong = np.sum(y != yex)
    n_total = yex.shape[0]
    error_rate = n_wrong / n_total
    return 100 * (1 - error_rate)

def FIRAL(X, Y, idx_sel, c, b, eta, tol):
    """Extend the labelled set ``idx_sel`` using a relax-then-round selection.

    Steps performed:

    1. Fit a multinomial logistic regression (no intercept) on the labelled
       rows ``X[idx_sel]`` / ``Y[idx_sel]`` and print its accuracy on all of
       ``X`` against ``Y``.
    2. From the predicted probabilities (last class column dropped) build,
       for every sample ``i``, ``G[i] = diag(g_i) - g_i g_i^T`` and a factor
       ``Q[i]`` with ``Q[i] Q[i]^T = G[i]`` (negative eigenvalues clipped to
       zero), then ``P[i] = kron(x_i, Q[i])``.
    3. Form ``D = sum_{i in idx_sel} P[i] P[i]^T`` and
       ``I_U = (1/n) sum_i P[i] P[i]^T``.
    4. Run ``RelaxSolver`` on the unlabelled pool to obtain weights ``pi``.
    5. Run ``SolverRound.run_l12`` on the whitened problem to pick pool points.

    Parameters
    ----------
    X : array of shape (n, d)
        Feature matrix for all samples.
    Y : array of length n
        Labels; used for fitting on ``idx_sel`` and for the printed accuracy.
    idx_sel : sequence of int
        Row indices that are already labelled.
    c : int
        Number of classes; ``c - 1`` is forwarded to ``RelaxSolver``.
        Presumably this must match the number of classes seen by the
        classifier -- TODO confirm.
    b : int
        Budget forwarded to both solvers (presumably the number of points to
        select -- verify against the solver implementations).
    eta : float
        Forwarded unchanged to ``SolverRound.run_l12``.
    tol : float
        Forwarded unchanged to ``RelaxSolver.run``.

    Returns
    -------
    numpy.ndarray
        ``idx_sel`` followed by the newly selected row indices of ``X``.
    """
    n, d = X.shape

    # prepare
    cl = LogisticRegression(
        penalty='l2', C=1.0, class_weight='balanced', random_state=0,
        fit_intercept=False, multi_class="multinomial", max_iter=1000,
    ).fit(X[idx_sel], Y[idx_sel])
    print('initial accuracy:', cal_acc(cl.predict(X), Y))

    # Class probabilities with the last class column removed.
    gamma = cl.predict_proba(X)[:, :-1]
    k = gamma.shape[1]

    # G[i] = diag(gamma[i]) - outer(gamma[i], gamma[i]), for all i at once.
    G = gamma[:, :, None] * np.eye(k) - gamma[:, :, None] * gamma[:, None, :]
    evals, evecs = np.linalg.eigh(G)
    # Round-off can produce tiny negative eigenvalues; clip before the sqrt.
    evals[evals < 0] = 0.0
    evals = np.sqrt(evals)

    # Q[i] = V[i] * sqrt(eig[i]) (column scaling), so Q[i] Q[i]^T = G[i].
    Q = evecs * evals[:, None, :]
    # P[i] = kron(X[i][:, None], Q[i]), shape (d * k, k), built without a loop.
    P = np.einsum('ij,ikl->ijkl', X, Q).reshape(n, d * k, k)

    # D = sum over labelled samples of P[i] P[i]^T.
    P_sel = P[idx_sel]
    D = np.einsum('ijk,ilk->jl', P_sel, P_sel, optimize=True)

    # I_U = average of P[i] P[i]^T over all samples (computed on the GPU).
    P = cp.asarray(P)
    I_U = cp.einsum('ijk,ilk->jl', P, P, optimize=True)
    I_U *= (1./n)

    # relax solve
    print("\n\n =================solving relaxed convex optimization problem===============")
    idx_pool = np.delete(np.arange(n), idx_sel)
    P = P[idx_pool]
    D = cp.asarray(D)
    # Use the real pool size so it always matches P's leading dimension.
    n_pool = len(idx_pool)
    c1 = c-1

    solver = RelaxSolver(n_pool, d, c1, P, I_U, D)
    itr_max = 2000
    print_unit = 200
    # NOTE(review): `5.-5` is the subtraction 5.0 - 5 and evaluates to 0.0.
    # This looks like a typo for `5e-5`; the value is kept unchanged here
    # because the intended meaning cannot be confirmed from this file.
    l = 5.-5
    pi = solver.run(l, b, itr_max, print_unit, tol)
    pi *= b

    # sparsification solve
    print("\n\n =================solving sparsification problem===============")
    D *= (1./b)
    pi = cp.asarray(pi)
    # S = sum_i pi[i] P[i] P[i]^T + b * D, where b * D undoes the scaling above.
    S = cp.einsum('i,ijk,ilk->jl', pi, P, P, optimize=True)
    S += b*D
    # Replace S by its inverse square root via the eigendecomposition.
    s_evals, s_evecs = cp.linalg.eigh(S)
    s_evals = cp.power(s_evals, -0.5)
    S = s_evecs @ cp.diag(s_evals) @ s_evecs.T
    Dtilde = S @ D @ S

    solver = SolverRound(Dtilde, P, S)
    sel = solver.run_l12(b, eta, print_unit=10)
    # cp.asnumpy accepts both host and device arrays; np.asarray raises on a
    # CuPy array because implicit device-to-host conversion is not allowed.
    sel = cp.asnumpy(sel)
    sel_ = idx_pool[sel]
    idx = np.concatenate((idx_sel, sel_))

    return idx


def uncertainty(X,Y, idx_sel, b, method):
    """Extend ``idx_sel`` with the ``b`` most uncertain unlabelled samples.

    A multinomial logistic regression (no intercept) is fitted on
    ``X[idx_sel]`` / ``Y[idx_sel]``; every remaining row of ``X`` is scored
    from its predicted class probabilities and the ``b`` highest-scoring rows
    are appended to ``idx_sel``.

    Parameters
    ----------
    X : array of shape (n, d)
        Feature matrix for all samples.
    Y : array of length n
        Labels; only ``Y[idx_sel]`` is used.
    idx_sel : sequence of int
        Row indices that are already labelled.
    b : int
        Number of pool samples to add. ``0`` adds nothing; values larger than
        the pool select the whole pool.
    method : {'entropy', 'varratios'}
        ``'entropy'`` scores by the entropy of the predicted distribution,
        ``'varratios'`` by one minus the largest predicted probability.

    Returns
    -------
    numpy.ndarray
        ``idx_sel`` followed by the selected row indices of ``X``.

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported names or ``b`` is negative.
    """
    # Validate first so a bad argument fails clearly, before any model fit.
    if method not in ('entropy', 'varratios'):
        raise ValueError(
            "method must be 'entropy' or 'varratios', got %r" % (method,))
    if b < 0:
        raise ValueError("b must be non-negative, got %r" % (b,))

    n = X.shape[0]
    pool = np.delete(np.arange(n), idx_sel)

    # A slice like order[-0:] would return the whole pool, so handle b == 0
    # explicitly (and skip the unnecessary model fit).
    if b == 0:
        return np.concatenate((idx_sel, pool[:0]))

    cl = LogisticRegression(penalty = 'l2',class_weight='balanced',random_state = 0, fit_intercept=False, multi_class="multinomial", max_iter = 1000).fit(X[idx_sel], Y[idx_sel])
    p = cl.predict_proba(X[pool])

    if method == 'entropy':
        # Imported explicitly: `import scipy as sp` alone does not guarantee
        # that the `scipy.stats` submodule has been loaded.
        from scipy.stats import entropy
        a = entropy(p, axis=1)
    else:  # 'varratios'
        a = 1. - np.max(p, axis=1)

    # argsort is ascending, so the most uncertain samples are at the end.
    order = np.argsort(a)
    sel_ = order[max(len(pool) - b, 0):]

    idx = np.concatenate((idx_sel, pool[sel_]))
    return idx

def BAIT(X,Y, idx_sel, c, b):
    """Extend the labelled set ``idx_sel`` with ``b`` points chosen by ``BaitSolver``.

    A multinomial logistic regression (no intercept) is fitted on
    ``X[idx_sel]`` / ``Y[idx_sel]``. From its predicted probabilities (last
    class column dropped) the routine builds, for every sample ``i``,
    ``G[i] = diag(g_i) - g_i g_i^T``, a factor ``Q[i]`` with
    ``Q[i] Q[i]^T = G[i]`` (negative eigenvalues clipped to zero) and
    ``P[i] = kron(x_i, Q[i])``. The matrices
    ``D = (1/n1) sum_{i in idx_sel} P[i] P[i]^T`` (``n1 = b + len(idx_sel)``)
    and ``I_U = (1/n) sum_i P[i] P[i]^T`` are handed to ``BaitSolver``
    together with the pool factors.

    Parameters
    ----------
    X : array of shape (n, d)
        Feature matrix for all samples.
    Y : array of length n
        Labels; only ``Y[idx_sel]`` is used.
    idx_sel : sequence of int
        Row indices that are already labelled.
    c : int
        Number of classes. Kept for interface compatibility; the matrix sizes
        are taken from the classifier's probability output.
    b : int
        Passed to ``BaitSolver.run`` (presumably the number of points to
        select -- verify against the solver implementation).

    Returns
    -------
    numpy.ndarray
        ``idx_sel`` followed by the newly selected row indices of ``X``.
    """
    n, d = X.shape

    # prepare
    cl = LogisticRegression(
        penalty='l2', C=1.0, class_weight='balanced', random_state=0,
        fit_intercept=False, multi_class="multinomial", max_iter=1000,
    ).fit(X[idx_sel], Y[idx_sel])

    # Class probabilities with the last class column removed.
    gamma = cl.predict_proba(X)[:, :-1]
    k = gamma.shape[1]

    # G[i] = diag(gamma[i]) - outer(gamma[i], gamma[i]), for all i at once.
    G = gamma[:, :, None] * np.eye(k) - gamma[:, :, None] * gamma[:, None, :]
    evals, evecs = np.linalg.eigh(G)
    # Round-off can produce tiny negative eigenvalues; clip before the sqrt.
    evals[evals < 0] = 0.0
    evals = np.sqrt(evals)

    # Q[i] = V[i] * sqrt(eig[i]) (column scaling), so Q[i] Q[i]^T = G[i].
    Q = evecs * evals[:, None, :]
    # P[i] = kron(X[i][:, None], Q[i]), shape (d * k, k), built without a loop.
    P = np.einsum('ij,ikl->ijkl', X, Q).reshape(n, d * k, k)

    # D = sum over labelled samples of P[i] P[i]^T.
    P_sel = P[idx_sel]
    D = np.einsum('ijk,ilk->jl', P_sel, P_sel, optimize=True)

    # I_U = average of P[i] P[i]^T over all samples (computed on the GPU).
    P = cp.asarray(P)
    I_U = cp.einsum('ijk,ilk->jl', P, P, optimize=True)
    I_U *= (1./n)

    # solve
    n1 = b + len(idx_sel)
    idx_pool = np.delete(np.arange(n), idx_sel)
    P = P[idx_pool]
    D = (1./n1) * cp.asarray(D)

    solver = BaitSolver(n1, D, P, I_U)
    sel = solver.run(b)
    # Bring the solver's selection back to host memory for NumPy indexing.
    sel = cp.asnumpy(sel)
    sel_ = idx_pool[sel]
    idx = np.concatenate((idx_sel, sel_))

    return idx
    

