import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
    HfArgumentParser
)
from datasets import load_dataset
import os
from dataclasses import dataclass, field
from typing import Optional, List

# Default to GPU 0, but respect a CUDA_VISIBLE_DEVICES value already supplied
# by the caller or a launcher instead of silently overriding it.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")

@dataclass
class ModelArguments:
    """Command-line options selecting the checkpoint to continue training from."""

    # Checkpoint location passed to ``from_pretrained`` for both the tokenizer
    # and the model.
    model_name: str = "/data/models_ckpt/zephyr-7b-beta"

@dataclass
class DataArguments:
    """Command-line options describing where the training corpus lives."""

    # JSON file read with ``load_dataset("json", ...)``; each record is
    # expected to carry a ``train_context`` text field.
    train_data: str = "/llm_unlearning/wmdp/data/processed_fictional_knowledge.json"

@dataclass
class TrainingArguments(TrainingArguments):
    """``transformers.TrainingArguments`` with defaults for this continued-pretraining run.

    Adds ``max_length`` (the truncation length used when tokenizing the
    dataset). When they are not given on the command line, it also derives
    ``logging_dir`` (``<output_dir>/logs``) and ``run_name``
    (``clm-<output_dir basename>-<seed>``) from ``output_dir``.
    """

    output_dir: str = field(default="/data/models_ckpt/trained_ckpt/unlearning/clm_zephyr_fictional")
    overwrite_output_dir: bool = field(default=True)
    num_train_epochs: int = field(default=7)
    per_device_train_batch_size: int = field(default=2)
    gradient_accumulation_steps: int = field(default=2)
    learning_rate: float = field(default=2e-5)
    # warmup_steps: int = field(default=500)
    logging_steps: int = field(default=1)
    # Maximum tokens per example; passed to the tokenizer with truncation=True.
    max_length: int = field(default=4096)
    seed: int = field(default=42)

    fp16: bool = field(default=False)
    bf16: bool = field(default=True)
    gradient_checkpointing: bool = field(default=True)

    # None means "derive from output_dir" (see __post_init__).
    logging_dir: Optional[str] = field(default=None)
    log_level: str = field(default="info")
    logging_strategy: str = field(default="steps")

    save_total_limit: int = field(default=1)
    remove_unused_columns: bool = field(default=False)

    report_to: List[str] = field(default_factory=lambda: ["wandb"])
    # None means "derive from output_dir and seed" (see __post_init__).
    run_name: Optional[str] = field(default=None)

    dataloader_num_workers: int = field(default=8)
    group_by_length: bool = field(default=True)
    lr_scheduler_type: str = field(default="linear")
    optim: str = field(default="paged_adamw_32bit")
    max_grad_norm: float = field(default=1.0)
    weight_decay: float = field(default=0.1)
    adam_beta1: float = field(default=0.9)
    adam_beta2: float = field(default=0.95)

    local_rank: int = field(default=-1)
    deepspeed: Optional[str] = field(default=None)

    def __post_init__(self):
        # Fill in our derived defaults BEFORE the parent's __post_init__ runs.
        # The base class replaces a None logging_dir with its own default
        # (output_dir/runs/...), and in some transformers versions also a None
        # run_name, so checks placed after super().__post_init__() never fire.
        if self.logging_dir is None:
            self.logging_dir = os.path.join(self.output_dir, "logs")
        if self.run_name is None:
            # normpath strips a trailing separator so basename() is not empty.
            run_dir = os.path.basename(os.path.normpath(self.output_dir))
            self.run_name = f"clm-{run_dir}-{self.seed}"
        super().__post_init__()

def get_model_and_tokenizer(model_args: ModelArguments, training_args: TrainingArguments):
    """Load the tokenizer and causal-LM weights named by ``model_args.model_name``.

    Args:
        model_args: supplies ``model_name``, the checkpoint passed to
            ``from_pretrained``.
        training_args: its ``bf16`` flag selects the dtype the weights are
            loaded in (bfloat16 when set, float32 otherwise).

    Returns:
        ``(model, tokenizer)``. If the tokenizer defines no pad token, its EOS
        token is reused for padding.
    """
    print("Loading tokenizer and model...")
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name)
    if tokenizer.pad_token is None:
        # The data collator needs a pad token to batch variable-length examples.
        tokenizer.pad_token = tokenizer.eos_token

    # Load bf16 weights only when training in bf16. Otherwise keep fp32
    # parameters, which fp16 mixed precision in the Trainer and full-precision
    # training both expect. With the default bf16=True this is unchanged.
    dtype = torch.bfloat16 if training_args.bf16 else torch.float32

    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name,
        torch_dtype=dtype,
        device_map="auto",
        # The KV cache is not needed for training and is incompatible with the
        # gradient checkpointing enabled in TrainingArguments.
        use_cache=False,
        # attn_implementation="flash_attention_2"
    )
    return model, tokenizer

def prepare_dataset(tokenizer, training_args: TrainingArguments, data_args: DataArguments):
    """Read the JSON corpus and turn its ``train_context`` text into token ids.

    Args:
        tokenizer: tokenizer applied in batches to the ``train_context`` column.
        training_args: ``max_length`` caps each tokenized example (truncation on).
        data_args: ``train_data`` is the JSON file handed to ``load_dataset``.

    Returns:
        The tokenized ``train`` split. All original columns are dropped, so only
        the tokenizer outputs remain; sequences are left unpadded.
    """
    limit = training_args.max_length

    def _encode(batch):
        # Padding is deferred to the data collator; only truncate here.
        return tokenizer(
            batch["train_context"],
            padding=False,
            truncation=True,
            max_length=limit,
            return_special_tokens_mask=True,
        )

    print("Loading dataset...")
    raw = load_dataset("json", data_files=data_args.train_data, split="train")

    print("Tokenizing dataset...")
    original_columns = raw.column_names
    return raw.map(
        _encode,
        batched=True,
        num_proc=4,
        remove_columns=original_columns,
    )

def main():
    """Parse CLI arguments, run causal-LM continued pre-training, and save the result.

    The trained model and the tokenizer are written to ``training_args.output_dir``.
    """
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Seed the torch RNGs (CPU and every CUDA device).
    torch.manual_seed(training_args.seed)
    torch.cuda.manual_seed_all(training_args.seed)

    os.makedirs(training_args.output_dir, exist_ok=True)

    model, tokenizer = get_model_and_tokenizer(model_args, training_args)
    tokenized_dataset = prepare_dataset(tokenizer, training_args, data_args)

    # mlm=False makes the collator build causal-LM labels from input_ids
    # instead of applying masked-LM token masking.
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=data_collator,
    )

    print("Starting continued pre-training...")
    trainer.train()

    # Save final model. Announce the save before it happens, not after.
    print("Saving model...")
    trainer.save_model()
    tokenizer.save_pretrained(training_args.output_dir)


if __name__ == "__main__":
    main()
