import logging
from typing import Callable, Dict, List, Optional, Tuple, Union, Any

import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from peft.tuners.lora.layer import Linear as LoraLinear
from peft.utils.other import transpose
from torch.utils.data import Dataset
from transformers import Trainer, Seq2SeqTrainingArguments
from transformers.data.data_collator import DataCollator
from transformers.trainer import (
    EvalPrediction,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    TrainerCallback,
)
from transformers.utils import is_sagemaker_mp_enabled


# Substrings used by LogTrainer to pick the modules/parameters it tracks: a
# name is tracked when it contains any of these keywords.
# Alternative presets kept for reference:
#   ["block.0", "block.4"]
#   ["layers.27", "layers.6"]  # for Llama
include_keywords = [
    f"encoder.block.{block_idx}" for block_idx in (2, 3, 4)
]  # for T5
# When False, LogTrainer.training_step defers to Trainer.training_step.
do_log = True

def get_forward_hook(name):
    """Build a forward hook that logs activation statistics under ``name``.

    The returned callable takes ``(module, inputs, output)``.  On every call
    it logs the mean and standard deviation of the first positional input and
    of the output to wandb, using ``commit=False``.
    """
    prefix = f"{name}/"

    def _log_activation_stats(module, inputs, output):
        first_input = inputs[0]
        stats = {
            prefix + "input_mean": first_input.mean().item(),
            prefix + "input_std": first_input.std().item(),
            prefix + "output_mean": output.mean().item(),
            prefix + "output_std": output.std().item(),
        }
        wandb.log(stats, commit=False)

    return _log_activation_stats


class LogTrainer_Bi(Trainer):
    """Trainer that optimises two LoRA groups stored in a single adapter.

    For every LoRA matrix pair of the ``"default"`` adapter, rows
    ``[:lora1_rank]`` of ``lora_A`` and columns ``[:lora1_rank]`` of ``lora_B``
    form "lora1"; the remaining rows/columns form "lora2".  During training
    the gradient of lora2 is negated before each optimizer update and lora2 is
    rescaled against ``args.rho``.  Evaluation and checkpoints only use lora1.

    Besides the stock training arguments, ``args`` must provide
    ``lora1_rank``, ``rho`` and ``lora_type``; ``exceed_rho`` is optional.
    """

    def __init__(
        self,
        model: Union[PreTrainedModel, nn.Module] = None,
        args: Seq2SeqTrainingArguments = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer,
                          torch.optim.lr_scheduler.LambdaLR] = (
                              None,
                              None,
                          ),
        logger=None,
        preprocess_logits_for_metrics: Optional[Callable[
            [torch.Tensor, torch.Tensor], torch.Tensor]] = None,
    ):
        """Forward the standard arguments to ``Trainer`` and set up state.

        Args:
            logger: Object exposing ``info``; defaults to this module's
                standard-library logger when ``None``.

        All other arguments are passed to ``Trainer.__init__`` unchanged.
        """
        # Keyword arguments keep each value bound to the intended parameter
        # even if the positional order of Trainer.__init__ differs between
        # transformers releases.
        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )

        # The default ``logger=None`` used to crash on the first ``.info``.
        self.logger = (logger if logger is not None else
                       logging.getLogger(__name__))
        self.logger.info(f"Training args: \n{self.args}")
        self.is_peft = "PeftModel" in type(model).__name__

        # Number of micro-batches processed so far by ``training_step``.
        self.gradient_accumulation_counter = 0
        self.perturbation_exceeds_rho_count = 0

    @torch.no_grad()
    def opposite_lora2_grad(self):
        """Negate, in place, the gradient of the lora2 slice of each LoRA matrix.

        Parameters without a gradient (e.g. frozen base weights) are skipped,
        as are parameters that are neither ``lora_A`` nor ``lora_B``.
        """
        lora1_rank = self.args.lora1_rank
        for name, param in self.model.named_parameters():
            grad = param.grad
            if grad is None:
                # Must be ``continue``: frozen parameters usually come first,
                # so stopping here would leave every LoRA gradient untouched.
                continue

            if "lora_A" in name:
                grad[lora1_rank:, :].mul_(-1)
            elif "lora_B" in name:
                grad[:, lora1_rank:].mul_(-1)

    def training_step(
            self, model: nn.Module,
            inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """Run forward/backward on one micro-batch.

        lora2 is rescaled at the start of every accumulation cycle, and its
        gradient is negated after the last backward pass of the cycle.

        Returns:
            The detached loss divided by ``gradient_accumulation_steps``.
        """
        accumulation_steps = self.args.gradient_accumulation_steps
        try:
            # A new cycle starts right after the optimizer has updated the
            # parameters, so this is where lora2 is re-normalised.
            if self.gradient_accumulation_counter % accumulation_steps == 0:
                self.normalize_lora2()

            model.train()
            inputs = self._prepare_inputs(inputs)

            with self.compute_loss_context_manager():
                loss = self.compute_loss(model, inputs)
            if self.args.n_gpu > 1:
                loss = loss.mean()
            self.accelerator.backward(loss)

            self.gradient_accumulation_counter += 1

            # Flip the sign of lora2's gradient once the cycle's gradients are
            # complete, i.e. just before the optimizer consumes them.
            if self.gradient_accumulation_counter % accumulation_steps == 0:
                self.opposite_lora2_grad()

            return loss.detach() / accumulation_steps

        except Exception as e:
            self.logger.info(f"Unexpected error in training_step: {str(e)}")
            raise

    def log(self, logs: Dict[str, float]) -> None:
        """Record ``logs`` in the state history, the logger and the callbacks.

        Args:
            logs (`Dict[str, float]`):
                The values to log.  ``epoch`` (and, when enabled,
                ``num_input_tokens_seen``) are added to it in place.
        """
        if self.state.epoch is not None:
            logs["epoch"] = self.state.epoch
        if self.args.include_num_input_tokens_seen:
            logs["num_input_tokens_seen"] = self.state.num_input_tokens_seen

        output = {**logs, "step": self.state.global_step}
        self.state.log_history.append(output)
        self.logger.info(output)
        self.control = self.callback_handler.on_log(self.args, self.state,
                                                    self.control, logs)

    @torch.no_grad()
    def normalize_lora2(self):
        """Rescale lora2 so that the global norm of its products equals ``rho``.

        The norm is the square root of the summed squared entries of
        ``lora2_B @ lora2_A`` over all LoRA linear layers.  Both factors are
        multiplied by ``sqrt(rho / norm)``, so each product is scaled by
        ``rho / norm``.

        The rescaling always happens when ``args.exceed_rho`` exists and is
        falsy; otherwise (attribute missing or truthy) only when the norm is
        larger than ``rho``.  Nothing happens if the model has no LoRA linear
        layer.
        """
        lora1_rank = self.args.lora1_rank
        rho = self.args.rho

        lora_layers = [
            module for module in self.model.modules()
            if isinstance(module, LoraLinear)
        ]
        if not lora_layers:
            return

        squared_norm = 0
        for layer in lora_layers:
            lora2_A = layer.lora_A["default"].weight[lora1_rank:, :]
            lora2_B = layer.lora_B["default"].weight[:, lora1_rank:]
            squared_norm = squared_norm + ((lora2_B @ lora2_A)**2).sum()
        perturbation_norm = torch.sqrt(squared_norm)

        # ``hasattr`` is checked before reading the attribute so that args
        # objects without ``exceed_rho`` do not raise.
        always_rescale = (hasattr(self.args, "exceed_rho")
                          and not self.args.exceed_rho)
        if not (always_rescale or perturbation_norm > rho):
            return

        eps = 1e-12
        scale = torch.sqrt(rho / (perturbation_norm + eps))
        for layer in lora_layers:
            layer.lora_A["default"].weight[lora1_rank:, :].mul_(scale)
            layer.lora_B["default"].weight[:, lora1_rank:].mul_(scale)

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """Write a checkpoint to ``output_dir`` (default ``args.output_dir``).

        When the model is a ``PreTrainedModel``/``PeftModel`` only the lora1
        slice of every ``lora_A``/``lora_B`` tensor is saved, as
        ``adapter_model.safetensors`` or ``adapter_model.bin``, and
        ``adapter_config.json`` is rewritten with ``r = lora1_rank``.  Any
        other model is saved the way the stock ``Trainer`` does.  The training
        arguments are always saved to ``training_args.bin``.

        Args:
            output_dir: Target directory; created if missing.
            state_dict: Optional pre-gathered state dict to save from instead
                of ``self.model.state_dict()``.
        """
        import json
        import os

        # Importing the submodule guarantees ``safetensors.torch`` is loaded.
        import safetensors.torch
        from peft import PeftModel
        from transformers.utils import (
            SAFE_WEIGHTS_NAME,
            WEIGHTS_NAME,
            is_peft_available,
        )

        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        self.logger.info(f"Saving model checkpoint to {output_dir}")

        supported_classes = ((PreTrainedModel, PeftModel)
                             if is_peft_available() else (PreTrainedModel, ))

        if not isinstance(self.model, supported_classes):
            if state_dict is None:
                state_dict = self.model.state_dict()

            unwrapped_model = self.accelerator.unwrap_model(self.model)
            if isinstance(unwrapped_model, supported_classes):
                unwrapped_model.save_pretrained(
                    output_dir,
                    state_dict=state_dict,
                    safe_serialization=self.args.save_safetensors)
            else:
                self.logger.info(
                    "Trainer.model is not a `PreTrainedModel`, only saving its state dict."
                )
                if self.args.save_safetensors:
                    safetensors.torch.save_file(
                        state_dict,
                        os.path.join(output_dir, SAFE_WEIGHTS_NAME),
                        metadata={"format": "pt"})
                else:
                    torch.save(state_dict,
                               os.path.join(output_dir, WEIGHTS_NAME))
        else:
            lora1_rank = self.args.lora1_rank
            full_state_dict = (state_dict if state_dict is not None else
                               self.model.state_dict())

            # Keep only the lora1 slice and drop the adapter name from the
            # key ("...lora_A.default.weight" -> "...lora_A.weight").
            # NOTE(review): other adapter tensors (e.g. a DoRA magnitude
            # vector) are not written here - confirm this is intended.
            lora1_state_dict = {}
            for key, value in full_state_dict.items():
                if "lora_A.default" in key:
                    lora1_state_dict[key.replace(".default", "")] = (
                        value[:lora1_rank, :].contiguous())
                elif "lora_B.default" in key:
                    lora1_state_dict[key.replace(".default", "")] = (
                        value[:, :lora1_rank].contiguous())

            if self.args.save_safetensors:
                safetensors.torch.save_file(
                    lora1_state_dict,
                    os.path.join(output_dir, "adapter_model.safetensors"))
            else:
                torch.save(lora1_state_dict,
                           os.path.join(output_dir, "adapter_model.bin"))

            # Nothing above writes adapter_config.json, so create it from the
            # model's PEFT config when the directory does not have one yet.
            config_path = os.path.join(output_dir, "adapter_config.json")
            if not os.path.exists(config_path):
                peft_configs = getattr(self.model, "peft_config", None)
                if isinstance(peft_configs, dict) and "default" in peft_configs:
                    peft_configs["default"].save_pretrained(output_dir)

            with open(config_path) as config_file:
                adapter_config = json.load(config_file)

            # NOTE(review): alpha is pinned to 2 * lora1_rank regardless of
            # the value used in training - confirm this matches the scaling.
            adapter_config["lora_alpha"] = lora1_rank * 2
            adapter_config["r"] = lora1_rank

            with open(config_path, "w") as config_file:
                json.dump(adapter_config, config_file, indent=2)

        # Good practice: save your training arguments together with the trained model
        TRAINING_ARGS_NAME = "training_args.bin"
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

    def evaluate(
        self,
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> Dict[str, float]:
        """Run evaluation, using only lora1 when ``"bi_lora"`` is in ``args.lora_type``.

        In that case ``LoraLinear.forward`` is replaced, for the duration of
        the call, by a forward pass that only reads the first ``lora1_rank``
        rows/columns of each adapter; the original forward is restored
        afterwards, even if evaluation raises.

        Args:
            eval_dataset (Union[`Dataset`, Dict[str, `Dataset`]), *optional*):
                Passed through to ``Trainer.evaluate``.
            ignore_keys (`List[str]`, *optional*):
                Passed through to ``Trainer.evaluate``.
            metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
                Passed through to ``Trainer.evaluate``.

        Returns:
            The metrics dictionary returned by ``Trainer.evaluate``.
        """
        lora1_only = "bi_lora" in self.args.lora_type

        if lora1_only:
            lora1_rank = self.args.lora1_rank
            logger = self.logger

            def new_forward(layer, x: torch.Tensor, *args: Any,
                            **kwargs: Any) -> torch.Tensor:
                """Forward pass of a LoRA linear layer restricted to lora1."""
                layer._check_forward_args(x, *args, **kwargs)
                adapter_names = kwargs.pop("adapter_names", None)

                if layer.disable_adapters:
                    if layer.merged:
                        layer.unmerge()
                    result = layer.base_layer(x, *args, **kwargs)
                elif adapter_names is not None:
                    result = layer._mixed_batch_forward(
                        x, *args, adapter_names=adapter_names, **kwargs)
                elif layer.merged:
                    logger.info("merged")
                    result = layer.base_layer(x, *args, **kwargs)
                else:
                    result = layer.base_layer(x, *args, **kwargs)
                    torch_result_dtype = result.dtype
                    for active_adapter in layer.active_adapters:
                        if active_adapter not in layer.lora_A.keys():
                            continue
                        # Only the first ``lora1_rank`` rows/columns are used.
                        lora1_A = layer.lora_A[
                            active_adapter].weight[:lora1_rank, :]
                        lora1_B = layer.lora_B[
                            active_adapter].weight[:, :lora1_rank]
                        dropout = layer.lora_dropout[active_adapter]
                        scaling = layer.scaling[active_adapter]
                        # presumably x is [batch, seq_len, in_dim] - TODO confirm
                        x = x.to(lora1_A.dtype)
                        lora_weight = lora1_B @ lora1_A
                        x = dropout(x)

                        if not layer.use_dora[active_adapter]:
                            # Same value as (dropout(x) @ lora1_A.T) @ lora1_B.T
                            result = result + (
                                x @ lora_weight.T) * scaling
                        else:
                            magnitude = layer.lora_magnitude_vector[
                                active_adapter]
                            weight = layer.get_base_layer().weight
                            weight = weight.to(x.dtype)

                            weight_norm = layer._get_weight_norm(
                                weight, lora_weight, scaling)
                            # The norm is treated as a constant.
                            weight_norm = weight_norm.detach()
                            mag_norm_scale = (magnitude /
                                              weight_norm).view(1, -1)
                            result_dora = (mag_norm_scale - 1) * (
                                F.linear(
                                    x,
                                    transpose(weight, layer.fan_in_fan_out))
                            ) + x @ lora_weight.T * mag_norm_scale * scaling

                            result = result + result_dora

                    result = result.to(torch_result_dtype)

                return result

            orig_forward = LoraLinear.forward
            LoraLinear.forward = new_forward

        try:
            metrics = super().evaluate(eval_dataset, ignore_keys,
                                       metric_key_prefix)
        finally:
            if lora1_only:
                LoraLinear.forward = orig_forward

        return metrics


class LogTrainer(Trainer):
    """Trainer that logs weight, gradient and update statistics to wandb.

    Only trainable parameters whose name contains one of the module-level
    ``include_keywords`` are tracked.  For PEFT models the LoRA ``A``/``B``
    matrices are compared with a snapshot taken on the first training step;
    for other models every tracked 2-D parameter is compared with its
    snapshot.  Statistics are logged on the last micro-batch of each
    gradient-accumulation cycle.  Setting the module-level ``do_log`` to
    ``False`` disables all of this.
    """

    def __init__(
        self,
        model: Union[PreTrainedModel, nn.Module] = None,
        args: Seq2SeqTrainingArguments = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer,
                          torch.optim.lr_scheduler.LambdaLR] = (
                              None,
                              None,
                          ),
        logger=None,
        preprocess_logits_for_metrics: Optional[Callable[
            [torch.Tensor, torch.Tensor], torch.Tensor]] = None,
    ):
        """Forward the standard arguments to ``Trainer`` and set up state.

        ``logger`` is stored as ``self.logger``; every other argument is
        passed to ``Trainer.__init__`` unchanged.
        """
        # Keyword arguments keep each value bound to the intended parameter
        # even if the positional order of Trainer.__init__ differs between
        # transformers releases.
        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )
        self.logger = logger
        self.is_peft = "PeftModel" in type(model).__name__

        # Scaling of the "default" adapter, read from the first LoRA linear
        # layer; stays None when the model has no such layer.
        self.scaling = None
        if self.is_peft:
            for module in model.modules():
                if isinstance(module, LoraLinear):
                    self.scaling = module.scaling["default"]
                    break

        # Snapshots of the tracked weights, filled on the first training step.
        self.orig_A = None
        self.orig_B = None
        self.orig_W = None
        # Number of micro-batches processed so far by ``training_step``.
        self.gradient_accumulation_counter = 0

    @staticmethod
    def _tracked_parameters(model):
        """Yield ``(name, param)`` for trainable parameters matching ``include_keywords``."""
        for name, param in model.named_parameters():
            if param.requires_grad and any(kw in name
                                           for kw in include_keywords):
                yield name, param

    @staticmethod
    def _grad_norm(param):
        """Return the norm of ``param.grad``, or NaN when there is no gradient."""
        grad = param.grad
        if grad is None:
            return float("nan")
        return torch.norm(grad).item()

    def _snapshot_initial_weights(self, model):
        """Clone the tracked weights once; for PEFT models also attach the forward hooks."""
        if self.is_peft:
            if self.orig_A is not None:
                return
            self.orig_A = {}
            self.orig_B = {}
            # Keys are the parameter name up to "lora_A." / "lora_B.", so the
            # two matrices of one layer share a key.
            for name, param in self._tracked_parameters(model):
                if "lora_A" in name:
                    self.orig_A[name.split("lora_A.")[0]] = (
                        param.detach().clone())
                elif "lora_B" in name:
                    self.orig_B[name.split("lora_B.")[0]] = (
                        param.detach().clone())
            # Hook the forward pass of the tracked LoRA layers to log the
            # mean and std of their inputs and outputs.
            for name, module in model.named_modules():
                if isinstance(module, LoraLinear) and any(
                        kw in name for kw in include_keywords):
                    module.register_forward_hook(get_forward_hook(name))
        elif self.orig_W is None:
            self.orig_W = {
                name: param.detach().clone()
                for name, param in self._tracked_parameters(model)
            }

    def _log_lora_stats(self, model):
        """Log norms and singular-value ratios of each tracked LoRA pair."""
        A_dict = {}
        B_dict = {}
        for name, param in self._tracked_parameters(model):
            if "lora_A" in name:
                A_dict[name.split("lora_A.")[0]] = param
            elif "lora_B" in name:
                B_dict[name.split("lora_B.")[0]] = param
        assert (len(A_dict) == len(self.orig_A) == len(B_dict) == len(
            self.orig_B)), (
                len(A_dict),
                len(self.orig_A),
                len(B_dict),
                len(self.orig_B),
            )

        for key, A in A_dict.items():
            B = B_dict[key]
            lora_r = A.shape[0]
            A_0 = self.orig_A[key]
            B_0 = self.orig_B[key]
            # \Delta W = BA - B_0 A_0
            BA_diff = torch.matmul(B, A) - torch.matmul(B_0, A_0)
            BA_diff_norm = torch.norm(BA_diff).item()
            # The accuracy of the low-rank SVD depends on q.
            singular_values = torch.svd_lowrank(BA_diff.float(),
                                                q=2 * lora_r)[1][:lora_r]
            singular_sum = singular_values.sum()
            # Share of the top-1 / top-4 singular values in the total.
            top_1_ratio = (singular_values[0] / singular_sum).item()
            top_4_ratio = (singular_values[:4].sum() / singular_sum).item()

            metrics = {
                f"A_norm/{key}": torch.norm(A).item(),
                f"B_norm/{key}": torch.norm(B).item(),
                f"A_grad_norm/{key}": self._grad_norm(A),
                f"B_grad_norm/{key}": self._grad_norm(B),
                f"A_diff_norm/{key}": torch.norm(A - A_0).item(),
                f"B_diff_norm/{key}": torch.norm(B - B_0).item(),
                f"BA_diff_norm/{key}": BA_diff_norm,
            }
            if self.scaling is not None:
                metrics[f"scaled_BA_diff_norm/{key}"] = (self.scaling *
                                                         BA_diff_norm)
            metrics[f"BA_top_1_ratio/{key}"] = top_1_ratio
            metrics[f"BA_top_4_ratio/{key}"] = top_4_ratio
            metrics["train/global_step"] = self.state.global_step
            wandb.log(metrics)

    def _log_dense_stats(self, model):
        """Log norms and singular-value ratios of each tracked 2-D parameter."""
        for key, W in self._tracked_parameters(model):
            if len(W.shape) != 2:
                continue
            W_diff = W - self.orig_W[key]
            # Only the singular values are needed, so U and V are not computed.
            S = torch.linalg.svdvals(W_diff.float())
            top_1_ratio = S[0] / S.sum()
            top_4_ratio = S[:4].sum() / S.sum()
            wandb.log({
                f"W_norm/{key}": torch.norm(W).item(),
                f"W_grad_norm/{key}": self._grad_norm(W),
                f"W_diff_norm/{key}": torch.norm(W_diff).item(),
                "train/global_step": self.state.global_step,
                f"W_top_1_ratio/{key}": top_1_ratio.item(),
                f"W_top_4_ratio/{key}": top_4_ratio.item(),
            })

    def training_step(
            self, model: nn.Module,
            inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """Run forward/backward on one micro-batch and log weight statistics.

        Statistics are logged after the backward pass of the last micro-batch
        of each gradient-accumulation cycle.  When ``do_log`` is false the
        call is delegated to ``Trainer.training_step``.

        Returns:
            The detached loss divided by ``gradient_accumulation_steps``.
        """
        if not do_log:
            return super().training_step(model, inputs)

        self._snapshot_initial_weights(model)

        model.train()
        inputs = self._prepare_inputs(inputs)

        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs)

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training

        self.accelerator.backward(loss)

        accumulation_steps = self.args.gradient_accumulation_steps
        is_last_micro_batch = (self.gradient_accumulation_counter %
                               accumulation_steps == accumulation_steps - 1)
        if is_last_micro_batch:
            with torch.no_grad():
                if self.is_peft:
                    self._log_lora_stats(model)
                else:
                    self._log_dense_stats(model)
        self.gradient_accumulation_counter += 1

        return loss.detach() / accumulation_steps


class LogTrainer_LoRA_SAM(Trainer):
    """Trainer whose training step uses two forward/backward passes.

    The first pass computes the gradient at the current weights; all
    trainable parameters are then moved along that gradient by a step of
    norm ``args.rho``.  The second pass computes the gradient at the moved
    weights, after which the original weights are restored.  The gradient of
    the second pass is what the optimizer receives.
    """

    def __init__(
        self,
        model: Union[PreTrainedModel, nn.Module] = None,
        args: Seq2SeqTrainingArguments = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer,
                          torch.optim.lr_scheduler.LambdaLR] = (
                              None,
                              None,
                          ),
        logger=None,
        preprocess_logits_for_metrics: Optional[Callable[
            [torch.Tensor, torch.Tensor], torch.Tensor]] = None,
    ):
        """Forward the standard arguments to ``Trainer`` and set up state.

        Args:
            logger: Object exposing ``info``; defaults to this module's
                standard-library logger when ``None``.

        All other arguments are passed to ``Trainer.__init__`` unchanged.
        """
        # Keyword arguments keep each value bound to the intended parameter
        # even if the positional order of Trainer.__init__ differs between
        # transformers releases.
        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )

        # The default ``logger=None`` used to crash on the first ``.info``.
        self.logger = (logger if logger is not None else
                       logging.getLogger(__name__))
        self.logger.info(f"Training args: \n{self.args}")
        self.is_peft = "PeftModel" in type(model).__name__
        # Number of micro-batches processed so far by ``training_step``.
        self.gradient_accumulation_counter = 0

    def training_step(
            self, model: nn.Module,
            inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """Run the two-pass step on one micro-batch.

        After the call ``p.grad`` holds the second-pass gradient of this
        micro-batch added to whatever gradient was already accumulated.

        Returns:
            The detached second-pass loss divided by
            ``gradient_accumulation_steps``.
        """
        model.train()
        inputs = self._prepare_inputs(inputs)

        # Gradients left by earlier micro-batches must not leak into the
        # ascent direction, so they are set aside before the first backward.
        self._stash_accumulated_grads(model)

        self.compute_loss_and_backward(model, inputs)
        self.first_step_and_zero_grad(model)

        # get grad on w + (B + eps_b)(A + eps_a)
        loss = self.compute_loss_and_backward(model, inputs)
        self.restore_params(model)

        self.gradient_accumulation_counter += 1

        return loss.detach() / self.args.gradient_accumulation_steps

    def log(self, logs: Dict[str, float]) -> None:
        """Record ``logs`` in the state history, the logger and the callbacks.

        Args:
            logs (`Dict[str, float]`):
                The values to log.  ``epoch`` (and, when enabled,
                ``num_input_tokens_seen``) are added to it in place.
        """
        if self.state.epoch is not None:
            logs["epoch"] = self.state.epoch
        if self.args.include_num_input_tokens_seen:
            logs["num_input_tokens_seen"] = self.state.num_input_tokens_seen

        output = {**logs, "step": self.state.global_step}
        self.state.log_history.append(output)
        self.logger.info(output)
        self.control = self.callback_handler.on_log(self.args, self.state,
                                                    self.control, logs)

    def compute_loss_and_backward(self, model, inputs):
        """Compute the loss on ``inputs``, back-propagate it and return it."""
        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs)
        if self.args.n_gpu > 1:
            loss = loss.mean()
        self.accelerator.backward(loss)

        return loss

    def _stash_accumulated_grads(self, model):
        """Move each trainable parameter's gradient to ``p.accumulated_grad``."""
        for p in model.parameters():
            if p.requires_grad:
                p.accumulated_grad = p.grad
                p.grad = None

    def first_step_and_zero_grad(self, model):
        """Move the weights along the current gradient and clear the gradient.

        Every trainable parameter that has a gradient is backed up in
        ``p.unattacked_p`` and shifted by ``rho * grad / ||grad||``, where the
        norm is taken over all of those gradients together.  Does nothing when
        no trainable parameter has a gradient.
        """
        params = [
            p for p in model.parameters()
            if p.requires_grad and p.grad is not None
        ]
        if not params:
            return

        # compute grad norm
        grad_norm = torch.norm(torch.stack(
            [torch.norm(p.grad) for p in params]),
                               p=2)
        scale = self.args.rho / (grad_norm + 1e-12)

        # gradient ascent
        for p in params:
            p.unattacked_p = p.data.clone()
            p.data.add_(p.grad * scale.to(p))
            p.grad = None

    def restore_params(self, model):
        """Restore the backed-up weights and re-add the stashed gradients.

        For every trainable parameter, the value saved in ``p.unattacked_p``
        (if any) is copied back, and the gradient stored in
        ``p.accumulated_grad`` (if any) is added to ``p.grad``.  Both
        attributes are reset to ``None`` afterwards.
        """
        for p in model.parameters():
            if not p.requires_grad:
                continue

            backup = getattr(p, "unattacked_p", None)
            if backup is not None:
                p.data.copy_(backup)
                p.unattacked_p = None

            stashed = getattr(p, "accumulated_grad", None)
            if stashed is not None:
                if p.grad is None:
                    p.grad = stashed
                else:
                    p.grad.add_(stashed)
                p.accumulated_grad = None
