from trl import SFTTrainer
from typing import Callable, Dict, List, Optional, Tuple, Union
import dataclasses
import inspect
import warnings
from functools import wraps
from typing import Callable, Dict, List, Optional, Tuple, Union

import datasets
import torch
import torch.nn as nn
from accelerate.state import PartialState
from datasets import Dataset
from datasets.arrow_writer import SchemaInferenceError
from datasets.builder import DatasetGenerationError
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollator,
    DataCollatorForLanguageModeling,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    Trainer,
    TrainingArguments,
)
from transformers.modeling_utils import unwrap_model
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalPrediction
from trl import is_peft_available
from peft import PeftModel
class MyTrainer(SFTTrainer):
    """``SFTTrainer`` subclass that keeps the tokenizer around for decoding.

    Every constructor argument is forwarded unchanged to ``SFTTrainer.__init__``;
    see that class for the meaning of each parameter. The tokenizer passed in is
    additionally stored as ``self.myTokenizer`` and used by ``decode_logits``.
    """

    def __init__(
        self,
        model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
        args: Optional[TrainingArguments] = None,
        data_collator: Optional[DataCollator] = None,  # type: ignore
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        peft_config=None,
        dataset_text_field: Optional[str] = None,
        packing: Optional[bool] = False,
        formatting_func: Optional[Callable] = None,
        max_seq_length: Optional[int] = None,
        infinite: Optional[bool] = None,
        num_of_sequences: Optional[int] = 1024,
        chars_per_token: Optional[float] = 3.6,
        dataset_num_proc: Optional[int] = None,
        dataset_batch_size: int = 1000,
        neftune_noise_alpha: Optional[float] = None,
        model_init_kwargs: Optional[Dict] = None,
        dataset_kwargs: Optional[Dict] = None,
        eval_packing: Optional[bool] = None,
    ):
        # Forward by keyword rather than by position so that a reordering of
        # SFTTrainer's signature in another trl release cannot silently bind
        # values to the wrong parameters.
        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
            peft_config=peft_config,
            dataset_text_field=dataset_text_field,
            packing=packing,
            formatting_func=formatting_func,
            max_seq_length=max_seq_length,
            infinite=infinite,
            num_of_sequences=num_of_sequences,
            chars_per_token=chars_per_token,
            dataset_num_proc=dataset_num_proc,
            dataset_batch_size=dataset_batch_size,
            neftune_noise_alpha=neftune_noise_alpha,
            model_init_kwargs=model_init_kwargs,
            dataset_kwargs=dataset_kwargs,
            eval_packing=eval_packing,
        )
        # Kept under its original public attribute name for existing callers.
        self.myTokenizer = tokenizer

    @wraps(SFTTrainer.train)
    def train(self, *args, **kwargs):
        # NOTE: ``@wraps`` copies SFTTrainer.train's docstring onto this method,
        # so the rationale lives in comments instead.
        #
        # SFTTrainer.train already activates NEFTune before training and removes
        # the forward hook / ``neftune_noise_alpha`` attribute afterwards.
        # Repeating that here registered the hook twice (leaking the first
        # handle) and then failed on a second ``del embeddings.neftune_noise_alpha``.
        #
        # The value returned by Trainer.train is a ``TrainOutput``
        # (global_step, training_loss, metrics); it carries no logits, so
        # predictions cannot be decoded from it. To inspect predictions, run
        # ``self.predict(dataset)`` and pass its logits to ``decode_logits``.
        return super().train(*args, **kwargs)

    def decode_logits(self, logits, skip_special_tokens: bool = True) -> List[str]:
        """Greedy-decode ``logits`` into strings with the stored tokenizer.

        Args:
            logits: Tensor or array-like of scores with the vocabulary on the last
                axis (presumably ``(batch, seq_len, vocab)`` — confirm against the
                model in use). A tuple of several model outputs is not supported.
            skip_special_tokens: Forwarded to ``batch_decode``.

        Returns:
            One decoded string per row of the argmax token ids.
        """
        # ``as_tensor`` lets callers pass the numpy arrays returned by
        # ``Trainer.predict`` as well as torch tensors.
        token_ids = torch.argmax(torch.as_tensor(logits), dim=-1)
        return self.myTokenizer.batch_decode(
            token_ids.detach().cpu().numpy(),
            skip_special_tokens=skip_special_tokens,
        )