import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import argparse
from awdpo.utils import *
from awdpo.rewards import *
from awdpo.trainer import AWDPO_MLE_Trainer
from awdpo.train_config import AWDPOConfig
from baselines.awdpo_filtered_trainer import AWDPO_Filtered_Trainer
from datasets import load_from_disk

def build_arg_parser():
    """Build the command-line parser for the AWDPO training script.

    Returns:
        argparse.ArgumentParser: parser exposing the model name, the fields
        forwarded to ``AWDPOConfig``, and the dataset / trainer selection
        options consumed by ``main``.
    """
    parser = argparse.ArgumentParser(description="Weighted Live DPO Training")

    # Model and tokenizer args
    parser.add_argument("--model_name", default="Qwen/Qwen2.5-0.5B")

    # AWDPO config arguments
    parser.add_argument("--output_dir", default="outputs/Qwen25_05B_AWDPO_Gsm8k")
    parser.add_argument("--run_name", default="Qwen25_05B_AWDPO_gsm8k_reasoner")
    parser.add_argument("--learning_rate", type=float, default=5e-5)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--warmup_steps", type=int, default=100)
    parser.add_argument("--num_generations", type=int, default=5)
    parser.add_argument("--max_prompt_length", type=int, default=2000)
    parser.add_argument("--max_completion_length", type=int, default=500)
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--logging_steps", type=int, default=1)
    parser.add_argument("--save_steps", type=int, default=250)
    parser.add_argument("--max_steps", type=int, default=50)
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--num_generated_samples_to_view", type=int, default=1000)
    # bf16 defaults to on. `--bf16` alone is a no-op (store_true with
    # default=True), so `--no_bf16` is provided as the only way to turn it
    # off; both flags write to the same `bf16` destination.
    parser.add_argument("--bf16", action="store_true", default=True)
    parser.add_argument("--no_bf16", dest="bf16", action="store_false",
                        help="Disable bf16 (enabled by default)")
    parser.add_argument("--per_device_train_batch_size", type=int, default=1)
    parser.add_argument("--use_vllm", action="store_true", help="Enable vLLM for inference")
    parser.add_argument("--vllm_device", default="cuda:0")
    parser.add_argument("--vllm_gpu_memory_utilization", type=float, default=0.3)
    parser.add_argument("--use_reference_model", action="store_true",
                        help="whether to use a reference model or not")
    parser.add_argument("--use_lora", action="store_true")
    parser.add_argument("--policy_reset", action="store_true")
    parser.add_argument("--use_advantage_scaling", action="store_true")
    parser.add_argument("--lora_rank", type=int, default=64)
    parser.add_argument("--lora_alpha", type=int, default=64)

    # Dataset and trainer selection (not part of AWDPOConfig)
    parser.add_argument("--training_data_directory", type=str,
                        default="../data/gsm8k_fewshot_qwen25")
    parser.add_argument("--few_shot_column", type=str, default="few_shot_examples_random")
    parser.add_argument("--train_column", type=str, default="train")
    parser.add_argument("--trainer_type", type=str, default='base',
                        help="'filtered' selects AWDPO_Filtered_Trainer; "
                             "any other value selects AWDPO_MLE_Trainer")

    return parser


def main(argv=None):
    """Parse arguments, load model and data, and run training.

    Args:
        argv: optional list of command-line tokens. ``None`` lets argparse
            read ``sys.argv[1:]``, which is what the script entry point uses.
    """
    args = build_arg_parser().parse_args(argv)

    model_name = args.model_name
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        use_cache=False
    ).to("cuda" if torch.cuda.is_available() else "cpu")
    # Padding always uses the EOS token; this overrides whatever pad token
    # the tokenizer ships with.
    tokenizer.pad_token = tokenizer.eos_token

    config = AWDPOConfig(
        output_dir=args.output_dir,
        run_name=args.run_name,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        warmup_steps=args.warmup_steps,
        num_generations=args.num_generations,
        max_prompt_length=args.max_prompt_length,
        max_completion_length=args.max_completion_length,
        num_train_epochs=args.num_train_epochs,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        max_steps=args.max_steps,
        temperature=args.temperature,
        num_generated_samples_to_view=args.num_generated_samples_to_view,
        bf16=args.bf16,
        per_device_train_batch_size=args.per_device_train_batch_size,
        use_vllm=args.use_vllm,
        vllm_device=args.vllm_device,
        vllm_gpu_memory_utilization=args.vllm_gpu_memory_utilization,
        use_reference_model=args.use_reference_model,
        use_lora=args.use_lora,
        policy_reset=args.policy_reset,
        use_advantage_scaling=args.use_advantage_scaling,
        lora_rank=args.lora_rank,
        lora_alpha=args.lora_alpha,
    )

    # The object loaded from disk is indexed by --train_column below.
    dataset = load_from_disk(args.training_data_directory)

    # Reward callables and SYSTEM_PROMPT come from the star imports at the
    # top of the file (presumably awdpo.rewards / awdpo.utils -- confirm).
    reward_functions = [reasoning_reward, accuracy_reward, soft_format_reward,
                        strict_format_reward, int_reward, xmlcount_reward,
                        proper_termination_reward, clean_answer_termination_reward,
                        coherency_reward]

    # Both trainers take the same positional arguments; only the class differs.
    if args.trainer_type == 'filtered':
        print("Using AWDPO for only accurate responses")
        trainer_cls = AWDPO_Filtered_Trainer
    else:
        print("Using AWDPO with MLE term")
        trainer_cls = AWDPO_MLE_Trainer

    trainer = trainer_cls(model, tokenizer, reward_functions, config,
                          dataset[args.train_column], args.few_shot_column, SYSTEM_PROMPT)

    trainer.train()


if __name__ == "__main__":
    main()
    

    