import numpy as np
import torch as t
from colorama import Fore, Style
from loguru import logger

import wandb
from auto_encoder.helpers.ae_metrics import AutoEncoderMetrics


def _train_logging(
    use_wandb: bool,
    sample_num: int,
    loss: float,
    l1_sparsity_coef: float,
    learning_rate: float,
    auxiliary_balancing_loss_coef: float,
    expert_importance_loss_coef: float,
    capacity_factor: float,
    grad_norm: float,
    grad_scale: float,
    num_tokens: int,
    stochastic_topk_temperature: float,
):
    """Log one training step to wandb and/or the console.

    When ``use_wandb`` is set, the loss, gradient statistics, token count and the
    current scheduled hyperparameters are sent to wandb at step ``sample_num``;
    console output is then limited to steps with ``sample_num < 100``. Without
    wandb, the loss and gradient norm are always written to the console.

    Args:
        use_wandb: Whether to send metrics to the active wandb run.
        sample_num: Training progress counter, used as the wandb ``step``.
        loss: Training loss for this step.
        l1_sparsity_coef: Current L1 sparsity coefficient (wandb only).
        learning_rate: Current learning rate (wandb only).
        auxiliary_balancing_loss_coef: Current auxiliary balancing loss coefficient
            (wandb only).
        expert_importance_loss_coef: Current expert importance loss coefficient
            (wandb only).
        capacity_factor: Current capacity factor (wandb only).
        grad_norm: Gradient norm for this step.
        grad_scale: Gradient scale for this step (wandb only).
        num_tokens: Token count reported for this step (wandb only).
        stochastic_topk_temperature: Current stochastic top-k temperature
            (wandb only).
    """
    if use_wandb:
        wandb.log(
            data={
                "train_loss": loss,
                "grad_norm": grad_norm,
                "grad_scale": grad_scale,
                "num_tokens": num_tokens,
                "scheduler_params": {
                    "learning_rate": learning_rate,
                    "l1_sparsity_coef": l1_sparsity_coef,
                    "auxiliary_balancing_loss_coef": auxiliary_balancing_loss_coef,
                    "expert_importance_loss_coef": expert_importance_loss_coef,
                    "capacity_factor": capacity_factor,
                    "stochastic_topk_temperature": stochastic_topk_temperature,
                },
            },
            step=sample_num,
        )

    # With wandb active, the console only echoes the earliest steps so the log is
    # not flooded; without wandb the console is the only sink, so always print.
    if not use_wandb or sample_num < 100:
        logger.info(f"Loss: {loss:.4} after {sample_num} samples")
        logger.info(f"Grad norm: {grad_norm:.4}")


def _or_zero(value):
    """Return ``value`` if it is truthy, otherwise ``0.0``.

    Several metrics on ``reduced_metrics`` may be missing (``None``); this gives
    wandb and the console report a numeric placeholder instead.
    """
    return value if value else 0.0


def _format_expert_usage(expert_usage):
    """Format per-expert usage fractions as a descending percentage string.

    Each fraction is rounded to three decimals and rendered like
    ``"41.2%, 30.0%, 28.8%"``. An empty tensor yields ``""``.

    Returns:
        The formatted string, or ``None`` (after emitting a debug message) when
        ``expert_usage`` is not a ``torch.Tensor``.
    """
    if not isinstance(expert_usage, t.Tensor):
        logger.debug("Expert usage N/A")
        return None

    # ravel() so a 0-d or multi-dimensional tensor still yields a flat list of floats.
    usage = np.round(expert_usage.cpu().detach().numpy(), 3).ravel().tolist()
    return ", ".join(f"{x * 100:.1f}%" for x in sorted(usage, reverse=True))


def eval_logging(
    train_sample_num: int,
    reduced_metrics: AutoEncoderMetrics,
    num_features: int,
    use_wandb: bool,
    density_penalty: float,
    time_elapsed_mins: float,
    total_num_steps: int,
    num_dead_features: int,
) -> None:
    """Report evaluation metrics to the console and, optionally, to wandb.

    Logs ``reduced_metrics`` followed by a formatted summary (sparsity, density,
    downstream loss recovered, expert usage, dead features, timing). When
    ``use_wandb`` is set, two ``wandb.log`` calls at step ``train_sample_num`` send
    the summary metrics and the individual loss terms.

    Args:
        train_sample_num: Training progress so far; also the wandb ``step``.
        reduced_metrics: Aggregated evaluation metrics.
        num_features: Number of autoencoder features, used for feature density,
            the proxy sweep metric, and the dead-feature ratio.
        use_wandb: Whether to send metrics to the active wandb run.
        density_penalty: Weight forwarded to ``reduced_metrics.proxy_sweep_metric``.
        time_elapsed_mins: Wall-clock minutes elapsed so far.
        total_num_steps: Total planned progress, in the same units as
            ``train_sample_num`` (used for the remaining-time estimate).
        num_dead_features: Number of dead features to report.

    Raises:
        AssertionError: If ``reduced_metrics.downstream_loss_recovered`` is ``None``.
    """
    assert reduced_metrics.downstream_loss_recovered is not None

    l0_norm = reduced_metrics.l0_sparsity_metric
    feature_density = reduced_metrics.feature_density(num_features)
    proxy_sweep_metric = reduced_metrics.proxy_sweep_metric(num_features, density_penalty)
    expert_usage_str = _format_expert_usage(reduced_metrics.expert_usage)

    logger.info(reduced_metrics)

    # Linear extrapolation from the average time per unit of progress so far.
    if train_sample_num > 0:
        time_remaining_mins = (
            time_elapsed_mins / train_sample_num * (total_num_steps - train_sample_num)
        )
    else:
        time_remaining_mins = 0

    downstream_loss_recovered = reduced_metrics.downstream_loss_recovered * 100

    decorr_score_perc = (
        reduced_metrics.decorr_score * 100 if reduced_metrics.decorr_score else 0
    )
    nfm_inf_loss_perc = (
        reduced_metrics.nfm_inf_loss * 100 if reduced_metrics.nfm_inf_loss else 0
    )
    nfm_loss_perc = reduced_metrics.nfm_loss * 100 if reduced_metrics.nfm_loss else 0

    # May be missing; formatting None (or int 0) with ":.4" would raise, so use the
    # same 0.0 fallback for both the console report and wandb.
    proportion_tokens_routed = _or_zero(reduced_metrics.proportion_tokens_routed)

    logger.info(
        f"""{(wandb.run.name if wandb.run else '')}

After {Fore.CYAN}{train_sample_num}/{total_num_steps}{Style.RESET_ALL} samples:
                Naive Description length: {reduced_metrics.naive_description_length_bits:.4f}
                L0 (sparsity) norm: {l0_norm:.4f}
                Feature density: {feature_density*100:.2f}%
                Proxy sweep metric: {proxy_sweep_metric:.4}
                Downstream loss recovered: {Fore.CYAN}{downstream_loss_recovered:.2f}%{Style.RESET_ALL}
                Expert usage: {expert_usage_str}
                Proportion tokens routed: {proportion_tokens_routed:.4}
                Decorrelation score (%): {decorr_score_perc:.2f}
                Dead Features: {num_dead_features}/{num_features}

                Time elapsed: {time_elapsed_mins:.4} minutes
                Expected time remaining: {Fore.CYAN}{int(time_remaining_mins)}{Style.RESET_ALL} minutes
                """
    )

    if not use_wandb:
        return

    ### LOG METRICS ###

    wandb.log(
        data={
            "eval_loss": reduced_metrics.overall_loss,
            "proxy_sweep_metric": proxy_sweep_metric,
            "eval_metrics": {
                "eval_acts_l0_norm": l0_norm,
                "eval_feature_density_%": feature_density * 100,
                "eval_downstream_loss_recovered_%": downstream_loss_recovered,
                "eval_expert_usage": expert_usage_str if expert_usage_str else "",
                "eval_proportion_tokens_routed": proportion_tokens_routed,
                "eval_naive_description_length": reduced_metrics.naive_description_length_bits,
                "eval_decorrelation_score_%": decorr_score_perc,
            },
            "num_dead_features": num_dead_features,
        },
        step=train_sample_num,
    )

    ### LOG LOSSES ###

    wandb.log(
        data={
            "eval_losses": {
                "eval_mse_reconstruction_loss": reduced_metrics.mse_reconstruction_loss,
                "eval_l1_sparsity_loss": _or_zero(reduced_metrics.l1_sparsity_loss),
                "eval_gating_reconstruction_loss": _or_zero(
                    reduced_metrics.gating_reconstruction_loss
                ),
                "eval_switch_load_balancing_loss": _or_zero(
                    reduced_metrics.switch_load_balancing_loss
                ),
                "eval_router_z_loss": _or_zero(reduced_metrics.router_z_loss),
                "eval_expert_importance_loss": _or_zero(
                    reduced_metrics.expert_importance_loss
                ),
                "eval_nfm_loss_%": nfm_loss_perc,
                "eval_nfm_inf_loss_%": nfm_inf_loss_perc,
                "eval_multi_info_loss": _or_zero(reduced_metrics.multi_info_loss),
                "eval_hessian_penalty_loss": _or_zero(reduced_metrics.hessian_penalty_loss),
                "eval_jump_sparsity_loss": _or_zero(reduced_metrics.jump_l0_loss),
                "eval_codebook_loss": _or_zero(reduced_metrics.codebook_loss),
                "eval_commitment_loss": _or_zero(reduced_metrics.commitment_loss),
                "eval_feature_reconstruction_loss": _or_zero(
                    reduced_metrics.feature_reconstruction_loss
                ),
            }
        },
        step=train_sample_num,
    )
