import os
import logging
import numpy as np
from utils.common import count_parameters

from utils.constants import LOG_BASIC_KEYS


def get_logger(name="shared_logger", log_file=None):
    """Return the named logger at INFO level with console/file handlers attached once.

    A console ``StreamHandler`` is added unless one (that is not a
    ``FileHandler``) is already attached. When ``log_file`` is given, a
    ``FileHandler`` for it is added unless one already writes to the same
    absolute path. Calling this repeatedly therefore does not duplicate output.

    Args:
        name: Logger name passed to ``logging.getLogger``.
        log_file: Optional path of a log file. Missing parent directories are
            created; a bare filename is resolved against the current working
            directory.

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter("%(asctime)s - %(message)s")

    # FileHandler is a subclass of StreamHandler, so it must be excluded here;
    # otherwise a pre-existing file handler would suppress console output.
    has_stream = any(
        isinstance(h, logging.StreamHandler)
        and not isinstance(h, logging.FileHandler)
        for h in logger.handlers
    )
    if not has_stream:
        ch = logging.StreamHandler()
        ch.setFormatter(formatter)
        logger.addHandler(ch)

    if log_file:
        # FileHandler stores baseFilename as an absolute path, so compare on that.
        log_path = os.path.abspath(log_file)
        has_file = any(
            isinstance(h, logging.FileHandler) and h.baseFilename == log_path
            for h in logger.handlers
        )
        if not has_file:
            # Use the absolute path's directory: dirname("run.log") is "" and
            # os.makedirs("") raises FileNotFoundError.
            os.makedirs(os.path.dirname(log_path), exist_ok=True)
            fh = logging.FileHandler(log_path)
            fh.setFormatter(formatter)
            logger.addHandler(fh)

    return logger


# Module-wide logger used by the logging helpers defined in this module.
logger = get_logger(name="shared_logger")


def log_basic_info(config):
    """Log selected configuration entries, one per line, with right-aligned labels.

    Each key in ``LOG_BASIC_KEYS`` is looked up in ``config`` and logged in
    order, followed by ``config['data_args']['path']`` under the label
    ``data_path`` and a blank separator line.

    Args:
        config: Object supporting ``[]`` lookup for every key in
            ``LOG_BASIC_KEYS`` and for ``'data_args'`` (itself indexable by
            ``'path'``).
    """
    width = 10

    def emit(label, value):
        # Labels are right-justified to a fixed width so the values line up.
        logger.info(f"{label.rjust(width)}:\t{value}")

    for key in LOG_BASIC_KEYS:
        emit(key, config[key])
    emit("data_path", config["data_args"]["path"])
    logger.info("")


def log_parameters(exp):
    """Log parameter counts for the experiment's components and for the whole.

    When ``exp.config.model_type == "pothos"`` the encoder, reconstruction
    network, initial encoder and time-varying module are each counted; for any
    other model type only the encoder is. ``exp`` itself is always counted last
    and labelled as the total. Counts are computed with ``count_parameters``.

    Args:
        exp: Experiment object exposing ``config.model_type`` and the component
            attributes referenced above.
    """
    if exp.config.model_type == "pothos":
        breakdown = (
            ("Encoder", exp.encoder),
            ("Recon Decoder", exp.recon_net),
            ("Init Encoder", exp.init_encoder),
            ("Time Varying Encoder", exp.tv_module),
        )
    else:
        breakdown = (("Encoder", exp.encoder),)
    breakdown += (("Total Parameters", exp),)

    logger.info("Parameter Count:")
    for label, component in breakdown:
        count = count_parameters(component)
        logger.info(f"\t{label.rjust(20)}:\t{count}")
    logger.info("")


def log_scores(scores, name):
    """Log a titled block of scores, each shown with four decimal places.

    The block is a blank line, ``"<name> Scores:"``, one tab-indented
    ``label:<tab>value`` line per item of ``scores`` (in ``.items()`` order), and
    a trailing blank line.

    Args:
        scores: Mapping of label -> value; every value must accept the ``.4f``
            format spec.
        name: Title prefix for the block.
    """
    logger.info("")
    logger.info(f"{name} Scores:")
    for metric, value in scores.items():
        formatted = format(value, ".4f")
        logger.info(f"\t{metric}:\t{formatted}")
    logger.info("")


def get_log_report(report, score_name):
    """Pull one metric out of every dict-valued entry of ``report``.

    Entries whose value is not a ``dict`` are skipped; insertion order of the
    remaining entries is preserved.

    Args:
        report: Mapping of label -> value; dict values are expected to contain
            ``score_name``.
        score_name: Key to extract from each nested dict.

    Returns:
        Dict mapping each dict-valued label to ``report[label][score_name]``.

    Raises:
        KeyError: If a dict-valued entry lacks ``score_name``.
    """
    extracted = {}
    for label, entry in report.items():
        if not isinstance(entry, dict):
            continue
        extracted[label] = entry[score_name]
    return extracted


def log_class_scores(results):
    """Log summary classification scores, per-label report scores and per-class accuracy.

    Args:
        results: Mapping containing:
            - ``"acc"``, ``"auroc"``, ``"auprc"``: numeric scores;
            - ``"report"``: mapping whose dict-valued entries contain
              ``"f1-score"``;
            - ``"cm"``: square confusion matrix. NOTE(review): rows are assumed
              to be true labels (scikit-learn convention) — confirm with caller.
    """
    log_scores({k: results[k] for k in ("acc", "auroc", "auprc")}, "Classification")

    score_name = "f1-score"
    log_report = get_log_report(results["report"], score_name)
    log_scores(log_report, f"Classification {score_name}")

    # The confusion-matrix diagonal holds correct-prediction *counts* for an
    # unnormalized matrix; per-class accuracy (recall) is that count divided by
    # the class's row total. For an already row-normalized matrix the row sums
    # are 1, so this leaves the values unchanged. Classes with no samples get
    # NaN rather than triggering a division-by-zero warning.
    cm = np.asarray(results["cm"], dtype=float)
    support = cm.sum(axis=1)
    class_acc = np.divide(
        cm.diagonal(),
        support,
        out=np.full(support.shape, np.nan),
        where=support != 0,
    )

    log_scores(dict(enumerate(class_acc)), "Class Acc")


def log_semisupervised(results):
    """Log the mean semi-supervised classification accuracy for each setting.

    The settings are the keys of ``results[0]``. For each one, the mean of
    ``r[setting]["acc"]`` over all entries ``r`` of ``results`` is logged.

    Args:
        results: Sequence of mappings sharing the same keys, each value being a
            mapping with an ``"acc"`` entry. Presumably one mapping per run or
            seed — confirm with the caller. An empty sequence logs only the
            header.
    """
    logger.info("")
    logger.info("evaluating semisupervised classification")
    # results[0] would raise IndexError on an empty sequence; there is
    # nothing to average, so finish the block cleanly instead.
    if not results:
        logger.info("")
        return
    for p in list(results[0].keys()):
        acc = np.mean([run[p]["acc"] for run in results])
        logger.info(f"\t{p}:\t{acc}")
    logger.info("")
