from huggingface_hub import hf_hub_download, list_repo_files
from transformers import (
    AutoModelForImageClassification, 
    AutoModelForSequenceClassification, 
    AutoTokenizer, 
    LlamaForSequenceClassification,
    LlamaTokenizer,
    BitsAndBytesConfig,
    T5ForConditionalGeneration,
    logging
)
import torch
from transformers.utils import quantization_config

from utils.lora_utils import add_adapters

def _bnb_config(load_in_4bit=True):
    """Build the bitsandbytes quantization config used for the GLUE Llama model.

    The config selects the NF4 quant type with bfloat16 compute dtype.

    Args:
        load_in_4bit: value passed through as ``load_in_4bit``.

    Returns:
        A ``transformers.BitsAndBytesConfig`` instance.
    """
    settings = dict(
        load_in_4bit=load_in_4bit,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    return BitsAndBytesConfig(**settings)

def _glue_num_labels(task):
    """Return the number of classifier outputs for a GLUE task name.

    ``stsb`` is treated as regression (one output), any task starting with
    ``mnli`` as 3-way classification, and every other task as binary.
    """
    if task == "stsb":
        return 1  # regression
    if task.startswith("mnli"):
        return 3  # contradiction / entailment / neutral
    return 2


def build_model(dataset, quantize=False, load_in_4bit=True):
    """Load the pretrained backbone (and tokenizer, if any) for *dataset*.

    Supported datasets:
      * ``glue_<task>``: ``meta-llama/Llama-3.2-1B`` as a
        ``LlamaForSequenceClassification`` whose ``num_labels`` depends on
        ``<task>``, plus its tokenizer with the pad token set to EOS.
      * ``cifar10`` / ``svhn`` (10 labels) and ``cifar100`` (100 labels):
        ``google/vit-base-patch16-224-in21k`` image classifier, no tokenizer.
      * ``20newsgroups``: ``t5-base`` sequence classifier with 20 labels.
      * ``mrqa``: ``t5-base`` ``T5ForConditionalGeneration``.

    Args:
        dataset: dataset identifier as listed above.
        quantize: if True, load with a bitsandbytes config. Only the GLUE
            branch honours this flag.
        load_in_4bit: forwarded to ``_bnb_config`` when *quantize* is True.

    Returns:
        ``(model, tokenizer)``. *tokenizer* is None for the ViT datasets.

    Raises:
        NotImplementedError: if *dataset* is not one of the supported names.
    """
    # Silence transformers' info/warning output (e.g. newly initialised heads).
    logging.set_verbosity_error()

    tokenizer = None
    if dataset.startswith('glue_'):
        model_name = "meta-llama/Llama-3.2-1B"
        task = dataset.split("_", 1)[1]
        quant_cfg = _bnb_config(load_in_4bit) if quantize else None
        model = LlamaForSequenceClassification.from_pretrained(
            model_name,
            num_labels=_glue_num_labels(task),
            quantization_config=quant_cfg,
        )
        # NOTE(review): trust_remote_code is likely unnecessary for this
        # official checkpoint; kept to preserve existing behaviour.
        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True, trust_remote_code=True)
        # The tokenizer has no dedicated pad token here, so EOS is reused; the
        # model config must agree so padded positions are found correctly.
        tokenizer.pad_token_id = tokenizer.eos_token_id
        model.config.pad_token_id = tokenizer.pad_token_id
    elif dataset in ('cifar10', 'svhn', 'cifar100'):
        model_name = "google/vit-base-patch16-224-in21k"
        num_labels = 100 if dataset == 'cifar100' else 10
        model = AutoModelForImageClassification.from_pretrained(model_name, num_labels=num_labels)
    elif dataset == '20newsgroups':
        model_name = "t5-base"
        model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=20)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
    elif dataset == 'mrqa':
        model_name = "t5-base"
        model = T5ForConditionalGeneration.from_pretrained(model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
    else:
        raise NotImplementedError(f"Unsupported dataset: {dataset!r}")

    return model, tokenizer

def copy_model(model, dataset, rank, alpha, b_var, r_var, a_var, num_heads, adaptation_method, device="cuda"):
    """Create a fresh adapter-equipped model for *dataset* holding *model*'s weights.

    A new backbone is built with ``build_model(dataset)`` using its default
    arguments, so it is never quantized. Adapters are then attached via
    ``add_adapters`` with the given hyper-parameters. Finally
    ``model.state_dict()`` is loaded with ``strict=False``, so keys present in
    only one of the two models are skipped without error.

    NOTE(review): if *model* was built with ``quantize=True``, its state dict
    may not match this unquantized copy; confirm callers only copy
    unquantized models.

    Args:
        model: source module whose state dict is copied.
        dataset: dataset name forwarded to ``build_model``.
        rank, alpha, b_var, r_var, a_var, num_heads, adaptation_method:
            forwarded unchanged to ``add_adapters``.
        device: device the copy is moved to. Defaults to ``"cuda"``, which
            matches the previous hard-coded ``.cuda()``.

    Returns:
        The new module, placed on *device*.
    """
    clone, _ = build_model(dataset)
    clone = add_adapters(clone, rank, alpha, b_var, r_var, a_var, num_heads, adaptation_method)
    clone.load_state_dict(model.state_dict(), strict=False)
    return clone.to(device)

def count_parameters(model):
    """Count parameter elements in *model*.

    Returns:
        A ``(total, trainable)`` tuple. *total* sums ``numel()`` over all
        parameters; *trainable* only over those with ``requires_grad`` set.
    """
    sizes = [(param.numel(), param.requires_grad) for param in model.parameters()]
    total = sum(count for count, _ in sizes)
    trainable = sum(count for count, needs_grad in sizes if needs_grad)
    return total, trainable

def models_are_equal(model1, model2, rtol=1e-05, atol=1e-08):
    """Return True if the state dicts of two modules match within tolerance.

    The checks run in this order:
      1. both state dicts have the same keys;
      2. each entry has the same shape;
      3. each entry has the same dtype;
      4. values agree under ``torch.allclose(rtol=rtol, atol=atol)``.

    On the first failure a description is printed and False is returned.
    Entries on different devices are compared on ``model1``'s device, so a
    CPU model can be checked against a CUDA copy.

    Args:
        model1, model2: modules exposing ``state_dict()``.
        rtol, atol: tolerances forwarded to ``torch.allclose``.

    Returns:
        True if every entry matches, otherwise False.
    """
    state_dict1 = model1.state_dict()
    state_dict2 = model2.state_dict()

    # Check if the models have the same set of keys
    if state_dict1.keys() != state_dict2.keys():
        print("State dictionaries have different keys:")
        print("Keys in model1 but not in model2:", state_dict1.keys() - state_dict2.keys())
        print("Keys in model2 but not in model1:", state_dict2.keys() - state_dict1.keys())
        return False

    # Check each parameter tensor
    for key, tensor1 in state_dict1.items():
        tensor2 = state_dict2[key]

        if tensor1.shape != tensor2.shape:
            print(f"Shape mismatch for parameter '{key}': {tensor1.shape} vs {tensor2.shape}")
            return False

        # torch.allclose raises on mismatched dtypes; report it as a mismatch instead.
        if tensor1.dtype != tensor2.dtype:
            print(f"Dtype mismatch for parameter '{key}': {tensor1.dtype} vs {tensor2.dtype}")
            return False

        # torch.allclose also requires both operands on the same device.
        if tensor2.device != tensor1.device:
            tensor2 = tensor2.to(tensor1.device)

        if not torch.allclose(tensor1, tensor2, rtol=rtol, atol=atol):
            # Widen before subtracting: `-` is unsupported for bool tensors and
            # wraps around for unsigned integer dtypes (0 - 1 == 255 in uint8).
            wide = torch.complex128 if tensor1.is_complex() else torch.float64
            diff = (tensor1.to(wide) - tensor2.to(wide)).abs().max().item()
            print(f"Parameter '{key}' differs (max difference: {diff})")
            return False

    return True

def get_embedding_dim(model):
    """Return ``shape[0]`` of the first parameter whose name contains ``'lora_B'``.

    NOTE(review): this is presumably the adapted layer's output dimension,
    assuming ``lora_B`` is stored as ``(out_features, rank)``; confirm
    against ``add_adapters``.

    Returns:
        The size of dimension 0 of that parameter, or None if no parameter
        name contains ``'lora_B'``.
    """
    for name, parameter in model.named_parameters():
        if 'lora_B' in name:
            # The parameter already carries its shape; there is no need to
            # materialise the whole state dict just to look one entry up.
            return parameter.shape[0]
    return None

