import torch
import transformers
from foam_torch import GaLoreAdamW, GaLoreAdamW8bit, Muon, APOLLO, FOAMAdamW
import bitsandbytes as bnb

def configure_optimizer(args, logger, model, model_config):
    """Build the optimizer selected by ``args.optimizer`` and log size statistics.

    The optimizer name is matched case-insensitively. For the optimizers that
    treat some parameters specially (GaLore, APOLLO, FOAM, Muon) the model's
    parameters are split into two lists by ``get_optimizer_params`` and an
    optimizer-state memory estimate is logged via ``estimate_memory``.

    Args:
        args: Namespace-like object. Always read: ``optimizer``, ``lr``,
            ``weight_decay``. Depending on the optimizer also ``beta1``,
            ``beta2``, ``rank``, ``update_proj_gap``, ``scale``, ``proj_type``,
            ``level``, ``res_scale``, ``no_norm_limit``, ``warmup_steps`` and
            ``n_steps``.
        logger: Object exposing ``info(str)``.
        model: Module exposing ``parameters()`` and ``named_parameters()``.
        model_config: Only read (``hidden_size``) in the ``"foam"`` branch.

    Returns:
        The constructed optimizer instance.

    Raises:
        ValueError: If ``args.optimizer`` is not one of the supported names.
    """
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    num_total_params = sum(p.numel() for p in model.parameters())
    num_trainable_params = sum(p.numel() for p in trainable_params)
    logger.info(f"\n{model}\n")
    logger.info(f"Total params: {num_total_params / 1_000_000:.2f}M")
    logger.info(f"Trainable params: {num_trainable_params / 1_000_000:.2f}M")

    def get_optimizer_params(model):
        """Split named parameters into (special, regular) lists.

        A parameter goes into the first list when it has at least two
        dimensions and its name contains neither "embed_tokens" nor
        "lm_head"; every other parameter goes into the second list.
        """
        special, regular = [], []
        for name, p in model.named_parameters():
            is_special = p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
            (special if is_special else regular).append(p)
        return special, regular

    optimizer_name = args.optimizer.lower()

    # Stay None for optimizers that do not split the parameters; the memory
    # estimate at the end is only produced when a split exists.
    scale_params = None
    base_params = None

    if optimizer_name == "adamw":
        # AdamW (decoupled weight decay) is what the option name promises;
        # torch.optim.Adam would apply weight_decay as an L2 penalty instead.
        optimizer = torch.optim.AdamW(
            trainable_params,
            lr=args.lr,
            weight_decay=args.weight_decay,
            betas=(args.beta1, args.beta2),
        )

    elif optimizer_name == "adafactor":
        # A beta1 of exactly 0.0 is passed on as None; a local is used so the
        # caller's args object is not modified.
        adafactor_beta1 = None if args.beta1 == 0.0 else args.beta1
        optimizer = transformers.optimization.Adafactor(
            trainable_params,
            lr=args.lr,
            eps=(1e-30, 1e-3),
            clip_threshold=1.0,
            decay_rate=-0.8,
            beta1=adafactor_beta1,
            weight_decay=args.weight_decay,
            relative_step=False,
            scale_parameter=False,
            warmup_init=False,
        )

    elif optimizer_name in ("galore_adamw", "galore_adamw8bit"):
        # Both GaLore variants need the same parameter groups; only the
        # optimizer class differs.
        scale_params, base_params = get_optimizer_params(model)
        param_groups = [
            {'params': base_params},
            {
                'params': scale_params,
                'rank': args.rank,
                'update_proj_gap': args.update_proj_gap,
                'scale': args.scale,
                'proj_type': args.proj_type,
            },
        ]
        galore_cls = GaLoreAdamW8bit if optimizer_name == "galore_adamw8bit" else GaLoreAdamW
        optimizer = galore_cls(param_groups, lr=args.lr, weight_decay=args.weight_decay)

    elif optimizer_name in ("apollo", "apollo_mini"):
        scale_params, base_params = get_optimizer_params(model)
        param_groups = [
            {'params': base_params},
            {
                'params': scale_params,
                'rank': args.rank,
                'update_proj_gap': args.update_proj_gap,
                'scale': args.scale,
                'proj_type': args.proj_type,
                'proj': 'random',
                'scale_type': 'tensor',
            },
        ]
        optimizer = APOLLO(param_groups, lr=args.lr, weight_decay=args.weight_decay)

    elif optimizer_name == "foam":
        scale_params, base_params = get_optimizer_params(model)
        param_groups = [
            {'params': base_params},
            {'params': scale_params, 'scale': args.scale, 'level': args.level},
        ]

        if 2 ** args.level > model_config.hidden_size:
            logger.info('Using compress level larger than model dimension, will adopt padding')
        optimizer = FOAMAdamW(
            param_groups,
            lr=args.lr,
            weight_decay=args.weight_decay,
            res_scale=args.res_scale,
            no_norm_limit=args.no_norm_limit,
            warmup_steps=args.warmup_steps,
        )

    elif optimizer_name == "muon":
        scale_params, base_params = get_optimizer_params(model)
        logger.info(f"Total params with Muon enabled: {sum(p.numel() for p in scale_params) / 1_000_000:.2f}M")

        optimizer = Muon(
            lr=args.lr,
            ns_steps=args.n_steps,
            wd=args.weight_decay,
            muon_params=scale_params,
            adamw_params=base_params,
        )

    # 8-bit Adam
    elif optimizer_name == "adam8bit":
        optimizer = bnb.optim.Adam8bit(
            trainable_params,
            lr=args.lr,
            weight_decay=args.weight_decay,
            betas=(args.beta1, args.beta2),
        )

    else:
        raise ValueError(f"Optimizer {args.optimizer} not supported")

    if scale_params is None:
        # No parameter split was made for this optimizer, so there is nothing
        # to hand to estimate_memory.
        logger.info(f"Optimizer memory estimate is not available for {args.optimizer}")
    else:
        optimizer_memory = estimate_memory(
            compressed_params=scale_params, base_params=base_params, args=args, logger=logger
        )
        # NOTE(review): the "* 2" presumably counts 2 bytes per trainable
        # parameter (16-bit weights) -- confirm against the training dtype.
        logger.info(
            f"Total Training Memory Consumption: {(optimizer_memory + num_trainable_params * 2) / 1_000_000:.2f}M"
        )

    return optimizer


def estimate_memory(compressed_params, base_params, args, logger):
    """Estimate optimizer-state memory and log a breakdown.

    The estimate depends on substrings of ``args.optimizer`` (compared in
    lower case). When no substring matches, every component stays 0.

    Args:
        compressed_params: Iterable of parameters exposing ``shape`` (at least
            two entries are read for apollo/galore) and ``numel()``.
        base_params: Iterable of parameters exposing ``numel()``.
        args: Namespace-like object; reads ``optimizer`` and, depending on the
            optimizer, ``rank`` or ``level``.
        logger: Object exposing ``info(str)``.

    Returns:
        The sum of the regular, compressed and projection estimates. This is a
        float for "foam" because of the division by ``2 ** args.level``.
    """
    # NOTE(review): the trailing "* 2" factors below presumably mean 2 bytes
    # per stored element (16-bit states) -- confirm.
    optimizer_name = args.optimizer.lower()
    project_bytes = 0
    compressed_opt_bytes = 0
    regular_opt_bytes = 0

    # The checks are independent `if`s on purpose: should a name ever match
    # more than one of them, the last matching branch wins.
    if 'apollo' in optimizer_name or 'galore' in optimizer_name:
        project_bytes = sum(min(p.shape[0], p.shape[1]) * args.rank for p in compressed_params) * 2
        compressed_opt_bytes = sum(2 * max(p.shape[0], p.shape[1]) * args.rank for p in compressed_params) * 2
        regular_opt_bytes = sum(p.numel() for p in base_params) * 2 * 2

    if 'foam' in optimizer_name:
        project_bytes = 0
        compressed_opt_bytes = sum(2 * p.numel() for p in compressed_params) * 2 / (2 ** args.level)
        regular_opt_bytes = sum(p.numel() for p in base_params) * 2 * 2

    if 'muon' in optimizer_name or 'adamw_mini' in optimizer_name:
        project_bytes = 0
        compressed_opt_bytes = sum(p.numel() for p in compressed_params) * 2
        regular_opt_bytes = sum(p.numel() for p in base_params) * 2 * 2

    total_bytes = regular_opt_bytes + compressed_opt_bytes + project_bytes

    # ".2f" gives two decimals; the earlier ":2f" was a width-2 spec that
    # printed the default six decimals.
    logger.info(f"Total params with {args.optimizer} enabled: {sum(p.numel() for p in compressed_params) / 1_000_000:.2f}M")
    logger.info(f"Compressed_opt_bytes: {compressed_opt_bytes / 1_000_000:.2f}M")
    logger.info(f"Projection_bytes: {project_bytes / 1_000_000:.2f}M")
    logger.info(f"Total Oprtimizer Memory Consumption: {total_bytes / 1_000_000:.2f}M")

    return total_bytes