import torch
import numpy as np
import importlib
import ut_algo
import ut_scores
import math
import time
#importlib.reload(scores)
#importlib.reload(utils_algo)
from torch import tensor

# Print floating-point tensors with 10 digits of precision (useful when
# inspecting the normalization checks / losses below).
# NOTE: torch print options are global, so this affects every tensor printed
# in the process after this module is imported, not just this module.
torch.set_printoptions(precision=10)
from torch.autograd import Variable
def mix_alpha_grad_ctt(T, rankcp, ranktucker, ranktrain, αs, lr=0.01, max_iter=100, verbose=True, verbose_interval=1, optim_method="GD"):
    """Fit a mixture of CP, Tucker and tensor-train models to a 3-way tensor.

    Every active component is parameterized through softmax transforms so that
    it sums to one. When more than one component is active, the components are
    combined with softmax-normalized, learnable mixture weights. The parameters
    are optimized by back-propagating ``ut_scores.mix_renyi_div(T, P, αs)``.

    Args:
        T: 3-way numpy array or torch tensor that must sum to 1. A numpy
            array is converted with ``torch.from_numpy``.
        rankcp: CP rank; 0 disables the CP component.
        ranktucker: Tucker ranks, one per mode (length ``T.ndim``);
            ``ranktucker[0] == 0`` disables the Tucker component.
        ranktrain: tensor-train ranks (length ``T.ndim - 1``);
            ``ranktrain[0] == 0`` disables the tensor-train component.
        αs: forwarded unchanged to ``ut_scores.mix_renyi_div`` and
            ``ut_algo.update_histos`` (presumably the α value(s) of the
            divergence -- confirm in ut_scores).
        lr: learning rate.
        max_iter: number of history entries; the initial state counts as one,
            so ``max_iter - 1`` parameter updates are performed.
        verbose: print a progress table via ``ut_algo.show_verbose``.
        verbose_interval: forwarded to ``ut_algo.show_verbose``.
        optim_method: ``"GD"`` (plain gradient descent) or one of
            ``"SGD"``, ``"RMSprop"``, ``"Adam"``, ``"Adagrad"``.

    Returns:
        ``(P, histos)``: the final reconstructed tensor and
        ``[history_loss, history_sum, history_time]`` as filled by
        ``ut_algo.update_histos``.

    Raises:
        ValueError: if every component is disabled or ``optim_method`` is unknown.
        AssertionError: on rank-length mismatch, or if ``T`` or the initial
            ``P`` does not sum to 1 (within 1e-5).
    """
    start_time = time.perf_counter()
    I, J, K = np.shape(T)

    assert len(ranktrain) == np.ndim(T) - 1
    assert len(ranktucker) == np.ndim(T)

    optimizer_classes = {
        "SGD": torch.optim.SGD,
        "RMSprop": torch.optim.RMSprop,
        "Adam": torch.optim.Adam,
        "Adagrad": torch.optim.Adagrad,
    }
    # Validate up front: an unknown name used to surface only later as a
    # NameError on `optimizer.step()`.
    if optim_method != "GD" and optim_method not in optimizer_classes:
        raise ValueError(
            f"unknown optim_method {optim_method!r}; expected 'GD' or one of {sorted(optimizer_classes)}"
        )

    # To compute the Renyi divergence, the input T needs to be a torch tensor.
    if not isinstance(T, torch.Tensor):
        T = torch.from_numpy(T)

    learn_cp = rankcp != 0
    learn_Tucker = ranktucker[0] != 0
    learn_train = ranktrain[0] != 0
    if not (learn_cp or learn_Tucker or learn_train):
        raise ValueError("at least one of rankcp, ranktucker[0], ranktrain[0] must be non-zero")

    params = []    # every leaf tensor that is optimized
    builders = []  # one zero-argument callable per active component, rebuilding it from params

    # Parameters are drawn in the order CP, Tucker, train, mixture weights.
    if learn_cp:
        # For CP
        ϕ = torch.rand(I, rankcp, requires_grad=True)
        ψ = torch.rand(J, rankcp, requires_grad=True)
        η = torch.rand(K, rankcp, requires_grad=True)
        params.extend([ϕ, ψ, η])

        def build_cp():
            A = softmax_transform_factor(ϕ, rankcp)
            B = softmax_transform_factor(ψ, rankcp)
            C = softmax_transform_factor(η, rankcp)
            return torch.einsum('ir,jr,kr->ijk', A, B, C)

        builders.append(build_cp)

    if learn_Tucker:
        # For Tucker
        ϕQ = torch.rand(ranktucker[0], ranktucker[1], ranktucker[2], requires_grad=True)
        ϕX = torch.rand(I, ranktucker[0], requires_grad=True)
        ϕY = torch.rand(J, ranktucker[1], requires_grad=True)
        ϕZ = torch.rand(K, ranktucker[2], requires_grad=True)
        params.extend([ϕQ, ϕX, ϕY, ϕZ])

        def build_tucker():
            Q, X, Y, Z = softmax_transform_factor_tucker(ϕQ, ϕX, ϕY, ϕZ)
            return torch.einsum('stu,is,jt,ku -> ijk', Q, X, Y, Z)

        builders.append(build_tucker)

    if learn_train:
        # For tensor train
        ϕG = torch.rand(I, ranktrain[0], requires_grad=True)
        ϕH = torch.rand(ranktrain[0], J, ranktrain[1], requires_grad=True)
        ϕF = torch.rand(ranktrain[1], K, requires_grad=True)
        params.extend([ϕG, ϕH, ϕF])

        def build_train():
            G, H, F = softmax_transform_factor_train_core(ϕG, ϕH, ϕF, ranktrain)
            return torch.einsum('is,sjt,tk -> ijk', G, H, F)

        builders.append(build_train)

    # Mixture weights exist (and are learned) only when several components are active.
    μ = None
    if len(builders) > 1:
        μ = torch.rand(len(builders), requires_grad=True)
        params.append(μ)

    def reconstruct():
        """Rebuild the low-rank tensor P from the current parameter values."""
        components = [build() for build in builders]
        if μ is None:
            return components[0]
        w = torch.softmax(μ, dim=0)
        # Each component sums to 1 and the weights sum to 1, so P sums to 1.
        return sum(w_k * P_k for w_k, P_k in zip(w, components))

    P = reconstruct()

    assert abs( torch.sum(P) - 1 ) < 1.0e-5, "P is not normalized"
    assert abs( torch.sum(T) - 1 ) < 1.0e-5, "T is not normalized"

    history_loss = {"alpha_div":[], "Renyi_div":[], "kl_div":[], "L2":[], "fit":[]}
    history_sum  = []
    history_time = []
    histos = [history_loss, history_sum, history_time]

    optimizer = None if optim_method == "GD" else optimizer_classes[optim_method](params, lr=lr)

    if verbose:
        ## Header of verbose
        print(f"\n{'Iteration':<22} {'α-div':<6} {'Renyi':<8} {'KL-div':<6}  {'L2':<7} {'Total sum':<13} {'Elapsed time'}")

    ut_algo.update_histos(histos, T, P, αs, start_time)
    for itr in range(max_iter - 1):
        loss = ut_scores.mix_renyi_div(T, P, αs)
        loss.backward()

        if optimizer is None:
            # Plain gradient descent on every learned parameter: CP, Tucker and
            # train factors as well as the mixture weights.
            with torch.no_grad():
                for p in params:
                    if p.grad is None:
                        continue
                    p.sub_(p.grad * lr)
                    p.grad.zero_()
        else:
            optimizer.step()
            optimizer.zero_grad()

        # Reconstruct the low-rank P from the updated parameters.
        P = reconstruct()

        with torch.no_grad():
            ut_algo.update_histos(histos, T, P, αs, start_time)
        if verbose:
            ut_algo.show_verbose(itr, histos, verbose_interval=verbose_interval)

    return P, histos

def softmax_transform_factor(ϕ, rank, order=3):
    """Map unconstrained CP factor parameters to a scaled column-stochastic matrix.

    A softmax is applied along axis 0, and the result is multiplied by
    ``rank ** (-1/order)``, so every column sums to ``rank ** (-1/order)``.
    With the default ``order=3``, the product of the three factors' column
    sums is ``1/rank``. The CP reconstruction
    ``einsum('ir,jr,kr->ijk', A, B, C)`` therefore sums to 1.

    Args:
        ϕ: floating-point tensor of unconstrained parameters, normalized along axis 0.
        rank: CP rank used for the scaling.
        order: number of factor matrices that share the scaling (the tensor
            order). Defaults to 3, the previously hard-coded cube root.

    Returns:
        Tensor of the same shape and dtype as ``ϕ``.
    """
    # torch.softmax shifts by the max before exponentiating, so large entries
    # no longer overflow exp() into inf/inf = NaN.
    return rank ** (-1.0 / order) * torch.softmax(ϕ, dim=0)

def softmax_transform_factor_tucker(ϕQ, ϕX, ϕY, ϕZ):
    """Map unconstrained Tucker parameters to a normalized core and factor matrices.

    Args:
        ϕQ: core parameters; a softmax is taken over *all* of its entries, so ``Q`` sums to 1.
        ϕX, ϕY, ϕZ: factor parameters; a softmax is taken along axis 0, so
            every column of ``X``, ``Y`` and ``Z`` sums to 1.

    Returns:
        ``(Q, X, Y, Z)`` with the same shapes as the inputs. With these
        constraints ``einsum('stu,is,jt,ku -> ijk', Q, X, Y, Z)`` sums to 1.
    """
    # Numerically stable softmax (max-shifted): large entries no longer give NaN.
    Q = torch.softmax(ϕQ.reshape(-1), dim=0).reshape(ϕQ.shape)
    X, Y, Z = (torch.softmax(ϕ, dim=0) for ϕ in (ϕX, ϕY, ϕZ))
    return Q, X, Y, Z

def softmax_transform_factor_train_core(ϕG, ϕH, ϕF, train_rank):
    """Map unconstrained tensor-train parameters to normalized cores.

    Args:
        ϕG: first-core parameters; softmax along axis 0, scaled by ``1/train_rank[0]``.
        ϕH: middle-core parameters; softmax along axis 1, scaled by ``1/train_rank[1]``.
        ϕF: last-core parameters; softmax along axis 1, unscaled.
        train_rank: sequence holding at least the two tensor-train ranks.

    Returns:
        ``(G, H, F)`` with the same shapes as the inputs. When ``ϕG`` has
        ``train_rank[0]`` columns and ``ϕH`` has ``train_rank[1]`` slices
        along its last axis, ``einsum('is,sjt,tk -> ijk', G, H, F)`` sums to 1.
    """
    # Numerically stable softmax (max-shifted): large entries no longer give NaN.
    G = (1.0 / train_rank[0]) * torch.softmax(ϕG, dim=0)
    H = (1.0 / train_rank[1]) * torch.softmax(ϕH, dim=1)
    F = torch.softmax(ϕF, dim=1)
    return G, H, F

def softmax_transform_weights(μ):
    """Turn unconstrained mixture parameters into weights that sum to 1.

    Args:
        μ: 1-D floating-point tensor, one entry per mixture component.

    Returns:
        Tuple of 0-d tensors (one per entry of ``μ``) holding ``softmax(μ)``.
        For the 3-component case this is ``(w1, w2, w3)``, as before.
        Gradients flow back to ``μ``.
    """
    # One max-shifted softmax instead of three unshifted exp/sum passes, which
    # overflowed to NaN for large μ; unbind yields a tuple of scalars.
    return torch.softmax(μ, dim=0).unbind(0)
