import torch
import wandb
from transformers import PreTrainedTokenizer,LlamaTokenizer,AutoModelForCausalLM,BitsAndBytesConfig
from datasets import load_dataset
from transformers import Trainer, TrainingArguments
from peft import LoraConfig,get_peft_model,TaskType, prepare_model_for_kbit_training

# Prompt layout for one training example: user tag, instruction, assistant
# tag and response, each separated by a single space.
custom_template = "{user_tag} {instruction} {assistant_tag} {response}"


class DollyDataProvider:
    """Load a Dolly-style JSONL file and tokenize it for causal-LM fine-tuning.

    Each record is rendered with ``custom_template`` (user tag, instruction,
    assistant tag, response), tokenized to a fixed length and given a
    ``labels`` column in which padding positions are set to ``-100``.
    """

    # Raw text columns dropped once a split has been tokenized.
    _RAW_COLUMNS = ['instruction', 'context', 'response', 'category']

    def __init__(self, user_tag, assistant_tag, tokenizer: PreTrainedTokenizer,data_file="./data/databricks-dolly-15k.jsonl", with_context=False, max_length=256):
        """Load the dataset and remember the formatting options.

        Args:
            user_tag: Text placed before the instruction in each prompt.
            assistant_tag: Text placed between the instruction and the response.
            tokenizer: Tokenizer used to encode the rendered prompts.
            data_file: Path of the JSON-lines file passed to ``load_dataset``.
            with_context: When true, a non-empty ``context`` field is appended
                to the instruction.
            max_length: Fixed sequence length every example is truncated or
                padded to.
        """
        self.ds = load_dataset("json", data_files=data_file)
        self.user_tag = user_tag
        self.assistant_tag = assistant_tag
        self.tokenizer = tokenizer
        self.with_context = with_context
        self.max_length = max_length

    def _prepare_format(self, examples):
        """Render and tokenize one batch of raw records.

        Args:
            examples: Batched mapping holding parallel ``instruction``,
                ``context`` and ``response`` lists.

        Returns:
            The tokenizer output for the batch with an added ``labels`` entry.
        """
        prompt_inputs = []
        for instruction, context, response in zip(examples["instruction"], examples["context"], examples["response"]):
            # Only append the context when there is one, so records without
            # context do not get a stray trailing space after the instruction.
            if self.with_context and context:
                instruction = instruction + " " + context
            prompt_inputs.append(custom_template.format(user_tag=self.user_tag,
                                                        instruction=instruction,
                                                        assistant_tag=self.assistant_tag,
                                                        response=response))

        batch = self.tokenizer(prompt_inputs, truncation=True, max_length=self.max_length, padding="max_length",
                               return_attention_mask=True, return_tensors="pt")
        labels = batch["input_ids"].clone()
        # Mask by attention mask rather than by pad token id: when the pad
        # token is aliased to EOS, comparing ids would also hide genuine EOS
        # tokens from the loss.
        labels[batch["attention_mask"] == 0] = -100
        batch["labels"] = labels

        return batch

    def __call__(self, *args, **kwargs):
        """Shortcut for :meth:`genrate_dataloader`, forwarding all arguments."""
        return self.genrate_dataloader(*args, **kwargs)

    def genrate_dataloader(self,train_size=100,test_size=100,seed=42):
        """Split the loaded data and tokenize both splits.

        The misspelled method name is kept because existing callers use it.

        Args:
            train_size: ``train_size`` passed to ``train_test_split``.
            test_size: ``test_size`` passed to ``train_test_split``.
            seed: Seed for the shuffled split.

        Returns:
            A ``(train, test)`` tuple of tokenized datasets (not
            ``DataLoader`` objects) with the raw text columns removed.
        """
        double_ds = self.ds["train"].train_test_split(train_size=train_size, test_size=test_size, seed=seed, shuffle=True)
        train_loader, test_loader = (
            double_ds[split].map(self._prepare_format, batched=True, remove_columns=self._RAW_COLUMNS)
            for split in ("train", "test")
        )

        return train_loader, test_loader


def peft_normal_func(model,tokenizer:PreTrainedTokenizer,train_dataloader,eval_data_loader,max_steps=10,lr=1e-4, eval_steps=10):
    """Run one training epoch with the Hugging Face ``Trainer``.

    Args:
        model: Model handed to the ``Trainer``.
        tokenizer: Accepted but not used by this function.
        train_dataloader: Dataset passed as ``train_dataset``.
        eval_data_loader: Dataset passed as ``eval_dataset``.
        max_steps: Accepted but not forwarded; the run length is fixed at
            one epoch.
        lr: Learning rate.
        eval_steps: Interval, in steps, used for both evaluation and logging.
    """
    # Evaluation and logging share the same step interval.
    arg_values = dict(
        output_dir="/home/project/ExtARE/output",
        learning_rate=lr,
        num_train_epochs=1,
        eval_steps=eval_steps,
        evaluation_strategy="steps",
        per_device_train_batch_size=10,
        per_device_eval_batch_size=10,
        logging_steps=eval_steps,
    )
    run_config = TrainingArguments(**arg_values)

    # No collator is supplied, so the Trainer falls back to its default.
    runner = Trainer(
        model=model,
        args=run_config,
        train_dataset=train_dataloader,
        eval_dataset=eval_data_loader,
    )
    runner.train()

def main():
    """Fine-tune the local Llama-2 chat checkpoint on Dolly data with LoRA.

    Steps: disable wandb, load the tokenizer (reusing EOS as the pad token
    when none is set), build the tokenized train/test splits, load the model
    with 8-bit quantization, wrap it with LoRA adapters on ``q_proj`` and
    ``v_proj``, then train.
    """
    wandb.init(mode="disabled")

    model_dir = "./models/Llama-2-7b-chat-hf"

    tokenizer = LlamaTokenizer.from_pretrained(model_dir)
    if tokenizer.pad_token is None:
        # Padding needs a pad token; fall back to EOS when none is defined.
        tokenizer.pad_token = tokenizer.eos_token

    data_source = DollyDataProvider("[INST]", "[/INST]", tokenizer)
    train_split, eval_split = data_source.genrate_dataloader()

    eight_bit = BitsAndBytesConfig(load_in_8bit=True)
    base_model = AutoModelForCausalLM.from_pretrained(
        model_dir,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        quantization_config=eight_bit,
    )

    lora_settings = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=8,
        target_modules=["q_proj", "v_proj"],
    )
    base_model = prepare_model_for_kbit_training(base_model)
    peft_model = get_peft_model(base_model, lora_settings)
    peft_model.print_trainable_parameters()

    peft_normal_func(peft_model, tokenizer, train_split, eval_split)

# Run the fine-tuning pipeline only when executed as a script, not on import.
if __name__ == '__main__':
    main()