from typing import Callable, Dict, List, Optional, Tuple, Union, Any

import torch
import wandb
import torch.nn as nn
from torch.utils.data import Dataset
from peft.utils.other import transpose
import torch.nn.functional as F

from transformers import Trainer, Seq2SeqTrainingArguments
from transformers.data.data_collator import DataCollator
from transformers.trainer import (
    EvalPrediction,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    TrainerCallback,
)
from transformers.utils import is_sagemaker_mp_enabled
from peft.tuners.lora.layer import Linear as LoraLinear


# Substring filters used by LogTrainer: a parameter or module is tracked when
# any of these keywords occurs in its dotted name.
# include_keywords = ["block.0", "block.4"]
include_keywords = ["encoder.block.2", "encoder.block.3",
                    "encoder.block.4"]  # for T5
# include_keywords = ["layers.27", "layers.6"]  # for Llama
# When False, LogTrainer.training_step defers entirely to Trainer.training_step.
do_log = True

def get_forward_hook(name):
    """Build a forward hook that reports activation statistics to wandb.

    Args:
        name: Prefix used for the logged metric keys (``{name}/input_mean``,
            ``{name}/input_std``, ``{name}/output_mean``, ``{name}/output_std``).

    Returns:
        A ``hook(module, input, output)`` callable suitable for
        ``nn.Module.register_forward_hook``. It is a no-op when no wandb run
        is active.
    """

    def hook(module, input, output):
        # Only compute the statistics when there is a run to send them to.
        if wandb.run is not None:
            first_input = input[0]
            stats = {
                f"{name}/input_mean": first_input.mean().item(),
                f"{name}/input_std": first_input.std().item(),
                f"{name}/output_mean": output.mean().item(),
                f"{name}/output_std": output.std().item(),
            }
            wandb.log(stats, commit=False)

    return hook


class LogTrainer_Bi(Trainer):

    def __init__(
        self,
        model: Union[PreTrainedModel, nn.Module] = None,
        args: Seq2SeqTrainingArguments = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer,
                          torch.optim.lr_scheduler.LambdaLR] = (
                              None,
                              None,
                          ),
        logger=None,
        preprocess_logits_for_metrics: Optional[Callable[
            [torch.Tensor, torch.Tensor], torch.Tensor]] = None,
    ):
        """Create the Bi-LoRA trainer.

        All arguments except ``logger`` are forwarded unchanged to
        ``Trainer.__init__``.

        Args:
            logger: Object with an ``info`` method used for all trainer
                messages. When ``None`` a standard-library logger named after
                this module is used instead.
        """
        # Forward by keyword so that the arguments cannot be bound to the
        # wrong parameter if the base-class signature gains or reorders
        # positional parameters.
        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )

        # `logger` defaults to None; fall back to a real logger so that the
        # `self.logger.info(...)` calls in this class never hit None.
        if logger is None:
            import logging

            logger = logging.getLogger(__name__)
        self.logger = logger
        self.logger.info(f"Training args: \n{self.args}")
        self.is_peft = "PeftModel" in type(model).__name__

        # Number of micro-batches processed so far by training_step.
        self.gradient_accumulation_counter = 0
        self.perturbation_exceeds_rho_count = 0

    @torch.no_grad()
    def opposite_lora2_grad(self):
        """Flip the sign of the lora2 part of every LoRA gradient, in place.

        For parameters whose name contains ``lora_A`` the rows from
        ``args.lora1_rank`` onward are negated; for names containing
        ``lora_B`` the columns from ``args.lora1_rank`` onward are negated.
        Parameters without a gradient and all other parameters are untouched.
        """
        for param_name, parameter in self.model.named_parameters():
            grad = parameter.grad
            if grad is None:
                continue

            if "lora_A" in param_name:
                lora2_grad = grad[self.args.lora1_rank:, :]
            elif "lora_B" in param_name:
                lora2_grad = grad[:, self.args.lora1_rank:]
            else:
                continue

            # The slice is a view, so negating it updates the gradient itself.
            lora2_grad.neg_()

    def training_step(
            self, model: nn.Module,
            inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """Run one micro-batch: forward, backward, and the Bi-LoRA bookkeeping.

        On the first micro-batch of each accumulation cycle the lora2 weights
        are renormalised; after the backward pass of the last micro-batch the
        lora2 gradients are negated. Any exception is logged and re-raised.

        Returns:
            The detached loss divided by ``args.gradient_accumulation_steps``.
        """
        try:
            grad_accum = self.args.gradient_accumulation_steps

            # A counter that is a multiple of grad_accum means the previous
            # cycle is complete, so lora2 is normalized before this one starts.
            starts_cycle = self.gradient_accumulation_counter % grad_accum == 0
            if starts_cycle:
                self.normalize_lora2()

            model.train()
            batch = self._prepare_inputs(inputs)

            with self.compute_loss_context_manager():
                loss = self.compute_loss(model, batch)
            if self.args.n_gpu > 1:
                # mean() to average on multi-gpu parallel training
                loss = loss.mean()
            self.accelerator.backward(loss)

            self.gradient_accumulation_counter += 1

            # Take the negative of the gradient of lora2 once the last
            # micro-batch of the cycle has been back-propagated.
            ends_cycle = self.gradient_accumulation_counter % grad_accum == 0
            if ends_cycle:
                self.opposite_lora2_grad()

            return loss.detach() / grad_accum

        except Exception as e:
            self.logger.info(f"Unexpected error in training_step: {str(e)}")
            raise e

    def log(self, logs: Dict[str, float]) -> None:
        """
        Record `logs` in the trainer state, write them to `self.logger`, and
        notify the callbacks.

        `logs` is updated in place with the current epoch (when known) and,
        if `args.include_num_input_tokens_seen` is set, the number of input
        tokens seen so far.

        Args:
            logs (`Dict[str, float]`):
                The values to log.
        """
        state = self.state
        if state.epoch is not None:
            logs["epoch"] = state.epoch
        if self.args.include_num_input_tokens_seen:
            logs["num_input_tokens_seen"] = state.num_input_tokens_seen

        # The history entry is a copy of `logs` plus the global step.
        record = dict(logs)
        record["step"] = state.global_step
        state.log_history.append(record)
        self.logger.info(record)

        self.control = self.callback_handler.on_log(
            self.args, state, self.control, logs)

    @torch.no_grad()
    def normalize_lora2(self):
        """
        Rescale the lora2 weights so that their global norm equals `args.rho`.

        In every `LoraLinear` module, rows `lora1_rank:` of
        `lora_A["default"]` and columns `lora1_rank:` of `lora_B["default"]`
        are treated as lora2. The global norm is
        sqrt(sum over layers of ||B2 @ A2||_F^2).

        The rescaling is applied on every call when `args.exceed_rho` is
        falsy; otherwise (including when `args` has no `exceed_rho`
        attribute) it is applied only when the norm is larger than `rho`.
        Does nothing when the model has no `LoraLinear` module.
        """
        rank = self.args.lora1_rank
        rho = self.args.rho
        # Read with a default: the attribute may be absent, in which case the
        # weights are only clipped when the norm exceeds rho.
        exceed_rho = getattr(self.args, "exceed_rho", True)

        lora_layers = [
            module for module in self.model.modules()
            if isinstance(module, LoraLinear)
        ]
        if not lora_layers:
            # Nothing to normalize; also avoids torch.sqrt on a Python int.
            return

        # Compute the global squared norm.
        # ||B A||_F^2 = trace((B^T B)(A A^T)), which only needs r x r Gram
        # matrices instead of the full out_dim x in_dim product.
        perturbation_sqnorm = 0
        for module in lora_layers:
            lora2_A = module.lora_A["default"].weight[rank:, :].detach()
            lora2_B = module.lora_B["default"].weight[:, rank:].detach()

            G_B = lora2_B.T @ lora2_B
            G_A = lora2_A @ lora2_A.T

            perturbation_sqnorm += (G_B * G_A.T).sum()

        perturbation_norm = torch.sqrt(perturbation_sqnorm)

        if exceed_rho and not perturbation_norm > rho:
            return

        # A and B are each scaled by sqrt(rho / norm), so their product, and
        # hence the global norm, is scaled by rho / norm.
        eps = 1e-12
        scale = torch.sqrt(rho / (perturbation_norm + eps))

        for module in lora_layers:
            module.lora_A["default"].weight[rank:, :].mul_(scale)
            module.lora_B["default"].weight[:, rank:].mul_(scale)

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """Write a checkpoint to `output_dir` (default: `args.output_dir`).

        For a `PreTrainedModel` / `PeftModel` two adapter files are written:

        * `adapter_model_dual.safetensors`: the output of `save_pretrained`,
          i.e. lora1 and lora2 together;
        * `adapter_model.safetensors` (or `adapter_model.bin` when
          `args.save_safetensors` is false): lora1 only, i.e. the first
          `args.lora1_rank` rows of every `lora_A.default` tensor and the
          first `args.lora1_rank` columns of every `lora_B.default` tensor.

        `adapter_config.json` is rewritten so that `r` matches the lora1-only
        file. Any other model falls back to `save_pretrained` on the unwrapped
        model or to a plain state-dict dump. The training arguments are always
        stored in `training_args.bin`.

        Args:
            output_dir: Target directory; created if missing.
            state_dict: Optional state dict to save instead of the one read
                from `self.model`.
        """
        # If we are executing this function, we are the process zero, so we don't check for that.
        output_dir = output_dir if output_dir is not None else self.args.output_dir

        import json
        import os

        # Import the submodule explicitly: `import safetensors` alone does not
        # guarantee that `safetensors.torch` is loaded.
        import safetensors.torch
        from peft import PeftModel
        from transformers.utils import (
            SAFE_WEIGHTS_NAME,
            WEIGHTS_NAME,
            is_peft_available,
        )

        os.makedirs(output_dir, exist_ok=True)
        self.logger.info(f"Saving model checkpoint to {output_dir}")

        supported_classes = (
            PreTrainedModel, ) if not is_peft_available() else (
                PreTrainedModel, PeftModel)

        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        if not isinstance(self.model, supported_classes):
            if state_dict is None:
                state_dict = self.model.state_dict()

            unwrapped_model = self.accelerator.unwrap_model(self.model)
            if isinstance(unwrapped_model, supported_classes):
                unwrapped_model.save_pretrained(
                    output_dir,
                    state_dict=state_dict,
                    safe_serialization=self.args.save_safetensors)
            else:
                self.logger.info(
                    "Trainer.model is not a `PreTrainedModel`, only saving its state dict."
                )
                if self.args.save_safetensors:
                    safetensors.torch.save_file(state_dict,
                                                os.path.join(
                                                    output_dir,
                                                    SAFE_WEIGHTS_NAME),
                                                metadata={"format": "pt"})
                else:
                    torch.save(state_dict,
                               os.path.join(output_dir, WEIGHTS_NAME))
        else:
            # save lora1 + lora2
            self.model.save_pretrained(output_dir, state_dict=state_dict)
            # os.replace overwrites an existing target on every platform.
            os.replace(
                os.path.join(output_dir, "adapter_model.safetensors"),
                os.path.join(output_dir, "adapter_model_dual.safetensors"))

            # Slice the same weights that were just written above: honour a
            # caller-supplied state dict instead of re-reading the live model.
            lora_state_dict = (state_dict if state_dict is not None else
                               self.model.state_dict())
            new_state_dict = {}
            lora1_rank = self.args.lora1_rank

            for key, value in lora_state_dict.items():
                # Keys are stored without the ".default" adapter name.
                if 'lora_A.default' in key:
                    new_state_dict[key.replace('.default', '')] = (
                        value[:lora1_rank, :].contiguous())
                elif 'lora_B.default' in key:
                    new_state_dict[key.replace('.default', '')] = (
                        value[:, :lora1_rank].contiguous())

            # only save lora1
            if self.args.save_safetensors:
                safetensors.torch.save_file(new_state_dict,
                                            os.path.join(
                                                output_dir,
                                                "adapter_model.safetensors"))
            else:
                torch.save(new_state_dict,
                           os.path.join(output_dir, "adapter_model.bin"))

            # Rewrite the adapter config for the lora1-only adapter. Context
            # managers make sure the file is flushed and closed.
            config_path = os.path.join(output_dir, "adapter_config.json")
            with open(config_path) as config_file:
                adapter_config = json.load(config_file)

            # NOTE(review): lora_alpha is hard-coded to 2 * lora1_rank, which
            # keeps alpha / r equal to the training value only when training
            # used lora_alpha == 2 * r -- confirm against the training config.
            adapter_config["lora_alpha"] = lora1_rank * 2
            adapter_config["r"] = lora1_rank

            with open(config_path, "w") as config_file:
                json.dump(adapter_config, config_file, indent=2)

        TRAINING_ARGS_NAME = "training_args.bin"
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

    def evaluate(
        self,
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> Dict[str, float]:
        """
        Run Bi-LoRA evaluation and return metrics. Bi-LoRA only uses lora1 during inference.

        When `args.lora_type` contains "bi_lora", `LoraLinear.forward` is
        replaced at class level for the duration of `Trainer.evaluate` by a
        forward that only uses the first `args.lora1_rank` rows of `lora_A`
        and columns of `lora_B`. The original forward is restored afterwards,
        even if evaluation raises. Because the patch is applied to the class,
        every `LoraLinear` instance is affected while it is active.

        Arguments and return value are those of `Trainer.evaluate`.
        """

        if "bi_lora" in self.args.lora_type:
            # Captured by the closure below, where `self` is the LoRA layer
            # rather than the trainer.
            lora1_rank = self.args.lora1_rank
            logger = self.logger

            def new_forward(self, x: torch.Tensor, *args: Any,
                            **kwargs: Any) -> torch.Tensor:
                """Forward pass of a LoRA linear layer restricted to lora1."""
                self._check_forward_args(x, *args, **kwargs)
                adapter_names = kwargs.pop("adapter_names", None)

                if self.disable_adapters:
                    if self.merged:
                        self.unmerge()
                    result = self.base_layer(x, *args, **kwargs)
                elif adapter_names is not None:
                    result = self._mixed_batch_forward(
                        x, *args, adapter_names=adapter_names, **kwargs)
                elif self.merged:
                    # NOTE(review): the base layer output is used as is here,
                    # so lora2 is presumably still part of the merged weights
                    # -- confirm this path is never hit for Bi-LoRA.
                    logger.info("merged")
                    result = self.base_layer(x, *args, **kwargs)
                else:
                    result = self.base_layer(x, *args, **kwargs)
                    torch_result_dtype = result.dtype
                    for active_adapter in self.active_adapters:
                        if active_adapter not in self.lora_A.keys():
                            continue
                        # Keep only the lora1 slice of both factors.
                        lora1_A = self.lora_A[
                            active_adapter].weight[:lora1_rank, :]
                        lora1_B = self.lora_B[
                            active_adapter].weight[:, :lora1_rank]
                        dropout = self.lora_dropout[active_adapter]
                        scaling = self.scaling[active_adapter]
                        x = x.to(lora1_A.dtype
                                 )  # [batch 32, seq_len 101, in_dim 768]
                        # Full lora1 update matrix B1 @ A1; also needed by the
                        # DoRA weight norm below.
                        lora_weight = lora1_B @ lora1_A
                        x = dropout(x)

                        if not self.use_dora[active_adapter]:
                            # Equivalent to:
                            # 1. dropout(x) @ lora1_A.T -> [batch, seq_len, r]
                            # 2. @ lora1_B.T -> [batch, seq_len, out_dim]
                            result = result + (
                                x @ lora_weight.T) * scaling
                        else:

                            # DoRA path written out with the lora1 slices:
                            # result = result + self._apply_dora(x, lora1_A, lora1_B, scaling, active_adapter)
                            magnitude = self.lora_magnitude_vector[active_adapter]
                            weight = self.get_base_layer().weight
                            weight = weight.to(x.dtype)

                            weight_norm = self._get_weight_norm(weight, lora_weight, scaling)
                            # Detached so the norm is treated as a constant.
                            weight_norm = weight_norm.detach()
                            mag_norm_scale = (magnitude / weight_norm).view(1, -1)
                            # breakpoint()
                            # `result` already holds the base layer output,
                            # hence the (mag_norm_scale - 1) factor on the
                            # base term.
                            result_dora = (mag_norm_scale - 1) * (
                                F.linear(x, transpose(weight, self.fan_in_fan_out))
                            ) +  x @ lora_weight.T * mag_norm_scale * scaling

                            result = result + result_dora

                    result = result.to(torch_result_dtype)

                return result

            # Remember the original so it can be restored in `finally`.
            orig_forward = LoraLinear.forward
            LoraLinear.forward = new_forward.__get__(None, LoraLinear)

        try:
            metrics = super().evaluate(eval_dataset, ignore_keys,
                                       metric_key_prefix)
        finally:
            # Undo the class-level patch whether or not evaluation succeeded.
            if "bi_lora" in self.args.lora_type:
                LoraLinear.forward = orig_forward

        return metrics


class LogTrainer(Trainer):

    def __init__(
        self,
        model: Union[PreTrainedModel, nn.Module] = None,
        args: Seq2SeqTrainingArguments = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer,
                          torch.optim.lr_scheduler.LambdaLR] = (
                              None,
                              None,
                          ),
        logger=None,
        preprocess_logits_for_metrics: Optional[Callable[
            [torch.Tensor, torch.Tensor], torch.Tensor]] = None,
    ):
        """Create the statistics-logging trainer.

        All arguments except ``logger`` are forwarded unchanged to
        ``Trainer.__init__``; ``logger`` is stored on the instance.
        """
        # Forward by keyword so that the arguments cannot be bound to the
        # wrong parameter if the base-class signature gains or reorders
        # positional parameters.
        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )
        self.logger = logger
        self.is_peft = "PeftModel" in type(model).__name__

        # Scaling of the "default" adapter, taken from the first LoRA layer
        # found; stays None when the model has no such layer.
        self.scaling = None
        if self.is_peft:
            for module in model.modules():
                if isinstance(module, LoraLinear):
                    self.scaling = module.scaling["default"]
                    break

        # Snapshots of the tracked weights, filled lazily by training_step.
        self.orig_A = None
        self.orig_B = None
        self.orig_W = None
        # Number of micro-batches processed so far by training_step.
        self.gradient_accumulation_counter = 0

    def training_step(
            self, model: nn.Module,
            inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """Run one micro-batch and log weight statistics for tracked layers.

        With ``do_log`` disabled this is exactly ``Trainer.training_step``.
        Otherwise the tracked weights (names matching ``include_keywords``)
        are snapshotted on the first call, and after the backward pass of the
        last micro-batch of every accumulation cycle their norms, gradient
        norms, drift from the snapshot, and singular-value concentration are
        sent to wandb, one ``wandb.log`` call per tracked layer.

        Returns:
            The detached loss divided by ``args.gradient_accumulation_steps``.
        """
        if not do_log:
            return super().training_step(model, inputs)

        self._record_initial_weights(model)

        model.train()
        inputs = self._prepare_inputs(inputs)

        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs)

        if self.args.n_gpu > 1:
            # mean() to average on multi-gpu parallel training
            loss = loss.mean()

        self.accelerator.backward(loss)

        grad_accum = self.args.gradient_accumulation_steps
        # Statistics are only taken on the last micro-batch of a cycle.
        if self.gradient_accumulation_counter % grad_accum == grad_accum - 1:
            with torch.no_grad():
                if self.is_peft:
                    self._log_lora_stats(model)
                else:
                    self._log_dense_stats(model)
        self.gradient_accumulation_counter += 1

        return loss.detach() / self.args.gradient_accumulation_steps

    @staticmethod
    def _is_tracked(name: str) -> bool:
        """Return True when `name` contains any of `include_keywords`."""
        return any(kw in name for kw in include_keywords)

    def _record_initial_weights(self, model: nn.Module) -> None:
        """Snapshot the tracked weights on the first call; no-op afterwards."""
        if self.is_peft:
            if self.orig_A is not None:
                return
            self.orig_A = {}
            self.orig_B = {}
            for name, param in model.named_parameters():
                # only act on "original" lora parameters A_0, B_0 that match
                # the include_keywords (clone and record)
                if param.requires_grad and self._is_tracked(name):
                    if "lora_A" in name:
                        self.orig_A[name.split("lora_A.")[0]] = (
                            param.detach().clone())
                    elif "lora_B" in name:
                        self.orig_B[name.split("lora_B.")[0]] = (
                            param.detach().clone())
            # hook forward pass of lora layers to get mean and std
            for name, module in model.named_modules():
                if self._is_tracked(name) and isinstance(module, LoraLinear):
                    module.register_forward_hook(get_forward_hook(name))
        elif self.orig_W is None:
            self.orig_W = {
                name: param.detach().clone()
                for name, param in model.named_parameters()
                if param.requires_grad and self._is_tracked(name)
            }

    def _log_lora_stats(self, model: nn.Module) -> None:
        """Log per-layer statistics of the tracked LoRA factors to wandb."""
        A_dict = {}
        B_dict = {}
        for name, param in model.named_parameters():
            if param.requires_grad and self._is_tracked(name):
                if "lora_A" in name:
                    A_dict[name.split("lora_A.")[0]] = param
                elif "lora_B" in name:
                    B_dict[name.split("lora_B.")[0]] = param
        assert (len(A_dict) == len(self.orig_A) == len(B_dict) ==
                len(self.orig_B)), (
                    len(A_dict),
                    len(self.orig_A),
                    len(B_dict),
                    len(self.orig_B),
                )
        for key, A in A_dict.items():
            B = B_dict[key]
            lora_r = A.shape[0]
            A_0 = self.orig_A[key]
            B_0 = self.orig_B[key]
            # Drift of the low-rank product from its initial value.
            BA_diff = torch.matmul(B, A) - torch.matmul(B_0, A_0)
            BA_diff_norm = torch.norm(BA_diff).item()
            A_diff_norm = torch.norm(A - A_0).item()
            B_diff_norm = torch.norm(B - B_0).item()
            A_norm = torch.norm(A).item()
            B_norm = torch.norm(B).item()
            A_grad_norm = torch.norm(A.grad).item()
            B_grad_norm = torch.norm(B.grad).item()
            BA_singular_values = torch.svd_lowrank(
                BA_diff.float(), q=2 * lora_r
            )[1][:lora_r]
            singular_sum = BA_singular_values.sum()
            top_1_ratio = (BA_singular_values[0] / singular_sum).item()
            top_4_ratio = (BA_singular_values[:4].sum() /
                           singular_sum).item()

            # Log inside the loop so that every tracked layer is reported,
            # not only the last one.
            if wandb.run is not None:
                wandb.log({
                    f"A_norm/{key}": A_norm,
                    f"B_norm/{key}": B_norm,
                    f"A_grad_norm/{key}": A_grad_norm,
                    f"B_grad_norm/{key}": B_grad_norm,
                    f"A_diff_norm/{key}": A_diff_norm,
                    f"B_diff_norm/{key}": B_diff_norm,
                    f"BA_diff_norm/{key}": BA_diff_norm,
                    f"scaled_BA_diff_norm/{key}":
                    self.scaling * BA_diff_norm,
                    f"BA_top_1_ratio/{key}": top_1_ratio,
                    f"BA_top_4_ratio/{key}": top_4_ratio,
                    "train/global_step": self.state.global_step,
                })

    def _log_dense_stats(self, model: nn.Module) -> None:
        """Log per-matrix statistics of the tracked 2-D weights to wandb."""
        W_dict = {
            name: param
            for name, param in model.named_parameters()
            if (param.requires_grad and self._is_tracked(name)
                and len(param.shape) == 2)
        }
        for key, W in W_dict.items():
            W_diff = W - self.orig_W[key]
            W_diff_norm = torch.norm(W_diff).item()
            W_norm = torch.norm(W).item()
            W_grad_norm = torch.norm(W.grad).item()
            # Only the singular values are needed, so skip computing U and V.
            S = torch.linalg.svdvals(W_diff.float())
            top_1_ratio = S[0] / S.sum()
            top_4_ratio = S[:4].sum() / S.sum()

            if wandb.run is not None:
                wandb.log({
                    f"W_norm/{key}": W_norm,
                    f"W_grad_norm/{key}": W_grad_norm,
                    f"W_diff_norm/{key}": W_diff_norm,
                    "train/global_step": self.state.global_step,
                    f"W_top_1_ratio/{key}": top_1_ratio.item(),
                    f"W_top_4_ratio/{key}": top_4_ratio.item(),
                })


class LogTrainer_LoRA_SAM(Trainer):

    def __init__(
        self,
        model: Union[PreTrainedModel, nn.Module] = None,
        args: Seq2SeqTrainingArguments = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer,
                          torch.optim.lr_scheduler.LambdaLR] = (
                              None,
                              None,
                          ),
        logger=None,
        preprocess_logits_for_metrics: Optional[Callable[
            [torch.Tensor, torch.Tensor], torch.Tensor]] = None,
    ):
        """Create the LoRA-SAM trainer.

        All arguments except ``logger`` are forwarded unchanged to
        ``Trainer.__init__``.

        Args:
            logger: Object with an ``info`` method used for all trainer
                messages. When ``None`` a standard-library logger named after
                this module is used instead.
        """
        # Forward by keyword so that the arguments cannot be bound to the
        # wrong parameter if the base-class signature gains or reorders
        # positional parameters.
        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )

        # `logger` defaults to None; fall back to a real logger so that the
        # `self.logger.info(...)` calls in this class never hit None.
        if logger is None:
            import logging

            logger = logging.getLogger(__name__)
        self.logger = logger
        self.logger.info(f"Training args: \n{self.args}")
        self.is_peft = "PeftModel" in type(model).__name__
        # Number of micro-batches processed so far by training_step.
        self.gradient_accumulation_counter = 0

    def training_step(
            self, model: nn.Module,
            inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """Run one micro-batch with two forward/backward passes.

        The first pass yields the gradient at the current weights, which is
        used to perturb them; the second pass yields the gradient at the
        perturbed weights, after which the weights are restored.

        Returns:
            The detached loss of the second pass divided by
            ``args.gradient_accumulation_steps``.
        """
        model.train()
        batch = self._prepare_inputs(inputs)

        # Pass 1: gradient on the clean weights, then perturb by gradient
        # ascent (this also clears the gradients).
        self.compute_loss_and_backward(model, batch)
        self.first_step_and_zero_grad(model)

        # Pass 2: gradient on the perturbed weights, then restore the
        # weights and accumulate that gradient.
        perturbed_loss = self.compute_loss_and_backward(model, batch)
        self.restore_params(model)

        self.gradient_accumulation_counter += 1

        return perturbed_loss.detach() / self.args.gradient_accumulation_steps

    def log(self, logs: Dict[str, float]) -> None:
        """
        Store `logs` in the trainer history, emit them through `self.logger`,
        and forward them to the registered callbacks.

        The current epoch (when known) and, if
        `args.include_num_input_tokens_seen` is set, the number of input
        tokens seen are added to `logs` in place before it is recorded.

        Args:
            logs (`Dict[str, float]`):
                The values to log.
        """
        epoch = self.state.epoch
        if epoch is not None:
            logs["epoch"] = epoch
        if self.args.include_num_input_tokens_seen:
            logs["num_input_tokens_seen"] = self.state.num_input_tokens_seen

        # History entries carry the global step in addition to the logs.
        entry = {**logs, "step": self.state.global_step}
        self.state.log_history.append(entry)
        self.logger.info(entry)

        handler = self.callback_handler
        self.control = handler.on_log(self.args, self.state, self.control,
                                      logs)

    def compute_loss_and_backward(self, model, inputs):
        """Compute the loss for `inputs`, back-propagate it, and return it.

        When ``args.n_gpu > 1`` the loss is reduced with ``mean()`` first;
        the reduced loss is the one that is back-propagated and returned.
        """
        with self.compute_loss_context_manager():
            batch_loss = self.compute_loss(model, inputs)

        reduced_loss = batch_loss.mean() if self.args.n_gpu > 1 else batch_loss
        self.accelerator.backward(reduced_loss)

        return reduced_loss

    def first_step_and_zero_grad(self, model):
        """Perturb the trainable weights along their gradient and clear it.

        Every trainable parameter is moved by ``grad * rho / ||grad||``, where
        the norm is taken over all trainable gradients together. The
        unperturbed value is kept in ``p.unattacked_p`` and ``p.grad`` is
        reset to ``None``.
        """
        trainable = [p for p in model.parameters() if p.requires_grad]

        # Global L2 norm = norm of the vector of per-parameter norms.
        per_param_norms = torch.stack([torch.norm(p.grad) for p in trainable])
        grad_norm = torch.norm(per_param_norms, p=2)
        scale = self.args.rho / (grad_norm + 1e-12)

        for p in trainable:
            # Keep a copy of the clean weights unless one is already stored.
            if getattr(p, "unattacked_p", None) is None:
                p.unattacked_p = p.data.clone()

            # gradient ascent
            p.data.add_(p.grad * scale.to(p))

            p.grad = None

    def restore_params(self, model):
        """Restore the clean weights and accumulate the perturbed gradient.

        For every trainable parameter the value saved in ``p.unattacked_p``
        is copied back, and ``p.grad`` is added to ``p.accumulated_grad``
        (which is started afresh on the first micro-batch of a cycle).
        ``p.grad`` is then cleared, except on the last micro-batch of a
        cycle, where it is set to the accumulated gradient.
        """
        grad_accum = self.args.gradient_accumulation_steps
        counter = self.gradient_accumulation_counter
        is_first_micro_batch = counter % grad_accum == 0
        is_last_micro_batch = (counter + 1) % grad_accum == 0

        for p in model.parameters():
            if not p.requires_grad:
                continue

            # restore original parameters
            if hasattr(p, "unattacked_p"):
                p.data.copy_(p.unattacked_p)
                p.unattacked_p = None

            # accumulate gradient on the perturbed weight
            if is_first_micro_batch:
                p.accumulated_grad = p.grad.clone()
            else:
                p.accumulated_grad += p.grad.clone()

            # clear the gradient so that it does not leak into the first
            # backward pass of the next micro-batch
            p.grad = None

            # on the last micro batch, expose the accumulated gradient
            if is_last_micro_batch:
                p.grad = p.accumulated_grad.clone()
                p.accumulated_grad = None
