import os
import torch
import dateutil.tz
from datetime import datetime
import time
import logging
import numpy as np
from torch.optim.lr_scheduler import LambdaLR



def create_logger(log_dir, phase='train'):
    """Configure the root logger to log to a timestamped file and the console.

    The log file path is ``<log_dir>/<YYYY-mm-dd-HH-MM>_<phase>.log``. File
    records use the format ``'%(asctime)-15s %(message)s'``; the console
    handler has no formatter set, so it uses logging's default (message only).

    ``logging.basicConfig`` only configures the root logger if it has no
    handlers yet, so only the first effective call attaches the file handler.
    The console handler is likewise attached at most once, so repeated calls
    do not duplicate console output.

    Args:
        log_dir: Directory in which the log file is created (must exist).
        phase: Tag appended to the file name, e.g. ``'train'`` or ``'test'``.

    Returns:
        The root :class:`logging.Logger`, with its level set to ``INFO``.
    """
    time_str = time.strftime('%Y-%m-%d-%H-%M')
    log_file = '{}_{}.log'.format(time_str, phase)
    final_log_file = os.path.join(log_dir, log_file)
    head = '%(asctime)-15s %(message)s'
    logging.basicConfig(filename=str(final_log_file),
                        format=head)
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # Previously every call stacked another StreamHandler on the root logger,
    # so each record was echoed to the console once per call. Tag ours and
    # add it only if it is not already present.
    already_attached = any(
        getattr(handler, '_create_logger_console', False)
        for handler in logger.handlers
    )
    if not already_attached:
        console = logging.StreamHandler()
        console._create_logger_console = True
        logger.addHandler(console)

    return logger


def set_log_dir(root_dir, exp_name):
    """Create a fresh, timestamped run directory with its standard subfolders.

    ``root_dir`` is created if missing. The run directory is
    ``<root_dir>/<exp_name>_<YYYY_mm_dd_HH_MM_SS>`` (local time), and inside it
    the subdirectories ``Model`` (checkpoints), ``Log`` and ``Samples`` (sample
    images for FID calculation) are created.

    Args:
        root_dir: Parent directory for all runs.
        exp_name: Experiment name used as the run directory's prefix.

    Returns:
        dict with keys ``'prefix'`` (run directory), ``'ckpt_path'``,
        ``'log_path'`` and ``'sample_path'``.

    Raises:
        FileExistsError: if the run directory (or one of its subdirectories)
            already exists, e.g. two runs started within the same second.
    """
    os.makedirs(root_dir, exist_ok=True)

    stamp = datetime.now(dateutil.tz.tzlocal()).strftime('%Y_%m_%d_%H_%M_%S')
    run_dir = '{}_{}'.format(os.path.join(root_dir, exp_name), stamp)
    os.makedirs(run_dir)

    paths = {'prefix': run_dir}
    subdirs = (
        ('ckpt_path', 'Model'),
        ('log_path', 'Log'),
        ('sample_path', 'Samples'),
    )
    for key, folder in subdirs:
        target = os.path.join(run_dir, folder)
        os.makedirs(target)
        paths[key] = target

    return paths


def compute_hb_details(max_t, eta):
    """Precompute a Hyperband-style schedule of brackets and rungs.

    For each bracket index ``s`` (iterated from ``s_max`` down to 0) this
    computes the number of configurations ``n`` and the resource ``r`` for
    every successive-halving rung. It also keeps a running total of the
    resource consumed.

    Args:
        max_t: Maximum resource per configuration. Presumably a positive int
            such as epochs; TODO confirm with callers.
        eta: Reduction factor between rungs. Expected to be > 1, since
            ``np.log(eta)`` is used as a divisor.

    Returns:
        Tuple ``(hb, budget, cumul_n0, s_max)``:
          * ``hb``: dict mapping ``str(s)`` to ``{"n": [...], "r": [...]}``,
            with one list entry per rung of bracket ``s``.
          * ``budget``: cumulative budget after each rung, accumulated across
            all brackets in iteration order (it is never reset).
          * ``cumul_n0``: sum of the first-rung ``n`` over all kept brackets.
          * ``s_max``: starting bracket index, decremented once for every
            bracket skipped because its initial resource rounded down to 0.
    """
    # Number of brackets (s_max + 1).
    s_max_1 = int(np.round(np.log(max_t) / np.log(eta))) + 1
    # Initial number of configurations for bracket s (uses the unreduced
    # bracket count even when high brackets are later skipped).
    get_n0 = lambda s: int(np.ceil(s_max_1 / (s + 1) * eta ** s))
    # Initial resource per configuration for bracket s, truncated to int.
    get_r0 = lambda s: int((max_t * eta ** (-s)))
    hb = {}
    cumul_n0 = 0
    budget = []
    cumul_b = 0
    s_max = s_max_1 - 1
    for s in range(s_max, -1, -1):
        # Drop brackets whose starting resource truncates to 0.
        if get_r0(s) == 0:
            s_max -= 1
            continue
        n = [get_n0(s)]
        r = [get_r0(s)]
        cumul_r = r[0]
        # The first rung pays the full n0 * r0.
        cumul_b += n[0] * r[0]
        budget.append(cumul_b)
        temp_n = n[0]
        temp_r = r[0]
        for i in range(s):
            # Keep ceil(n / eta) configurations for the next rung.
            temp_n /= eta
            temp_n = int(np.ceil(temp_n))
            temp_r *= eta
            # NOTE(review): this caps r at max_t minus the previous rung's r,
            # not at max_t. For example, compute_hb_details(8, 2) gives
            # bracket "3" r = [1, 2, 4, 4], so the last rung never reaches
            # max_t. Confirm this is intended.
            temp_r = int(min(temp_r, max_t - cumul_r))
            # r is treated as cumulative: survivors only pay the additional
            # resource beyond the previous rung.
            cumul_b += (temp_r - cumul_r) * temp_n
            budget.append(cumul_b)
            cumul_r = temp_r
            n.append(temp_n)
            r.append(temp_r)

        hb[str(s)] = {"n": n, "r": r}
        cumul_n0 += n[0]
    return hb, budget, cumul_n0, s_max


def get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps, num_training_steps, decay_rate=0.0, last_epoch=-1
):
    """
    Create a schedule with linear warmup followed by linear decay.

    During the first ``num_warmup_steps`` steps the learning-rate multiplier
    rises linearly from 0 to 1. After that it falls linearly from 1 (at
    ``num_warmup_steps``) to ``decay_rate`` (at ``num_training_steps``). The lr
    therefore moves from the optimizer's initial lr to
    ``decay_rate * initial_lr``. For ``decay_rate < 1`` the same line keeps
    descending past ``num_training_steps``, and the multiplier is clamped at 0.
    Args:
        optimizer (:class:`~torch.optim.Optimizer`):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (:obj:`int`):
            The number of steps for the warmup phase.
        num_training_steps (:obj:`int`):
            The step at which the multiplier reaches ``decay_rate``.
        decay_rate (:obj:`float`, `optional`, defaults to 0.0):
            Fraction of the initial lr reached at ``num_training_steps``. With
            the default of 0.0 the schedule decays linearly to zero.
        last_epoch (:obj:`int`, `optional`, defaults to -1):
            The index of the last epoch when resuming training.
    Return:
        :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            # Linear warmup 0 -> 1; max(1, ...) guards num_warmup_steps == 0.
            return float(current_step) / float(max(1, num_warmup_steps))
        # Equivalent to 1 - (1 - decay_rate) * progress, where progress is
        # (current_step - num_warmup_steps) / (num_training_steps - num_warmup_steps).
        return max(
            0.0,
            float(num_training_steps - current_step +
                  decay_rate * (current_step - num_warmup_steps))
            / float(max(1, num_training_steps - num_warmup_steps)),
        )

    return LambdaLR(optimizer, lr_lambda, last_epoch)
