from transformers import Trainer, TrainerCallback
from trl import SFTTrainer
import os
from torch.distributed.fsdp import (
        FullyShardedDataParallel as FSDP,
        MixedPrecision,
        BackwardPrefetch,
        ShardingStrategy,
        FullStateDictConfig,
        StateDictType,
    )

class FSDPTrainer(Trainer):
    """``Trainer`` variant that requests a full state dict when checkpointing under FSDP."""

    def __init__(self, *args, **kwargs):
        """Pass every argument through unchanged to the next ``__init__`` in the MRO."""
        super().__init__(*args, **kwargs)

    def _save_checkpoint(self, model, trial, metrics=None):
        """Set the FSDP state-dict type to ``FULL_STATE_DICT`` (when FSDP is
        enabled) and then delegate to the parent ``_save_checkpoint``.

        Args:
            model: Forwarded unchanged to the parent implementation.
            trial: Forwarded unchanged to the parent implementation.
            metrics: Forwarded to the parent implementation as a keyword.
        """
        if self.is_fsdp_enabled:
            # Reconfigure the accelerate FSDP plugin before the parent saves.
            fsdp_plugin = self.accelerator.state.fsdp_plugin
            fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")

        super()._save_checkpoint(model, trial, metrics=metrics)

        
class FSDPSFTTrainer(FSDPTrainer, SFTTrainer):
    """``SFTTrainer`` combined with ``FSDPTrainer``'s checkpoint override.

    The method resolution order is
    ``FSDPSFTTrainer -> FSDPTrainer -> SFTTrainer -> Trainer``, so the
    ``_save_checkpoint`` defined on ``FSDPTrainer`` takes precedence while
    construction still goes through ``SFTTrainer.__init__``.
    """

    def __init__(self, *args, **kwargs):
        """Initialise the trainer exactly once.

        All positional and keyword arguments are forwarded unchanged.
        """
        # A single cooperative call walks the whole MRO. Calling
        # ``SFTTrainer.__init__`` and then ``FSDPTrainer.__init__`` explicitly
        # ran ``SFTTrainer.__init__`` twice, because ``FSDPTrainer.__init__``
        # itself calls ``super().__init__`` which resolves to ``SFTTrainer``
        # for instances of this class.
        super().__init__(*args, **kwargs)
    


class SaveFSDPModelCallback(TrainerCallback):
    """Trainer callback that saves a consolidated (full) FSDP state dict.

    Depending on ``callback_save_strategy`` a checkpoint is written every
    ``callback_save_steps`` global steps (``'steps'``) or at the end of every
    epoch (``'epoch'``). Each checkpoint goes to
    ``<output_dir>/checkpoint-<global_step>``.

    The write itself is delegated to ``trainer._save``, so a trainer must be
    attached, either through the ``trainer`` argument or by assigning
    ``callback.trainer = trainer`` after the trainer has been constructed.
    """

    def __init__(self, output_dir, callback_save_strategy, callback_save_steps=None, trainer=None):
        """Store the save configuration.

        Args:
            output_dir: Root directory under which checkpoint folders are created.
            callback_save_strategy: ``'steps'`` or ``'epoch'``; any other value
                disables saving from this callback.
            callback_save_steps: Save interval in global steps. Required and
                must be positive when the strategy is ``'steps'``.
            trainer: Object whose ``_save(output_dir, state_dict=...)`` performs
                the write. May be left as ``None`` and assigned later.

        Raises:
            ValueError: If the strategy is ``'steps'`` and ``callback_save_steps``
                is missing or not positive.
        """
        super().__init__()
        # Fail at construction instead of with a TypeError / ZeroDivisionError
        # on the modulo in the middle of training.
        if callback_save_strategy == 'steps' and (
            callback_save_steps is None or callback_save_steps <= 0
        ):
            raise ValueError(
                "callback_save_steps must be a positive number when "
                "callback_save_strategy is 'steps'"
            )
        self.output_dir = output_dir
        self.callback_save_strategy = callback_save_strategy
        self.callback_save_steps = callback_save_steps
        self.trainer = trainer

    def _save_full_checkpoint(self, model, state, progress):
        """Gather the full state dict from ``model`` and save it via the trainer.

        Args:
            model: The FSDP-wrapped model passed to the callback hook.
            state: Trainer state; ``global_step`` names the checkpoint folder
                and ``is_world_process_zero`` selects the writing process.
            progress: Text inserted into the log line (e.g. ``"step 100"``).

        Raises:
            RuntimeError: If no trainer has been attached to the callback.
        """
        # Checked before the state-dict gathering so that every rank fails the
        # same way rather than some ranks entering the gather and waiting.
        if self.trainer is None:
            raise RuntimeError(
                "SaveFSDPModelCallback has no trainer attached; pass trainer=... "
                "or set callback.trainer before training starts"
            )

        checkpoint_dir = os.path.join(self.output_dir, f"checkpoint-{state.global_step}")

        # Every rank must take part in building the state dict; with
        # rank0_only=True only global rank 0 ends up holding the full weights.
        save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
            cpu_state_dict = model.state_dict()

        # Only the process that holds the populated state dict writes, and it
        # writes into the per-step checkpoint folder rather than the root.
        if state.is_world_process_zero:
            self.trainer._save(checkpoint_dir, state_dict=cpu_state_dict)
            print(f"Saving FSDP model checkpoint at {progress} to {checkpoint_dir}")

    def on_step_end(self, args, state, control, model, **kwargs):
        """Save a checkpoint when the strategy is ``'steps'`` and the interval is hit."""
        if self.callback_save_strategy != 'steps':
            return
        if state.global_step % self.callback_save_steps == 0:
            self._save_full_checkpoint(model, state, f"step {state.global_step}")

    def on_epoch_end(self, args, state, control, model, **kwargs):
        """Save a checkpoint at the end of each epoch when the strategy is ``'epoch'``."""
        if self.callback_save_strategy == 'epoch':
            self._save_full_checkpoint(model, state, f"epoch {state.epoch}")
