import torch
import transformers
from sron_torch import GaLoreAdamW, GaLoreAdamW8bit, Muon, APOLLO, SRON, SRONSGD, SignSGD, Lion, SGD
import bitsandbytes as bnb


def get_optimizer_params(model, logger):
    """Split ``model``'s parameters into optimizer-specific and fallback groups.

    A parameter is routed to ``scale_params`` when it has at least 2 dimensions
    and its name contains none of the excluded keywords for the model's
    ``model_type`` (embeddings, heads, classifiers, ...). Everything else --
    parameters with fewer than 2 dimensions (e.g. biases) and parameters of
    excluded modules -- is routed to ``base_params``. Parameters are not
    filtered on ``requires_grad``.

    Args:
        model: Module exposing ``config.to_dict()`` and ``named_parameters()``.
        logger: Object with an ``info`` method; one line is logged for every
            parameter routed to ``scale_params``.

    Returns:
        ``(scale_params, base_params)``: two lists of parameters, each in
        ``named_parameters()`` order.
    """
    config = model.config.to_dict()
    # to_dict() returns a plain dict, so the key must be looked up rather than
    # read as an attribute; configs without a model_type use the generic list.
    model_type = config.get("model_type", "default")

    exclude_keywords = {
        "gpt2": ["wte", "wpe", "lm_head"],
        "llama": ["embed_tokens", "lm_head"],
        "gemma": ["embed_tokens", "lm_head"],
        "mistral": ["embed_tokens", "lm_head"],
        "deberta-v2": ["embeddings", "cls", "pooler", "classifier"],
        "default": ["embed", "embeddings", "cls", "pooler", "classifier", "lm_head", "wte", "wpe", "embed_tokens"]
    }
    # Unknown model types fall back to the broad default keyword list.
    keywords = exclude_keywords.get(model_type, exclude_keywords["default"])

    scale_params = []
    base_params = []
    for name, p in model.named_parameters():
        # Keyword matching is substring-based against the full dotted name.
        if p.ndim >= 2 and not any(k in name for k in keywords):
            scale_params.append(p)
            logger.info(f'enable specified optimizer for weights in module: {name}')
        else:
            base_params.append(p)

    return scale_params, base_params

def _projection_param_groups(args, base_params, scale_params, **extra):
    """Build the two param groups shared by the GaLore/APOLLO optimizers.

    The first group holds ``base_params`` with the optimizer defaults. The
    second holds ``scale_params`` together with the projection settings read
    from ``args`` (``rank``, ``update_proj_gap``, ``scale``, ``proj_type``),
    followed by any optimizer-specific keys passed in ``extra``.
    """
    projected_group = {
        'params': scale_params,
        'rank': args.rank,
        'update_proj_gap': args.update_proj_gap,
        'scale': args.scale,
        'proj_type': args.proj_type,
        **extra,
    }
    return [{'params': base_params}, projected_group]


def configure_optimizer(args, logger, model, model_config):
    """Create the optimizer selected by ``args.optimizer`` and log memory estimates.

    Args:
        args: Run configuration. Always reads ``optimizer`` (case-insensitive),
            ``lr`` and ``weight_decay``. Depending on the optimizer it also reads
            ``momentum``, ``rank``, ``update_proj_gap``, ``scale``,
            ``proj_type`` and ``n_steps``.
        logger: Object with an ``info`` method used for logging.
        model: Module to optimize. The optimizers that split parameters call
            ``get_optimizer_params``, which needs ``model.config.to_dict()``.
        model_config: Unused; kept for call-site compatibility.

    Returns:
        The constructed optimizer.

    Raises:
        ValueError: If ``args.optimizer`` is not a supported name.

    Note:
        For ``adafactor``, a ``momentum`` of ``0.0`` is rewritten to ``None``
        on ``args`` itself (``estimate_memory`` reads that value).
    """
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    num_total_params = sum(p.numel() for p in model.parameters())
    num_trainable_params = sum(p.numel() for p in trainable_params)

    name = args.optimizer.lower()

    # Defaults for optimizers that do not split parameters; the split is only
    # consumed by the memory estimate at the end.
    base_params = trainable_params
    scale_params = []

    if name == "adamw":
        # AdamW applies decoupled weight decay; torch.optim.Adam would instead
        # add weight_decay * p to the gradient (L2 regularization).
        optimizer = torch.optim.AdamW(trainable_params, lr=args.lr, weight_decay=args.weight_decay)

    elif name == "adafactor":
        # momentum 0.0 is passed as beta1=None. The value is written back to args
        # because estimate_memory's adafactor branch tests `args.momentum is not None`.
        args.momentum = None if args.momentum == 0.0 else args.momentum
        optimizer = transformers.optimization.Adafactor(
            trainable_params,
            lr=args.lr,
            eps=(1e-30, 1e-3),
            clip_threshold=1.0,
            decay_rate=-0.8,
            beta1=args.momentum,
            weight_decay=args.weight_decay,
            relative_step=False,
            scale_parameter=False,
            warmup_init=False,
        )

    elif name in ("galore_adamw", "galore_adamw8bit"):
        scale_params, base_params = get_optimizer_params(model, logger)
        optimizer_cls = GaLoreAdamW if name == "galore_adamw" else GaLoreAdamW8bit
        optimizer = optimizer_cls(
            _projection_param_groups(args, base_params, scale_params),
            lr=args.lr,
            weight_decay=args.weight_decay,
        )

    elif name in ("apollo", "apollo_mini"):
        scale_params, base_params = get_optimizer_params(model, logger)
        # apollo_mini selects scale_type='tensor'; plain apollo uses 'channel'.
        scale_type = 'tensor' if name == "apollo_mini" else 'channel'
        optimizer = APOLLO(
            _projection_param_groups(
                args, base_params, scale_params, proj='random', scale_type=scale_type
            ),
            lr=args.lr,
            weight_decay=args.weight_decay,
        )

    # For sgd / signsgd / lion the split below only feeds the memory estimate;
    # the optimizer itself receives every trainable parameter.
    elif name == "sgd":
        scale_params, base_params = get_optimizer_params(model, logger)
        optimizer = SGD(trainable_params, lr=args.lr, wd=args.weight_decay, momentum=args.momentum)

    elif name == "signsgd":
        scale_params, base_params = get_optimizer_params(model, logger)
        optimizer = SignSGD(trainable_params, lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum)

    elif name == "lion":
        scale_params, base_params = get_optimizer_params(model, logger)
        optimizer = Lion(trainable_params, lr=args.lr, weight_decay=args.weight_decay)

    elif name == "muon":
        scale_params, base_params = get_optimizer_params(model, logger)
        logger.info(f"Total params with Muon enabled: {sum(p.numel() for p in scale_params) / 1_000_000:.2f}M")
        optimizer = Muon(
            lr=args.lr,
            ns_steps=args.n_steps,
            wd=args.weight_decay,
            muon_params=scale_params,
            adamw_params=base_params,
        )

    elif name == "sron":
        scale_params, base_params = get_optimizer_params(model, logger)
        logger.info(f"Total params with SGD/M enabled: {sum(p.numel() for p in scale_params) / 1_000_000:.2f}M")
        optimizer = SRON(
            lr=args.lr,
            wd=args.weight_decay,
            sgd_params=scale_params,
            adamw_params=base_params,
            momentum=args.momentum,
            scale=args.scale,
        )

    elif name == 'sronsgd':
        scale_params, base_params = get_optimizer_params(model, logger)
        logger.info(f"Total params with White SGD/M enabled: {sum(p.numel() for p in scale_params) / 1_000_000:.2f}M")
        optimizer = SRONSGD(
            trainable_params,
            lr=args.lr,
            wd=args.weight_decay,
            momentum=args.momentum,
        )

    elif name == "adam8bit":
        optimizer = bnb.optim.Adam8bit(trainable_params, lr=args.lr, weight_decay=args.weight_decay)

    else:
        raise ValueError(f"Optimizer {args.optimizer} not supported")

    optimizer_memory = estimate_memory(compressed_params=scale_params, base_params=base_params, args=args, logger=logger)

    logger.info(f"Total params: {num_total_params / 1_000_000:.2f}M")
    logger.info(f"Trainable params: {num_trainable_params / 1_000_000:.2f}M")
    # Weights are counted at 2 bytes per element (presumably bf16/fp16 -- confirm).
    logger.info(f"Total Training Memory Consumption: {(optimizer_memory + num_trainable_params * 2) / 1_000_000:.2f}M")
    logger.info(f"\n{model}\n")

    return optimizer


def estimate_memory(compressed_params, base_params, args, logger):
    """Estimate optimizer-state memory in bytes for ``args.optimizer`` and log it.

    State elements are costed at 2 bytes each (assumed bf16/fp16 states --
    TODO confirm), except ``adam8bit``, whose states are costed at 1 byte each.
    The ``adam8bit`` figure is approximate and ignores any higher-precision
    fallback bitsandbytes may use for small tensors.

    Args:
        compressed_params: Parameters handled by the optimizer's specialised
            path (low-rank projection, Muon, SRON, ...).
        base_params: The remaining parameters.
        args: Run configuration. Reads ``optimizer`` (case-insensitive) and,
            depending on the optimizer, ``momentum`` and ``rank``.
        logger: Object with an ``info`` method.

    Returns:
        Estimated bytes: regular states + compressed states + projection
        matrices. Unrecognised optimizer names yield 0.
    """
    name = args.optimizer.lower()

    def numel(params):
        return sum(p.numel() for p in params)

    def factored_numel(params):
        # Matrices keep row and column second-moment vectors (see the
        # exp_avg_sq_col shape) instead of a full-size state; vectors keep a
        # full-size state.
        return sum(
            (p.shape[:-1].numel() + p.shape[:-2].numel() * p.shape[-1]) if p.ndim >= 2 else p.numel()
            for p in params
        )

    n_base = numel(base_params)
    n_compressed = numel(compressed_params)

    project_bytes = 0
    compressed_opt_bytes = 0
    regular_opt_bytes = 0

    if name == 'adamw':
        # Two full-size states (first and second moment) per element.
        regular_opt_bytes = (n_base + n_compressed) * 2 * 2
    elif name == 'adam8bit':
        # Two full-size states per element at 1 byte each.
        regular_opt_bytes = (n_base + n_compressed) * 2 * 1
    elif name in ('sgd', 'signsgd'):
        # A momentum buffer exists only when momentum is enabled.
        if args.momentum > 0.0:
            regular_opt_bytes = (n_base + n_compressed) * 2
    elif name == 'lion':
        regular_opt_bytes = (n_base + n_compressed) * 2
    elif name == 'sron':
        if args.momentum > 0.0:
            compressed_opt_bytes = n_compressed * 2
        # Base params are costed as two full-size (AdamW-style) states.
        regular_opt_bytes = n_base * 2 * 2
    elif name == 'sronsgd':
        if args.momentum > 0.0:
            compressed_opt_bytes = (n_compressed + n_base) * 2
    elif 'apollo' in name or 'galore' in name:
        # Projection matrix of size min(dim) x rank per 2-D param, plus two
        # low-rank states of size max(dim) x rank.
        project_bytes = sum(min(p.shape[0], p.shape[1]) * args.rank for p in compressed_params) * 2
        compressed_opt_bytes = sum(2 * max(p.shape[0], p.shape[1]) * args.rank for p in compressed_params) * 2
        regular_opt_bytes = n_base * 2 * 2
    elif 'muon' in name or 'adamw_mini' in name:
        compressed_opt_bytes = n_compressed * 2
        regular_opt_bytes = n_base * 2 * 2
    elif 'adafactor' in name:
        compressed_opt_bytes = 2 * (factored_numel(base_params) + factored_numel(compressed_params))
        # configure_optimizer maps momentum 0.0 to None for Adafactor.
        if args.momentum is not None:
            regular_opt_bytes = (n_base + n_compressed) * 2

    total_bytes = regular_opt_bytes + compressed_opt_bytes + project_bytes

    logger.info(f"Total params with {args.optimizer} enabled: {n_compressed / 1_000_000:.2f}M")
    logger.info(f"Compressed_opt_bytes: {compressed_opt_bytes / 1_000_000:.2f}M")
    logger.info(f"Projection_bytes: {project_bytes / 1_000_000:.2f}M")
    logger.info(f"Total Optimizer Memory Consumption: {total_bytes / 1_000_000:.2f}M")

    return total_bytes