import getpass, logging, socket
from typing import Any, List
from itertools import chain

import torch
from omegaconf.dictconfig import DictConfig
from omegaconf.omegaconf import OmegaConf
from pytorch_lightning.loggers import NeptuneLogger

from src.utils.metrics import calc_preds, get_step_metrics, get_epoch_metrics

# Per-service credential registry. get_neptune_logger reads
# API_LIST["neptune"][<local username>] as the Neptune API key, so the
# "neptune" mapping must be populated (username -> API key) before use.
API_LIST = {"neptune": {}}


def get_username() -> str:
    """Return the login name of the current user, as reported by ``getpass.getuser()``."""
    login_name = getpass.getuser()
    return login_name

def flatten_cfg(cfg: Any) -> dict:
    """Flatten a nested dict into a single-level dict with "/"-joined keys.

    Keys of nested dicts are joined to their parent as ``f"{parent}/{child}"``.
    Top-level leaf keys are kept as-is. Non-dict values, including lists, are
    treated as leaves. An empty nested dict contributes no entries. A non-dict
    ``cfg`` is returned as ``{"": cfg}``.

    Leaves are detected explicitly rather than through an empty-key sentinel,
    so falsy nested keys such as ``0`` or ``False`` keep their own path
    component instead of being merged into the parent key.

    >>> flatten_cfg({"a": 1, "b": {"c": 2, "d": {"e": 3}}})
    {'a': 1, 'b/c': 2, 'b/d/e': 3}
    """
    if not isinstance(cfg, dict):
        return {"": cfg}

    flat = {}
    for key, value in cfg.items():
        if isinstance(value, dict):
            # Recurse and prefix every flattened sub-key with this key.
            for sub_key, leaf in flatten_cfg(value).items():
                flat[f"{key}/{sub_key}"] = leaf
        else:
            flat[key] = value
    return flat

def get_logger(name=__name__, level=logging.INFO) -> logging.Logger:
    """Fetch the stdlib logger called ``name`` and set its level to ``level``.

    ``logging.getLogger`` returns the same object for the same name, so this
    also changes the level of an already-existing logger with that name.
    """
    named_logger = logging.getLogger(name)
    named_logger.setLevel(level)
    return named_logger

def get_neptune_logger(
    cfg: DictConfig, project_name: str,
    name: str, tag_attrs: List[str], log_db: str,
    offline: bool, logger: str,
):
    """Build a Lightning ``NeptuneLogger`` for the current run.

    Args:
        cfg: Run configuration. It is converted with ``OmegaConf.to_object``,
            flattened into "a/b/c"-style keys, and uploaded as params together
            with the current hostname.
        project_name: Neptune project to log into.
        name: Experiment name.
        tag_attrs: Tags to attach to the experiment. This list is not
            modified.
        log_db: Extra tag appended after ``tag_attrs``.
        offline: Passed to ``NeptuneLogger`` as ``offline_mode``.
        logger: Not used by this function; kept for interface compatibility.

    Returns:
        The configured ``NeptuneLogger``.

    Raises:
        KeyError: If ``API_LIST["neptune"]`` has no API key for the current
            user.
    """
    username = get_username()
    try:
        neptune_api_key = API_LIST["neptune"][username]
    except KeyError as err:
        raise KeyError(
            f"No Neptune API key registered for user {username!r} "
            f"in API_LIST['neptune']"
        ) from err

    # Flatten the config so nested values become individual params.
    args_dict = {
        **flatten_cfg(OmegaConf.to_object(cfg)),
        "hostname": socket.gethostname(),
    }
    # Build a new list: appending to tag_attrs would mutate the caller's list
    # (tags would accumulate across calls, and a read-only ListConfig would raise).
    tags = [*tag_attrs, log_db]

    neptune_logger = NeptuneLogger(
        api_key=neptune_api_key,
        project_name=project_name,
        experiment_name=name,
        params=args_dict,
        tags=tags,
        offline_mode=offline,
    )

    try:
        # For an unknown reason the experiment must be accessed once here,
        # otherwise it later becomes None.
        print(neptune_logger.experiment)
    except Exception:
        # Best effort: keep going, but do not hide the failure silently.
        logging.getLogger(__name__).warning(
            "Accessing the Neptune experiment failed", exc_info=True
        )

    return neptune_logger

def log_data_to_neptune(model_class, data, data_name, data_type, suffix, split, ret_dict=None, detach_data=True):
    """Log one value through ``model_class.log`` and optionally record it in ``ret_dict``.

    The logged name is ``f'{split}_{key}_{suffix}'``, where ``key`` is
    ``f'{data_name}_{data_type}'``, except that ``'total_loss'`` is shortened
    to ``'loss'``. A detached copy is always what gets logged. It is logged
    with ``prog_bar=True``, and ``sync_dist`` is True for every split except
    ``'train'``.

    Args:
        model_class: Object exposing a Lightning-style ``log(name, value, **kw)``
            method (presumably the LightningModule itself).
        data: Tensor-like value with a ``detach()`` method.
        data_name, data_type: Combined into the metric key.
        suffix: Appended to the logged name, e.g. ``'step'`` / ``'epoch'``.
        split: Prefix of the logged name, e.g. ``'train'``.
        ret_dict: Optional dict that receives the value under ``key``.
        detach_data: If True, store ``data.detach()`` in ``ret_dict``;
            otherwise store ``data`` itself.

    Returns:
        ``ret_dict`` (mutated in place), or None when it was not given.
    """
    combined = f'{data_name}_{data_type}'
    key = 'loss' if combined == 'total_loss' else combined
    metric_name = f'{split}_{key}_{suffix}'
    model_class.log(metric_name, data.detach(), prog_bar=True, sync_dist=split != 'train')

    if ret_dict is None:
        return None
    ret_dict[key] = data if not detach_data else data.detach()
    return ret_dict

def log_step_losses(model_class, loss_dict, ret_dict, split):
    """Log the per-step losses in ``loss_dict`` and collect them into ``ret_dict``.

    ``loss_dict['loss']`` is always logged as the total loss (stored under key
    ``'loss'``). Each optional component ``'<name>_loss'``, for name in
    task / aux / kd / kd_input / kd_target, is logged only when present.
    Values are stored without detaching (``detach_data=False``), presumably so
    the returned loss can still be backpropagated.

    Returns:
        The ``ret_dict`` returned by the last ``log_data_to_neptune`` call.
    """
    optional_components = ('task', 'aux', 'kd', 'kd_input', 'kd_target')

    ret_dict = log_data_to_neptune(
        model_class, loss_dict['loss'], 'total', 'loss', 'step', split,
        ret_dict, detach_data=False,
    )
    for component in optional_components:
        loss_key = f'{component}_loss'
        if loss_key not in loss_dict:
            continue
        ret_dict = log_data_to_neptune(
            model_class, loss_dict[loss_key], component, 'loss', 'step', split,
            ret_dict, detach_data=False,
        )
    return ret_dict

def log_epoch_losses(model_class, outputs, split):
    """Average per-step losses over ``outputs`` and log them at epoch level.

    ``outputs`` is a sequence of per-step dicts (presumably those collected
    from ``log_step_losses`` — confirm with the caller). Each dict must hold a
    ``'loss'`` tensor. The optional ``'<name>_loss'`` components are detected
    from the first element only, so every element is expected to carry the
    same keys.
    """
    def _mean_of(field):
        # Stack one tensor per step and reduce to their mean.
        return torch.stack([step_out[field] for step_out in outputs]).mean()

    log_data_to_neptune(model_class, _mean_of('loss'), 'total', 'loss', 'epoch', split)

    first_step = outputs[0]
    for component in ('task', 'aux', 'kd', 'kd_input', 'kd_target'):
        field = f'{component}_loss'
        if field in first_step:
            log_data_to_neptune(model_class, _mean_of(field), component, 'loss', 'epoch', split)

def log_epoch_metrics(model_class, outputs, split, lm_mode):
    """Compute epoch-level accuracy from collected predictions and log it.

    Args:
        model_class: Module exposing ``perf_metrics[lm_mode]`` and a
            Lightning-style ``log`` method.
        outputs: Per-step dicts. Each holds an iterable of predictions under
            ``'pred_label'`` (or ``'aux_pred_label'`` for ``lm_mode='aux'``)
            and an iterable of gold labels under ``'label'``.
        split: Split name used as the logged-metric prefix.
        lm_mode: ``'task'`` or ``'aux'``. Selects the prediction key, the
            metric collection, and the ``'aux_'`` name prefix.

    Raises:
        ValueError: If ``lm_mode`` is invalid, or the number of predictions
            does not match the number of labels.
    """
    # Raise explicitly rather than assert: asserts are stripped under `python -O`.
    if lm_mode not in ('task', 'aux'):
        raise ValueError(f"lm_mode must be 'task' or 'aux', got {lm_mode!r}")
    prefix = 'aux_' if lm_mode == 'aux' else ''

    preds_ = list(chain.from_iterable(x[f'{prefix}pred_label'] for x in outputs))
    labels_ = list(chain.from_iterable(x['label'] for x in outputs))
    if len(preds_) != len(labels_):
        raise ValueError(
            f"got {len(preds_)} predictions but {len(labels_)} labels"
        )

    # Encode per-example correctness as the "prediction" against an all-ones
    # target. Assuming the 'acc' metric is plain accuracy, its value is then
    # the fraction of correct predictions.
    preds = torch.LongTensor([pred == label for pred, label in zip(preds_, labels_)])
    labels = torch.ones(len(labels_)).long()

    metrics = model_class.perf_metrics[lm_mode]
    # The return value is not needed. Presumably this call updates the metric
    # state that get_epoch_metrics reads below — confirm in src.utils.metrics.
    get_step_metrics(preds, labels, metrics)
    perf_metrics = get_epoch_metrics(metrics)
    log_data_to_neptune(model_class, perf_metrics['acc'], f'{prefix}acc', 'metric', 'epoch', split)