from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, DataCollatorForLanguageModeling, Trainer
from datasets import load_from_disk
from peft import LoraConfig,get_peft_model
import torch
import argparse
import os
import numpy as np
import random
import json

# Global RNG seed shared by every random source used in this run.
seed = 23

# Seed torch (CPU + all CUDA devices), NumPy's legacy global RNG and Python's
# `random` module, in that order.
for _seed_fn in (
    torch.manual_seed,
    torch.cuda.manual_seed_all,
    np.random.seed,
    random.seed,
):
    _seed_fn(seed)
del _seed_fn

# NOTE: PYTHONHASHSEED is read at interpreter start-up, so assigning it here does
# not change hash randomization of the current process; only child processes
# launched afterwards inherit it.
os.environ['PYTHONHASHSEED'] = str(seed)

# Command-line interface for a single LoRA fine-tuning run.
parser = argparse.ArgumentParser(description='Trainer Script')
parser.add_argument('--model', type=str, default='meta-llama/Llama-3.2-1B', help='Model name')
parser.add_argument('--epochs', type=int, default=1, help='Number of training epochs')
parser.add_argument('--cluster', type=int, default=0, help='Cluster number')
parser.add_argument('--batchSize', type=int, default=4, help='Batch size')
parser.add_argument('--gradAcc', type=int, default=1, help='Gradient Accumulation Steps')
# Help text previously (wrongly) said "Gradient Accumulation Steps".
parser.add_argument('--weight_decay', type=float, default=0.01, help='Weight decay')
parser.add_argument('--lr', type=float, default=2e-4, help='Learning rate')
parser.add_argument('--logToWandb', action='store_true', help='Log to wandb')
parser.add_argument('--config', type=str, default='/path/to/config', help='Config Path')
args = parser.parse_args()

# Load the JSON run configuration. The `with` block guarantees the file handle
# is closed; JSON text is UTF-8 (RFC 8259), so decode it explicitly rather than
# relying on the platform's locale encoding.
with open(args.config, encoding='utf-8') as config_file:
    config = json.load(config_file)
print(args, flush=True)

# Short aliases for the CLI-selected checkpoint name and epoch count.
MODEL, EPOCHS = args.model, args.epochs

# Base causal LM in bfloat16 using the FlashAttention-2 kernels; device_map="auto"
# hands weight placement to accelerate.
model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

tokenizer = AutoTokenizer.from_pretrained(MODEL)
# Reuse EOS as the padding token (mirrored into the model config) and pad on
# the right.
# NOTE(review): with pad == EOS, DataCollatorForLanguageModeling(mlm=False)
# sets labels to -100 for every EOS token too, so EOS is never learned —
# confirm that is intended.
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = "right"

def tokenize_function(examples):
    """Tokenize one batch of raw text for causal-LM training.

    Args:
        examples: batch passed by ``Dataset.map(batched=True)``; the text is read
            from the column named by ``config["dataset_data_column_name"]``.

    Returns:
        The tokenizer's output for that batch, truncated to at most 1024 tokens
        per example and including a ``special_tokens_mask``.
    """
    texts = examples[config["dataset_data_column_name"]]
    return tokenizer(
        texts,
        return_special_tokens_mask=True,
        truncation=True,
        max_length=1024,
    )


# Load this cluster's shard from <working_dir>/dataset_split/<cluster>.
dataset = load_from_disk(
    os.path.join(config["working_dir"], "dataset_split", str(args.cluster))
)

# Keep only the raw-text column, tokenize it in batches, then drop the text so
# that only the tokenizer outputs remain as features.
text_column = config["dataset_data_column_name"]
dataset = dataset.select_columns([text_column])
dataset = dataset.map(
    tokenize_function,
    batched=True,
    num_proc=1,
    remove_columns=[text_column],
)

# mlm=False: build causal-LM batches (labels derived from input_ids).
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# LoRA adapter: rank 64, alpha 16, no dropout, no bias training, applied to all
# attention projections (q/k/v/o) and MLP projections (gate/up/down).
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=64,
    lora_alpha=16,
    lora_dropout=0.0,
    bias="none",
    target_modules=[
        "k_proj",
        "q_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "down_proj",
        "up_proj",
    ],
)

# Show the base architecture before wrapping, then wrap with PEFT and report
# how many parameters remain trainable.
print(model)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

# Trainer configuration. Evaluation and checkpointing are disabled
# (eval_strategy/save_strategy="no"); the adapter is saved once after training.
training_arguments = TrainingArguments(
    # Reuse the module-level seed instead of repeating the literal, so the two
    # cannot drift apart.
    seed=seed,
    eval_on_start=False,
    per_device_train_batch_size=args.batchSize,
    per_device_eval_batch_size=args.batchSize,
    gradient_accumulation_steps=args.gradAcc,
    save_total_limit=0,
    optim="paged_adamw_32bit",
    num_train_epochs=EPOCHS,
    eval_strategy="no",
    save_strategy="no",
    logging_strategy="epoch",
    learning_rate=args.lr,
    fp16=False,
    bf16=True,
    bf16_full_eval=True,
    weight_decay=args.weight_decay,
    report_to="wandb" if args.logToWandb else "none",
    save_only_model=True,
)

trainer = Trainer(
    model=model,
    train_dataset=dataset,
    data_collator=data_collator,
    args=training_arguments,
)

trainer.train()

# Persist the trained PEFT model to <working_dir>/adapters/<cluster>.
adapter_dir = os.path.join(config["working_dir"], "adapters", str(args.cluster))
trainer.save_model(adapter_dir)