import os
# These environment variables are set before the ML libraries below are
# imported -- presumably so that accelerate / transformers pick them up at
# import time (TODO confirm against the libraries' documentation).
os.environ["ACCELERATE_USE_APEX"] = "False"
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
import json
import datetime
from dataclasses import dataclass, field, asdict
from typing import Optional
import torch
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
)
from peft import LoraConfig
from trl import DPOTrainer, DPOConfig
from transformers import HfArgumentParser
import torch.distributed as dist
# Import-time timestamp. NOTE(review): nothing in the visible part of this
# file reads this module-level name (TrainingLogger.__init__ computes its own
# local timestamp) -- confirm whether it is still needed.
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

@dataclass
class ScriptArguments:
    """Command-line options for the LoRA DPO training script.

    The dataclass is handed to ``HfArgumentParser`` in ``main``; each field
    becomes a ``--<field_name>`` flag and its ``help`` metadata is shown by
    ``--help``. Defaults are unchanged from the original definition.
    """

    model_name_or_path: str = field(
        default="",
        metadata={"help": "Model id or path given to from_pretrained for the tokenizer and the model."},
    )
    dataset_path: str = field(
        default="",
        metadata={"help": "Path to a JSON file with a list of records holding 'prompt', 'chosen' and 'rejected'."},
    )
    output_dir: str = field(
        default="",
        metadata={"help": "Directory for checkpoints, TensorBoard logs and the run log/config files."},
    )
    max_seq_length: int = field(
        default=2048,
        metadata={"help": "Maximum sequence length (not consumed by main; see max_length)."},
    )
    per_device_train_batch_size: int = field(
        default=4,
        metadata={"help": "Training batch size per device (forwarded to DPOConfig)."},
    )
    gradient_accumulation_steps: int = field(
        default=1,
        metadata={"help": "Number of gradient accumulation steps (forwarded to DPOConfig)."},
    )
    learning_rate: float = field(
        default=5e-6,
        metadata={"help": "Learning rate (forwarded to DPOConfig)."},
    )
    num_train_epochs: int = field(
        default=3,
        metadata={"help": "Number of training epochs (forwarded to DPOConfig)."},
    )
    fp16: bool = field(
        default=True,
        metadata={"help": "Enable fp16 mixed precision (forwarded to DPOConfig)."},
    )
    bf16: bool = field(
        default=False,
        metadata={"help": "Enable bf16 mixed precision (forwarded to DPOConfig)."},
    )
    logging_steps: int = field(
        default=10,
        metadata={"help": "Log every N steps (forwarded to DPOConfig)."},
    )
    save_steps: int = field(
        default=100,
        metadata={"help": "Save a checkpoint every N steps (forwarded to DPOConfig)."},
    )
    save_total_limit: int = field(
        default=16,
        metadata={"help": "Maximum number of checkpoints to keep (forwarded to DPOConfig)."},
    )
    lora_r: int = field(
        default=64,
        metadata={"help": "LoRA rank (LoraConfig r)."},
    )
    lora_alpha: int = field(
        default=16,
        metadata={"help": "LoRA alpha (LoraConfig lora_alpha)."},
    )
    lora_dropout: float = field(
        default=0.05,
        metadata={"help": "LoRA dropout (LoraConfig lora_dropout)."},
    )
    gradient_checkpointing: bool = field(
        default=True,
        metadata={"help": "Enable gradient checkpointing; also disables the model KV cache in main."},
    )
    deepspeed_config: Optional[str] = field(
        default=None,
        metadata={"help": "Optional DeepSpeed config forwarded to DPOConfig as 'deepspeed'."},
    )
    warmup_ratio: float = field(
        default=0.1,
        metadata={"help": "Warmup ratio (forwarded to DPOConfig)."},
    )
    weight_decay: float = field(
        default=0.01,
        metadata={"help": "Weight decay (forwarded to DPOConfig)."},
    )
    optim: str = field(
        default="adamw_torch",
        metadata={"help": "Optimizer name (forwarded to DPOConfig)."},
    )
    lr_scheduler_type: str = field(
        default="cosine",
        metadata={"help": "Learning-rate scheduler type (forwarded to DPOConfig)."},
    )
    beta: float = field(
        default=0.01,
        metadata={"help": "DPO beta (forwarded to DPOConfig)."},
    )
    max_prompt_length: int = field(
        default=512,
        metadata={"help": "Maximum prompt length (forwarded to DPOConfig)."},
    )
    max_length: int = field(
        default=1024,
        metadata={"help": "Maximum total length (forwarded to DPOConfig)."},
    )

def is_main_process():
    """Return True when this process is the global rank-0 ("main") process.

    Once the ``torch.distributed`` process group is initialized the answer
    comes from ``dist.get_rank()``. Before that, the rank is read from the
    environment: the global ``RANK`` is preferred, then ``LOCAL_RANK``. Using
    only ``LOCAL_RANK`` would make the first process on *every* node look like
    the main process in a multi-node launch. With neither variable set the
    process is treated as main.

    Returns:
        bool: True for the main process, False otherwise.

    Raises:
        ValueError: If the rank environment variable is set to a non-integer.
    """
    # is_available() guards builds of torch without distributed support.
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank() == 0
    for env_name in ("RANK", "LOCAL_RANK"):
        raw_rank = os.environ.get(env_name)
        if raw_rank is not None:
            return int(raw_rank) == 0
    return True

class TrainingLogger:
    """Rank-aware run logger writing a text log and a JSON config summary.

    Only the process that ``is_main_process()`` reports as main at
    construction time creates the output directory, opens the text log and
    accumulates config data; on every other process the methods do nothing.

    The logger can be used as a context manager so that the text log is
    closed even if the body raises.
    """

    def __init__(self, output_dir):
        """Prepare log file paths and, on the main process, open the text log.

        Args:
            output_dir: Directory that receives ``training_config_<ts>.json``
                and ``training_log_<ts>.txt``. An empty string means the
                current working directory.
        """
        self.is_main = is_main_process()
        self.output_dir = output_dir
        run_stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        self.log_file = os.path.join(output_dir, f"training_config_{run_stamp}.json")
        self.txt_log_file = os.path.join(output_dir, f"training_log_{run_stamp}.txt")
        self.config_data = {}
        self.txt_file = None
        if self.is_main:
            # os.makedirs("") raises FileNotFoundError, so skip it for the
            # empty default; the files then land in the working directory.
            if output_dir:
                os.makedirs(output_dir, exist_ok=True)
            self.config_data = {
                "timestamp": run_stamp,
                "start_time": datetime.datetime.now().isoformat(),
            }
            # Kept open for the lifetime of the logger; released in close().
            self.txt_file = open(self.txt_log_file, 'w', encoding='utf-8')

    def __enter__(self):
        """Return the logger itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the text log; exceptions from the body are not suppressed."""
        self.close()
        return False

    def log(self, message, print_to_console=True):
        """Write ``message`` to the text log and optionally to stdout.

        Args:
            message: Value to log; it is converted with ``str`` formatting so
                non-string values are accepted.
            print_to_console: When True the message is also printed.
        """
        if not self.is_main:
            return
        if print_to_console:
            print(message)
        if self.txt_file is not None:
            self.txt_file.write(f"{message}\n")
            # Flush per message so the log survives an abrupt termination.
            self.txt_file.flush()

    def update_config(self, key, value):
        """Store ``value`` under ``key`` in the config summary (main only)."""
        if self.is_main:
            self.config_data[key] = value

    def save_config(self):
        """Write the accumulated config summary to ``self.log_file`` as JSON."""
        if self.is_main:
            with open(self.log_file, 'w', encoding='utf-8') as f:
                json.dump(self.config_data, f, indent=2, ensure_ascii=False)

    def close(self):
        """Close the text log. Safe to call more than once."""
        if self.txt_file is not None:
            self.txt_file.close()
            # Dropping the handle turns later log() calls into console-only
            # output instead of a write to a closed file.
            self.txt_file = None

def print_trainable_parameters(model, logger=None):
    """Report how many of a model's parameters are trainable.

    Args:
        model: Object exposing ``named_parameters()`` that yields
            ``(name, param)`` pairs, where each ``param`` has ``numel()`` and
            ``requires_grad``.
        logger: Optional object with a ``log(message)`` method. When omitted
            (or falsy) the summary is printed on the main process only.

    Returns:
        dict: ``trainable_params``, ``all_params`` and
        ``trainable_percentage``. The percentage is ``0.0`` for a model that
        has no parameters instead of raising ``ZeroDivisionError``.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        count = param.numel()
        all_param += count
        if param.requires_grad:
            trainable_params += count

    # Computed once and guarded against an empty parameter list.
    trainable_percentage = 100 * trainable_params / all_param if all_param else 0.0

    message = (
        f"trainable params: {trainable_params:,} || "
        f"all params: {all_param:,} || "
        f"trainable%: {trainable_percentage:.2f}"
    )
    if logger:
        logger.log(message)
    elif is_main_process():
        print(message)
    return {
        "trainable_params": trainable_params,
        "all_params": all_param,
        "trainable_percentage": trainable_percentage,
    }

def load_and_format_dataset(dataset_path, tokenizer):
    """Load a JSON preference dataset and format it for DPO training.

    The file must contain a JSON list of records, each with ``prompt``,
    ``chosen`` and ``rejected`` strings and an optional ``source``. Every
    prompt is wrapped with a fixed system message: through the tokenizer's
    ``apply_chat_template`` when that method exists, otherwise through a
    hand-built template using the tokenizer's BOS token. The end-of-sentence
    marker string is appended to both completions.

    Args:
        dataset_path: Path to the UTF-8 encoded JSON file.
        tokenizer: Tokenizer used to render the prompt.

    Returns:
        Dataset: Rows with ``prompt``, ``chosen``, ``rejected`` and ``source``.

    Raises:
        ValueError: If the top-level JSON value is not a list.
        KeyError: If a record lacks ``prompt``, ``chosen`` or ``rejected``.
    """
    with open(dataset_path, "r", encoding="utf-8") as f:
        raw_data = json.load(f)
    if not isinstance(raw_data, list):
        # Iterating a dict would walk its keys and fail later with an
        # unrelated-looking error, so reject it up front.
        raise ValueError(
            f"Expected a JSON list of records in {dataset_path!r}, "
            f"got {type(raw_data).__name__}"
        )

    # Values that do not change between records are resolved once.
    system_message = "You are a helpful assistant."
    eos_marker = "<｜end▁of▁sentence｜>"
    use_chat_template = hasattr(tokenizer, 'apply_chat_template')
    # Only the manual template needs the BOS token; an unset (None) token
    # must not end up in the prompt as the literal text "None".
    bos_token = "" if use_chat_template else (getattr(tokenizer, "bos_token", None) or "")

    def format_example(example):
        """Turn one raw record into the prompt/chosen/rejected row."""
        prompt = example["prompt"].strip()
        chosen = example["chosen"].strip()
        rejected = example["rejected"].strip()
        if use_chat_template:
            formatted_prompt = tokenizer.apply_chat_template(
                [
                    {"role": "system", "content": system_message},
                    {"role": "user", "content": prompt},
                ],
                tokenize=False,
                add_generation_prompt=True,
            )
        else:
            formatted_prompt = (
                f"{bos_token}{system_message}"
                f"<｜User｜>{prompt}"
                "<｜Assistant｜><think>\n"
            )
        return {
            "prompt": formatted_prompt,
            "chosen": chosen + eos_marker,
            "rejected": rejected + eos_marker,
            "source": example.get("source", "unknown"),
        }

    return Dataset.from_list([format_example(item) for item in raw_data])

def main():
    """Run LoRA-based DPO training end to end.

    Steps: parse ``ScriptArguments``; open the run logger; load the tokenizer
    and the preference dataset; load the policy model; build the LoRA and DPO
    configurations; train; save the final model and tokenizer under
    ``<output_dir>/final_checkpoint``; write the JSON run summary. The logger
    is closed even when a step raises.
    """
    parser = HfArgumentParser(ScriptArguments)
    args = parser.parse_args_into_dataclasses()[0]
    logger = TrainingLogger(args.output_dir)
    try:
        cuda_available = torch.cuda.is_available()
        gpu_count = torch.cuda.device_count()
        logger.update_config("script_arguments", asdict(args))
        logger.update_config("system_info", {
            "pytorch_version": torch.__version__,
            "cuda_available": cuda_available,
            "cuda_version": torch.version.cuda if cuda_available else None,
            "gpu_count": gpu_count,
            "gpu_names": [torch.cuda.get_device_name(i) for i in range(gpu_count)] if cuda_available else [],
        })
        logger.log("="*50)
        logger.log(f"DPO Training started at: {datetime.datetime.now()}")
        logger.log("="*50)

        # NOTE(review): this tokenizer is used for dataset formatting and is
        # saved with the final checkpoint, but it is not passed to DPOTrainer
        # below -- confirm whether the installed TRL version needs it
        # (tokenizer= / processing_class=).
        tokenizer = AutoTokenizer.from_pretrained(
            args.model_name_or_path,
            trust_remote_code=True,
            padding_side="left",
        )
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        logger.log("Loading dataset...")
        dataset = load_and_format_dataset(args.dataset_path, tokenizer)
        logger.update_config("dataset_size", len(dataset))

        if is_main_process() and len(dataset) > 0:
            logger.log("\n=== Dataset Sample ===")
            sample = dataset[0]
            logger.log(f"Prompt: {sample['prompt'][:100]}...")
            logger.log(f"Chosen: {sample['chosen'][:100]}...")
            logger.log(f"Rejected: {sample['rejected'][:100]}...")
            logger.log("="*50)

        # Load dtype follows the requested mixed-precision mode. fp16 keeps
        # priority so existing invocations behave as before; bf16 runs no
        # longer fall through to float32.
        if args.fp16:
            load_dtype = torch.float16
        elif args.bf16:
            load_dtype = torch.bfloat16
        else:
            load_dtype = torch.float32

        logger.log("Loading model...")
        model = AutoModelForCausalLM.from_pretrained(
            args.model_name_or_path,
            trust_remote_code=True,
            torch_dtype=load_dtype,
            # The KV cache is switched off whenever gradient checkpointing is on.
            use_cache=not args.gradient_checkpointing,
        )
        # No ref_model is passed to DPOTrainer; with peft_config set, TRL
        # derives the reference from the base weights, so the previous
        # "Loading reference model..." message was misleading.
        logger.log("Reference model: using the base model with LoRA adapters disabled (no separate copy is loaded).")

        if args.gradient_checkpointing:
            model.gradient_checkpointing_enable()
            model.enable_input_require_grads()

        lora_config = LoraConfig(
            r=args.lora_r,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        )
        # target_modules may be held as a set, which json cannot serialize.
        target_modules = (
            list(lora_config.target_modules)
            if isinstance(lora_config.target_modules, set)
            else lora_config.target_modules
        )

        logger.update_config("lora_config", {
            "r": lora_config.r,
            "lora_alpha": lora_config.lora_alpha,
            "lora_dropout": lora_config.lora_dropout,
            "target_modules": target_modules,
            "bias": lora_config.bias,
        })

        training_args = DPOConfig(
            output_dir=args.output_dir,
            per_device_train_batch_size=args.per_device_train_batch_size,
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            learning_rate=args.learning_rate,
            num_train_epochs=args.num_train_epochs,
            logging_steps=args.logging_steps,
            save_steps=args.save_steps,
            save_total_limit=args.save_total_limit,
            bf16=args.bf16,
            fp16=args.fp16,
            gradient_checkpointing=args.gradient_checkpointing,
            gradient_checkpointing_kwargs={"use_reentrant": False},
            report_to=["tensorboard"] if is_main_process() else [],
            deepspeed=args.deepspeed_config,
            warmup_ratio=args.warmup_ratio,
            weight_decay=args.weight_decay,
            optim=args.optim,
            lr_scheduler_type=args.lr_scheduler_type,
            seed=42,
            remove_unused_columns=False,
            logging_dir=os.path.join(args.output_dir, "logs"),
            beta=args.beta,
            max_prompt_length=args.max_prompt_length,
            max_length=args.max_length,
            generate_during_eval=False,
            precompute_ref_log_probs=True,
        )

        trainer = DPOTrainer(
            model=model,
            args=training_args,
            train_dataset=dataset,
            peft_config=lora_config,
        )

        logger.log("="*50)
        logger.log("Model Information:")
        logger.log(f"Model: {args.model_name_or_path}")
        logger.log(f"LoRA rank: {args.lora_r}")
        logger.log(f"LoRA alpha: {args.lora_alpha}")
        logger.log(f"LoRA target modules: {target_modules}")
        logger.log(f"DPO beta: {args.beta}")
        logger.log("="*50)
        logger.log("\nTrainable Parameters:")
        param_info = print_trainable_parameters(trainer.model, logger)
        logger.update_config("parameter_info", param_info)
        logger.log("="*50)
        logger.log("\nTraining Configuration:")
        logger.log(f"Batch size per device: {args.per_device_train_batch_size}")
        logger.log(f"Gradient accumulation steps: {args.gradient_accumulation_steps}")
        # Effective global batch size taken from the training arguments
        # (per-process batch x accumulation x number of processes).
        # torch.cuda.device_count() only counts local GPUs and is 0 on
        # CPU-only hosts, which made the division below raise.
        total_batch_size = (
            training_args.train_batch_size
            * args.gradient_accumulation_steps
            * training_args.world_size
        )
        logger.log(f"Total batch size: {total_batch_size}")
        logger.log(f"Number of epochs: {args.num_train_epochs}")
        logger.log(f"Learning rate: {args.learning_rate}")
        logger.log(f"Dataset size: {len(dataset)}")
        # Estimate for the summary only; the trainer computes its own schedule.
        num_training_steps = len(dataset) // total_batch_size * args.num_train_epochs
        logger.log(f"Total training steps: {num_training_steps}")
        logger.log("="*50 + "\n")
        logger.update_config("training_info", {
            "total_batch_size": total_batch_size,
            "num_training_steps": num_training_steps,
            "warmup_steps": int(num_training_steps * args.warmup_ratio),
            "dpo_beta": args.beta,
        })
        logger.save_config()

        logger.log("Starting DPO training...")
        train_result = trainer.train()
        metric_names = (
            "train_loss",
            "train_runtime",
            "train_samples_per_second",
            "train_steps_per_second",
            "total_flos",
        )
        logger.update_config(
            "training_results",
            {name: train_result.metrics.get(name) for name in metric_names},
        )

        final_save_path = os.path.join(args.output_dir, "final_checkpoint")
        # save_model must run on every rank: the Trainer restricts the actual
        # write to the main process itself, and sharded setups (DeepSpeed
        # ZeRO-3 / FSDP) need all ranks to take part in gathering the weights.
        trainer.save_model(final_save_path)
        if is_main_process():
            tokenizer.save_pretrained(final_save_path)
        logger.log(f"\nTraining completed! Model saved to: {final_save_path}")

        logger.update_config("end_time", datetime.datetime.now().isoformat())
        logger.log(f"\nTraining ended at: {datetime.datetime.now()}")
        logger.save_config()
        logger.log(f"\nConfiguration saved to: {logger.log_file}")
        logger.log(f"Training log saved to: {logger.txt_log_file}")
    except Exception as exc:
        # Record the failure in the run log, then let it propagate unchanged.
        logger.log(f"Training failed with {type(exc).__name__}: {exc}")
        raise
    finally:
        logger.close()

# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()