import torch
import torch.nn as nn
import torch.optim as optim

class adam_emb_wd(optim.Adam):
    """Adam optimizer with a scaled-down weight decay on embedding parameters.

    Trainable parameters of ``model`` are split into two parameter groups by
    name: every parameter whose name contains ``"embedding"``
    (case-insensitive) goes into the first group, all remaining trainable
    parameters go into the second one. Parameters with
    ``requires_grad=False`` are not handed to the optimizer at all.

    :param model: Object exposing ``named_parameters()`` (e.g. ``nn.Module``).
    :param lr: Learning rate shared by both parameter groups.
    :param wd: Weight decay applied to the non-embedding parameters.
    :param wd_scale: Multiplier applied to ``wd`` to obtain the weight decay
        of the embedding parameters (``wd_scale * wd``).
    """

    def __init__(self, model, lr=0.001, wd=0.003, wd_scale=0.01):
        embedding_params = []
        other_params = []

        for name, param in model.named_parameters():
            if not param.requires_grad:
                # Frozen parameters are never optimized.
                continue
            if 'embedding' in name.lower():
                embedding_params.append(param)
            else:
                other_params.append(param)

        # Group 0: embeddings with reduced (scaled) weight decay.
        # Group 1: everything else with the full weight decay.
        param_groups = [
            {'params': embedding_params, 'weight_decay': wd_scale * wd},
            {'params': other_params, 'weight_decay': wd},
        ]

        super().__init__(param_groups, lr=lr)

    def step(self, closure=None):
        """Perform a single Adam step.

        :param closure: Optional callable that reevaluates the model and
            returns the loss.
        :return: Whatever the parent ``step`` returns (the closure's loss
            when a closure is given, otherwise ``None``).
        """
        # Propagate the parent's return value so callers relying on
        # ``loss = optimizer.step(closure)`` actually receive the loss.
        return super().step(closure)


def lasso_proximal_step(param, threshold):
    """Apply the soft-thresholding operator used for L1 (Lasso) regularization.

    Each entry's magnitude is reduced by ``threshold`` and floored at zero,
    then the original sign is restored.

    :param param: Tensor to shrink.
    :param threshold: Amount subtracted from each entry's absolute value.
    :return: New tensor holding the soft-thresholded values.
    """
    magnitude = torch.abs(param)
    # Entries whose magnitude does not exceed the threshold collapse to zero.
    shrunk = torch.clamp(magnitude - threshold, min=0.0)
    direction = torch.sign(param)
    return direction * shrunk

from torch.optim import Optimizer

class CoordinatedOptimizer(Optimizer):
    """Optimizer that alternates plain gradient steps between two groups.

    Calls to :meth:`step` alternate: the first call updates only the
    embedding parameters, the next call only the other parameters, and so
    on. Each update is ``p -= lr * (grad + weight_decay * p)`` using the
    learning rate and weight decay stored in the parameter group.
    """

    # Positions of the two groups inside ``self.param_groups``. Groups are
    # identified by position rather than by comparing learning rates, so the
    # alternation still works when both learning rates are equal or when a
    # group's ``lr`` is changed after construction.
    _EMBEDDING_GROUP = 0
    _OTHER_GROUP = 1

    def __init__(self, embedding_params, other_params, lr_embedding, lr_other, wd_embedding, wd_other):
        """
        Custom optimizer for coordinated updates.
        :param embedding_params: List of parameters for embeddings.
        :param other_params: List of other parameters.
        :param lr_embedding: Learning rate for embedding parameters.
        :param lr_other: Learning rate for other parameters.
        :param wd_embedding: Weight decay for embedding parameters.
        :param wd_other: Weight decay for other parameters.
        """
        # Order matters: it must match _EMBEDDING_GROUP / _OTHER_GROUP.
        param_groups = [
            {'params': embedding_params, 'lr': lr_embedding, 'weight_decay': wd_embedding},
            {'params': other_params, 'lr': lr_other, 'weight_decay': wd_other},
        ]
        defaults = {"lr_embedding": lr_embedding, "lr_other": lr_other,
                    "wd_embedding": wd_embedding, "wd_other": wd_other}
        super().__init__(param_groups, defaults)
        self.update_embedding = True  # Internal flag for coordinated updates

    def step(self, closure=None):
        """
        Perform a single optimization step, updating only embedding or other parameters based on the internal flag.
        :param closure: A closure that reevaluates the model and returns the loss.
        :return: The value returned by ``closure``, or ``None`` if no closure was given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        if self.update_embedding:
            active_index = self._EMBEDDING_GROUP
        else:
            active_index = self._OTHER_GROUP
        group = self.param_groups[active_index]

        lr = group['lr']
        weight_decay = group['weight_decay']

        # The parameter update itself must not be recorded by autograd.
        with torch.no_grad():
            for param in group['params']:
                if param.grad is None:
                    # Parameter did not take part in the last backward pass.
                    continue
                param -= lr * (param.grad + weight_decay * param)

        # Toggle the update flag so the next call handles the other group.
        self.update_embedding = not self.update_embedding
        return loss