import os

import ANONYMOUS.io as jio
import numpy as np
import torch as th


def linear_step(total_epoch, cur_epoch):
    """Return the fraction of a linear schedule left after ``cur_epoch``.

    Computes ``(total_epoch - cur_epoch - 1) / total_epoch`` and clamps it
    from below at 0.0, so the value falls linearly and then stays at zero.

    Args:
        total_epoch: Length of the schedule in epochs (must be non-zero).
        cur_epoch: Current epoch index.

    Returns:
        A numpy floating scalar >= 0.0 (NaN propagates unchanged).
    """
    remaining = (total_epoch - cur_epoch - 1) / total_epoch
    return np.maximum(remaining, 0.0)


def temp_step(cur_iter, temp_start, temp_decay):
    """Exponentially anneal a temperature, floored at 1.0.

    For the first two iterations (``cur_iter < 2``) ``temp_start`` is
    returned unchanged. After that the temperature is
    ``temp_start * exp(-cur_iter / temp_decay)``, clipped so it never drops
    below 1.0.

    NOTE(review): ``temp_decay`` is interpreted as a decay time-scale in
    iterations (larger value -> slower decay). Confirm against the configs
    that set ``cfg.model.temp_decay``.

    Args:
        cur_iter: Current training iteration.
        temp_start: Initial temperature.
        temp_decay: Decay time-scale (must be non-zero).

    Returns:
        ``temp_start`` for ``cur_iter < 2``; otherwise a numpy scalar >= 1.0.
    """
    if cur_iter < 2:
        return temp_start
    # np.exp takes a single input; its second positional parameter is the
    # ``out`` buffer, so the decay must be folded into the exponent itself.
    decayed = temp_start * np.exp(-cur_iter / temp_decay)
    return np.clip(decayed, 1.0, None)


def temp_adjust_wrapper(cfg):
    """Build a trainer callback that updates the temperature ``big_t``.

    The returned callable reads ``cfg`` each time it is invoked, not when
    the wrapper is built. When ``cfg.model.enable_temp`` is falsy it does
    nothing. Otherwise it picks a schedule:

    * ``cfg.data.is_linear``: ``linear_step`` over a horizon of
      ``min(cfg.trainer.epochs / 2, 10)`` epochs, driven by
      ``trainer.epoch_cnt``;
    * otherwise: ``temp_step`` driven by ``trainer.iter_cnt`` with
      ``cfg.model.temp`` / ``cfg.model.temp_decay``.

    The resulting value is written to ``trainer.train_set.big_t`` and
    ``trainer.val_set.big_t``, and logged under ``"T"`` via
    ``trainer.cur_monitor.update``.
    """

    def _temp_adjust(trainer):
        if not cfg.model.enable_temp:
            return

        if cfg.data.is_linear:
            horizon = np.min([cfg.trainer.epochs / 2, 10])
            temperature = linear_step(horizon, trainer.epoch_cnt)
        else:
            temperature = temp_step(
                trainer.iter_cnt, cfg.model.temp, cfg.model.temp_decay
            )

        # Keep the train and validation datasets at the same temperature.
        trainer.train_set.big_t = temperature
        trainer.val_set.big_t = temperature
        trainer.cur_monitor.update({"T": temperature})

    return _temp_adjust


def dump_info(trainer):
    """Dump ``trainer.latest_result`` to ``trainer.fdump`` via ``jio.dump``.

    Does nothing if the trainer has no ``latest_result`` attribute. Before
    dumping, the current working directory is recorded under ``"run_path"``.
    Note that this mutates ``trainer.latest_result`` in place.
    """
    if not hasattr(trainer, "latest_result"):
        return
    info = trainer.latest_result
    # Store the path as a plain string. The previous ``(os.getcwd(),)`` was a
    # stray trailing comma that produced a 1-tuple.
    info["run_path"] = os.getcwd()
    jio.dump(trainer.fdump, info)


def loss2logz_info(loss):
    """Compute log-normalizer estimates from per-sample losses.

    ``-loss`` is treated as the per-sample log importance weight. All
    elements of ``loss`` are pooled, regardless of its shape.

    Returns a dict of 0-d tensors:
        ``loss_lower_bound``: ``mean(-loss)``.
        ``loss_upper_bound``: ``sum(w * -loss)`` with ``w = softmax(-loss)``.
        ``loss_unbiased``: ``log(mean(exp(-loss)))``, i.e. the log of the
            importance-sampling average. By Jensen's inequality it lies
            between the two bounds above.

    ``softmax`` / ``logsumexp`` are used so that widely spread losses do not
    overflow ``exp`` and produce inf/NaN.
    """
    import math

    flat = loss.reshape(-1)
    log_weight = -flat
    weight = th.softmax(log_weight, dim=0)
    return {
        "loss_lower_bound": -loss.mean(),
        "loss_upper_bound": th.sum(-weight * flat),
        # log(mean(exp(x))) == logsumexp(x) - log(N), computed stably.
        "loss_unbiased": th.logsumexp(log_weight, dim=0) - math.log(flat.numel()),
    }


def loss2ess_info(loss):
    """Compute the normalized effective sample size from per-sample losses.

    The self-normalized importance weights are ``w = softmax(-loss)`` over
    all elements of ``loss``. The returned ``"ess"`` is
    ``1 / sum(w**2) / N``, which lies in ``(0, 1]`` (1.0 when all losses are
    equal).

    ``softmax`` avoids the inf/NaN that ``exp`` of centered log-weights
    produces when the losses are widely spread.
    """
    flat = loss.reshape(-1)
    weight = th.softmax(-flat, dim=0)
    return {"ess": 1.0 / (weight * weight).sum() / flat.numel()}
