import torch
from torch import nn
from splora import SpLoRaModel
from splora import ScaledAdamW, SpLoRaLinear, AdamWFD
from galore_torch import GaLoreAdamW, GaLoreAdamW8bit, GaLoreAdafactor
import bitsandbytes as bnb
from peft_pretraining import training_utils
import transformers
from loguru import logger


def build_model(model, args):
    """Wrap ``model`` in a ``SpLoRaModel`` when SpLoRa is the selected PEFT method.

    Args:
        model: The base model to (optionally) wrap.
        args: Run configuration. ``args.peft_model`` is compared
            case-insensitively against ``'splora'``; when it matches, the
            SpLoRa settings (``rank``, ``lora_alpha``, ``lora_dropout``,
            ``target_modules``, ``sp_ratio``, ``sp_type``, ``train_scaling``,
            ``random_subspace``) are read from it as well.

    Returns:
        A ``SpLoRaModel`` built around ``model`` for SpLoRa runs; otherwise the
        very same ``model`` object that was passed in.
    """
    # Any other PEFT choice leaves the model untouched.
    if args.peft_model.lower() != 'splora':
        return model

    splora_options = {
        'r': args.rank,
        'lora_alpha': args.lora_alpha,
        'lora_dropout': args.lora_dropout,
        'target_modules': args.target_modules,
        'sp_ratio': args.sp_ratio,
        'sp_type': args.sp_type,
        'trainable_scaling': args.train_scaling,
        'random_subspace': args.random_subspace,
    }
    return SpLoRaModel(model, **splora_options)


def _collect_lora_param_groups(model, **lora_group_options):
    """Split the model's parameters into a regular group and a LoRA group.

    The LoRA group holds ``lora_B`` and ``lora_A`` (in that order, module by
    module) of every ``SpLoRaLinear`` found in ``model``; it is tagged with
    ``'lora': True`` plus any ``lora_group_options``. The regular group holds
    every remaining parameter that requires gradients.
    """
    lora_params = []
    for _, module in model.named_modules():
        if isinstance(module, SpLoRaLinear):
            lora_params.append(module.lora_B)
            lora_params.append(module.lora_A)

    # A set keeps the membership test below O(1) per parameter.
    lora_ids = {id(p) for p in lora_params}
    regular_params = [p for p in model.parameters() if id(p) not in lora_ids and p.requires_grad]
    return [{'params': regular_params},
            {'params': lora_params, 'lora': True, **lora_group_options}]


def _build_adafactor(adafactor_cls, params, args):
    """Instantiate ``adafactor_cls`` with the fixed Adafactor settings used here.

    A ``beta1`` of ``0.0`` is translated to ``None`` before being passed on. The
    translated value is also written back to ``args.beta1``, exactly as the
    inline code this helper replaces did.
    """
    args.beta1 = None if args.beta1 == 0.0 else args.beta1
    return adafactor_cls(
        params,
        lr=args.lr,
        eps=(1e-30, 1e-3),
        clip_threshold=1.0,
        decay_rate=-0.8,
        beta1=args.beta1,
        weight_decay=args.weight_decay,
        relative_step=False,
        scale_parameter=False,
        warmup_init=False,
    )


def _build_per_layer_optimizers(model, args, make_optimizer):
    """Create one optimizer and one scheduler per trainable parameter.

    ``make_optimizer(p)`` must return the optimizer for parameter ``p``. Each
    trainable parameter gets a post-accumulate-grad hook that steps its
    optimizer, zeroes its gradient and steps its scheduler.

    Returns:
        dict mapping each trainable parameter to its optimizer. The schedulers
        are only reachable through the registered hooks.
    """
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizers = {p: make_optimizer(p) for p in trainable}

    # TODO: seems scheduler call twice in one update step, need to check, for now
    # double the num_training_steps and warmup_steps (callers double
    # update_proj_gap themselves where it applies).
    schedulers = {
        p: training_utils.get_scheculer(
            optimizer=optimizers[p],
            scheduler_type=args.scheduler,
            num_training_steps=args.num_training_steps * 2,
            warmup_steps=args.warmup_steps * 2,
            min_lr_ratio=args.min_lr_ratio,
        )
        for p in trainable
    }

    def optimizer_hook(p):
        if p.grad is None:
            return
        optimizers[p].step()
        optimizers[p].zero_grad()
        schedulers[p].step()

    # Register the hook onto every trainable parameter.
    for p in trainable:
        p.register_post_accumulate_grad_hook(optimizer_hook)

    return optimizers


def _build_galore_optimizer(model, args, optimizer_name):
    """Build one of the GaLore optimizers named by ``optimizer_name``.

    The weights of every ``nn.Linear`` whose module name contains one of
    ``args.target_modules`` form the GaLore parameter group; all other model
    parameters form the regular group.

    Raises:
        ValueError: If ``optimizer_name`` is not a supported GaLore optimizer.
    """
    galore_params = []
    for module_name, module in model.named_modules():
        if not isinstance(module, nn.Linear):
            continue
        if not any(target_key in module_name for target_key in args.target_modules):
            continue
        print('enable GaLore for weights in module: ', module_name)
        galore_params.append(module.weight)

    # A set keeps the membership tests below O(1) per parameter.
    galore_ids = {id(p) for p in galore_params}
    regular_params = [p for p in model.parameters() if id(p) not in galore_ids]
    param_groups = [{'params': regular_params},
                    {'params': galore_params, 'rank': args.rank, 'update_proj_gap': args.update_proj_gap,
                     'scale': args.galore_scale, 'proj_type': args.proj_type}]

    if optimizer_name == "galore_adamw":
        return GaLoreAdamW(param_groups, lr=args.lr, weight_decay=args.weight_decay)
    if optimizer_name == "galore_adafactor":
        return _build_adafactor(GaLoreAdafactor, param_groups, args)
    if optimizer_name == "galore_adamw8bit":
        return GaLoreAdamW8bit(param_groups, lr=args.lr, weight_decay=args.weight_decay)
    if optimizer_name == 'galore_adamw8bit_per_layer':
        def make_layer_optimizer(p):
            if id(p) in galore_ids:
                # update_proj_gap is doubled here; see the TODO in
                # _build_per_layer_optimizers.
                return GaLoreAdamW8bit([{'params': [p], 'rank': args.rank,
                                         'update_proj_gap': args.update_proj_gap * 2,
                                         'scale': args.galore_scale, 'proj_type': args.proj_type}],
                                       lr=args.lr, weight_decay=args.weight_decay)
            return bnb.optim.Adam8bit([p], lr=args.lr, weight_decay=args.weight_decay)

        return _build_per_layer_optimizers(model, args, make_layer_optimizer)

    raise ValueError(f"Optimizer {args.optimizer} not supported")


def _build_standard_optimizer(model, trainable_params, args, optimizer_name):
    """Build a non-GaLore, non-SpLoRa-specific optimizer over ``trainable_params``.

    Raises:
        ValueError: If ``optimizer_name`` is not one of the supported names.
    """
    if optimizer_name == "adam":
        return torch.optim.Adam(trainable_params, lr=args.lr, weight_decay=args.weight_decay)
    if optimizer_name == "adamw":
        return torch.optim.AdamW(trainable_params, lr=args.lr, weight_decay=args.weight_decay)
    if optimizer_name == "sgd":
        return torch.optim.SGD(trainable_params, lr=args.lr, weight_decay=args.weight_decay, momentum=args.beta1)
    if optimizer_name == "adafactor":
        return _build_adafactor(transformers.optimization.Adafactor, trainable_params, args)
    if optimizer_name == "adam8bit":
        return bnb.optim.Adam8bit(trainable_params, lr=args.lr, weight_decay=args.weight_decay)
    if optimizer_name == "adam8bit_per_layer":
        return _build_per_layer_optimizers(
            model,
            args,
            lambda p: bnb.optim.Adam8bit([p], lr=args.lr, weight_decay=args.weight_decay),
        )
    raise ValueError(f"Optimizer {args.optimizer} not supported")


def build_optimizer(model, trainable_params, args):
    """Create the optimizer selected by ``args.optimizer`` / ``args.peft_model``.

    Selection order (names are matched case-insensitively):

    1. Any optimizer name containing ``'galore'`` -> a GaLore optimizer
       (``galore_adamw``, ``galore_adafactor``, ``galore_adamw8bit``,
       ``galore_adamw8bit_per_layer``).
    2. ``peft_model == 'splora'`` with ``args.precondition`` -> ``ScaledAdamW``.
    3. ``peft_model == 'splora'`` with ``args.f_decay > 0`` -> ``AdamWFD``.
    4. Otherwise -> ``adam``, ``adamw``, ``sgd``, ``adafactor``, ``adam8bit`` or
       ``adam8bit_per_layer`` over ``trainable_params``.

    Args:
        model: Model whose modules/parameters are inspected to form parameter
            groups and, for the ``*_per_layer`` variants, to register hooks.
        trainable_params: Parameters handed to the optimizers of case 4.
        args: Run configuration holding the hyper-parameters read above.

    Returns:
        A single optimizer, or for the ``*_per_layer`` variants a dict mapping
        each trainable parameter to its own optimizer (stepping then happens
        inside hooks registered on the parameters).

    Raises:
        ValueError: If GaLore is the PEFT method but the optimizer name does not
            contain ``'galore'``, or if the optimizer name is not supported.
        NotImplementedError: If a SpLoRa-specific branch (cases 2 and 3) is
            requested with an optimizer other than ``adamw``.
    """
    peft_name = args.peft_model.lower()
    optimizer_name = args.optimizer.lower()
    uses_galore = 'galore' in optimizer_name

    if peft_name == 'galore' and not uses_galore:
        raise ValueError(f"Galore is used as peft model but optimizer is {optimizer_name}, which does not "
                         f"contain galore")

    if uses_galore:
        optimizer = _build_galore_optimizer(model, args, optimizer_name)

    elif peft_name == 'splora' and args.precondition:
        param_groups = _collect_lora_param_groups(model, precon_type=args.precon_type)
        if optimizer_name != "adamw":
            raise NotImplementedError(
                f"Optimizer {args.optimizer} is not implemented for SpLoRa with preconditioning; use 'adamw'")
        optimizer = ScaledAdamW(param_groups, lr=args.lr, weight_decay=args.weight_decay)

    elif peft_name == 'splora' and args.f_decay > 0:
        param_groups = _collect_lora_param_groups(model)
        if optimizer_name != "adamw":
            raise NotImplementedError(
                f"Optimizer {args.optimizer} is not implemented for SpLoRa with f_decay; use 'adamw'")
        optimizer = AdamWFD(param_groups, lr=args.lr, weight_decay=args.weight_decay, f_decay=args.f_decay)

    else:
        # NOTE: plain SpLoRa (no preconditioning, no f_decay) also lands here.
        # A variant with a separate lr / weight decay for the LoRA factors
        # (args.lr_lora, args.weight_decay_lora) is not enabled.
        optimizer = _build_standard_optimizer(model, trainable_params, args, optimizer_name)

    return optimizer
