import copy 
import os 
import sys 
from datetime import timedelta 

import torch 
import torch .distributed as torch_distributed 
from torch .distributed .fsdp .fully_sharded_data_parallel import CPUOffload 
import torch .optim as optim 
from torch .distributed .fsdp import FullyShardedDataParallel as FSDP 
from torch .optim .lr_scheduler import StepLR 
import wandb 

from llama_recipes .policies import AnyPrecisionAdamW ,apply_fsdp_checkpointing 
from llama_recipes .utils .train_utils import (
clear_gpu_cache ,
freeze_transformer_layers ,
get_policies ,
print_model_size ,
setup_environ_flags ,
train ,
)
from llama_recipes .optimizer import WarmupCosineAnnealingLR 
from llama_recipes .utils .random import set_seed 
from llama_recipes .utils .distributed import (
print_rank_0 ,
is_rank_0 ,
set_mpi_env ,
get_rank ,
get_local_rank ,
)
from llama_recipes .get_models import get_model 
from llama_recipes .utils .checkpoint import (
load_model_state_dict ,
load_optimizer_state_dict ,
load_dist_model_state_dict ,
load_dist_optimizer_state_dict ,
load_scheduler_state_dict ,
load_rng_state_dict ,
get_latest_iteration ,
)

from llama_recipes .arguments import parse_args 
from llama_recipes .get_fsdp import get_sharding_strategy 
from llama_recipes .utils .precision import preserve_fp32_buffers 
from megatron_lm .megatron .global_vars import set_global_variables 

import datetime 

# Append "<cwd>/llama-recipes/src/" to the module search path.
# NOTE(review): this executes after the llama_recipes imports earlier in the
# file, so it can only influence imports performed later (e.g. inside main()).
current_path: str = os.getcwd()
sys.path.append(current_path + "/llama-recipes/src/")


def log(msg: str) -> None:
    """Print ``msg`` to stdout prefixed with a timestamp and the process rank.

    The rank label is ``get_rank()`` when that call succeeds; if it raises,
    the ``RANK`` environment variable is used, or ``"?"`` when it is unset.
    Output is flushed immediately.
    """
    stamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    try:
        rank_label = str(get_rank())
    except Exception:
        # Best-effort on purpose: a failing rank lookup must never break logging.
        rank_label = os.environ.get("RANK", "?")
    print(f"[{stamp}][rank={rank_label}] {msg}", flush=True)



def _cast_to_training_dtype(module: torch.nn.Module, args) -> None:
    """Cast ``module`` in place to the reduced precision selected by ``args``.

    The cast runs inside ``preserve_fp32_buffers(module)``. ``args.bf16`` takes
    priority over ``args.fp16``; when neither flag is set the module is left
    unchanged.
    """
    with preserve_fp32_buffers(module):
        if args.bf16:
            module.to(torch.bfloat16)
        elif args.fp16:
            module.to(torch.float16)


def _wrap_with_fsdp(
    module: torch.nn.Module,
    args,
    rank: int,
    mixed_precision_policy,
    wrapping_policy,
    device_mesh=None,
) -> FSDP:
    """Wrap ``module`` in FSDP using the settings shared by all models here.

    Args:
        module: The model to shard.
        args: Parsed arguments; reads ``fsdp_cpu_offload`` and ``low_cpu_fsdp``.
        rank: Global rank of this process.
        mixed_precision_policy: Passed to FSDP as ``mixed_precision``.
        wrapping_policy: Passed to FSDP as ``auto_wrap_policy``.
        device_mesh: Optional device mesh forwarded to FSDP.

    Returns:
        The FSDP-wrapped module.
    """
    # The original spelled this as ``lambda m: ... if cond else None``, where
    # the conditional is part of the lambda body, so ``param_init_fn`` was
    # always a callable (a no-op one when the condition was false). Build the
    # callable-or-None explicitly so FSDP really receives ``None`` when no
    # custom initialisation is requested.
    if args.low_cpu_fsdp and rank != 0:
        param_init_fn = lambda submodule: submodule.to_empty(  # noqa: E731
            device=torch.cuda.current_device(), recurse=False,
        )
    else:
        param_init_fn = None

    return FSDP(
        module,
        auto_wrap_policy=wrapping_policy,
        cpu_offload=CPUOffload(offload_params=True) if args.fsdp_cpu_offload else None,
        mixed_precision=mixed_precision_policy,
        sharding_strategy=get_sharding_strategy(),
        device_id=torch.cuda.current_device(),
        limit_all_gathers=True,
        sync_module_states=args.low_cpu_fsdp,
        param_init_fn=param_init_fn,
        device_mesh=device_mesh,
    )


def _configure_finetuning_schedule(args, dataset_size: int, rank: int) -> None:
    """Derive iteration counts for instruction tuning / DPO and store them on ``args``.

    Sets ``train_iters``, ``lr_decay_iters``, ``lr_warmup_iters`` (one tenth of
    the decay iterations) and ``save_sampler_state``; on global rank 0 it also
    calls ``update_iter_info()``.
    """
    args.train_iters = dataset_size // args.global_batch_size * args.epoch
    args.lr_decay_iters = args.train_iters
    args.lr_warmup_iters = args.lr_decay_iters // 10
    args.save_sampler_state = True
    if rank == 0:
        log("update_iter_info START")
        from llama_recipes.utils.wandb_utils import update_iter_info
        update_iter_info()
        log("update_iter_info DONE")


def main() -> None:
    """Run one training job end to end, logging the start/end of every stage.

    Stages, in order: parse arguments and set global state, seed RNGs, read
    ``RANK``/``WORLD_SIZE`` from the environment, initialise the NCCL process
    group, optionally start wandb on rank 0, optionally restore RNG state,
    build the model (plus a frozen reference copy for DPO), optionally load a
    checkpoint, cast dtype, wrap in FSDP, build the data loaders for the
    selected training mode, build optimizer and LR scheduler (restoring their
    state when ``args.load`` is set) and finally call ``train``.

    Raises:
        KeyError: If ``RANK`` or ``WORLD_SIZE`` is missing from the environment.
        ValueError: If the global batch size is too small for even one
            gradient-accumulation step, or no known training mode is selected.
    """
    log("PROGRAM START")

    log("parse_args START")
    args = parse_args()
    log("parse_args DONE")

    is_pretraining = not (args.instruction_tuning or args.direct_preference_optimization)

    log("set_global_variables START")
    set_global_variables(args=args, build_tokenizer=is_pretraining)
    log("set_global_variables DONE")

    log(f"set_seed START (seed={args.seed})")
    set_seed(seed=args.seed)
    log("set_seed DONE")

    if args.use_mpi:
        log("set_mpi_env START")
        set_mpi_env()
        log("set_mpi_env DONE")

    log("read RANK/WORLD_SIZE from env START")
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    args.rank = rank
    args.world_size = world_size
    args.gradient_accumulation_steps = args.global_batch_size // (args.micro_batch_size * world_size)
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            f"global_batch_size ({args.global_batch_size}) must be at least "
            f"micro_batch_size * world_size ({args.micro_batch_size} * {world_size})"
        )
    log(f"read RANK/WORLD_SIZE DONE (rank={rank}, world_size={world_size}, grad_accum_steps={args.gradient_accumulation_steps})")

    # CUDA device indices are per node, so bind to the local rank rather than
    # the global rank (which exceeds the device count on multi-node jobs).
    # Falls back to the global rank when LOCAL_RANK is not exported.
    local_rank = int(os.environ.get("LOCAL_RANK", rank))

    timeout = timedelta(minutes=args.distributed_timeout_minutes)
    log("init_process_group START")
    torch_distributed.init_process_group(
        backend="nccl",
        world_size=world_size,
        rank=rank,
        timeout=timeout,
        device_id=torch.device("cuda", local_rank),
    )
    log("init_process_group DONE")

    if args.wandb_name is not None and is_rank_0():
        log("wandb.init START")
        wandb_setting: dict = {
            "entity": args.wandb_entity,
            "project": args.wandb_project,
            "name": args.wandb_name,
            "config": vars(args),
        }
        wandb.require("core")
        wandb.init(**wandb_setting)
        log("wandb.init DONE")

    if torch_distributed.is_initialized():
        log("set_device/clear_gpu_cache/setup_environ_flags START")
        torch.cuda.set_device(get_local_rank())
        clear_gpu_cache(get_local_rank())
        setup_environ_flags(get_rank())
        log("set_device/clear_gpu_cache/setup_environ_flags DONE")

    log("get_latest_iteration START")
    iteration: int = get_latest_iteration(args.load)
    args.iteration = iteration
    log(f"get_latest_iteration DONE (iteration={iteration})")

    log("barrier before RNG load START")
    torch_distributed.barrier()
    log("barrier before RNG load DONE")

    if args.load:
        log("load_rng_state_dict START")
        load_rng_state_dict(args.load)
        log("load_rng_state_dict DONE")
        log("barrier after RNG load START")
        torch_distributed.barrier()
        log("barrier after RNG load DONE")

    use_cache = False
    log(f"get_model START (base_model={args.base_model})")
    model = get_model(model_name=args.base_model, use_cache=use_cache)
    log("get_model DONE")

    # Only populated for DPO; stays None for every other training mode.
    reference_model = None
    if args.direct_preference_optimization:
        # The copy is taken before any checkpoint is loaded into ``model``.
        log("reference_model deepcopy START")
        reference_model = copy.deepcopy(model)
        for param in reference_model.parameters():
            param.requires_grad = False
        log("reference_model deepcopy DONE")

    if args.load:
        if args.use_dist_ckpt:
            log("load_dist_model_state_dict START")
            load_dist_model_state_dict(model, args.load)
            log("load_dist_model_state_dict DONE")
        else:
            log("load_model_state_dict START")
            load_model_state_dict(model, args.load)
            log("load_model_state_dict DONE")

    log("print_model_size START")
    print_model_size(model, args.base_model, rank)
    log("print_model_size DONE")

    log("dtype cast START")
    _cast_to_training_dtype(model, args)
    log("dtype cast DONE")

    if args.direct_preference_optimization:
        log("reference_model dtype cast START")
        _cast_to_training_dtype(reference_model, args)
        log("reference_model dtype cast DONE")

    if args.use_freeze_layers:
        log("freeze_transformer_layers START")
        print_rank_0("NOTE: freeze transformer layers")
        freeze_transformer_layers(model=model, layer_ranges=args.freeze_layers)
        log("freeze_transformer_layers DONE")

    log("get_policies START")
    mixed_precision_policy, wrapping_policy = get_policies(
        rank=get_rank(),
        model_name=args.base_model,
    )
    log("get_policies DONE")

    from torch.distributed._tensor.device_mesh import init_device_mesh
    log(f"init_device_mesh START (world_size={world_size})")
    device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,))
    log("init_device_mesh DONE")

    log("FSDP(model) wrap START")
    model = _wrap_with_fsdp(
        model,
        args,
        rank,
        mixed_precision_policy,
        wrapping_policy,
        device_mesh=device_mesh,
    )
    log("FSDP(model) wrap DONE")

    if args.fsdp_activation_checkpointing:
        log("apply_fsdp_checkpointing START")
        apply_fsdp_checkpointing(model=model, model_name=args.base_model)
        log("apply_fsdp_checkpointing DONE")

    if args.direct_preference_optimization:
        # As in the original, the reference model is wrapped without a device mesh.
        log("FSDP(reference_model) wrap START")
        reference_model = _wrap_with_fsdp(
            reference_model,
            args,
            rank,
            mixed_precision_policy,
            wrapping_policy,
        )
        log("FSDP(reference_model) wrap DONE")

    if not args.instruction_tuning and not args.direct_preference_optimization:
        args.continual_pretraining = True
    # Resolve once so the log line and the branch below agree even when the
    # attribute was never defined on ``args``.
    continual_pretraining = getattr(args, "continual_pretraining", False)
    log(f"training_mode resolved (continual_pretraining={continual_pretraining}, "
        f"instruction_tuning={args.instruction_tuning}, dpo={args.direct_preference_optimization})")

    dpo_loss_fn = None
    if continual_pretraining:
        log("build_train_valid_test_datasets START")
        from llama_recipes.datasets.pretrain_dataset import build_train_valid_test_datasets
        from megatron_lm.megatron.data.data_samplers import build_pretraining_data_loader

        train_dataset, validation_dataset, test_dataset = build_train_valid_test_datasets()
        log("build_train_valid_test_datasets DONE")

        # Sample counters derived from the iteration returned by get_latest_iteration.
        args.consumed_train_samples = args.global_batch_size * args.iteration
        args.consumed_valid_samples = args.global_batch_size * (
            args.iteration // args.eval_interval) * args.eval_iters
        log(f"consumed_samples set (train={args.consumed_train_samples}, valid={args.consumed_valid_samples})")

        log("build_pretraining_data_loader(train) START")
        train_dataloader = build_pretraining_data_loader(
            dataset=train_dataset,
            consumed_samples=args.consumed_train_samples,
        )
        log("build_pretraining_data_loader(train) DONE")

        log("build_pretraining_data_loader(valid) START")
        validation_dataloader = build_pretraining_data_loader(
            dataset=validation_dataset,
            consumed_samples=args.consumed_valid_samples,
        )
        log("build_pretraining_data_loader(valid) DONE")

    else:
        from transformers import AutoTokenizer
        from llama_recipes.utils.instruction_tuning import get_instruction_tuning_dataloader
        from llama_recipes.utils.dpo_dataset import get_dpo_dataloader

        log("AutoTokenizer.from_pretrained START")
        hf_tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path=args.hf_transformer_model_dir
        )
        log("AutoTokenizer.from_pretrained DONE")

        if args.instruction_tuning:
            log("get_instruction_tuning_dataloader(train) START")
            train_dataloader = get_instruction_tuning_dataloader(
                tokenizer=hf_tokenizer,
                data_path=args.instruction_train_data_path,
                train=True,
            )
            log("get_instruction_tuning_dataloader(train) DONE")

            log("get_instruction_tuning_dataloader(valid) START")
            validation_dataloader = get_instruction_tuning_dataloader(
                tokenizer=hf_tokenizer,
                data_path=args.instruction_valid_data_path,
            )
            log("get_instruction_tuning_dataloader(valid) DONE")

            _configure_finetuning_schedule(args, args.instruction_dataset_size, rank)

        elif args.direct_preference_optimization:
            from llama_recipes.utils.dpo_loss import DPOLoss

            log("DPOLoss ctor START")
            dpo_loss_fn = DPOLoss(
                beta=args.dpo_beta,
                label_smoothing=args.dpo_label_smoothing,
            )
            log("DPOLoss ctor DONE")

            log("get_dpo_dataloader(train) START")
            train_dataloader = get_dpo_dataloader(
                tokenizer=hf_tokenizer,
                data_path=args.dpo_train_data_path,
                train=True,
            )
            log("get_dpo_dataloader(train) DONE")

            log("get_dpo_dataloader(valid) START")
            validation_dataloader = get_dpo_dataloader(
                tokenizer=hf_tokenizer,
                data_path=args.dpo_valid_data_path,
            )
            log("get_dpo_dataloader(valid) DONE")

            _configure_finetuning_schedule(args, args.dpo_dataset_size, rank)
        else:
            raise ValueError("unknown training mode")

    log("optimizer ctor START")
    optimizer = optim.AdamW(
        model.parameters(),
        lr=args.lr,
        betas=(args.adam_beta1, args.adam_beta2),
        eps=args.adam_eps,
        weight_decay=args.weight_decay,
    )
    log("optimizer ctor DONE")

    if args.load:
        if args.use_dist_ckpt:
            log("load_dist_optimizer_state_dict START")
            load_dist_optimizer_state_dict(model=model, optimizer=optimizer, path=args.load)
            log("load_dist_optimizer_state_dict DONE")
        else:
            log("load_optimizer_state_dict START")
            load_optimizer_state_dict(model=model, optimizer=optimizer, path=args.load)
            log("load_optimizer_state_dict DONE")

    log(f"scheduler ctor START (style={args.lr_decay_style})")
    if args.lr_decay_style == "cosine":
        scheduler = WarmupCosineAnnealingLR(
            optimizer=optimizer,
            warmup_iterations=args.lr_warmup_iters,
            decay_iterations=args.lr_decay_iters,
            max_iterations=args.train_iters,
            eta_min=args.min_lr,
        )
    else:
        # Any non-cosine style uses a fixed per-step multiplicative decay.
        scheduler = StepLR(optimizer, step_size=1, gamma=0.85)
    log("scheduler ctor DONE")

    if args.load:
        log("load_scheduler_state_dict START")
        load_scheduler_state_dict(scheduler, args.load)
        log("load_scheduler_state_dict DONE")

    log("train() START")
    train(
        model=model,
        train_dataloader=train_dataloader,
        eval_dataloader=validation_dataloader,
        optimizer=optimizer,
        lr_scheduler=scheduler,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        local_rank=get_local_rank(),
        rank=get_rank(),
        dpo_loss_fn=dpo_loss_fn,
        reference_model=reference_model,
    )
    log("train() DONE")


# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ =="__main__":
    main ()
