import os

# HF_ENDPOINT must be in the environment BEFORE transformers / huggingface_hub
# are imported: huggingface_hub reads it once, at import time, into its
# module-level constants, so assigning it after the imports below has no effect
# on downloads. setdefault keeps any endpoint the user already exported.
os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com")

import ssl  # noqa: E402

import torch  # noqa: E402
import torch.nn as nn  # noqa: E402
from transformers import GPT2LMHeadModel, AutoModelForCausalLM, AutoTokenizer  # noqa: E402

# NOTE(review): SECURITY — this disables TLS certificate verification for every
# HTTPS connection made through the stdlib default context (urllib / http.client)
# in the whole process. As far as I can tell huggingface_hub downloads go through
# requests/httpx, which build their own SSL context, so this may not even affect
# model downloads — confirm and consider removing.
ssl._create_default_https_context = ssl._create_unverified_context

# Hugging Face access token. The shipped value is a placeholder to be replaced
# with a real token (or left as-is / emptied to access public models anonymously).
access_token = "your_huggingface_access_token"
class TransformerModel(nn.Module):
    """``nn.Module`` wrapper around a Hugging Face causal LM and its tokenizer.

    Model names noted as candidates in this file:
    ``mistralai/Mistral-7B-v0.1`` (default), ``mistralai/Mixtral-8x7B-Instruct-v0.1``,
    ``openai-community/gpt2`` and ``HuggingFaceH4/zephyr-7b-beta``.
    """

    # The "fill me in" value shipped in the module-level ``access_token``. It is
    # truthy, so it has to be filtered out explicitly; otherwise it would be sent
    # to the Hub as a bogus bearer token and authentication would fail.
    _PLACEHOLDER_TOKEN = "your_huggingface_access_token"

    def __init__(self, model_name='mistralai/Mistral-7B-v0.1', *,
                 attn_implementation='flash_attention_2',
                 torch_dtype=torch.bfloat16,
                 token=None):
        """Load the causal LM and tokenizer for ``model_name``.

        Args:
            model_name: Hub repo id (or local path) passed to ``from_pretrained``.
            attn_implementation: attention backend forwarded to
                ``AutoModelForCausalLM.from_pretrained``. Defaults to
                ``'flash_attention_2'``; pass e.g. ``'sdpa'`` or ``'eager'`` where
                the ``flash_attn`` package or a supported GPU is unavailable.
            torch_dtype: dtype the weights are loaded in (default ``torch.bfloat16``).
            token: Hub access token. ``None`` falls back to the module-level
                ``access_token``. Empty values and the placeholder string are
                treated as "no token" and passed through as ``None``.
        """
        super().__init__()
        self.model_name = model_name
        # '/' cannot appear inside a single directory name, so flatten the repo id.
        safe_name = model_name.replace("/", "_")
        hub_token = self._resolve_token(access_token if token is None else token)
        # Weights are cached under models/pretrained/<safe_name>-model (relative to
        # the current working directory) so later runs can reuse them.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            cache_dir=f"models/pretrained/{safe_name}-model",
            low_cpu_mem_usage=True,
            attn_implementation=attn_implementation,
            token=hub_token,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            token=hub_token,
            # NOTE(review): this allows executing tokenizer code shipped with the
            # repo, while the model above is loaded without trust_remote_code —
            # confirm this asymmetry is intended.
            trust_remote_code=True,
        )
        # Reuse EOS as the padding token (overrides any pad token already defined).
        self.tokenizer.pad_token = self.tokenizer.eos_token
        if model_name == 'openai-community/gpt2':
            # Presumably for batched decoder-only generation, which needs padding
            # on the left — TODO confirm against the callers.
            self.tokenizer.padding_side = "left"

    @classmethod
    def _resolve_token(cls, token):
        """Return ``token``, or ``None`` if it is empty/``None`` or the placeholder."""
        if not token or token == cls._PLACEHOLDER_TOKEN:
            return None
        return token

    def get_model(self):
        """Return the wrapped causal LM."""
        return self.model

    def get_tokenizer(self):
        """Return the tokenizer loaded alongside the model."""
        return self.tokenizer

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Call the wrapped model with these three arguments and return its output as-is."""
        return self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
