from torch.optim import AdamW
import torch.nn as nn
from torch.optim.lr_scheduler import LambdaLR



def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
    """Build a ``LambdaLR`` scheduler with a linear warmup followed by a linear decay.

    The multiplier handed to ``LambdaLR`` for a given step is:

    * ``step / max(1, num_warmup_steps)`` while ``step < num_warmup_steps``;
    * ``max(0.0, 1 - step / max(1, num_training_steps))`` from then on.

    The decay branch is measured from step 0 rather than from the end of the
    warmup, so at ``step == num_warmup_steps`` the multiplier is
    ``1 - num_warmup_steps / num_training_steps`` instead of 1.

    Args:
        optimizer: Optimizer passed through to ``LambdaLR``.
        num_warmup_steps: Number of steps over which the multiplier ramps up.
        num_training_steps: Step at which the multiplier reaches 0.
        last_epoch: Passed through to ``LambdaLR``.

    Returns:
        The configured ``LambdaLR`` instance.
    """

    def lr_lambda(current_step):
        in_warmup = current_step < num_warmup_steps
        if in_warmup:
            # Ramp-up: fraction of the warmup completed so far.
            warmup_span = float(max(1, num_warmup_steps))
            return float(current_step) / warmup_span

        # Decay: fraction of the whole run still ahead, floored at zero so the
        # multiplier never goes negative past num_training_steps.
        total_span = float(max(1, num_training_steps))
        remaining_fraction = 1 - float(current_step) / total_span
        return max(0.0, remaining_fraction)

    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)



def create_optimizer_and_scheduler(model: nn.Module, num_train_optimization_steps, args):
    """Create an AdamW optimizer and its warmup/linear-decay scheduler for ``model``.

    Parameters are split into two groups by name:

    * names containing none of the no-decay markers use ``args.weight_decay``;
    * names containing ``'bias'``, ``'LayerNorm.bias'`` or ``'LayerNorm.weight'``
      use a weight decay of 0.0.

    Args:
        model: Module whose ``named_parameters()`` are optimized.
        num_train_optimization_steps: Total step count given to the scheduler.
        args: Object providing ``weight_decay``, ``learning_rate`` and
            ``warmup_steps`` attributes.

    Returns:
        An ``(optimizer, scheduler)`` tuple.
    """
    exempt_markers = ('bias', 'LayerNorm.bias', 'LayerNorm.weight')

    # One pass over the parameters; each list keeps named_parameters() order.
    decayed_params = []
    exempt_params = []
    for param_name, param in model.named_parameters():
        is_exempt = any(marker in param_name for marker in exempt_markers)
        if is_exempt:
            exempt_params.append(param)
        else:
            decayed_params.append(param)

    param_groups = [
        {'params': decayed_params, 'weight_decay': args.weight_decay},
        {'params': exempt_params, 'weight_decay': 0.0},
    ]
    optimizer = AdamW(param_groups, lr=args.learning_rate)

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=num_train_optimization_steps,
    )
    return optimizer, scheduler

def create_optimizer_and_scheduler_jepa(model: nn.Module, predictor: nn.Module, num_train_optimization_steps, args):
    """Create an AdamW optimizer and warmup/linear-decay scheduler over ``model`` and ``predictor``.

    Two parameter groups are built:

    * model parameters whose names contain none of ``'bias'``,
      ``'LayerNorm.bias'`` or ``'LayerNorm.weight'`` use ``args.weight_decay``;
    * the remaining model parameters, followed by every predictor parameter,
      use a weight decay of 0.0.

    Predictor parameters are selected by object identity rather than by name.
    Matching predictor names as substrings (e.g. a predictor parameter called
    ``'weight'`` or ``'0.weight'``) would also match unrelated model
    parameters and silently remove their weight decay.

    Args:
        model: Module whose parameters are split by the no-decay name markers.
        predictor: Module whose parameters all go in the no-decay group.
        num_train_optimization_steps: Total step count given to the scheduler.
        args: Object providing ``weight_decay``, ``learning_rate`` and
            ``warmup_steps`` attributes.

    Returns:
        An ``(optimizer, scheduler)`` tuple.
    """
    no_decay = ('bias', 'LayerNorm.bias', 'LayerNorm.weight')

    predictor_params = list(predictor.parameters())
    predictor_param_ids = {id(p) for p in predictor_params}

    decay_params = []
    no_decay_params = []
    for name, param in model.named_parameters():
        # A tensor shared with the predictor is handled once, via the
        # predictor, so it is never passed to the optimizer twice.
        if id(param) in predictor_param_ids:
            continue
        if any(nd in name for nd in no_decay):
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    # Every predictor parameter is exempt from weight decay, whatever its name.
    no_decay_params.extend(predictor_params)

    optimizer_grouped_parameters = [
        {'params': decay_params, 'weight_decay': args.weight_decay},
        {'params': no_decay_params, 'weight_decay': 0.0},
    ]

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=num_train_optimization_steps,
    )

    return optimizer, scheduler