from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from __future__ import print_function

import math
import torch
import torch.distributed as dist
from torch.optim.lr_scheduler import LambdaLR
from datetime import datetime


def get_time():
    """Return the current local time formatted as ``MMDD_HHMMSS``."""
    stamp_format = "%m%d_%H%M%S"
    now = datetime.now()
    return now.strftime(stamp_format)


class AllGather(torch.autograd.Function):
    """Autograd-aware all-gather over ``torch.distributed``.

    Forward concatenates ``tensor`` from every rank along dim 0, in rank order.
    Backward hands back only the rows of the incoming gradient that belong to
    this rank's own input; no cross-rank reduction of the gradient is done here.
    """

    @staticmethod
    def forward(ctx, tensor, args):
        # `args` only needs `world_size` and `rank`; presumably the parsed
        # run configuration — confirm against callers.
        gathered = [torch.empty_like(tensor) for _ in range(args.world_size)]
        dist.all_gather(gathered, tensor)
        # Remember where this rank's rows sit in the concatenated result.
        # NOTE(review): this assumes every rank contributes the same number
        # of rows.
        ctx.batch_size = tensor.shape[0]
        ctx.rank = args.rank
        return torch.cat(gathered, 0)

    @staticmethod
    def backward(ctx, grad_output):
        start = ctx.rank * ctx.batch_size
        stop = start + ctx.batch_size
        # `args` is a non-tensor input, so its gradient slot is None.
        return grad_output[start:stop], None


def get_model_name(args):
    """Build a descriptive run/model name from the experiment configuration.

    The name is the concatenation, in order, of:
      * ``<dataset_name>_<all|train>_<task_name>`` (``all`` when
        ``args.extract_ilp`` is truthy, ``train`` otherwise),
      * ``_<num_layers>layer``,
      * ``_re<resample_lowerbound>_atthead<text_att_n_head>``
        ``_ffdim<text_feedforward_dim>_rnndim<hidden_dim>``,
      * ``_e<epochs>_lr<learning_rate_in_float>_batch<train_batch_size>_dt``,
      * the timestamp returned by ``get_time()``.

    Args:
        args: configuration object providing the attributes listed above.

    Returns:
        str: the assembled name.
    """
    split_tag = 'all' if args.extract_ilp else 'train'
    pieces = [
        '{}_{}_{}'.format(args.dataset_name, split_tag, args.task_name),
        '_{}layer'.format(args.num_layers),
        '_re{}_atthead{}_ffdim{}_rnndim{}'.format(
            args.resample_lowerbound,
            args.text_att_n_head,
            args.text_feedforward_dim,
            args.hidden_dim,
        ),
        '_e{}_lr{}_batch{}_dt'.format(
            args.epochs,
            args.learning_rate_in_float,
            args.train_batch_size,
        ),
        # Timestamp last, so names from separate runs differ.
        get_time(),
    ]
    return ''.join(pieces)


'''
Copied from Hugging Face Transformers (`get_cosine_schedule_with_warmup`):
https://huggingface.co/docs/transformers/main_classes/optimizer_schedules
'''
def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5, last_epoch=-1):
    """Build a ``LambdaLR`` with linear warmup followed by cosine decay.

    The scheduler scales each parameter group's initial learning rate by a
    factor that:

    * rises linearly from 0 towards 1 during the first ``num_warmup_steps``
      steps, then
    * equals ``0.5 * (1 + cos(2 * pi * num_cycles * progress))``. Here
      ``progress`` runs from 0 at the end of warmup to 1 at step
      ``num_training_steps``. With the default ``num_cycles=0.5`` this is a
      single half-cosine decaying from 1 to 0.

    ``progress`` is not clamped. Steps past ``num_training_steps`` keep
    following the cosine; with the default ``num_cycles`` the factor rises
    again.

    Args:
        optimizer: optimizer whose learning rate is scheduled.
        num_warmup_steps: length of the linear warmup phase.
        num_training_steps: total number of steps, warmup included.
        num_cycles: number of cosine periods over the post-warmup phase.
        last_epoch: index of the last step, forwarded to ``LambdaLR``
            (``-1`` starts a fresh schedule).

    Returns:
        torch.optim.lr_scheduler.LambdaLR
    """

    def _factor(step):
        if step >= num_warmup_steps:
            # max(1, ...) avoids division by zero when there is no decay phase.
            span = float(max(1, num_training_steps - num_warmup_steps))
            progress = float(step - num_warmup_steps) / span
            cosine = math.cos(math.pi * float(num_cycles) * 2.0 * progress)
            return max(0.0, 0.5 * (1.0 + cosine))
        # Linear warmup; max(1, ...) guards the division.
        return float(step) / float(max(1, num_warmup_steps))

    return LambdaLR(optimizer, _factor, last_epoch=last_epoch)
