import os
import sys

# Make the parent of this script's directory importable so that the `src.*`
# imports further down resolve when this file is executed directly.
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(current_dir)
if project_root not in sys.path:
    sys.path.append(project_root)

from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE, SIMPLE_SFT_CHAT_TEMPLATE
from transformers import AutoTokenizer
from tokenizers.processors import TemplateProcessing
import wandb
import torch
import argparse
import json

import src.constants as constants
# import src.setup_datasets as setup_datasets
import src.setup_datasets_clean as setup_datasets
import src.trainers as trainers

'''
Example Usage:
If using config file:
python scripts/train.py --config_file configs/train_dpo.yaml --run_name dpo_test_1
'''

# Strings accepted as booleans on the command line and in config files.
_TRUE_STRINGS = ('true', '1', 'yes')
_FALSE_STRINGS = ('false', '0', 'no', '')


def _str2bool(value):
    """Convert a command-line string into a bool for argparse.

    ``type=bool`` must not be used with argparse: ``bool('False')`` is ``True``
    because any non-empty string is truthy.
    """
    if isinstance(value, bool):
        return value
    text = str(value).lower()
    if text in _TRUE_STRINGS:
        return True
    if text in _FALSE_STRINGS:
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


def _build_parser():
    """Create the argument parser with every supported training option."""
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    add('--run_name', type=str, default='', help='Name of the run')

    add('--method', type=str, default='', help='Training method')
    add('--model_name', type=str, default='', help='Model name')
    add('--model_id', type=str, default='', help='Model ID for LLM being trained')
    add('--dataset_name', type=str, default='', help='Dataset names')
    add('--max_prompt_length', type=int, default=512, help='Maximum prompt length')
    add('--dataset_dirs', type=str, default='', help='Dataset directories')
    add('--root_save_dir', type=str, default='', help='Add root save directory')
    add('--local_wandb_dir', type=str, default='', help='Local directory for wandb logs')
    add('--trust_remote_code', action='store_true', help='Trust remote code for model loading')
    add('--dataset_num_proc', type=int, default=1, help='Number of processes for dataset loading')
    add('--val_dataset_size', type=int, default=100, help='Maximum size of validation dataset for SFT')

    # NOTE: flags declared with action='store_true' and default=True are always
    # True on the command line; they can only be switched off via --config_file.

    # BitsAndBytesConfig parameters (common for DPO and SFT)
    add('--load_in_4bit', action='store_true', default=True, help='Load model in 4-bit precision')
    add('--bnb_4bit_use_double_quant', action='store_true', default=True, help='Use double quantization')
    add('--bnb_4bit_quant_type', type=str, default='nf4', help='Quantization type')
    add('--bnb_4bit_compute_dtype', type=str, default='bfloat16', help='Compute dtype for BitsAndBytes')

    # AutoModelForCausalLM parameters (common for DPO and SFT)
    add('--device_map', type=str, default='auto', help='Device map for model loading')
    add('--use_cache', action='store_true', default=True, help='Use cache for model')
    add('--attn_implementation', type=str, default='flash_attention_2', help='Attention implementation')
    add('--torch_dtype', type=str, default='bfloat16', help='PyTorch dtype')

    # LoraConfig parameters (common for DPO and SFT)
    add('--lora_alpha', type=int, default=128, help='LoRA alpha parameter')
    add('--lora_dropout', type=float, default=0.05, help='LoRA dropout rate')
    add('--lora_r', type=int, default=256, help='LoRA rank')
    add('--lora_bias', type=str, default='none', help='LoRA bias setting')
    add('--lora_target_modules', type=str, default='all-linear', help='LoRA target modules')
    add('--lora_task_type', type=str, default='CAUSAL_LM', help='LoRA task type')

    # Training parameters (common for DPO and SFT)
    add('--num_train_epochs', type=int, default=1, help='Number of training epochs')
    add('--per_device_train_batch_size', type=int, default=12, help='Batch size per device during training')
    add('--per_device_eval_batch_size', type=int, default=4, help='Batch size for evaluation')
    add('--gradient_accumulation_steps', type=int, default=1, help='Number of steps before performing a backward/update pass')
    add('--gradient_checkpointing', action='store_true', default=True, help='Use gradient checkpointing to save memory')
    add('--save_strategy', type=str, default='steps', help='Save strategy')
    add('--save_steps', type=int, default=50, help='When to save checkpoint')
    # add('--save_total_limit', type=int, default=2, help='Limit the total amount of checkpoints')
    add('--eval_strategy', type=str, default='steps', help='Evaluation strategy')
    add('--eval_steps', type=int, default=20, help='When to evaluate')
    add('--max_steps', type=int, default=200, help='Number of training steps')
    add('--logging_strategy', type=str, default='steps', help='Logging strategy')
    add('--logging_steps', type=int, default=5, help='Log every N steps')
    add('--bf16', action='store_true', default=True, help='Use bfloat16 precision')
    add('--tf32', action='store_true', default=True, help='Use tf32 precision')
    add('--push_to_hub', action='store_true', default=False, help='Push model to hub')
    add('--report_to', type=str, default='wandb', help='Report metrics to')

    # DPO-specific parameters
    add('--optim', type=str, default='adamw_torch_fused', help='Optimizer')
    add('--learning_rate', type=float, default=5e-5, help='Learning rate')
    add('--max_grad_norm', type=float, default=0.3, help='Maximum gradient norm')
    add('--warmup_ratio', type=float, default=0.1, help='Warmup ratio')
    add('--lr_scheduler_type', type=str, default='cosine', help='Learning rate scheduler type')
    add('--beta', type=float, default=0.1, help='DPO beta parameter')
    add('--loss_type', type=str, default='sigmoid', help='DPO loss type')
    add('--prompt_length', type=int, default=1024, help='DPO prompt length')
    add('--max_seq_length', type=int, default=1512, help='DPO maximum sequence length')

    # SFT-specific parameters
    add('--max_length', type=int, default=512, help='SFT maximum sequence length')
    add('--dataset_text_field', type=str, default=None, help='Text field in dataset to use for training')
    # Parsed with _str2bool so that e.g. "--multi_turn False" really means False.
    add('--completion_only_loss', type=_str2bool, default=False, help='Compute loss only on completion part of the sequence')
    add('--multi_turn', type=_str2bool, default=False, help='Use multi-turn dialogue for HH dataset (False for single-turn)')

    # PPO-specific parameters
    add('--value_model_id', type=str, default='', help='Value model ID for PPO')
    add('--init_kl_coef', type=float, default=0.2, help='Initial KL coefficient')
    add('--target_kl', type=float, default=6.0, help='Target KL divergence')
    add('--horizon', type=int, default=10000, help='PPO horizon')
    add('--gamma', type=float, default=1.0, help='Discount factor')
    add('--lam', type=float, default=0.95, help='Lambda for GAE')
    add('--cliprange', type=float, default=0.2, help='PPO clip range')
    add('--cliprange_value', type=float, default=0.2, help='Value function clip range')
    add('--vf_coef', type=float, default=0.1, help='Value function coefficient')
    add('--max_response_length', type=int, default=512, help='Maximum response length for PPO')

    # Sample generation and logging parameters
    add('--num_sample_generations', type=int, default=10,
        help='Number of samples to generate for logging')

    # Custom reward function parameters
    add('--reward_model_name', type=str, default='', help='Model name for reward scoring')
    add('--reward_max_length', type=int, default=512, help='Maximum length for reward model')
    add('--query_response_separator', type=str, default='Assistant:', help='Separator between query and response')
    add('--use_api', action='store_true', default=False, help='Use API for reward scoring')

    # Reward combiner parameters
    add('--reward_combiner_type', type=str, default='linear',
        help='Type of reward combiner: linear, linear_regression, gradient_boosting, mlp')
    add('--reward_objectives', type=str, default='',
        help='JSON list of objectives to evaluate, e.g., ["harmlessness", "clarity"]')
    add('--reward_manual_weights', type=str, default='',
        help='JSON string of manual weights for linear combiner')
    add('--reward_manual_bias', type=float, default=0.0,
        help='Bias term for linear combiner')
    add('--reward_combiner_path', type=str, default=None,
        help='Path to a pre-fitted reward combiner to load')

    # MLP combiner specific
    add('--mlp_hidden_sizes', type=str, default='[64,32]',
        help='JSON list of hidden layer sizes for MLP combiner')
    add('--mlp_dropout_rate', type=float, default=0.1,
        help='Dropout rate for MLP combiner')

    # Gradient boosting specific
    add('--gb_n_estimators', type=int, default=100,
        help='Number of estimators for gradient boosting combiner')
    add('--gb_max_depth', type=int, default=3,
        help='Max depth for gradient boosting combiner')

    # GRPO-specific parameters (reusing existing parameters where possible)
    add('--temperature', type=float, default=0.7,
        help='Temperature for generation (GRPO)')
    add('--num_samples', type=int, default=4,
        help='Number of samples per prompt (GRPO)')
    add('--rloo_k', type=int, default=2,
        help='RLOO baseline parameter k (GRPO)')
    add('--reward_type', type=str, default='llm_function',
        help='Type of reward function: llm_function or reward_model')
    add('--use_quantization', action='store_true', default=False,
        help='Use quantization for reward model')
    add('--dataset_type', type=str, default='hh',
        help='Dataset type for reward scoring: hh or tldr')
    add('--use_detailed_rubric', action='store_true', default=True,
        help='Use detailed rubric for LLM scoring')
    add('--max_concurrent', type=int, default=50,
        help='Maximum concurrent API calls for async reward scoring')

    add('--config_file', type=str, default=None, help='Path to config file (json or yaml)')
    return parser


def _load_config(path):
    """Read a JSON (``*.json``) or YAML (anything else) config file into a dict."""
    with open(path, encoding='utf-8') as f:
        if path.endswith('.json'):
            loaded = json.load(f)
        else:
            import yaml  # imported lazily: only needed for YAML configs
            loaded = yaml.safe_load(f)
    if loaded is None:
        # An empty YAML file (or a JSON "null") means "no overrides".
        return {}
    if not isinstance(loaded, dict):
        raise ValueError(
            f"Config file {path} must contain a mapping of option names to values, "
            f"got {type(loaded).__name__}"
        )
    return loaded


def _apply_config_overrides(args, path):
    """Copy every key of the config file at *path* onto *args*."""
    for key, value in _load_config(path).items():
        # A quoted boolean such as "false" would otherwise be stored as a
        # truthy string; coerce it for every boolean-valued option.
        if isinstance(getattr(args, key, None), bool) and isinstance(value, str):
            value = value.lower() in _TRUE_STRINGS
        setattr(args, key, value)


def _parse_json_arg(args, name, fallback):
    """Decode ``args.<name>`` in place if it is a non-empty JSON string.

    On a decode error a warning is printed and *fallback* is stored instead.
    Returns True when the attribute was a non-empty string (decoded or not),
    False when it was left untouched.
    """
    raw = getattr(args, name, None)
    if not (isinstance(raw, str) and raw):
        return False
    try:
        setattr(args, name, json.loads(raw))
    except json.JSONDecodeError:
        print(f"Warning: Could not parse {name} JSON string: {raw}")
        setattr(args, name, fallback)
    return True


def parse_args(argv=None):
    """Parse training options from the command line and an optional config file.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with all options. When ``--config_file`` is given,
        every key in that file is set on the namespace and takes precedence
        over the value parsed from the command line. ``reward_objectives``,
        ``reward_manual_weights`` and ``mlp_hidden_sizes`` are decoded from
        JSON when supplied as non-empty strings.

    Raises:
        ValueError: If the config file does not contain a mapping.
    """
    args = _build_parser().parse_args(argv)

    if args.config_file:
        _apply_config_overrides(args, args.config_file)

    # reward_objectives: JSON string -> decoded value, list (e.g. from YAML)
    # kept as-is, anything else (including the empty default) becomes None.
    if not _parse_json_arg(args, 'reward_objectives', None):
        if not isinstance(getattr(args, 'reward_objectives', None), list):
            args.reward_objectives = None

    # reward_manual_weights: left untouched when empty or already decoded.
    _parse_json_arg(args, 'reward_manual_weights', None)

    # mlp_hidden_sizes: falls back to the default layout on malformed JSON.
    _parse_json_arg(args, 'mlp_hidden_sizes', [64, 32])

    print(args)
    return args

def init_tokenizer(args):
    """Load and configure the tokenizer for the selected training method.

    Args:
        args: Parsed options. Reads ``method``, ``model_name`` (SFT and the
            Qwen check) and ``model_id`` (GRPO).

    Returns:
        The configured tokenizer. For SFT it is loaded from ``model_name`` with
        a pad token guaranteed and left-side padding/truncation; for GRPO it is
        loaded from ``model_id``. A custom chat template is installed when the
        tokenizer has none or when ``model_name`` contains ``'Qwen'``.

    Raises:
        ValueError: If ``args.method`` is neither ``'sft'`` nor ``'grpo'``.
    """
    supported_methods = ('sft', 'grpo')
    if args.method not in supported_methods:
        # Without this check no branch below would assign `tokenizer`, and the
        # function would fail later with an opaque UnboundLocalError.
        raise ValueError(
            f"init_tokenizer does not support method {args.method!r}; "
            f"expected one of {supported_methods}"
        )

    if args.method == 'sft':
        tokenizer = AutoTokenizer.from_pretrained(args.model_name, add_eos_token=True, use_fast=False)
        if getattr(tokenizer, "pad_token", None) is None:
            tokenizer.add_special_tokens({'pad_token': '<|reserved_special_token_0|>'})
            # tokenizer.pad_token = tokenizer.eos_token
        if args.model_name == 'meta-llama/Llama-3.1-8B':
            # Replace the post-processor so every encoded sequence is wrapped
            # in BOS ... EOS.
            # NOTE(review): relies on the private `_tokenizer` attribute, which
            # presumably only exists on fast tokenizers - confirm with use_fast=False.
            bos = tokenizer.bos_token
            eos = tokenizer.eos_token
            tokenizer._tokenizer.post_processor = TemplateProcessing(
                single=f"{bos}:0 $A:0 {eos}:0",
                pair=f"{bos}:0 $A:0 {eos}:0 {bos}:1 $B:1 {eos}:1",
                special_tokens=[
                    (f"{bos}", tokenizer.bos_token_id),
                    (f"{eos}", tokenizer.eos_token_id)
                ],
            )
        elif 'Qwen' in args.model_name:
            tokenizer.add_special_tokens({'pad_token': '<|reserved_special_token_0|>'})
            tokenizer.pad_token = '<|reserved_special_token_0|>' # ensuring pad_token != eos_token for Qwen models
        tokenizer.padding_side = 'left' # to prevent errors with flash attention (adds padding to the left, so the rightmost tokens are the most recent)
        tokenizer.truncation_side = 'left' # to prevent cutting off last generation
    elif args.method == 'grpo':
        print("Initializing tokenizer for GRPO...")
        tokenizer = AutoTokenizer.from_pretrained(args.model_id, trust_remote_code=True)
        # tokenizer = AutoTokenizer.from_pretrained(args.model_id, trust_remote_code=True, use_fast=False)

    if (tokenizer.chat_template is None) or ('Qwen' in args.model_name):
        print("Setting chat template for tokenizer...")
        # Custom template that includes EOS token after assistant messages
        CHAT_TEMPLATE_WITH_EOS = "{% for message in messages %}{% if message['role'] == 'assistant' %}{{ message['role'].capitalize() + ': ' + message['content'] + eos_token }}{% else %}{{ message['role'].capitalize() + ': ' + message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
        tokenizer.chat_template = CHAT_TEMPLATE_WITH_EOS
        # tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
        # tokenizer.chat_template = SIMPLE_SFT_CHAT_TEMPLATE
    return tokenizer

def train_model(args):
    """Run one training job end to end and save the resulting model.

    Starts a wandb run, ensures ``<root_save_dir>/<run_name>`` exists, builds
    the tokenizer, datasets and trainer, then trains and saves.

    Args:
        args: Parsed options. Reads ``method``, ``run_name``, ``model_name``,
            ``dataset_name``, ``local_wandb_dir``, ``root_save_dir`` and
            ``val_dataset_size`` here; the whole namespace is also passed on
            to the tokenizer, dataset and trainer factories.
    """
    wandb.init(
        project=constants.WANDB_PROJECT_NAME,
        group=f'train_{args.method}',
        name=args.run_name,
        tags=[args.method, args.model_name, args.dataset_name],
        config=vars(args),
        dir=args.local_wandb_dir,
    )

    output_dir = os.path.join(args.root_save_dir, args.run_name)
    if not os.path.exists(output_dir):
        print("Creating output directory: {}".format(output_dir))
        # exist_ok avoids a crash if another process creates the directory
        # between the existence check and this call.
        os.makedirs(output_dir, exist_ok=True)

    tokenizer = init_tokenizer(args)
    train_dataset, val_dataset = setup_datasets.load_datasets(
        args, tokenizer, val_dataset_size=args.val_dataset_size
    )
    trainer = trainers.init_trainer(args, tokenizer, train_dataset, val_dataset)

    # NOTE: an earlier retry loop that resumed from an existing checkpoint
    # (trainer.train(resume_from_checkpoint=True)) is disabled; training runs
    # once and any failure propagates to the caller.
    trainer.train()
    trainer.save_model()
    print('Finished training model!')

def main():
    """Script entry point: parse the options, then run training with them."""
    train_model(parse_args())
    print("Training complete.")

# Run training only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()