from accelerate.utils import DummyOptim
from peft import PeftModel
from transformers import Trainer
from transformers.modeling_utils import unwrap_model
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from transformers.utils import is_peft_available


class PITPretrainTrainer(Trainer):
    """Trainer that can freeze the protein/text sub-models or train them at a scaled learning rate.

    The wrapped model is expected to expose ``protein_model`` and ``text_model``
    attributes (both are accessed in :meth:`create_optimizer`).
    """

    def __init__(self, **kwargs):
        """Pop the trainer-specific options and forward the rest to ``Trainer.__init__``.

        Keyword Args:
            protein_model_fixed (bool): If true, ``model.protein_model`` is frozen in
                :meth:`create_optimizer` unless it is a ``PeftModel``. Defaults to ``True``.
            text_model_fixed (bool): Same as above, for ``model.text_model``.
                Defaults to ``True``.
            lr_ratio (float): Factor applied to ``args.learning_rate`` for the parameters
                of every sub-model that is *not* fixed. Defaults to ``0.1``.
            use_moe (bool): If true, the parameter groups are passed through DeepSpeed's
                ``split_params_into_different_moe_groups_for_optimizer``. Defaults to ``False``.
        """
        self.protein_model_fixed = kwargs.pop("protein_model_fixed", True)
        self.text_model_fixed = kwargs.pop("text_model_fixed", True)
        self.lr_ratio = kwargs.pop("lr_ratio", 0.1)
        self.use_moe = kwargs.pop("use_moe", False)
        super().__init__(**kwargs)

    def _freeze_fixed_submodels(self):
        """Disable gradients for each sub-model flagged as fixed, skipping ``PeftModel`` wrappers."""
        submodels = (
            (self.protein_model_fixed, self.model.protein_model),
            (self.text_model_fixed, self.model.text_model),
        )
        for is_fixed, submodel in submodels:
            # Each sub-model is tested against its own PEFT status; a PeftModel
            # is left untouched so this method does not override its
            # requires_grad flags.
            if is_fixed and not isinstance(submodel, PeftModel):
                for param in submodel.parameters():
                    param.requires_grad = False

    def _build_param_groups(self):
        """Return the optimizer parameter groups (weight-decay x learning-rate-ratio split)."""
        model = self.model
        decay_names = set(self.get_decay_parameter_names(model))

        # Parameters of the non-fixed sub-models get the scaled learning rate.
        # They are matched by identity: names yielded by a sub-module's
        # named_parameters() are relative to that sub-module and would never
        # equal the names yielded by model.named_parameters().
        ratio_param_ids = set()
        if not self.protein_model_fixed:
            ratio_param_ids.update(id(p) for p in model.protein_model.parameters())
        if not self.text_model_fixed:
            ratio_param_ids.update(id(p) for p in model.text_model.parameters())

        trainable = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
        weight_decay = self.args.weight_decay

        if self.protein_model_fixed and self.text_model_fixed:
            return [
                {
                    "params": [p for n, p in trainable if n in decay_names],
                    "weight_decay": weight_decay,
                    "name": "decay",
                },
                {
                    "params": [p for n, p in trainable if n not in decay_names],
                    "weight_decay": 0.0,
                    "name": "no_decay",
                },
            ]

        def select(use_decay, use_ratio):
            """Pick trainable parameters matching the requested decay/ratio membership."""
            return [
                p
                for n, p in trainable
                if (n in decay_names) == use_decay and (id(p) in ratio_param_ids) == use_ratio
            ]

        scaled_lr = self.lr_ratio * self.args.learning_rate
        return [
            {
                "params": select(True, True),
                "weight_decay": weight_decay,
                "lr": scaled_lr,
                "name": "decay_ratio",
            },
            {
                "params": select(False, True),
                "weight_decay": 0.0,
                "lr": scaled_lr,
                "name": "no_decay_ratio",
            },
            {
                "params": select(True, False),
                "weight_decay": weight_decay,
                "name": "decay_no_ratio",
            },
            {
                "params": select(False, False),
                "weight_decay": 0.0,
                "name": "no_decay_no_ratio",
            },
        ]

    def create_optimizer(self):
        """
        Setup the optimizer.

        Sub-models flagged as fixed are frozen first (unless they are ``PeftModel`` instances). The remaining
        trainable parameters are split into weight-decay / no-weight-decay groups; when at least one sub-model is
        not fixed, its parameters are placed in separate groups whose learning rate is
        ``lr_ratio * args.learning_rate``.

        If an optimizer was already supplied (e.g. through the Trainer's `optimizers` init argument), it is kept
        and returned unchanged.

        Returns:
            The optimizer stored in ``self.optimizer``.
        """
        self._freeze_fixed_submodels()

        # Respect an optimizer injected by the caller instead of overwriting it.
        if self.optimizer is not None:
            return self.optimizer

        optimizer_grouped_parameters = self._build_param_groups()

        if self.use_moe:
            # Imported lazily so DeepSpeed is only required when MoE is enabled.
            from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer
            optimizer_grouped_parameters = split_params_into_different_moe_groups_for_optimizer(
                optimizer_grouped_parameters)

        optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
        self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
        return self.optimizer
