import torch
import bitsandbytes as bnb

# Registry of supported Hugging Face model ids -> {"type": <family>}.
# The family is either "seq2seq" or "decoder"; presumably callers use it to
# choose how to load/prompt the model (TODO confirm against callers).
# Entries are grouped by family; insertion order is all seq2seq ids first,
# then all decoder ids, each group in the order listed below.
hf_models = {
    model_id: {"type": family}
    for family, model_ids in (
        (
            "seq2seq",
            (
                "allenai/macaw-large",
                "allenai/macaw-3b",
                "google/t5-small-ssm-nq",
                "google/t5-large-ssm-nq",
                "google/flan-t5-small",
                "google/t5-3b-ssm-nq",
                # NOTE(review): confirm LXMERT really belongs in the
                # seq2seq group for how callers consume this registry.
                "unc-nlp/lxmert-base-uncased",
            ),
        ),
        (
            "decoder",
            (
                "gpt2-medium",
                "sharpbai/Llama-2-7b-chat",
                "sharpbai/Llama-2-7b-hf",
                "sharpbai/Llama-2-13b-hf",
                "NousResearch/Llama-2-7b-chat-hf",
                "mistralai/Mistral-7B-Instruct-v0.1",
                "NousResearch/Llama-2-70b-chat-hf",
                "NousResearch/Llama-2-7b-hf",
                "NousResearch/Llama-2-70b-hf",
                "TinyLlama/TinyLlama-1.1B-intermediate-step-1195k-token-2.5T",
                "NousResearch/Llama-2-13b-hf",
            ),
        ),
    )
    for model_id in model_ids
}

def save_checkpoint(epoch, model, optimizer, filename='checkpoint.ptorch.tar'):
    """Write the current training state to ``filename`` via ``torch.save``.

    The saved object is a dict with, in order:

    - ``'epoch'``: ``epoch + 1`` (presumably the epoch to resume from; confirm
      with the loading code),
    - ``'state_dict'``: ``model.state_dict()``,
    - ``'optimizer'``: ``optimizer.state_dict()``.

    Args:
        epoch: index of the epoch that just finished.
        model: module whose ``state_dict()`` is stored.
        optimizer: optimizer whose ``state_dict()`` is stored.
        filename: destination path passed straight to ``torch.save``.
    """
    checkpoint = dict(
        epoch=epoch + 1,
        state_dict=model.state_dict(),
        optimizer=optimizer.state_dict(),
    )
    torch.save(checkpoint, filename)

def count_parameters(model):
    """Return how many scalar elements live in ``model``'s trainable parameters.

    A parameter counts only when ``requires_grad`` is true; its contribution
    is ``param.numel()``. A model with no trainable parameters yields ``0``.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def find_all_linear_names(model, *, linear_cls=None):
    """Collect the unique leaf names of ``linear_cls`` submodules in ``model``.

    For every submodule returned by ``model.named_modules()`` that is an
    instance of ``linear_cls``, the last component of its dotted name is
    kept (e.g. ``'q_proj'`` for ``'model.layers.0.self_attn.q_proj'``).
    Presumably the result is used as LoRA ``target_modules``; confirm with
    callers.

    Args:
        model: module tree to scan.
        linear_cls: class, or tuple of classes, to match with ``isinstance``.
            Defaults to ``bnb.nn.Linear4bit``, which was the previously
            hard-coded class.

    Returns:
        Sorted list of unique leaf names.
    """
    if linear_cls is None:
        # Resolved lazily so callers that pass their own class never touch bnb.
        linear_cls = bnb.nn.Linear4bit
    lora_module_names = {
        # rsplit keeps the whole name when it has no dots.
        name.rsplit('.', 1)[-1]
        for name, module in model.named_modules()
        if isinstance(module, linear_cls)
    }
    # Sort instead of list(set): str hashing is randomized per process, so set
    # order (and hence the returned list) would otherwise differ between runs.
    return sorted(lora_module_names)


def print_trainable_parameters(model):
    """Print trainable vs. total parameter counts for ``model``.

    Output format (one line)::

        trainable params: <n> || all params: <m> || trainables%: <pct>

    ``<n>`` sums ``numel()`` over parameters with ``requires_grad`` set,
    ``<m>`` sums ``numel()`` over all parameters, and ``<pct>`` is
    ``100 * n / m``. If the model has no parameters, ``<pct>`` is reported
    as ``0.0`` instead of raising ``ZeroDivisionError``.

    Returns:
        None; the result is only printed.
    """
    trainable_params = 0
    all_param = 0
    for param in model.parameters():
        numel = param.numel()
        all_param += numel
        if param.requires_grad:
            trainable_params += numel
    # NOTE(review): bitsandbytes 4-bit weights may report a packed numel();
    # confirm the totals if this is used on quantized models.
    trainables_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainables%: {trainables_pct}"
    )