import torch
import numpy as np
import importlib
import ut_algo
import ut_scores
import math
import time
#importlib.reload(scores)
#importlib.reload(utils_algo)
from torch import tensor

torch.set_printoptions(precision=10)
from torch.autograd import Variable
def mix_alpha_grad_cp(T, rank, αs, lr=0.01, max_iter=100, verbose=True, verbose_interval=1, optim_method="GD"):
    """Fit a rank-``rank`` CP model to a 3-way tensor by gradient descent.

    Three unconstrained parameter matrices (one per mode) are mapped through
    ``softmax_transform_factor`` to obtain the factors ``A``, ``B`` and ``C``.
    The reconstruction is ``P[i, j, k] = sum_r A[i, r] * B[j, r] * C[k, r]``
    and the objective minimised is ``ut_scores.mix_renyi_div(T, P, αs)``.

    Parameters
    ----------
    T : torch.Tensor or numpy.ndarray
        3-way input tensor. Anything that is not a ``torch.Tensor`` is
        converted with ``torch.from_numpy``.
    rank : int
        Number of rank-one components.
    αs
        Forwarded unchanged to ``ut_scores.mix_renyi_div`` and
        ``ut_algo.update_histos`` (presumably the α values of the mixed
        divergence -- confirm against ``ut_scores``).
    lr : float
        Learning rate, used both for the manual "GD" update and for the
        torch optimizers.
    max_iter : int
        Number of gradient steps.
    verbose : bool
        If true, print a header and call ``ut_algo.show_verbose`` after
        every step.
    verbose_interval : int
        Forwarded to ``ut_algo.show_verbose``.
    optim_method : str
        ``"GD"`` for the hand-written update, or one of ``"SGD"``,
        ``"RMSprop"``, ``"Adam"`` for the matching ``torch.optim`` class.

    Returns
    -------
    tuple
        ``(A, B, P, histos)`` where ``histos`` is the list
        ``[history_loss, history_sum, history_time]`` maintained by
        ``ut_algo.update_histos``. ``A``, ``B`` and ``P`` are still attached
        to the autograd graph.

    Raises
    ------
    ValueError
        If ``optim_method`` is not one of the supported names.
    """
    # Reject unknown optimizers before doing any work; previously this only
    # surfaced as an unbound ``optimizer`` after the first backward pass.
    optimizer_classes = {
        "SGD": torch.optim.SGD,
        "RMSprop": torch.optim.RMSprop,
        "Adam": torch.optim.Adam,
    }
    supported_methods = ("GD", *optimizer_classes)
    if optim_method not in supported_methods:
        raise ValueError(
            f"Unknown optim_method {optim_method!r}; "
            f"expected one of {list(supported_methods)}"
        )

    start_time = time.perf_counter()
    I, J, K = np.shape(T)

    # The divergence is computed with torch, so T must be a torch tensor.
    if not isinstance(T, torch.Tensor):
        T = torch.from_numpy(T)

    # Real-valued (unconstrained) parameters that are actually optimised.
    ϕ = torch.rand(I, rank, requires_grad=True)
    ψ = torch.rand(J, rank, requires_grad=True)
    η = torch.rand(K, rank, requires_grad=True)

    # Factors obtained from the unconstrained parameters.
    A = softmax_transform_factor(ϕ, rank)
    B = softmax_transform_factor(ψ, rank)
    C = softmax_transform_factor(η, rank)

    # Low-rank reconstruction.
    P = torch.einsum('ir,jr,kr->ijk', A, B, C)

    history_loss = {"alpha_div":[], "Renyi_div":[], "kl_div":[], "L2":[], "fit":[]}
    history_sum  = []
    history_time = []
    histos = [history_loss, history_sum, history_time]

    params = [ϕ, ψ, η]
    # "GD" uses the manual update below and needs no optimizer object.
    optimizer = None
    if optim_method != "GD":
        optimizer = optimizer_classes[optim_method](params, lr=lr)

    if verbose:
        ## Header of verbose
        print(f"\n{'Iteration':<22} {'α-div':<6} {'Renyi':<8} {'KL-div':<6}  {'L2':<7} {'Total sum':<13} {'Elapsed time'}")

    # Record the state before the first step. History bookkeeping never needs
    # gradients, so it runs under no_grad like the in-loop call below.
    with torch.no_grad():
        ut_algo.update_histos(histos, T, P, αs, start_time)

    for itr in range(max_iter):
        loss = ut_scores.mix_renyi_div(T, P, αs)
        loss.backward()

        if optimizer is None:
            # Plain gradient descent: in-place step, then clear the
            # accumulated gradients for the next iteration.
            with torch.no_grad():
                for param in params:
                    param -= param.grad * lr
                for param in params:
                    param.grad.zero_()
        else:
            optimizer.step()
            optimizer.zero_grad()

        # Reconstruct the low-rank P from the updated parameters.
        A = softmax_transform_factor(ϕ, rank)
        B = softmax_transform_factor(ψ, rank)
        C = softmax_transform_factor(η, rank)
        P = torch.einsum('ir,jr,kr->ijk', A, B, C)

        with torch.no_grad():
            ut_algo.update_histos(histos, T, P, αs, start_time)
        if verbose:
            ut_algo.show_verbose(itr, histos, verbose_interval=verbose_interval)

    # NOTE(review): the third factor C is computed but not returned. The
    # return shape is kept as-is for existing callers -- confirm whether C
    # should be exposed as well.
    return A, B, P, histos

def softmax_transform_factor(ϕ, rank):
    """Map unconstrained parameters to a scaled, column-wise softmax factor.

    Each column of ``ϕ`` is passed through a softmax over ``dim=0`` and then
    multiplied by ``rank ** (-1/3)``, so every column of the result sums to
    ``rank ** (-1/3)``.

    ``torch.softmax`` is used instead of ``exp(ϕ) / sum(exp(ϕ))`` because the
    hand-written form overflows to ``inf / inf = NaN`` for large entries.

    Parameters
    ----------
    ϕ : torch.Tensor
        Floating-point parameter tensor; normalisation runs along ``dim=0``.
    rank : int
        Number of components, used only for the scaling constant.

    Returns
    -------
    torch.Tensor
        Tensor of the same shape as ``ϕ``.
    """
    # A Python scalar avoids allocating a helper tensor on every call.
    scale = float(rank) ** (-1.0 / 3.0)
    return scale * torch.softmax(ϕ, dim=0)

