#coding: utf-8

from math import ceil
from scipy.sparse import identity, lil_array, dia_array, csc_array, csr_array
from faiss import IndexFlat, METRIC_INNER_PRODUCT, normalize_L2
import numpy as np
from tqdm import tqdm
from sklearn.neighbors import KDTree

from bdivrec.fabaphe.utils import power, chunks, SAMPLE, MAX, set_score

class Recommender(object):
    """Base class for recommenders built on a quality/similarity factorisation.

    Subclasses implement ``compute_dpp``; this class combines the two factors
    with the trade-off weight ``lbd`` and delegates the item selection to the
    ``SAMPLE`` or ``MAX`` routine named by ``rec_type``.
    """

    def __init__(self, params):
        """Validate ``params`` and copy every entry onto the instance.

        Required keys: ``lbd`` (in [0, 1]), ``c`` (> 1), ``eta`` (>= 0),
        ``rec_type`` ("SAMPLE" or "MAX") and ``seed``. Any additional key is
        also stored as an attribute of the same name.
        """
        assert "lbd" in params ## the relative weight for the relevance task
        assert params["lbd"] <= 1 and params["lbd"] >= 0
        assert "c" in params ## multiplicative factor to avoid numerical issues
        assert params["c"] > 1
        assert "eta" in params ## regularization factor
        assert params["eta"] >= 0
        assert "rec_type" in params ## sampling or maximization
        assert params["rec_type"] in ["SAMPLE", "MAX"]
        assert "seed" in params ## random seed (for sampling)
        self.name = "Recommender"
        for param in params:
            setattr(self, param, params[param])
        ## set after the copy loop, so a "quantize_digit" entry in params is overridden
        self.quantize_digit = -1
        self.rng = np.random.default_rng(self.seed)

    def recommend(self, B, u, env, K, S=None):
        """Return ``B`` recommended items for user ``u``.

        :param B: number of items to return (checked against the selector output)
        :param u: user index, must satisfy ``0 <= u < env.nuser``
        :param env: environment object forwarded to ``compute_dpp``
        :param K: kernel object forwarded to ``compute_dpp`` and to the selector
        :param S: optional list of candidate item ids in ``[0, env.nitem)``;
            when given, the selector output indexes into ``S``
        :return: list of ``B`` items (elements of ``S`` when ``S`` is given)
        """
        assert u >= 0
        assert u < env.nuser
        if (S is not None):
            ## item ids are 0-based, so env.nitem itself is out of range
            assert all(0 <= i < env.nitem for i in S)
        QS, KS = self.compute_dpp(u, env, K, S=S)
        ## apply adaptive quality-relevance coefficient
        QQ = power(QS, self.lbd*self.c) ## tradeoff
        KK = power(KS, (1-self.lbd)*self.c) ## tradeoff
        ## explicit dispatch instead of eval(): rec_type is restricted to these two names
        selector = {"SAMPLE": SAMPLE, "MAX": MAX}[self.rec_type]
        recs = selector(QQ, KK, B, K, self.seed, eta=self.eta)
        assert len(recs) == B
        if (S is not None):
            return [S[i] for i in recs]
        return recs

    def compute_dpp(self, u, env, K, S=None):
        """Return [Q, M] such that the L-matrix is (Q.M).(Q.M)^T.

        :raises NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError

    def grad_value_f(self, rels, u, env, K, S, lbd=None):
        """Return ``(grad, val_lbd)`` for the current (or overridden) ``lbd``.

        :param rels: object exposing ``shape`` and ``data``; ``rels.data`` is
            placed on the diagonal of an N x N matrix, N = ``rels.shape[0]``
        :param u, env, K, S: forwarded to ``compute_dpp``
        :param lbd: optional indexable; when given, ``lbd[0]`` temporarily
            replaces ``self.lbd`` for the duration of the call
        :return: tuple ``(grad, val_lbd)`` where ``val_lbd`` is
            ``set_score`` of the powered factors
        """
        override = lbd is not None
        if override:
            saved_lbd = self.lbd
            self.lbd = lbd[0]
        try:
            _, M = self.compute_dpp(u, env, K, S=S)
            N = rels.shape[0]
            Q = csr_array(dia_array((rels.data, 0), shape=(N, N)))
            detQ, volM = set_score(Q), np.sqrt(set_score(M))
            QQ = power(Q, self.lbd*self.c) ## tradeoff
            KK = power(M, (1-self.lbd)*self.c) ## tradeoff
            val_lbd = set_score(QQ, KK)
            ## a zero score contributes 0 instead of log(0)
            log_volM = 0 if (volM==0) else np.log(volM)
            log_detQ = 0 if (detQ==0) else np.log(detQ)
            grad = -4*self.c*self.lbd*log_volM + 4*self.c*self.lbd*log_detQ
        finally:
            ## restore the instance weight even if one of the calls above raises
            if override:
                self.lbd = saved_lbd
        return grad, val_lbd
    
## version for known feedback model
class Fabaphe(Recommender):
    """Recommender that removes from the similarity factor the candidates
    lying within cosine distance ``alpha`` of an item in the user's history.

    The nearest-neighbour search runs either on a FAISS flat inner-product
    index (default) or on a scikit-learn KD-tree (``type_algo="kd-tree"``).
    """

    def __init__(self, params):
        """Validate the subclass keys, then defer to ``Recommender.__init__``.

        Required key: ``alpha``. Optional keys: ``type_algo`` ("FAISS" or
        "kd-tree", default "FAISS") and ``nchunks`` (default 10).
        """
        assert "alpha" in params ## maximum fuzzy denuding distance
        ## defaults are set first; the parent constructor then copies params over them
        if ("type_algo" not in params):
            self.type_algo = "FAISS"
        else:
            assert params["type_algo"] in ["FAISS", "kd-tree"]
        if ("nchunks" not in params):
            self.nchunks = 10
        super().__init__(params)
        self.name = "Fabaphe"

    @staticmethod
    def _unit_rows(X):
        """Return ``X`` as a dense 2-D float64 array with L2-normalised rows."""
        dense = X.toarray() if hasattr(X, "toarray") else np.asarray(X)
        dense = np.array(dense, dtype=np.float64)
        if (dense.ndim==1):
            dense = dense.reshape(1, -1)
        norms = np.linalg.norm(dense, axis=1, keepdims=True)
        norms[norms==0] = 1.0 ## leave all-zero rows untouched instead of dividing by 0
        return dense / norms

    ## https://github.com/facebookresearch/faiss/wiki/MetricType-and-distances#how-can-i-index-vectors-for-cosine-similarity
    def compute_tree(self, Phi_H):
        """Build a FAISS inner-product index over the L2-normalised rows of ``Phi_H``.

        :param Phi_H: non-empty sparse matrix (must expose ``toarray``)
        :return: the populated ``IndexFlat``
        """
        assert Phi_H.shape[0]>0
        MM = np.array(Phi_H.toarray(), dtype=np.float32)
        tree = IndexFlat(MM.shape[1], METRIC_INNER_PRODUCT)
        normalize_L2(MM) ## in place; inner product on unit rows = cosine similarity
        tree.add(MM)
        return tree

    def query_tree(self, x, tree):
        """Flag the rows of ``x`` whose best cosine similarity in ``tree`` is >= 1 - alpha.

        :param x: sparse matrix of query rows (must expose ``toarray``)
        :param tree: index returned by ``compute_tree``
        :return: 1-D boolean array, one entry per row of ``x``
        """
        assert tree is not None
        XX = np.array(x.toarray(), dtype=np.float32)
        normalize_L2(XX)
        d, _ = tree.search(XX, 1)
        return (d>=1-self.alpha).ravel()

    def compute_kdtree(self, Phi_H):
        """Build a Euclidean KD-tree over the L2-normalised rows of ``Phi_H``.

        Rows are densified and normalised so that the Euclidean distance maps
        to cosine similarity, mirroring what ``compute_tree`` does for FAISS.
        """
        assert Phi_H.shape[0]>0
        return KDTree(self._unit_rows(Phi_H), metric="euclidean")

    def query_kdtree(self, x, tree):
        """KD-tree counterpart of ``query_tree``.

        :param x: query rows (sparse or dense)
        :param tree: tree returned by ``compute_kdtree``
        :return: 1-D boolean array, one entry per row of ``x``
        """
        assert tree is not None
        XX = self._unit_rows(x)
        d, _ = tree.query(XX, 1)
        ## for unit vectors ||a-b||^2 = 2 - 2<a,b>, hence <a,b> = 1 - ||a-b||^2 / 2
        sims = 1.0 - 0.5*np.square(d.ravel())
        return sims>=1-self.alpha

    def compute_dpp(self, u, env, K, S=None):
        """Return [Q, M] such that the L-matrix is (Q.M).(Q.M)^T.

        :param u: user index
        :param env: environment giving feedback, item embeddings and user history
        :param K: callable kernel returning a pair whose first element is the
            feature matrix; its ``beta`` attribute, when not None, is used as a
            hard threshold on the features
        :param S: optional list of candidate item ids; all items when None
        :return: ``(QS, KSH)``: diagonal feedback matrix and similarity factor
        """
        if (S is not None):
            qs = env.feedback(S, u)
            Phi = env.item_embs(S)
        else:
            qs = env.feedback_slice(0, env.nitem, u)
            Phi = env.item_embs_slice(0, env.nitem)
        N = qs.shape[0]
        QS = csr_array(dia_array((qs.data, 0), shape=(N, N)))
        KS, _ = K(Phi)
        if (K.beta is not None):
            KS *= (KS>K.beta).astype(int)
        PSH = lil_array(Phi.shape)
        H = env.get_user_hist(u)
        if (len(H)==0):
            ## no history: nothing to remove from the similarity factor
            if (KS.shape[0]==KS.shape[1]):
                KS = KS @ KS.T
            return QS, KS
        Phi_H = env.item_embs(H)
        ## NOTE(review): KH is not thresholded by K.beta here whereas KS is — confirm this is intended
        KH, _ = K(Phi_H)
        use_kdtree = (self.type_algo == "kd-tree")
        tree = self.compute_kdtree(KH) if use_kdtree else self.compute_tree(KH)
        N = Phi.shape[0]
        for ii, lst in (pbar := tqdm(
            enumerate(chunks(N, self.nchunks)),
            position=4,
            leave=False,
        )):
            pbar.set_description(f"Finding neighbors {min((ii+1)*self.nchunks,N)}/{N}")
            x = KS[lst]
            v = self.query_kdtree(x, tree) if use_kdtree else self.query_tree(x, tree)
            posids = np.flatnonzero(v).tolist()
            if (len(posids)==0):
                continue ## nothing in this chunk is close to the history
            posids_S = [lst[i] for i in posids]
            ## copy the embeddings of the flagged candidates; other rows of PSH stay zero
            PSH[posids_S] = Phi[lst][posids]
        KPSH, _ = K(PSH)
        if (K.beta is not None):
            KPSH *= (KPSH>K.beta).astype(int)
        KSH = KS - KPSH
        if (KSH.shape[0]==KSH.shape[1]):
            KSH = KSH @ KSH.T
        return QS, KSH

if __name__ == "__main__":
    ## Timing script: builds a synthetic environment, computes the L-matrix
    ## factors for one user with a short history, then times SAMPLE and MAX.
    from time import time
    from kernels import DenudedKernel
    import sys
    sys.path.insert(0,"../../experiments/known_setting/")
    from known_envs import SyntheticCosine
    seed = 1234
    eta = 1e-3
    c = 2
    lbd = 0.5
    beta = 0.28
    nc = 10
    nchunks=100000
    nitem=1000000
    data_folder = "../../datasets/Synthetic_tests_fabaphe"
    env = SyntheticCosine(dict(name=data_folder, nitem=nitem, nchunks=nchunks, nuser=100, d=10, seed=seed, quantize_digit=3, new=True))
    K = DenudedKernel(kernel="linear", n_components=nc, nchunks=nchunks//20, seed=seed, beta=beta)
    u = 0
    B = 5
    ## add history to user 0 (drawn from a seeded generator so runs are reproducible)
    history_rng = np.random.default_rng(seed)
    env.update_user_hist(u, history_rng.choice(nitem, size=5).tolist())
    ## explicit name -> class mapping instead of eval() on the class name
    recommenders = {"Fabaphe": Fabaphe}
    for rec_name, rec_class in recommenders.items():
        print(rec_name)
        rec = rec_class(dict(lbd=lbd, c=c, eta=eta, alpha=0, rec_type="SAMPLE", nchunks=nchunks, seed=seed))
        start_T = time()
        Q, M = rec.compute_dpp(u, env, K)
        print(f"Time L-matrix computation (omega-wide): {np.round(time()-start_T,3)} seconds.")
        start_T = time()
        recs = SAMPLE(Q, M, B, K, seed, eta=eta)
        print(f"Time Recommendation after sampling (omega-wide): {np.round(time()-start_T,3)} seconds.")
        print(recs)
        start_T = time()
        recs = MAX(Q, M, B, K, seed, eta=eta)
        print(f"Time Recommendation after maximization (omega-wide): {np.round(time()-start_T,3)} seconds.")
        print(recs)
