from transformers import (
    PreTrainedModel, AutoModelForCausalLM, AlbertModel, AlbertConfig, DebertaModel, DebertaConfig, LlamaConfig,
    DebertaV2Config, DebertaV2Model, 
    ElectraModel, ElectraConfig, RobertaModel, RobertaConfig,
    load_tf_weights_in_electra, AutoConfig, AutoTokenizer, pipeline
)
import torch

model_name = "gpt2-xl"

model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# NOTE(review): this substring test only matches names that literally contain
# "llama2"; hub ids such as "meta-llama/Llama-2-7b-hf" would not match — confirm.
if "llama2" in model_name:
    tokenizer.pad_token = '<PAD>'
    tokenizer.sep_token = '<SEP>'
    tokenizer.padding_side = 'left'
if "gpt2" in model_name:
    # GPT-2 ships without a pad token; reuse EOS as the pad token.
    tokenizer.pad_token = tokenizer.eos_token

# The prompt is assembled with implicit string-literal concatenation. A
# backslash line continuation *inside* a string literal would keep the next
# line's leading indentation as part of the text, silently injecting spaces.
# NOTE: the name `input` shadows the builtin, but later module-level code reads
# `input` and `tokenizer`, so both names are kept.
input = tokenizer(
    "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>"
    "Considering the fact that <farmland at location countryside>. "
    "James was looking for a good place to buy farmland.  "
    "Where might he look? countryside <SEP> "
    "Considering the fact that <danielle darrieux nationality english>, "
    "The mother tongue of Danielle Darrieux is "
)
# `input` holds plain Python lists (no return_tensors), so wrap each in a batch of one.
output = model.generate(
    input_ids=torch.tensor([input.input_ids]),
    attention_mask=torch.tensor([input.attention_mask]),
    max_new_tokens=5,  # same as max_length = prompt length + 5
)
# `skip_special_tokens` is the real decode option (the previous
# `skip_special_padding` spelling was not a decode argument and had no effect).
print("!" * 10, tokenizer.decode(output[0], skip_special_tokens=False))


class Model_causal():
    """Load a causal LM and its tokenizer, then run a fixed smoke-test prompt.

    Attributes:
        model: the model returned by ``AutoModelForCausalLM.from_pretrained``.
        tokenizer: the tokenizer returned by ``AutoTokenizer.from_pretrained``.
        input: ``PROMPT`` tokenized without ``return_tensors``; its
            ``input_ids`` / ``attention_mask`` are plain Python lists.
    """

    # Smoke-test prompt. Built with implicit string-literal concatenation so
    # that source-code indentation can never leak into the text (a backslash
    # continuation inside a string literal keeps the next line's leading
    # whitespace).
    PROMPT = (
        "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>"
        "Considering the fact that <farmland at location countryside>. "
        "James was looking for a good place to buy farmland.  "
        "Where might he look? countryside <SEP> "
        "Considering the fact that <danielle darrieux nationality english>, "
        "The mother tongue of Danielle Darrieux is "
    )

    def __init__(self, model_name, config):
        """Load ``model_name``, tokenize ``PROMPT`` and print a 5-token continuation.

        Args:
            model_name: identifier passed to ``from_pretrained`` for both the
                model and the tokenizer; also inspected for "llama2"/"gpt2"
                to pick tokenizer special-token settings.
            config: currently unused; accepted so existing call sites keep working.
        """
        # self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto")
        self.model = AutoModelForCausalLM.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        # NOTE(review): substring match only; hub ids like "meta-llama/Llama-2-7b-hf"
        # do not contain "llama2" — confirm which names this should catch.
        if "llama2" in model_name:
            self.tokenizer.pad_token = '<PAD>'
            self.tokenizer.sep_token = '<SEP>'
            self.tokenizer.padding_side = 'left'
        if "gpt2" in model_name:
            # GPT-2 has no pad token; reuse EOS as the pad token.
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.input = self.tokenizer(self.PROMPT)
        # The encoding holds plain lists, so wrap each in a batch of one.
        output = self.model.generate(
            input_ids=torch.tensor([self.input.input_ids]),
            attention_mask=torch.tensor([self.input.attention_mask]),
            max_new_tokens=5,  # same as max_length = prompt length + 5
        )
        print("!" * 10, self.tokenizer.decode(output[0], skip_special_tokens=False))



lm_config = AutoConfig.from_pretrained(model_name)
# Rebinds the module-level name `model` to the wrapper; the underlying HF model
# is reachable as `model.model`.
model = Model_causal(model_name, lm_config)

# Continuation of the module-level tokenized prompt `input`, produced by the
# wrapper's model and decoded with the wrapper's own tokenizer.
output = model.model.generate(
    input_ids=torch.tensor([input.input_ids]),
    attention_mask=torch.tensor([input.attention_mask]),
    max_new_tokens=5,  # same as max_length = prompt length + 5
)
print("!" * 10, model.tokenizer.decode(output[0], skip_special_tokens=False))

# Continuation of the wrapper's own tokenized prompt.
output = model.model.generate(
    input_ids=torch.tensor([model.input.input_ids]),
    attention_mask=torch.tensor([model.input.attention_mask]),
    max_new_tokens=5,
)
print("!" * 10, model.tokenizer.decode(output[0], skip_special_tokens=False))

# Token counts of the two prompts; a mismatch means the prompt strings differ
# (e.g. in embedded whitespace) even though they look the same in the source.
print(len(model.input.input_ids))
print(len(input.input_ids))