import argparse
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from datasets import load_dataset
from peft import LoraConfig, get_peft_model

def parse_arguments():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with ``output_dir`` (required) and ``model``
        (a model name or local path, defaulting to the LLaMA-2 7B chat model).
    """
    cli = argparse.ArgumentParser(description="Fine-tune LLaMA model with LoRA")
    cli.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Output directory for the fine-tuned model",
    )
    cli.add_argument(
        "--model",
        type=str,
        default="meta-llama/Llama-2-7b-chat-hf",
        help="Model name or path.",
    )
    namespace = cli.parse_args()
    return namespace

# Define the dataset class
class AlpacaDataset(Dataset):
    """Map-style torch ``Dataset`` that tokenizes instruction-tuning records.

    Each record is expected to be a mapping with an ``"instruction"`` key and,
    optionally, ``"input"`` and ``"output"`` keys (Alpaca-style layout --
    NOTE(review): confirm against the dataset actually loaded). The non-empty
    fields listed in ``text_fields`` are joined with blank lines, so the
    causal-LM objective covers the prompt *and* its response rather than the
    instruction alone.

    Args:
        dataset: Indexable, sized collection of mapping records (e.g. a
            Hugging Face ``datasets.Dataset`` split).
        tokenizer: Hugging Face tokenizer used to encode each record.
        text_fields: Record keys concatenated, in order, to form the training
            text. Missing or empty fields are skipped.
        max_length: Truncation length in tokens. ``None`` lets the tokenizer
            fall back to its own ``model_max_length``.
    """

    def __init__(
        self,
        dataset,
        tokenizer,
        text_fields=("instruction", "input", "output"),
        max_length=None,
    ):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.text_fields = tuple(text_fields)
        self.max_length = max_length

    def __len__(self):
        return len(self.dataset)

    def _format(self, record):
        """Join the non-empty ``text_fields`` values of *record* with blank lines."""
        parts = (record.get(field) for field in self.text_fields)
        return "\n\n".join(part for part in parts if part)

    def __getitem__(self, idx):
        """Return ``{"input_ids": 1-D LongTensor}`` for record *idx*."""
        text = self._format(self.dataset[idx])
        encoded = self.tokenizer(text, truncation=True, max_length=self.max_length)
        # Build the 1-D tensor directly: squeezing a (1, n) tensor would turn a
        # single-token sequence into a 0-D tensor that the padding collator
        # cannot handle.
        return {"input_ids": torch.tensor(encoded["input_ids"], dtype=torch.long)}

def main(args):
    """Fine-tune ``args.model`` with LoRA adapters and save the result.

    Loads the ``lavita/ChatDoctor-HealthCareMagic-100k`` dataset, attaches
    LoRA adapters to the model's ``q_proj``/``v_proj`` modules, trains with
    the Hugging Face ``Trainer`` using the DeepSpeed config ``deepspeed.json``
    (resolved relative to the working directory), then writes the trained
    model and the tokenizer to ``args.output_dir``.

    Args:
        args: Parsed CLI namespace; ``args.model`` (model name or path) and
            ``args.output_dir`` are read.
    """
    # Create TrainingArguments *before* loading the model. If deepspeed.json
    # enables ZeRO-3, transformers can only load the weights already
    # partitioned (zero.Init) when the DeepSpeed config is registered first.
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=3,
        per_device_train_batch_size=64,
        learning_rate=1e-4,
        warmup_steps=500,
        logging_steps=50,
        save_steps=500,
        save_total_limit=3,
        fp16=True,
        gradient_accumulation_steps=1,
        remove_unused_columns=False,
        deepspeed="deepspeed.json",  # path is relative to the working directory
    )

    # To train on Alpaca instead, load "tatsu-lab/alpaca" here.
    dataset = load_dataset("lavita/ChatDoctor-HealthCareMagic-100k")

    # Load the pre-trained model and tokenizer.
    model = AutoModelForCausalLM.from_pretrained(args.model)
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    if tokenizer.pad_token is None:
        # No pad token defined (presumably true for the default LLaMA-2
        # tokenizer): reuse EOS so the collator can pad batches. Keep any pad
        # token the tokenizer already provides.
        tokenizer.pad_token = tokenizer.eos_token

    # Wrap the base model with LoRA adapters; only adapter weights train.
    lora_config = LoraConfig(
        r=8,
        lora_alpha=16,
        target_modules=["q_proj", "v_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)

    train_dataset = AlpacaDataset(dataset["train"], tokenizer)

    # mlm=False -> causal-LM labels are a copy of input_ids. NOTE: the collator
    # sets labels to -100 wherever input_ids == pad_token_id, so when pad ==
    # EOS, EOS tokens are never trained on.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=data_collator,
    )
    trainer.train()

    # trainer.save_model writes only from the main process and handles
    # DeepSpeed ZeRO-3 state gathering; calling model.save_pretrained on every
    # rank races on the same files and can miss partitioned weights.
    trainer.save_model(args.output_dir)
    if trainer.is_world_process_zero():
        # Save the tokenizer (with its pad token) next to the model so the
        # output directory can be reloaded on its own.
        tokenizer.save_pretrained(args.output_dir)

if __name__ == "__main__":
    # Script entry point: read the CLI options, then run fine-tuning with them.
    main(parse_arguments())