import ut_scores
import time
import numpy as np
import torch

def show_verbose(itr, histos, verbose_interval=10):
    """Print a one-line progress summary every ``verbose_interval`` iterations.

    Args:
        itr: Current iteration index.
        histos: Tuple ``(history_loss, history_sum, history_time)``.
            ``history_loss`` maps the keys ``"alpha_div"``, ``"Renyi_div"``,
            ``"kl_div"`` and ``"L2"`` to non-empty lists; ``history_sum`` and
            ``history_time`` are non-empty lists. Only the last entry of each
            is printed.
        verbose_interval: Print only when ``itr`` is a multiple of this value
            (must be non-zero, it is used as a modulus).

    The printed error columns are, in order: ``alpha_div``, ``Renyi_div``,
    ``kl_div`` and ``L2``.
    """
    if itr % verbose_interval != 0:
        return

    history_loss, history_sum, history_time = histos

    score_alpha = history_loss["alpha_div"][-1]
    score_renyi = history_loss["Renyi_div"][-1]
    score_kl = history_loss["kl_div"][-1]
    score_l2 = history_loss["L2"][-1]

    total_sum = history_sum[-1]
    elapsed_time = history_time[-1]

    print(f"Iter: {itr:5d} | Errors: {score_alpha:.5f} {score_renyi:.5f} {score_kl:.5f} {score_l2:.5f} | Sum: {total_sum:.5f} | {elapsed_time:.5f} sec.")

    
def _as_python_scalar(value):
    """Return ``value.item()`` for a torch tensor, otherwise ``value`` unchanged."""
    # Tensor.item() works for CPU and CUDA tensors and for tensors that require
    # grad, without the extra copies of .clone().detach().numpy().item().
    if isinstance(value, torch.Tensor):
        return value.item()
    return value


def update_histos(histos, T, P, α, start_time):
    """Append the current scores, total of ``P`` and elapsed time to the histories.

    Args:
        histos: Tuple ``(history_loss, history_sum, history_time)``, mutated in
            place. ``history_loss`` must be a dict whose keys ``"alpha_div"``,
            ``"Renyi_div"``, ``"kl_div"``, ``"L2"`` and ``"fit"`` map to lists;
            ``history_sum`` and ``history_time`` are lists.
        T: Reference passed as the first argument to the ``ut_scores`` metrics
            (numpy array or torch tensor).
        P: Current estimate passed as the second argument to the metrics; its
            sum is also recorded.
        α: Parameter forwarded to ``mix_alpha_div`` and ``mix_renyi_div``.
        start_time: A reading of ``time.perf_counter()``; the time elapsed
            since it is appended to ``history_time``.

    Torch-tensor results are stored as Python scalars; other results
    (e.g. numpy values) are stored as returned.
    """
    history_loss, history_sum, history_time = histos

    alpha_score_val = ut_scores.mix_alpha_div(T, P, α)
    renyi_score_val = ut_scores.mix_renyi_div(T, P, α)
    kl_score_val = ut_scores.kl_div(T, P)
    L2score, fitscore = ut_scores.l2_score(T, P)

    ## update history_loss
    history_loss["alpha_div"].append(_as_python_scalar(alpha_score_val))
    history_loss["Renyi_div"].append(_as_python_scalar(renyi_score_val))
    history_loss["kl_div"].append(_as_python_scalar(kl_score_val))
    history_loss["L2"].append(_as_python_scalar(L2score))
    history_loss["fit"].append(_as_python_scalar(fitscore))

    ## update history_sum -- branch on P, the value actually being summed
    if isinstance(P, torch.Tensor):
        history_sum.append(torch.sum(P).item())
    else:
        history_sum.append(np.sum(P))

    ## update history_time
    history_time.append(time.perf_counter() - start_time)