import math

import torch
import torch.nn as nn
from torch.optim import AdamW


def get_learning_rate_schedule(scheduler_config):
    """Build a ``LambdaLR``-compatible learning-rate multiplier from a config.

    The returned callable maps an optimizer step index to a factor that
    ``torch.optim.lr_scheduler.LambdaLR`` multiplies into each base learning rate:

    * ``current_step < num_warmup_steps``: linear warmup ``step / num_warmup_steps``.
    * ``'linear'``: linear decay from 1 at the end of warmup to ``decay_factor``
      at step ``num_training_steps - 1``; later steps stay at ``decay_factor``.
    * ``'cosine'``: half-cosine decay over the same range, also ending at (and
      then holding) ``decay_factor``.
    * ``'const'``: constant ``1.0`` after warmup.

    Args:
        scheduler_config: object exposing ``num_training_steps``,
            ``num_warmup_steps``, ``schedule`` and, for ``'linear'``/``'cosine'``,
            ``decay_factor`` attributes.

    Returns:
        A function ``lr_lambda(current_step: int) -> float``.

    Raises:
        ValueError: raised by the returned function when it is called for a
            post-warmup step and ``schedule`` is not one of the names above.
    """

    def lr_lambda(current_step: int) -> float:
        # Index of the final training step; decay schedules reach decay_factor here.
        total_training_steps = scheduler_config.num_training_steps - 1
        training_steps_after_warmup = total_training_steps - scheduler_config.num_warmup_steps

        if current_step < scheduler_config.num_warmup_steps:
            return float(current_step) / float(max(1, scheduler_config.num_warmup_steps))
        elif scheduler_config.schedule == 'linear':
            # max(0.0, ...) keeps steps beyond the end pinned at decay_factor.
            linear_decay = max(0.0,
                               float(total_training_steps - current_step) / float(max(1, training_steps_after_warmup)))
            return scheduler_config.decay_factor + (1 - scheduler_config.decay_factor) * linear_decay
        elif scheduler_config.schedule == 'cosine':
            # Clamp progress to 1 so steps past the end hold at decay_factor instead of
            # climbing back up the cosine curve (or oscillating when the decay window is
            # empty because warmup covers all steps).
            progress = min(1.0, float(current_step - scheduler_config.num_warmup_steps) / float(
                max(1, training_steps_after_warmup)))
            cosine_decay = (1.0 + math.cos(math.pi * progress)) / 2.0
            return scheduler_config.decay_factor + (1 - scheduler_config.decay_factor) * cosine_decay
        elif scheduler_config.schedule == 'const':
            return 1.0
        else:
            raise ValueError(f"Unknown schedule: {scheduler_config.schedule}")

    return lr_lambda


def group_parameters_for_optimizer(model, optimizer_cfg, vocab_size, regularize_embedding,
                                   regularize_head, logger):
    if 'weight_decay' in optimizer_cfg:
        weight_decay = optimizer_cfg.weight_decay
    else:
        weight_decay = 0.0

    if weight_decay == 0.0 and not any(hasattr(p, '_optim') for p in model.parameters()):
        return model.parameters()

    skip = model.no_weight_decay() if hasattr(model, 'no_weight_decay') else set()
    skip_keywords = (model.no_weight_decay_keywords() if hasattr(model, 'no_weight_decay_keywords')
                     else set())

    decay = set()
    no_decay = set()
    special = set()

    whitelist_weight_modules = (nn.Linear,)

    blacklist_weight_modules = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
                                nn.LazyBatchNorm1d, nn.LazyBatchNorm2d, nn.LazyBatchNorm3d,
                                nn.GroupNorm, nn.SyncBatchNorm,
                                nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d,
                                nn.LayerNorm, nn.LocalResponseNorm)

    param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad}
    for mn, m in model.named_modules():
        for pn, p in m.named_parameters():
            fpn = '%s.%s' % (mn, pn) if mn else pn
            if not p.requires_grad or fpn not in param_dict:
                continue
            if hasattr(p, '_optim'):
                special.add(fpn)
            elif fpn in skip or any(skip_keyword in fpn for skip_keyword in skip_keywords):
                no_decay.add(fpn)
            elif getattr(p, '_no_weight_decay', False):
                no_decay.add(fpn)

            elif pn.endswith('bias'):
                no_decay.add(fpn)
            elif len(p.shape) <= 1:
                no_decay.add(fpn)
            elif regularize_embedding and isinstance(m, (nn.Embedding)):
                decay.add(fpn)
            elif not regularize_embedding and isinstance(m, (nn.Embedding)):
                no_decay.add(fpn)
            elif regularize_embedding and ("embed" in fpn or "wte" in fpn):
                decay.add(fpn)
            elif not regularize_embedding and ("embed" in fpn or "wte" in fpn):
                no_decay.add(fpn)
            elif regularize_head and ("head_linear" in fpn or "lm_head" in fpn):
                decay.add(fpn)
            elif not regularize_head and ("head_linear" in fpn or "lm_head" in fpn):
                no_decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                decay.add(fpn)
            elif isinstance(m, blacklist_weight_modules):
                no_decay.add(fpn)

    decay |= (param_dict.keys() - no_decay - special)
    inter_params = decay & no_decay
    union_params = decay | no_decay
    assert len(inter_params) == 0, f"Parameters {str(inter_params)} made it into both decay/no_decay sets!"
    assert len(
        param_dict.keys() - special - union_params) == 0, f"parameters {str(param_dict.keys() - union_params)}  were not separated into either decay/no_decay set!"

    if weight_decay == 0.0 or not no_decay:
        param_groups = [{"params": [param_dict[pn] for pn in sorted(list(no_decay | decay))],
                         "weight_decay": weight_decay}]
    else:
        param_groups = [
            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay},
            {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
        ]

    logger.info(f"Parameters with weight decay: {sorted(list(decay))}")
    logger.info(f"Parameters without weight decay<: {sorted(list(no_decay))}")

    hps = [dict(s) for s in set(frozenset(param_dict[pn]._optim.items()) for pn in special)]
    for hp in hps:
        params = [param_dict[pn] for pn in sorted(list(special)) if param_dict[pn]._optim == hp]
        param_groups.append({"params": params, **hp})

    return param_groups


def configure_optimizer(cfg_optim, model, vocab_size, logger):
    """Create a fused AdamW optimizer and a ``LambdaLR`` scheduler for ``model``.

    Args:
        cfg_optim: optimizer config exposing ``lr``, ``betas``,
            ``regularize_embedding``, ``regularize_head`` and ``scheduler``, plus an
            optional ``weight_decay`` (treated as ``0.0`` when absent, matching
            ``group_parameters_for_optimizer``).
        model: module whose parameters are optimized.
        vocab_size: forwarded to ``group_parameters_for_optimizer``.
        logger: logger receiving a one-line summary per optimizer parameter group.

    Returns:
        ``(optimizer, lr_scheduler)``.
    """
    # Use the same default as group_parameters_for_optimizer, so a config without
    # ``weight_decay`` does not fail here after grouping has already accepted it.
    weight_decay = cfg_optim.weight_decay if 'weight_decay' in cfg_optim else 0.0

    parameters = group_parameters_for_optimizer(model, cfg_optim, vocab_size=vocab_size,
                                                regularize_embedding=cfg_optim.regularize_embedding,
                                                regularize_head=cfg_optim.regularize_head, logger=logger)

    # NOTE(review): fused=True requires tensors on a device the installed PyTorch's
    # fused AdamW supports (older releases: CUDA only). Confirm for CPU runs.
    optimizer = AdamW(parameters, lr=cfg_optim.lr, betas=cfg_optim.betas, weight_decay=weight_decay,
                      fused=True)

    for i, group in enumerate(optimizer.param_groups):
        group_params = group['params']
        ntensors = len(group_params)
        nparams = sum(p.numel() for p in group_params)
        hparams = {k: v for k, v in group.items() if k != 'params'}
        logger.info(f'Optimizer group {i}: {ntensors} tensors, {nparams} parameters, {hparams}')

    lr_lambda = get_learning_rate_schedule(cfg_optim.scheduler)
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=-1)
    return optimizer, lr_scheduler
