import os
import torch
from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTTrainer
from datasets import Dataset, load_dataset
import random
import numpy as np
import transformers
import wandb

def seed_everything(seed=42):
    """Seed the random number generators this script uses so runs repeat.

    Seeds Python's ``random`` module, NumPy's global generator and torch's
    generators (``torch.manual_seed`` plus ``torch.cuda.manual_seed``), and
    asks cuDNN for deterministic kernels.

    Note: exporting ``PYTHONHASHSEED`` here only influences interpreters that
    start afterwards (e.g. worker processes); it does not change string
    hashing inside the already-running process.

    Args:
        seed: Integer seed applied to every generator.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True


seed_everything(42)


def get_data(data_file, tokenizer, num_proc=8):
    """Load a JSON instruction file and render every row into one training string.

    Each record is read through the keys ``en_input`` (the instruction),
    ``en_other`` (optional extra input) and ``en_output`` (the target
    response). A column ``en_text`` is produced for each row. It holds the
    instruction/response prompt template, followed by the response and
    ``tokenizer.eos_token``.

    Args:
        data_file: Passed as ``data_files`` to ``load_dataset("json", ...)``.
        tokenizer: Tokenizer whose ``eos_token`` terminates every example.
        num_proc: Worker process count for ``Dataset.map`` (default 8, as before).

    Returns:
        The ``"train"`` split of the loaded dataset after mapping in ``en_text``.
    """
    header = (
        'Below is an instruction that describes a task. '
        'Write a response that appropriately completes the request.\n\n'
    )

    def generate_prompt(data_point):
        instruction = data_point['en_input']
        context = data_point['en_other']
        # A JSON null arrives as None; treat it like an empty string instead of
        # falling into the "with input" branch and failing on str + None.
        if not context:
            input_text = header + '### Instruction:\n' + instruction + '\n\n### Response:\n'
        else:
            input_text = (
                header + '### Instruction:\n' + instruction
                + '\n\n### Input:\n' + context + '\n\n### Response:\n'
            )
        target_text = data_point["en_output"] + tokenizer.eos_token
        return {'en_text': input_text + target_text}

    data = load_dataset("json", data_files=data_file)["train"]
    return data.map(generate_prompt, num_proc=num_proc)


model_ckpt = "../Llama-2-7b-chat-hf"

tokenizer = AutoTokenizer.from_pretrained(model_ckpt, trust_remote_code=True)
# Reuse EOS as the padding token; presumably this checkpoint's tokenizer
# defines no pad token of its own — confirm if the checkpoint changes.
tokenizer.pad_token = tokenizer.eos_token

train_dataset = get_data('../train/train.jsonl', tokenizer)

# Hold out 10% for validation, then reshuffle each side with the same fixed seed.
train_val = train_dataset.train_test_split(test_size=0.1, shuffle=True, seed=42)
train_dataset, val_dataset = (
    train_val[split_name].shuffle(seed=42) for split_name in ("train", "test")
)

print(train_dataset)


# No torch_dtype is passed, so weights load in the installed transformers'
# default dtype — NOTE(review): confirm this fits the GPU memory budget.
model = AutoModelForCausalLM.from_pretrained(
    model_ckpt, trust_remote_code=True, device_map="auto"
)

# The LoRA adapter is not applied here; this config is handed to SFTTrainer
# below. target_modules is left unset, so peft's per-architecture defaults
# apply (presumably the attention projections for Llama — verify).
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=64,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
)

run = wandb.init(project='ocean')

output_dir = "../results"
# Samples per optimizer step = per_device_train_batch_size *
# gradient_accumulation_steps (times the number of data-parallel processes).
training_args = TrainingArguments(
    output_dir=output_dir,
    # Optimiser and schedule.
    optim="adamw_torch",
    learning_rate=2e-4,
    lr_scheduler_type="linear",
    warmup_ratio=0.01,
    num_train_epochs=6,
    # Batching and mixed precision.
    per_device_train_batch_size=8,
    gradient_accumulation_steps=8,
    fp16=True,
    # Evaluation, checkpointing and logging cadence.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_steps=100,
    report_to=["wandb"],
)

# NOTE(review): get_data already appends tokenizer.eos_token to every example,
# and with packing=True trl's packed dataset may insert its own EOS between
# examples, which would give doubled EOS tokens. Confirm for the installed trl.
# max_seq_length=200 with packing also means long examples are split across
# packed chunks.
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    dataset_text_field="en_text",
    max_seq_length=200,
    tokenizer=tokenizer,
    args=training_args,
    packing=True,
    peft_config=peft_config,
)

trainer.train()

final_dir = os.path.join(output_dir, "final_checkpoint")
# trainer.model was wrapped using peft_config, so this presumably writes the
# LoRA adapter rather than full model weights — verify for the peft version.
trainer.model.save_pretrained(final_dir)
# Store the tokenizer (including its pad_token setting) next to the weights
# so the final checkpoint directory can be reloaded without the source model dir.
tokenizer.save_pretrained(final_dir)

# Close the wandb run explicitly so the final metrics are flushed deterministically.
run.finish()
