from transformers import Seq2SeqTrainer
import math
import torch
from transformers.trainer import _is_peft_model
from transformers.models.auto.modeling_auto import (
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
)
class PruningTrainer(Seq2SeqTrainer):
    """Seq2SeqTrainer that ramps a model's sparsity up while training.

    Before each forward pass the trainer computes the sparsity scheduled for
    the current ``global_step`` and pushes ``1 - sparsity`` into the model via
    ``set_pruning_ratio``. The model (or its ``.module`` when it is wrapped)
    must therefore expose a ``set_pruning_ratio(ratio)`` method.
    """

    def __init__(self, sparsity_warmup_steps=25000, target_sparsity=0.5, warmup_type='linear', *args, **kwargs):
        """Create the trainer and store the sparsity schedule.

        NOTE: the three schedule parameters come *before* ``*args``, so any
        positional argument is bound to them first. Pass the regular Trainer
        arguments (``model``, ``args``, ...) by keyword.

        Args:
            sparsity_warmup_steps: Number of global steps over which the
                sparsity moves from ``start_sparsity`` (0.0) to
                ``target_sparsity``.
            target_sparsity: Sparsity used once the warmup is finished.
            warmup_type: ``'linear'`` or ``'logarithmic'``.
            *args: Forwarded to ``Seq2SeqTrainer.__init__``.
            **kwargs: Forwarded to ``Seq2SeqTrainer.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.start_sparsity = 0.0
        self.target_sparsity = target_sparsity
        self.sparsity_warmup_steps = sparsity_warmup_steps
        self.warmup_type = warmup_type
        # Set once the post-warmup target has been pushed into the model, so
        # that it is applied exactly once instead of on every step.
        self._final_sparsity_applied = False

    def get_current_target_sparsity(self, global_step):
        """Return the sparsity scheduled for ``global_step``.

        During warmup (``global_step < sparsity_warmup_steps``):

        * ``'linear'`` interpolates the sparsity itself between
          ``start_sparsity`` and ``target_sparsity``.
        * ``'logarithmic'`` interpolates ``log(1 - sparsity)`` instead; this
          requires both sparsities to be strictly below 1, otherwise
          ``math.log`` raises ``ValueError``.

        After warmup ``target_sparsity`` is returned unchanged.

        Raises:
            ValueError: If ``warmup_type`` is not a known schedule (only
                checked while still inside the warmup window).
        """
        if global_step >= self.sparsity_warmup_steps:
            return self.target_sparsity

        if self.warmup_type == 'linear':
            return (
                self.start_sparsity
                + (self.target_sparsity - self.start_sparsity) * global_step / self.sparsity_warmup_steps
            )

        if self.warmup_type == 'logarithmic':
            log_start = math.log(1 - self.start_sparsity)
            log_target = math.log(1 - self.target_sparsity)
            log_one_minus_sparsity = log_start + (log_target - log_start) * global_step / self.sparsity_warmup_steps
            return 1 - math.exp(log_one_minus_sparsity)

        raise ValueError(f'Unknown warmup type: {self.warmup_type}')

    def _apply_scheduled_sparsity(self, model):
        """Push the pruning ratio scheduled for the current step into ``model``.

        During warmup the ratio is refreshed on every call. After warmup the
        final target is applied once; this also covers ``sparsity_warmup_steps
        == 0`` and runs resumed past the warmup window, where the ratio would
        otherwise never be set.
        """
        in_warmup = self.state.global_step < self.sparsity_warmup_steps
        if not in_warmup and self._final_sparsity_applied:
            return

        sparsity = self.get_current_target_sparsity(self.state.global_step)

        # A wrapped model does not expose the method itself; fall back to the
        # wrapped ``.module``. Choosing the target explicitly (rather than
        # try/except around the call) lets errors raised *inside*
        # ``set_pruning_ratio`` propagate instead of being masked.
        target = model if hasattr(model, "set_pruning_ratio") else model.module
        target.set_pruning_ratio(1 - sparsity)

        if not in_warmup:
            self._final_sparsity_applied = True

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Update the model's pruning ratio, run the forward pass, return the loss.

        Apart from the pruning-ratio update and the forward pass always being
        called with ``output_attentions=True``, the loss is obtained the usual
        Trainer way: from ``compute_loss_func`` or the label smoother when one
        is configured and ``inputs`` contains ``"labels"``, otherwise from the
        model output (``outputs["loss"]`` for dicts, ``outputs[0]`` otherwise).

        Args:
            model: Model to run; it (or ``model.module``) must provide
                ``set_pruning_ratio``.
            inputs: Keyword arguments for the model's forward pass. The
                ``"labels"`` entry is popped when the loss is computed
                outside the model.
            return_outputs: When true, return ``(loss, outputs)``.
            num_items_in_batch: Forwarded to the model and/or
                ``compute_loss_func`` when available.

        Returns:
            The loss, or ``(loss, outputs)`` if ``return_outputs`` is true.

        Raises:
            ValueError: If the model returns a dict without a ``"loss"`` key
                and the loss is not computed externally, or if the warmup
                type is unknown.
        """
        computes_loss_externally = self.label_smoother is not None or self.compute_loss_func is not None
        if computes_loss_externally and "labels" in inputs:
            labels = inputs.pop("labels")
        else:
            labels = None

        if self.model_accepts_loss_kwargs:
            loss_kwargs = {}
            if num_items_in_batch is not None:
                loss_kwargs["num_items_in_batch"] = num_items_in_batch
            inputs = {**inputs, **loss_kwargs}

        self._apply_scheduled_sparsity(model)

        outputs = model(**inputs, output_attentions=True)
        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is not None:
            unwrapped_model = self.accelerator.unwrap_model(model)
            if _is_peft_model(unwrapped_model):
                model_name = unwrapped_model.base_model.model._get_name()
            else:
                model_name = unwrapped_model._get_name()
            # User-defined compute_loss function
            if self.compute_loss_func is not None:
                loss = self.compute_loss_func(outputs, labels, num_items_in_batch=num_items_in_batch)
            elif model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
                # Causal LMs predict the next token, so labels are shifted.
                loss = self.label_smoother(outputs, labels, shift_labels=True)
            else:
                loss = self.label_smoother(outputs, labels)
        else:
            if isinstance(outputs, dict) and "loss" not in outputs:
                raise ValueError(
                    "The model did not return a loss from the inputs, only the following keys: "
                    f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
                )
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

        if (
            self.args.average_tokens_across_devices
            and (self.model_accepts_loss_kwargs or self.compute_loss_func)
            and num_items_in_batch is not None
        ):
            loss *= self.accelerator.num_processes

        return (loss, outputs) if return_outputs else loss
