import torch
import transformers
from peft_pretraining import training_utils
from transformers.pytorch_utils import Conv1D
import torch.nn as nn
from optim import ProjFactor, Galore

def optimizer_constructor(args, logger, model):
    """Create the optimizer selected by ``args.optimizer`` and its LR scheduler.

    The name is compared case-insensitively against ``"projfactor"``,
    ``"adam"`` and ``"adafactor"``.

    * ``projfactor`` -- ``ProjFactor`` is built from the whole ``model`` and
      is handed the scheduling settings (``scheduler``, ``warmup_steps``,
      ``num_training_steps``, ``min_lr_ratio``) directly. No separate
      scheduler is created, and ``None`` is returned in its place.
    * ``adam`` / ``adafactor`` -- only parameters with ``requires_grad`` set
      are optimized. A scheduler is obtained from
      ``training_utils.get_scheculer`` using the same scheduling settings.

    Args:
        args: Configuration namespace (e.g. parsed CLI arguments).
            - All branches read the fields listed above, plus ``lr`` and
              ``weight_decay``.
            - The ProjFactor branch also reads ``rank``,
              ``update_proj_gap``, ``scale``, ``factor``,
              ``gradient_accumulation`` and ``proj_matrix_dist``.
            - The Adafactor branch also reads ``beta1``.
        logger: Not used by this function; kept for call-site compatibility.
        model: Model whose parameters are to be optimized.

    Returns:
        ``(optimizer, scheduler)``; ``scheduler`` is ``None`` for ProjFactor.

    Raises:
        ValueError: If ``args.optimizer`` is not one of the supported names.

    Side effects:
        For Adafactor, ``args.beta1`` is overwritten in place (``0.0``
        becomes ``None``).
    """
    name = args.optimizer.lower()

    if name == "projfactor":
        # ProjFactor receives the model itself (not a parameter list) and
        # performs its own LR scheduling from the settings passed here.
        proj_optimizer = ProjFactor(
            model,
            lr=args.lr,
            weight_decay=args.weight_decay,
            rank=args.rank,
            update_proj_gap=args.update_proj_gap,
            scale=args.scale,
            factor=args.factor,
            scheduler=args.scheduler,
            gradient_accumulation=args.gradient_accumulation,
            warmup_steps=args.warmup_steps,
            num_training_steps=args.num_training_steps,
            min_lr_ratio=args.min_lr_ratio,
            correct_bias=True,
            proj_matrix_dist=args.proj_matrix_dist,
        )
        return proj_optimizer, None

    # Frozen parameters (requires_grad=False) are left out of the optimizer.
    params = [param for param in model.parameters() if param.requires_grad]

    if name == "adam":
        optimizer = torch.optim.Adam(
            params,
            lr=args.lr,
            weight_decay=args.weight_decay,
        )
    elif name == "adafactor":
        # Adafactor takes beta1=None (not 0.0) to disable its first-moment
        # average. The normalized value is written back onto ``args``.
        beta1 = None if args.beta1 == 0.0 else args.beta1
        args.beta1 = beta1
        optimizer = transformers.optimization.Adafactor(
            params,
            lr=args.lr,
            eps=(1e-30, 1e-3),
            clip_threshold=1.0,
            decay_rate=-0.8,
            beta1=beta1,
            weight_decay=args.weight_decay,
            # Use the explicit ``lr`` above rather than Adafactor's
            # relative-step / parameter-scaled learning rate.
            relative_step=False,
            scale_parameter=False,
            warmup_init=False,
        )
    else:
        raise ValueError(f"Invalid optimizer: {args.optimizer}")

    # ``get_scheculer`` (sic) is the helper's actual name in training_utils.
    scheduler = training_utils.get_scheculer(
        optimizer=optimizer,
        scheduler_type=args.scheduler,
        num_training_steps=args.num_training_steps,
        warmup_steps=args.warmup_steps,
        min_lr_ratio=args.min_lr_ratio,
    )
    return optimizer, scheduler