#coding:utf-8

import numpy as np
from scipy.sparse import identity, dia_array, csr_array, find
from tqdm import tqdm
from sksparse.cholmod import CholmodNotPositiveDefiniteError

from bdivrec.fabaphe.utils import power, SAMPLE, MAX, fast_inverse, fast_inverse_fullmat, Cholesky, stds
from bdivrec.fabaphe.fabaphe import Recommender

## Names of the baseline recommender classes defined in this module
all_baselines_known = ["QDDecomposition", "ConditionalDPP", "EpsGreedy", "kMarkovDPP", "MMR"]

## Regular quality-diversity decomposition for DPPs
class QDDecomposition(Recommender):
    """Baseline returning the quality and kernel factors of a DPP for one user."""

    def __init__(self, params):
        """Forward ``params`` to ``Recommender`` and set the baseline name."""
        super().__init__(params)
        self.name = "QDDecomposition"

    def compute_dpp(self, u, env, K, S=None):
        """Build the pair ``(QS, KS)`` for user ``u``.

        Parameters
        ----------
        u : user identifier passed to the environment.
        env : environment providing ``feedback``/``item_embs`` (and their
            ``*_slice`` variants) as well as ``nitem``.
        K : callable kernel; its first output is used and its ``beta``
            attribute acts as an optional threshold.
        S : optional collection of items. When omitted, feedback and
            embeddings are requested for the slice from 0 to ``env.nitem``.

        Returns
        -------
        QS : CSR array holding ``qs.data`` on its main diagonal.
        KS : first output of ``K(Phi)``, multiplied by the 0/1 mask
            ``KS > K.beta`` when ``K.beta`` is set, and replaced by
            ``KS @ KS.T`` when it has as many rows as columns.

        No exponent (``power``) is applied to either factor here.
        """
        if S is None:
            qs = env.feedback_slice(0, env.nitem, u)
            Phi = env.item_embs_slice(0, env.nitem)
        else:
            qs = env.feedback(S, u)
            Phi = env.item_embs(S)
        n_items = qs.shape[0]
        diagonal = dia_array((qs.data, 0), shape=(n_items, n_items))
        QS = csr_array(diagonal)
        KS, _ = K(Phi)
        threshold = K.beta
        if threshold is not None:
            ## keep only the entries strictly above the threshold
            mask = (KS > threshold).astype(int)
            KS *= mask
        n_rows, n_cols = KS.shape[0], KS.shape[1]
        if n_rows == n_cols:
            KS = KS @ KS.T
        return QS, KS

## Classical conditional DPP
class ConditionalDPP(Recommender):
    """Baseline whose kernel factor is conditioned on the user's history."""

    def __init__(self, params):
        """Forward ``params`` to ``Recommender`` and set the baseline name."""
        super().__init__(params)
        self.name = "ConditionalDPP"

    def compute_dpp(self, u, env, K, S=None):
        """Build the pair ``(QS, KS)`` for user ``u``, conditioned on the history.

        ``QS`` is a CSR array holding ``qs.data`` on its main diagonal and
        ``KS`` starts as the first output of ``K`` on the item embeddings.

        If the user history is empty, ``(QS, KS)`` is returned immediately.
        Otherwise ``KS`` is right-multiplied by the ``Cholesky`` factor of
        ``identity - KH.T @ KH_inv @ KH`` (``KH`` being the first output of
        ``K`` on the history embeddings), then multiplied by the 0/1 mask
        ``KS > K.beta`` when ``K.beta`` is set, and finally replaced by
        ``KS @ KS.T`` when it has as many rows as columns.

        NOTE(review): the empty-history return skips both the ``K.beta``
        thresholding and the square-matrix product applied on the other
        path -- confirm this is intended.
        """
        if S is None:
            qs = env.feedback_slice(0, env.nitem, u)
            Phi = env.item_embs_slice(0, env.nitem)
        else:
            qs = env.feedback(S, u)
            Phi = env.item_embs(S)
        n_items = qs.shape[0]
        diagonal = dia_array((qs.data, 0), shape=(n_items, n_items))
        QS = csr_array(diagonal)
        KS, _ = K(Phi)

        history = env.get_user_hist(u)
        if len(history) == 0:
            return QS, KS

        ## kernel features of the items already seen by the user
        hist_embs = env.item_embs(history)
        KH, _ = K(hist_embs)
        KH_inv = fast_inverse(KH, self.eta)
        n_cols_hist = KH.shape[1]
        conditioning = identity(n_cols_hist) - KH.T @ KH_inv @ KH
        factor = Cholesky(conditioning, eta=0)
        ## KS @ factor, so that a later product with its transpose
        ## equals KS @ conditioning @ KS.T
        KS = KS @ factor

        threshold = K.beta
        if threshold is not None:
            mask = (KS > threshold).astype(int)
            KS *= mask
        n_rows, n_cols = KS.shape[0], KS.shape[1]
        if n_rows == n_cols:
            KS = KS @ KS.T
        return QS, KS

## epsilon-greedy DPP
class EpsGreedy(Recommender):
    """Baseline that randomly keeps or discards the quality scores."""

    def __init__(self, params):
        """Require an ``"epsilon"`` entry in ``params`` and initialise the parent."""
        ## epsilon is the probability of keeping the quality scores
        assert "epsilon" in params
        super().__init__(params)
        self.name = "EpsGreedy"

    def compute_dpp(self, u, env, K, S=None):
        """Build ``(QS, KS)`` for user ``u`` with a random choice on ``QS``.

        ``KS`` is the first output of ``K`` on the item embeddings,
        multiplied by the 0/1 mask ``KS > K.beta`` when ``K.beta`` is set
        and replaced by ``KS @ KS.T`` when it has as many rows as columns.

        A value is drawn from ``self.rng``: it is 1 with probability
        ``self.epsilon`` and 0 otherwise. On 1 the diagonal matrix of
        feedback values is returned as the first element; on 0 an identity
        matrix of the same size is returned instead.
        """
        if S is None:
            qs = env.feedback_slice(0, env.nitem, u)
            Phi = env.item_embs_slice(0, env.nitem)
        else:
            qs = env.feedback(S, u)
            Phi = env.item_embs(S)
        n_items = qs.shape[0]
        diagonal = dia_array((qs.data, 0), shape=(n_items, n_items))
        QS = csr_array(diagonal)
        KS, _ = K(Phi)
        threshold = K.beta
        if threshold is not None:
            mask = (KS > threshold).astype(int)
            KS *= mask
        n_rows, n_cols = KS.shape[0], KS.shape[1]
        if n_rows == n_cols:
            KS = KS @ KS.T
        draw = self.rng.choice([0,1], size=1, p=[1-self.epsilon, self.epsilon])
        if not draw:
            ## ignore quality scores
            return identity(QS.shape[0]), KS
        ## regular QD decomposition
        return QS, KS
        
# k-Markov Determinantal Process
# See https://www.alexkulesza.com/pubs/markov_uai12.pdf
# Contrary to the other algorithms, this one takes the length T of the trajectory as input
# and should be called only once
class kMarkovDPP(Recommender):
    """k-Markov determinantal process baseline.

    A single call to :meth:`recommend` produces a whole trajectory of
    ``self.T`` batches; each new batch is drawn among the items that have
    not been selected yet (user history included).
    """

    def __init__(self, params):
        """Require a ``"T"`` entry in ``params`` and build the inner QD decomposition."""
        assert "T" in params
        super().__init__(params)
        self.qd_dec = QDDecomposition(params)
        self.name = "kMarkovDPP"

    def subsetId(self, D, A): # gives I_{Y/A}, defined in Equation 5
        """Return the ``D x D`` identity whose diagonal entries indexed by ``A`` are zeroed."""
        I = identity(D)
        for i in A:
            I.data[0,i] = 0
        return I

    def subsetmat(self, M, A): # gives [M]_{Y/A}, defined in Equation 7
        """Return the rows of ``M`` whose index is not in ``A`` (original order kept)."""
        excluded = set(A) ## O(1) membership tests instead of scanning A for every row
        ids = [i for i in range(M.shape[0]) if (i not in excluded)]
        return M[ids,:]

    def single_update(self, L, Yt_ids, eta=0): ## Adapted from Algorithm 4
        """Compute the matrices used to sample among the items outside ``Yt_ids``.

        Inverts ``L + I_{Y/Yt_ids}``, removes the rows and columns indexed
        by ``Yt_ids``, inverts the result again and subtracts the identity.

        Returns ``(I, Lt)`` where ``I`` is an identity matrix with the same
        number of rows as ``Lt``.
        """
        D = L.shape[0]
        It = self.subsetId(D, Yt_ids) ## Equation 30
        LL_inv = fast_inverse_fullmat(L + It, eta=eta)
        LL_inv_sub = self.subsetmat(LL_inv, Yt_ids) ## filter rows
        LL_inv_sub = LL_inv_sub.T
        LL_inv_sub = self.subsetmat(LL_inv_sub, Yt_ids) ## filter columns
        LL_inv_sub = LL_inv_sub.T
        LL = fast_inverse_fullmat(LL_inv_sub, eta=0)
        I = identity(LL.shape[0])
        Lt = LL - I
        return I, Lt

    def initial_sample(self, QS, KS, B, K, seed, eta): ## Algorithm 4
        """Draw ``2*B`` items with ``SAMPLE``, shuffle them and keep the first ``B``."""
        Z1_ids = SAMPLE(QS, KS, 2*B, K, seed, eta)
        self.rng.shuffle(Z1_ids)
        Y1_ids = [Z1_ids[i] for i in range(B)]
        return Y1_ids

    def recommend(self, B, u, env, K, S=None):
        """Return a trajectory of ``self.T`` batches of ``B`` items for user ``u``.

        Parameters
        ----------
        B : number of items per batch.
        u : user index, must satisfy ``0 <= u < env.nuser``.
        env : environment providing the user history and item data.
        K : kernel, forwarded to the QD decomposition and to the sampler.
        S : optional collection of candidate item indices, each in
            ``0..env.nitem-1``. When given, the returned indices are mapped
            through ``S``.

        Returns
        -------
        A list of ``self.T`` lists, each containing ``B`` item indices.

        NOTE(review): when ``S`` is given, the history ``H`` is used as row
        indices of matrices built on ``S`` -- confirm that history indices
        and positions in ``S`` agree.
        """
        assert u >= 0
        assert u < env.nuser
        if (S is not None):
            ## valid item indices are 0..nitem-1 (the upper bound is exclusive)
            assert all([i>=0 and i<env.nitem for i in S])
        #assert env.nitem<=5000
        H = env.get_user_hist(u)
        QS, KS = self.qd_dec.compute_dpp(u, env, K, S=S)
        if (len(H)==0):
            Yt_ids = self.initial_sample(QS, KS, B, K, self.seed, self.eta)
            ## store a copy: Yt_ids is extended in place below and must not
            ## grow the first batch of the trajectory
            recs = [list(Yt_ids)]
            start_sample=1
        else:
            ## not in the original paper; copy so that the list returned by
            ## the environment is not mutated by the in-place extension below
            Yt_ids = list(H)
            recs = []
            start_sample=0
        QQ = power(QS, self.lbd/2*self.c) ## tradeoff
        if (K.beta is not None):
            KS *= (KS>K.beta).astype(int)
        KK = power(KS @ KS.T, (1-self.lbd)*self.c) ## tradeoff
        L = QQ @ KK @ QQ ## not tractable for millions of items
        for it in (pbar := tqdm(
            range(start_sample, self.T),
            position=4,
            leave=False,
        )):
            try:
                I, Lt = self.single_update(L, Yt_ids, eta=self.eta)
            except CholmodNotPositiveDefiniteError: ## not diverse enough
                I, Lt = self.single_update(QQ @ QQ, Yt_ids, eta=self.eta) ## not in the original paper
            ## NOTE: rec_type is evaluated as code; it must name a trusted
            ## callable in scope (e.g. "SAMPLE" or "MAX"), never external input
            S_ids = eval(self.rec_type)(I, Lt, B, K, self.seed, eta=self.eta) ## recommendations on Omega\Yt_ids
            selected = set(Yt_ids)
            all_remaining_items = [item for item in range(QS.shape[0]) if (item not in selected)] ## all items but already recommended
            Yt_ids += [all_remaining_items[item] for item in S_ids] ## recs on Omega ## not in the original paper
            recs += [Yt_ids[-B:]] ## not in the original paper
            pbar.set_description(f"{self.name} in trajectory (T={it+1}/{self.T})")
        assert len(recs) == self.T
        assert all([len(r) == B for r in recs])
        if (S is not None):
            ## recs is a list of batches: map every index of every batch
            return [[S[i] for i in batch] for batch in recs]
        else:
            return recs
            
## Maximum Marginal Relevance (MMR)
## We adapt it to our setting by considering sim1 = relevance, sim2 = kernel similarity
## and the query is actually the user context
class MMR(Recommender):
    """Maximum Marginal Relevance baseline (greedy, no DPP involved).

    Items are added one at a time, maximising
    ``lbd * relevance - (1 - lbd) * max similarity`` to the items already
    selected, the user history counting as already selected.
    """

    def __init__(self, params):
        """Forward ``params`` to ``Recommender`` and set the baseline name."""
        super().__init__(params)
        self.name = "MMR"

    def compute_dpp(self, u, env, K, S=None):
        """Always raise ``ValueError``: this baseline has no DPP."""
        raise ValueError("MMR does not use DPPs.")

    def recommend(self, B, u, env, K, S=None):
        """Greedily select ``B`` items for user ``u``.

        Parameters
        ----------
        B : number of items to recommend.
        u : user index, must satisfy ``0 <= u < env.nuser``.
        env : environment providing feedback, embeddings and user history.
        K : kernel called as ``K(Phi, Phi_H)``; ``K.beta`` is an optional
            threshold applied to both outputs.
        S : optional collection of candidate item indices, each in
            ``0..env.nitem-1``. When given, the result is mapped through ``S``.

        Returns
        -------
        A list of ``B`` item indices (history items are excluded).

        NOTE(review): when ``S`` is given, the selected positions are also
        passed to ``env.item_embs`` and mixed with history indices --
        confirm that positions in ``S`` and item indices agree.
        """
        assert u >= 0
        assert u < env.nuser
        if (S is not None):
            ## valid item indices are 0..nitem-1 (the upper bound is exclusive)
            assert all([i>=0 and i<env.nitem for i in S])
            qs = env.feedback(S, u)
            Phi = env.item_embs(S)
        else:
            qs = env.feedback_slice(0, env.nitem, u)
            Phi = env.item_embs_slice(0, env.nitem)
        H = env.get_user_hist(u)
        ## iteratively build the set, starting from the user history
        recs = []
        recs += H
        while (len(recs)<B+len(H)):
            if (len(recs)==0):
                ## nothing selected yet: relevance only
                scores = self.lbd*qs
            else:
                Phi_H = env.item_embs(recs)
                KS, KH = K(Phi, Phi_H)
                if (K.beta is not None):
                    KS *= (KS>K.beta).astype(int)
                    KH *= (KH>K.beta).astype(int)
                ## largest similarity of each candidate to any selected item
                KK = (KS @ KH.T).max(axis=1)
                scores = self.lbd*qs-(1-self.lbd)*KK.reshape(qs.shape)
            ## already selected items can never be picked again
            scores[recs] = -float("inf")
            row, _, _ = find(scores==scores.max())
            top_items = row.ravel().tolist()
            top_item = self.rng.choice(top_items, size=1)[0] ## break ties at random
            recs.append(top_item)
        ## drop the history prefix, keep the B new items
        recs = recs[len(H):]
        assert len(recs) == B
        if (S is not None):
            return [S[i] for i in recs]
        else:
            return recs
        
## Smoke test: runs every baseline once on a small synthetic environment
## and prints timings and recommendations.
if __name__ == "__main__":
    import sys ## needed by sys.path.insert below (was missing)
    from time import time
    from kernels import DenudedKernel
    sys.path.insert(0,"HAN/")
    from known_envs import SyntheticCosine
    from utils import SAMPLE, MAX
    seed = 1234
    eta = 1e-3
    c = 2
    lbd = 0.5
    epsilon = lbd/2
    beta = 0.28
    nc = 10
    nchunks=1000 #100000
    nitem=1000 #1000000
    data_folder = "../../datasets/Synthetic_tests_baselines_known"
    env = SyntheticCosine(dict(name=data_folder, nitem=nitem, nchunks=nchunks, nuser=100, d=10, seed=seed, quantize_digit=3, new=True))
    K = DenudedKernel(kernel="linear", n_components=nc, nchunks=nchunks//20, seed=seed, beta=beta)
    u = 0
    B = 5
    ## add history to user 0
    env.update_user_hist(u, np.random.choice(range(nitem), size=5).tolist())
    for rec_name in ["kMarkovDPP", "QDDecomposition", "EpsGreedy", "ConditionalDPP", "MMR"]:
        print(rec_name)
        if (rec_name not in ["kMarkovDPP","MMR"]):
            ## baselines exposing compute_dpp: build the matrices, then sample / maximize
            if (env.nitem>5000):
                continue
            rec = eval(rec_name)(dict(lbd=lbd, c=c, eta=eta, rec_type="SAMPLE", seed=seed, epsilon=epsilon))
            start_T = time()
            Q, M = rec.compute_dpp(u, env, K)
            #print((M!=0).mean()) ## target .0001
            print(f"Time L-matrix computation (omega-wide): {np.round(time()-start_T,3)} seconds.")
            start_T = time()
            recs = SAMPLE(Q, M, B, K, seed, eta=eta)
            print(f"Time Recommendation after sampling (omega-wide): {np.round(time()-start_T,3)} seconds.")
            print(recs)
            start_T = time()
            recs = MAX(Q, M, B, K, seed, eta=eta)
            print(f"Time Recommendation after maximization (omega-wide): {np.round(time()-start_T,3)} seconds.")
            print(recs)
            print("")
        else:
            ## baselines with their own recommend(); kMarkovDPP returns a trajectory (T=1 here)
            rec = eval(rec_name)(dict(lbd=lbd, c=c, eta=eta, rec_type="SAMPLE", T=1, seed=seed, epsilon=epsilon))
            start_T = time()
            recs = rec.recommend(B, u, env, K)
            print(f"Time Recommendation after sampling (omega-wide): {np.round(time()-start_T,3)} seconds.")
            if (rec_name == "kMarkovDPP"):
                print(recs[0])
            else:
                print(recs)
            if (rec_name == "kMarkovDPP"):
                rec = eval(rec_name)(dict(lbd=lbd, c=c, eta=eta, rec_type="MAX", T=1, seed=seed, epsilon=epsilon))
                start_T = time()
                recs = rec.recommend(B, u, env, K)
                print(f"Time Recommendation after maximization (omega-wide): {np.round(time()-start_T,3)} seconds.")
                print(recs[0])
            print("")
        
    #proc = sb.Popen(f"rm -rf {data_folder}".split(" "))
    #proc.wait()
