import os
# Point the Hugging Face cache root at the shared checkpoints directory.
# NOTE(review): this assignment is placed before the huggingface_hub /
# datasets / transformers imports below, presumably so those libraries see
# the value when they are imported -- confirm before reordering.
os.environ['HF_HOME'] = "../llama_on_glue/checkpoints"
from huggingface_hub import login
#
from datasets import load_dataset

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

from peft import PeftModel
import torch

# def download_m(model_id):
#     dtype = torch.bfloat16
#     print(model_id)
#     tokenizer = AutoTokenizer.from_pretrained(model_id)
#     model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype)
#     del model, tokenizer
#     torch.cuda.empty_cache()
# download_m('Qwen/Qwen2.5-7B-Instruct')

from models_lots_of_loras import dataset_list
import sys

# For every dataset name in ``dataset_list``: load its ``test`` split, derive
# the task number from the name, then load the matching Lots-of-LoRAs adapter
# on top of the Mistral base model. Nothing is kept; the loaded objects are
# discarded after each iteration. Names that fail at any step are printed to
# stdout (one per line), and the reason is written to stderr.
for dataset_name in dataset_list:
    try:
        load_dataset(dataset_name, split='test')

        # The task number is the digit run between "/task" and the first "_".
        # Fail explicitly when the marker is absent instead of slicing from a
        # meaningless offset (str.find would return -1 in that case).
        _, marker, tail = dataset_name.partition('/task')
        if not marker:
            raise ValueError(f"no '/task' marker in dataset name {dataset_name!r}")
        task_number = int(tail.split('_')[0])

        adapter_id = f"Lots-of-LoRAs/Mistral-7B-Instruct-v0.2-4b-r16-task{task_number}"
        # NOTE(review): the base model is instantiated anew for each adapter;
        # PEFT wraps the model object it is given, so confirm it is safe to
        # share one instance before hoisting this out of the loop.
        base_model = AutoModelForCausalLM.from_pretrained('mistralai/Mistral-7B-Instruct-v0.2')
        PeftModel.from_pretrained(base_model, adapter_id)
    except Exception as exc:
        # ``Exception`` rather than a bare ``except`` so that Ctrl-C
        # (KeyboardInterrupt) and SystemExit still stop the run.
        print(dataset_name)
        print(f"  failed: {exc!r}", file=sys.stderr)
# for i in list(lots_of_lora.keys())[1:]:
#     dataset = lots_of_lora[i]['datasets'][0]
#     a = load_dataset(dataset, split='test')

# from models_and_datas import models_and_datas
# for i in models_and_datas:
#     # if 0==len(models_and_datas[i]['datasets']):
#     #     continue
#     for m in models_and_datas[i]['model']:
#         download_m(m)
    # for d in models_and_datas[i]['datasets']:
    #     download_d(d)

# model_ids = ["meta-llama/Meta-Llama-3-8B", "meta-llama/Meta-Llama-Guard-2-8B", 
#              "MLP-KTLim/llama-3-Korean-Bllossom-8B", "hfl/llama-3-chinese-8b-instruct-v3", 
#              "DeepMount00/Llama-3-8b-Ita", "hkust-nlp/dart-math-llama3-8b-prop2diff",
#              "MonteXiaofeng/CareBot_Medical_multi-llama3-8b-base", "meta-llama/Meta-Llama-3-8B-Instruct"]

# device = "cpu"
# dtype = torch.bfloat16
# for model_id in model_ids:
#     print(model_id)
#     tokenizer = AutoTokenizer.from_pretrained(model_id)
#     model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, device_map=device)
#     del model, tokenizer

