from transformers import TrainerCallback
import os
import torch
import torch.distributed as dist
class TrainingMonitorCallback(TrainerCallback):
    """Trainer callback that writes a model snapshot at the end of every epoch.

    When ``torch.distributed`` is initialized, only rank 0 writes, so several
    processes never race on the same path.
    """

    def __init__(self, output_dir='out', save_format='hf', resume_training=False,
                 epoch_begin=None, name_prefix='llama2'):
        """Create the callback and make sure ``output_dir`` exists.

        Args:
            output_dir: Directory that receives one snapshot per epoch; it is
                created if missing.
            save_format: ``'hf'`` saves with ``model.save_pretrained``. Any
                other value saves ``model.state_dict()`` with ``torch.save``.
            resume_training: When True, snapshots are labelled with
                ``epoch_begin`` instead of ``state.epoch``.
            epoch_begin: Epoch label used when ``resume_training`` is True. If
                it is None, ``state.epoch`` is used instead.
            name_prefix: Snapshot name prefix; paths look like
                ``f"{name_prefix}_epoch_{N}"``.
        """
        self.output_dir = output_dir
        os.makedirs(self.output_dir, exist_ok=True)
        self.save_format = save_format
        self.resume_training = resume_training
        self.epoch_begin = epoch_begin
        self.name_prefix = name_prefix

    @staticmethod
    def _is_main_process():
        """Return True on rank 0, or when torch.distributed is not in use."""
        # is_available() guards builds compiled without distributed support.
        if not (dist.is_available() and dist.is_initialized()):
            return True
        return dist.get_rank() == 0

    def on_epoch_end(self, args, state, control, **kwargs):
        """Save ``kwargs['model']`` under ``output_dir`` (main process only)."""
        if not self._is_main_process():
            return
        model = kwargs['model']
        epoch = state.epoch
        if self.resume_training and self.epoch_begin is not None:
            # NOTE(review): every epoch after a resume gets the same label, so
            # each snapshot overwrites the previous one. Confirm whether
            # epoch_begin + state.epoch was intended.
            epoch = self.epoch_begin
        save_path = os.path.join(self.output_dir, f"{self.name_prefix}_epoch_{int(epoch)}")
        if self.save_format == 'hf':
            model.save_pretrained(save_path, safe_serialization=False)
        else:
            torch.save(model.state_dict(), save_path)
        print(f"Model saved at {save_path} after epoch {int(epoch)}")
        
        
from trl import SFTTrainer,SFTConfig      
from torch.nn import functional as F      
from typing import Callable, Dict, List, Optional, Tuple, Union
from transformers.trainer_utils import EvalPrediction
from transformers import (
    DataCollator,
    PreTrainedModel,
    PreTrainedTokenizerBase,

)
import torch.nn as nn
from datasets import Dataset
from torch.nn.utils import clip_grad_norm_

class DistillationTrainer(SFTTrainer):
    """SFTTrainer whose training loss is a KL distillation loss against a teacher.

    The optimized loss is ``distillation_weight`` times the KL divergence
    between the teacher's and the student's token distributions at every
    non-padding position. The SFT cross-entropy on ``labels`` is not part of
    the returned loss.
    """

    def __init__(
        self,
        model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
        args: Optional[SFTConfig] = None,
        data_collator: Optional[DataCollator] = None,  # data collator
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        peft_config: Optional["PeftConfig"] = None,
        dataset_text_field: Optional[str] = None,
        packing: Optional[bool] = False,
        formatting_func: Optional[Callable] = None,
        max_seq_length: Optional[int] = None,
        infinite: Optional[bool] = None,
        num_of_sequences: Optional[int] = None,
        chars_per_token: Optional[float] = None,
        dataset_num_proc: Optional[int] = None,
        dataset_batch_size: Optional[int] = None,
        neftune_noise_alpha: Optional[float] = None,
        model_init_kwargs: Optional[Dict] = None,
        dataset_kwargs: Optional[Dict] = None,
        eval_packing: Optional[bool] = None,

        teacher_model: Optional[PreTrainedModel] = None,
        distillation_weight: float = 0.01,
        distillation_temperature: float = 1.0,
    ):
        """Store the teacher and distillation settings, then build the SFTTrainer.

        All arguments except the last three are forwarded unchanged to
        ``SFTTrainer.__init__``.

        Args:
            teacher_model: Model whose logits the student is trained to match.
                It is switched to eval mode here and only run under
                ``torch.no_grad()``. NOTE(review): it is not moved to the
                training device by this class; it must already be on the device
                the batches end up on.
            distillation_weight: Multiplier applied to the KL loss (was a
                hard-coded 0.01).
            distillation_temperature: Softmax temperature for the KL loss.
        """
        self.teacher_model = teacher_model
        self.distillation_weight = distillation_weight
        self.distillation_temperature = distillation_temperature
        if teacher_model is not None:
            # The teacher is inference-only; disable dropout and similar layers.
            teacher_model.eval()

        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
            peft_config=peft_config,
            dataset_text_field=dataset_text_field,
            packing=packing,
            formatting_func=formatting_func,
            max_seq_length=max_seq_length,
            infinite=infinite,
            num_of_sequences=num_of_sequences,
            chars_per_token=chars_per_token,
            dataset_num_proc=dataset_num_proc,
            dataset_batch_size=dataset_batch_size,
            neftune_noise_alpha=neftune_noise_alpha,
            model_init_kwargs=model_init_kwargs,
            dataset_kwargs=dataset_kwargs,
            eval_packing=eval_packing
        )

    @staticmethod
    def distillation_loss(student_logits, teacher_logits, temperature=1.0, mask=None):
        """Return KL(teacher || student), summed over positions and averaged over dim 0.

        Args:
            student_logits: Student logits; the last dimension is the vocabulary.
            teacher_logits: Teacher logits with the same shape as ``student_logits``.
            temperature: Softmax temperature T applied to both logit tensors. The
                result is multiplied by T**2 so gradient scale stays comparable
                across temperatures. With T == 1 this is the plain KL divergence.
            mask: Optional tensor shaped like ``student_logits.shape[:-1]``.
                Positions where it is 0 (e.g. padding) contribute nothing.

        Returns:
            Scalar tensor equal to the sum of per-position KL terms divided by
            ``student_logits.size(0)``. Without a mask this matches
            ``F.kl_div(..., reduction='batchmean')``.

        Raises:
            ValueError: If ``temperature`` is not positive.
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")
        student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
        teacher_log_probs = F.log_softmax(teacher_logits / temperature, dim=-1)
        # log_target=True keeps the teacher in log space (no softmax -> log round trip).
        kl = F.kl_div(student_log_probs, teacher_log_probs, reduction='none', log_target=True)
        per_position = kl.sum(dim=-1)
        if mask is not None:
            per_position = per_position * mask.to(per_position.dtype)
        # Divide by dim 0 only, exactly like reduction='batchmean'.
        return per_position.sum() / student_logits.size(0) * (temperature ** 2)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the weighted distillation loss for one batch.

        Args:
            model: The (possibly wrapped) student model.
            inputs: Batch dict passed to the student as keyword arguments.
            return_outputs: If True, also return the student's forward outputs.
            num_items_in_batch: Accepted because newer ``Trainer`` versions pass
                it. It is unused because the KL loss does its own normalization.

        Returns:
            The scalar loss, or ``(loss, student_outputs)`` when ``return_outputs``.

        Raises:
            ValueError: If no teacher model was provided.
        """
        if self.teacher_model is None:
            raise ValueError("DistillationTrainer requires a `teacher_model`.")

        student_outputs = model(**inputs)
        student_logits = student_outputs.logits

        # The teacher only needs its logits; leaving out labels skips its unused CE loss.
        teacher_inputs = {key: value for key, value in inputs.items() if key != "labels"}
        with torch.no_grad():
            teacher_logits = self.teacher_model(**teacher_inputs).logits

        # Exclude padding positions (attention_mask == 0) from the KL loss.
        loss = self.distillation_weight * self.distillation_loss(
            student_logits,
            teacher_logits,
            temperature=self.distillation_temperature,
            mask=inputs.get("attention_mask"),
        )

        return (loss, student_outputs) if return_outputs else loss
    