#%%
import torch, pickle
from peft_utils.layers.linear import low_rank_linear
from peft_utils.layers.conv import low_rank_CP
# from peft.tuners.adalora.layer import AdaLoraLayer
# from loralib.layers import LoRALayer as AdaLoraLayer
from peft.tuners.lora import LoraLayer as AdaLoraLayer
from transformers import Trainer
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import inspect
from transformers.utils import logging
from transformers.data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
from transformers.file_utils import (
    WEIGHTS_NAME,
    is_datasets_available,
)
if is_datasets_available():
    import datasets
from peft_utils.adapters_utils import wrap_classinstance_with_lr_methods,get_hypergradient,activate_lower_level,activate_upper_level

logger = logging.get_logger(__name__)
from accelerate import Accelerator

import wandb


def set_debug_apis(state: bool = False):
    """Forward ``state`` to the autograd profiling / anomaly-detection helpers.

    Args:
        state: Flag handed to ``profile``, ``emit_nvtx`` and
            ``set_detect_anomaly``.

    NOTE(review): the ``profile`` and ``emit_nvtx`` objects are constructed
    here but never entered with ``with``, so they presumably have no lasting
    effect -- confirm whether that is intended.
    """
    autograd = torch.autograd
    autograd.profiler.profile(enabled=state)
    autograd.profiler.emit_nvtx(enabled=state)
    autograd.set_detect_anomaly(mode=state)

# from transformers.integrations import (  # isort: split
#     default_hp_search_backend,
#     get_reporting_integration_callbacks,
#     hp_params,
#     is_fairscale_available,
#     is_optuna_available,
#     is_ray_tune_available,
#     run_hp_search_optuna,
#     run_hp_search_ray,
#     init_deepspeed,
# )

from transformers.file_utils import (
    WEIGHTS_NAME,
    is_apex_available,
    is_datasets_available,
    # is_in_notebook,
    # is_sagemaker_distributed_available,
    # is_torch_tpu_available,
    # is_training_run_on_sagemaker,
)

if is_apex_available():
    from apex import amp

from packaging import version
if version.parse(torch.__version__) >= version.parse("1.6"):
    _is_native_amp_available = True
    from torch.cuda.amp import autocast

# if is_fairscale_available():
#     import fairscale
#     from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
#     from fairscale.optim import OSS
#     from fairscale.optim.grad_scaler import ShardedGradScaler


@torch.no_grad()
def basis_norm_check(low_rank_layers):
    """
    Print how far the active basis vectors are from unit norm (Oblique manifold).

    For every supported layer, the basis vectors whose singular-value entry is
    non-zero are selected, their Euclidean norms are computed, and the norm of
    ``(norms - 1)`` is printed. Layers of any other type are ignored.

    low_rank_layers: iterable of low-rank layers (``low_rank_linear``,
        ``low_rank_CP`` or ``AdaLoraLayer``).
    """

    def _column_gap(mat, mask):
        # Basis vectors are stored as columns.
        return torch.linalg.norm(torch.linalg.norm(mat[:, mask], dim=0) - 1)

    def _row_gap(mat, mask):
        # Basis vectors are stored as rows.
        return torch.linalg.norm(torch.linalg.norm(mat[mask, :], dim=1) - 1)

    for l in low_rank_layers:
        if isinstance(l, low_rank_linear):
            mask = (l.s != 0.0).bool()
            # The second entry now checks ``vs``; it previously repeated ``us``.
            print(
                f"check U,V distance to constraint: {(_column_gap(l.us, mask), _column_gap(l.vs, mask))}"
            )
        elif isinstance(l, low_rank_CP):
            mask = (l.s != 0.0).bool()
            print(
                f"check us distance to constraint: {[_column_gap(u, mask) for u in l.us]}"
            )
        elif isinstance(l, AdaLoraLayer):
            adapter_name = l.adapter_name
            if adapter_name is None:
                # Factors are stored directly on the layer.
                lora_E, lora_A, lora_B = l.lora_E, l.lora_A, l.lora_B
            else:
                # Factors are stored in containers keyed by adapter name.
                lora_E = l.lora_E[adapter_name]
                lora_A = l.lora_A[adapter_name]
                lora_B = l.lora_B[adapter_name]
            mask = (lora_E.squeeze() != 0.0).bool()
            print(
                f"check us distance to constraint: {(_row_gap(lora_A, mask), _column_gap(lora_B, mask))}"
            )


@torch.no_grad()
def stiefel_constraint_check(low_rank_layers):
    """
    Print how far each basis is from orthonormality (Stiefel manifold).

    For column bases the printed value is ``||M^T M - I||``; for ``lora_A``,
    whose basis vectors are rows, it is ``||M M^T - I||``. Layers of any other
    type are ignored.

    low_rank_layers: iterable of low-rank layers (``low_rank_linear``,
        ``low_rank_CP`` or ``AdaLoraLayer``).
    """

    def _column_gap(mat):
        eye = torch.eye(mat.shape[1], device=mat.device)
        return torch.linalg.norm(mat.T @ mat - eye)

    def _row_gap(mat):
        eye = torch.eye(mat.shape[0], device=mat.device)
        return torch.linalg.norm(mat @ mat.T - eye)

    for l in low_rank_layers:
        if isinstance(l, low_rank_linear):
            print(
                f"check U,V distance to constraint: {(_column_gap(l.us), _column_gap(l.vs))}"
            )
        elif isinstance(l, low_rank_CP):
            print(
                f"check us distance to constraint: {[_column_gap(u) for u in l.us]}"
            )
        elif isinstance(l, AdaLoraLayer):
            adapter_name = l.adapter_name
            if adapter_name is None:
                # Factors are stored directly on the layer.
                lora_A, lora_B = l.lora_A, l.lora_B
            else:
                # Factors are stored in containers keyed by adapter name.
                lora_A, lora_B = l.lora_A[adapter_name], l.lora_B[adapter_name]
            print(
                f"check us distance to constraint: {(_row_gap(lora_A), _column_gap(lora_B))}"
            )


def load_trainer(path):
    """
    Load and return the object pickled at ``path``.

    path: location of the pickle file, opened in binary mode.
    returns: whatever object ``pickle.load`` reconstructs from the file.
    """
    # NOTE: unpickling can execute arbitrary code; only load files from a
    # trusted source.
    with open(path, "rb") as file:  # the handle is closed even if loading fails
        return pickle.load(file)


@torch.no_grad()
def prune(s, cr, s_measure=None):
    """
    Zero out the entries of ``s`` with the smallest scores and return the result.

    s: global s vector (1-D tensor); it is not modified.
    cr: compression ratio, either a tensor (as before) or a plain Python number.
    s_measure: optional tensor of scores, same length as ``s``, used to rank
        the entries; when omitted the entries of ``s`` are ranked themselves.
    returns: a new tensor equal to ``s`` where kept and 0 elsewhere.

    ``n_to_cut = min(ceil(cr * len(s)), len(s) - 1)`` and an entry is kept only
    when its score is strictly greater than the score at position ``n_to_cut``
    of the ascending sort.
    NOTE(review): with that rule at least ``n_to_cut + 1`` entries are zeroed
    (e.g. the smallest entry is removed even for ``cr == 0``) -- confirm that
    this is the intended pruning schedule.
    """
    n_entries = len(s)
    if n_entries == 0:
        # Nothing to rank; avoids indexing an empty sort result.
        return s.clone()
    # ``torch.ceil`` needs a tensor, so wrap plain numbers.
    ratio = cr if torch.is_tensor(cr) else torch.tensor(float(cr))
    n_to_cut = min(int(torch.ceil(ratio * n_entries)), n_entries - 1)
    scores = s if s_measure is None else s_measure
    ascending, _ = torch.sort(scores, descending=False)
    return (scores > ascending[n_to_cut]) * s


@torch.no_grad()
def get_s(low_rank_layers):
    """
    Collect the singular-value vectors of all layers into one global vector.

    low_rank_layers: list of low_rank layers
    returns: ``None`` for an empty list, the single layer's own tensor when
        there is exactly one layer (no copy is made, as before), otherwise the
        concatenation of all per-layer vectors in list order.
    raises: TypeError if a layer is of an unsupported type.
    """
    pieces = []
    for l in low_rank_layers:
        if isinstance(l, (low_rank_linear, low_rank_CP)):
            pieces.append(l.s)
        elif isinstance(l, AdaLoraLayer):
            adapter_name = l.adapter_name
            # ``lora_E`` lives directly on the layer when there is no adapter
            # name, otherwise in a container keyed by the adapter name (peft).
            lora_E = l.lora_E if adapter_name is None else l.lora_E[adapter_name]
            pieces.append(lora_E.squeeze())
        else:
            raise TypeError(f"unsupported low-rank layer type: {type(l).__name__}")
    if not pieces:
        return None
    if len(pieces) == 1:
        return pieces[0]
    # One concatenation at the end instead of one per layer.
    return torch.concat(pieces)


@torch.no_grad()
def get_grad_s(low_rank_layers):
    """
    Collect the gradients of all singular-value vectors into one global vector.

    low_rank_layers: list of low_rank layers
    returns: ``None`` for an empty list, the single layer's gradient when there
        is exactly one layer, otherwise the concatenation of the per-layer
        gradients in list order (same layout as ``get_s``).
    raises: TypeError if a layer is of an unsupported type; RuntimeError if a
        layer has no gradient yet (previously a leading ``None`` gradient was
        dropped silently, which shifted the remaining entries).
    """
    pieces = []
    for l in low_rank_layers:
        if isinstance(l, (low_rank_linear, low_rank_CP)):
            grad = l.s.grad
            squeeze = False
        elif isinstance(l, AdaLoraLayer):
            adapter_name = l.adapter_name
            lora_E = l.lora_E if adapter_name is None else l.lora_E[adapter_name]
            grad = lora_E.grad
            squeeze = True
        else:
            raise TypeError(f"unsupported low-rank layer type: {type(l).__name__}")
        if grad is None:
            raise RuntimeError(
                f"{type(l).__name__} has no gradient for its singular values; "
                "run a backward pass first"
            )
        pieces.append(grad.squeeze() if squeeze else grad)
    if not pieces:
        return None
    if len(pieces) == 1:
        return pieces[0]
    return torch.concat(pieces)


@torch.no_grad()
def update_state(s, low_rank_layers):
    """
    Write consecutive slices of the global vector ``s`` back into the layers.

    low_rank_layers: list of low_rank layers
    s: global s vector, laid out in the same layer order as ``get_s``
    returns: the same list of layers, whose singular values were overwritten
        in place.
    raises: TypeError if a layer is of an unsupported type.
    """
    offset = 0
    for l in low_rank_layers:
        if isinstance(l, (low_rank_linear, low_rank_CP)):
            target = l.s
            as_column = False
        elif isinstance(l, AdaLoraLayer):
            adapter_name = l.adapter_name
            target = l.lora_E if adapter_name is None else l.lora_E[adapter_name]
            # ``lora_E`` is written as a column, hence the unsqueeze below.
            as_column = True
        else:
            raise TypeError(f"unsupported low-rank layer type: {type(l).__name__}")
        width = target.shape[0]
        chunk = s[offset : offset + width]
        target.copy_(torch.unsqueeze(chunk, 1) if as_column else chunk)
        # Advance by the number of entries just consumed. The adapter-less
        # branch used to index ``lora_E`` with ``None`` here, which adds a
        # leading axis and therefore always advanced the offset by 1.
        offset += width
    return low_rank_layers


@torch.no_grad()
def freeze_zero_entries_s(low_rank_layers):
    """
    Zero the basis vectors that belong to pruned singular values.

    low_rank_layers: list of low_rank layers
    For every layer the basis vectors whose s_i equals 0.0 are overwritten with
    zeros in place; all other vectors are left unchanged. Layers of an
    unsupported type are skipped.
    """
    for layer in low_rank_layers:
        if isinstance(layer, low_rank_linear):
            alive = layer.s != 0.0
            for basis in (layer.us, layer.vs):
                basis.copy_(torch.einsum("ij,j->ij", basis, alive))
        elif isinstance(layer, low_rank_CP):
            alive = layer.s != 0.0
            for factor in layer.us:
                factor.copy_(torch.einsum("ij,j->ij", factor, alive))
        elif isinstance(layer, AdaLoraLayer):
            name = layer.adapter_name
            if name == None:
                # Factors stored directly on the layer.
                lora_E, lora_A, lora_B = layer.lora_E, layer.lora_A, layer.lora_B
            else:
                # Factors stored in containers keyed by adapter name.
                lora_E = layer.lora_E[name]
                lora_A = layer.lora_A[name]
                lora_B = layer.lora_B[name]
            alive = lora_E[:, 0] != 0.0
            # lora_B keeps its basis vectors as columns, lora_A as rows.
            lora_B.copy_(torch.einsum("ij,j->ij", lora_B, alive))
            lora_A.copy_(torch.einsum("ji,j->ji", lora_A, alive))


@torch.no_grad()
def in_train_prune(s_it, final_cr, n_steps, it, low_rank_layers):
    """
    Run one step of the gradual pruning schedule.

    s_it : current global s vector to prune
    final_cr : final compression ratio
    n_steps: number of steps in which to fully prune
    it: current iteration
    low_rank_layers: list of low_rank layers

    The ratio used at this step is ``it * final_cr / n_steps``. The pruned
    vector is written back into the layers and the basis vectors of the zeroed
    entries are frozen.
    returns: the pruned s_it and the updated low_rank_layers
    """
    current_ratio = it * torch.tensor(final_cr / n_steps)
    pruned = prune(s_it, current_ratio)
    layers = update_state(pruned, low_rank_layers)
    freeze_zero_entries_s(layers)
    return pruned, layers


@torch.no_grad()
def print_sparsity(s):
    """
    Print the fraction of entries of ``s`` that are exactly zero.

    s: global s vector
    """
    zero_fraction = (s == 0.0).sum() / s.shape[0]
    print(f"sparsity: {zero_fraction}")

@torch.no_grad()
def print_total_params(low_rank_layers):
    """
    Print the number of parameters held by the active rank-one components.

    low_rank_layers: list of layers exposing ``lora_E``, ``lora_A`` and
        ``lora_B``, either directly as tensors or as containers keyed by
        ``adapter_name`` (peft layout).

    For each layer with ``r`` non-zero entries in ``lora_E`` the count is
    ``r * (lora_B.shape[0] + lora_A.shape[1] + 1)``.
    """
    total_params = 0
    for l in low_rank_layers:
        if torch.is_tensor(l.lora_E):
            lora_E, lora_A, lora_B = l.lora_E, l.lora_A, l.lora_B
        else:
            # Adapter-keyed containers, as handled by the other helpers here.
            name = l.adapter_name
            lora_E, lora_A, lora_B = l.lora_E[name], l.lora_A[name], l.lora_B[name]
        active_rank = torch.sum(lora_E != 0.0)
        total_params += active_rank * (lora_B.shape[0] + lora_A.shape[1] + 1)
    print(f"total_params: {total_params}")


@torch.no_grad()
def calculate_max_step_as(s, d_AS, tau, eps):
    """
    Compute the largest step along ``d_AS`` that keeps ``s`` feasible.

    s: current s
    d_AS : away step direction
    tau : radius of the l1 constraint; the step is halved while
        ``sum(|s + beta * d_AS|) >= tau``
    eps : lower bound on the entries of ``s``; the initial step is the smallest
        ``(eps - s_i) / d_i`` over the entries with ``d_i < 0``
    returns: the step ``beta`` (a 0-dim tensor, or ``0.0`` when no entry of
        ``d_AS`` is negative).
    """
    print("max length AS search...")
    decreasing = d_AS < 0
    if bool(decreasing.any()):
        # Vectorised form of the per-entry limit computation.
        beta = torch.min((eps - s[decreasing]) / d_AS[decreasing])
    else:
        beta = 0.0
    s_try = s + beta * d_AS
    while torch.sum(torch.abs(s_try)) >= tau:
        if beta == 0:
            # A zero step cannot shrink further. Without this guard the loop
            # never ends when ``s`` itself violates the l1 bound.
            break
        beta = beta / 2
        s_try = s + beta * d_AS
    print("done!")
    return beta


@torch.no_grad()
def collect_lr_layers(model, instance_modules):
    """
    Return every sub-module of ``model`` that is an instance of a given class.

    model: module to traverse with ``model.apply``
    instance_modules: iterable of module classes to look for
    returns: list of matching modules, in the order ``model.apply`` visits them.
    """
    # Materialise once: cheaper than rebuilding per module, and a one-shot
    # iterable would otherwise be exhausted after the first module.
    wanted = tuple(instance_modules)
    adapters_layers = []

    def fn(mod):
        if isinstance(mod, wanted):
            adapters_layers.append(mod)

    model.apply(fn)
    return adapters_layers


class adalora_bilevel_trainer(Trainer):

    """
    Custom trainer class for the Bilevel low-rank FW optimizer
    """

    def __init__(
        self,
        data_collator,
        train_dataset,
        eval_dataset,
        tokenizer,
        low_rank_instances,
        max_epochs_ll=5,
        tau=10,
        optimizer_and_scheduler = [],
        eps=1e-4,
        riemannian=False,
        final_cr=0.0,
        pruning_steps=10,
        adapter_name = 'default',
        data_args = None,
        **kwargs,
    ):
        """
        Build the trainer: collect the low-rank layers, make the global s vector
        feasible, set up data loaders, Accelerate and the Frank-Wolfe state.

        data_collator: collator to use; when None, ``DataCollatorWithPadding(tokenizer)``
            is used, or ``default_data_collator`` if ``tokenizer`` is also None
        train_dataset / eval_dataset: datasets stored on the trainer and used to
            build the data loaders
        tokenizer: stored on the trainer and used for the fallback collator
        low_rank_instances: module classes passed to ``collect_lr_layers``
        max_epochs_ll: stored as ``self.max_epochs_ll``
        tau: bound compared against ``sum(|s|)``
        optimizer_and_scheduler: indexed with the keys 'optimizer' and 'scheduler'
        eps: lower bound compared against the entries of s
        riemannian: stored as ``self.riemannian``
        final_cr / pruning_steps: stored in ``self.fw_state``
        adapter_name: forwarded to ``wrap_classinstance_with_lr_methods``
        data_args: its ``task_name`` attribute is read for ``wandb.init``
        kwargs: forwarded to ``Trainer.__init__``
        """

        super().__init__(**kwargs)
        # Find the low-rank layers in the model and attach the bilevel helper methods to them.
        self.low_rank_layers = collect_lr_layers(self.model, low_rank_instances)
        self.low_rank_layers = wrap_classinstance_with_lr_methods(self.low_rank_layers,[get_hypergradient,activate_lower_level,activate_upper_level],adapter_default_name=adapter_name)#adapter_default_name=list(self.low_rank_layers[0].lora_E.keys())[0])
        s_temp = get_s(self.low_rank_layers)
        #######--------- initial constraint
        # Feasibility means sum(|s|) < tau and every entry >= eps. If violated,
        # s is rescaled so that sum(|s|) equals 0.5*tau and then shifted by eps.
        with torch.no_grad():
            print(f'initial check constraints {torch.sum(torch.abs(s_temp))<= tau,torch.sum(s_temp<eps)==0}')
            if torch.sum(torch.abs(s_temp))>= tau or torch.sum(s_temp<eps)!=0:
                print(f'constraints not satisfied')
                s_temp/= (torch.sum(torch.abs(s_temp))/(0.5*tau))
                s_temp+= eps
            self.low_rank_layers = update_state(s_temp,self.low_rank_layers)
            print(f'updates s satisfies constraints: {torch.sum(torch.abs(s_temp))<=tau,torch.sum(s_temp<eps)==0}')
        #######
        self.problem_dim = s_temp.shape[0]
        self.data_args = data_args
        self.device = s_temp.device
        self.batch_size = self.args.per_device_train_batch_size
        self.max_epochs_ll = max_epochs_ll
        # NOTE(review): this divides the dataset length (not the number of
        # batches) by the accumulation steps -- confirm that is intended.
        num_update_steps_per_epoch = len(train_dataset) // self.args.gradient_accumulation_steps
        num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
        self.total_epochs = self.args.num_train_epochs * num_update_steps_per_epoch
        self.riemannian = riemannian
        # NOTE(review): the default ``[]`` cannot be indexed with string keys,
        # so a mapping with 'optimizer' and 'scheduler' must always be passed.
        self.optimizer_UV = optimizer_and_scheduler['optimizer']
        self.scheduler_UV = optimizer_and_scheduler['scheduler']
        set_debug_apis(False)
        print(f"problem dimension: {self.problem_dim}")
        # print(f"pre-training validation metrics: {self.evaluate(self.eval_dataset)}")
        ##### initialization Frank-Wolfe variables #########
        # Boolean mask with one True entry per component of the global s vector.
        self.mask = torch.ones(
            get_s(self.low_rank_layers).shape[0], device=self.device
        ).bool()
        # self.retract_Oblique()
        self.fw_state = {
            "s_it": s_temp,
            "it": torch.tensor(0, device=self.device, dtype=torch.int32),
            "it_init_prune": torch.tensor(1, device=self.device, dtype=torch.int32),
            "tau": torch.tensor(tau, device=self.device),
            # tau enlarged by problem_dim * eps
            "scaled_tau": tau
            + torch.tensor(self.problem_dim * eps, device=self.device),
            "eps": torch.tensor(eps, device=self.device),
            "tol": 1e-5,
            "final_cr": final_cr,
            "pruning_steps": pruning_steps,
        }
        #### dataset creation
        default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
        self.data_collator = data_collator if data_collator is not None else default_collator
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.tokenizer = tokenizer
        if self.args.fp16:
            if self.args.fp16_backend == "auto":
                # NOTE(review): ``_is_native_amp_available`` is only defined at
                # module level when torch >= 1.6; older versions raise NameError here.
                self.fp16_backend = "amp" if _is_native_amp_available else "apex"
            else:
                self.fp16_backend = self.args.fp16_backend
            logger.info(f"Using {self.fp16_backend} fp16 backend")
        if self.args.fp16 and not self.args.deepspeed:  # deepspeed manages its own fp16
            if self.fp16_backend == "amp":
                self.use_amp = True
                self.scaler =  torch.cuda.amp.GradScaler()
            else:
                if not is_apex_available():
                    raise ImportError(
                        "Using FP16 with APEX but APEX is not installed, please refer to https://www.github.com/nvidia/apex."
                    )
                # The apex branch sets no ``self.scaler``.
                self.use_apex = True

        else:
            # Reached when fp16 is off or deepspeed is on: a scaler is still created.
            self.scaler = torch.cuda.amp.GradScaler()
        self._signature_columns = None
        # NOTE(review): despite the name this is the number of dataset items,
        # not the number of batches.
        self.n_batches = len(self.train_dataset)
        if is_datasets_available():
            if isinstance(train_dataset, datasets.Dataset):
                self._remove_unused_columns(self.train_dataset, description="training")
            if isinstance(eval_dataset, datasets.Dataset):
                self._remove_unused_columns(self.eval_dataset, description="evaluation")
        self.train_loader = self.get_train_dataloader()
        self.eval_loader = self.get_eval_dataloader(self.eval_dataset) if self.eval_dataset is not None else None
        #### initialize accelerator
        # The evaluation loader is not passed through ``prepare``.
        self.accelerator = Accelerator()
        self.model, self.optimizer_UV, self.train_loader, self.scheduler_UV = self.accelerator.prepare(
            self.model, self.optimizer_UV, self.train_loader, self.scheduler_UV
            )
        self.initialize_fw()
        # NOTE(review): ``data_args`` defaults to None, in which case reading
        # ``task_name`` below raises AttributeError.
        wandb.init(
            project=f'blo_{data_args.task_name}',
            config = {'task':data_args.task_name}
        )
        print(f'TRAINER INITIALIZED!\n')

    def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
        """
        Restrict ``dataset`` to the columns accepted by ``self.model.forward``.

        The accepted names (forward parameters plus "label" and "label_ids")
        are computed once and cached in ``self._signature_columns``. Columns
        outside that list are reported through the logger, and the dataset
        format is reset so that only the accepted columns are returned. Does
        nothing when ``self.args.remove_unused_columns`` is false.

        dataset: dataset whose format is changed in place
        description: optional name of the split, used only in the log message
        """
        if not self.args.remove_unused_columns:
            return

        if self._signature_columns is None:
            # Inspect the forward signature to learn which arguments the model takes.
            forward_params = inspect.signature(self.model.forward).parameters
            # The default data collator also understands these two label names.
            self._signature_columns = [*forward_params.keys(), "label", "label_ids"]

        present = dataset.column_names
        columns = [name for name in self._signature_columns if name in present]
        ignored_columns = list(set(dataset.column_names) - set(self._signature_columns))
        if len(ignored_columns) > 0:
            if description is None:
                dset_description = ""
            else:
                dset_description = f"in the {description} set "
            logger.info(
                f"The following columns {dset_description} don't have a corresponding argument in "
                f"`{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
            )

        current_format = dataset.format
        dataset.set_format(
            type=current_format["type"],
            columns=columns,
            format_kwargs=current_format["format_kwargs"],
        )

    def compute_loss(self, model, inputs, return_outputs=False):
        """
        Run ``model`` on ``inputs`` and return the loss.

        When a label smoother is configured and ``inputs`` contains "labels",
        the labels are popped from ``inputs`` and the loss comes from the label
        smoother; otherwise the loss is read from the model outputs ("loss" key
        for dict outputs, first element otherwise).

        model: callable model, invoked as ``model(**inputs)``
        inputs: mapping of keyword arguments for the model (may be mutated)
        return_outputs: when true, return ``(loss, outputs)`` instead of ``loss``
        """
        labels = None
        if self.label_smoother is not None and "labels" in inputs:
            labels = inputs.pop("labels")

        outputs = model(**inputs)

        # Keep the past state when the arguments ask for it.
        # TODO: this needs to be fixed and made cleaner later.
        past_index = self.args.past_index
        if past_index >= 0:
            self._past = outputs[past_index]

        if labels is None:
            # ``.loss`` is avoided because the model may return a plain tuple.
            if isinstance(outputs, dict):
                loss = outputs["loss"]
            else:
                loss = outputs[0]
        else:
            loss = self.label_smoother(outputs, labels)

        if return_outputs:
            return (loss, outputs)
        return loss

    @torch.no_grad()
    def get_Rgrad_Oblique(self):
        """
        Turn the Euclidean gradients of the bases into Riemannian gradients on
        the oblique manifold (product of spheres, ambient metric), in place.

        For every active basis vector ``u`` (non-zero singular value) with
        gradient ``g`` the component along ``u`` is removed: ``g - u * <u, g>``.
        This is the same quantity as ``G - U @ diag(diag(U^T G))``, computed
        without forming the r x r product.

        Raises TypeError for unsupported layer types.
        """

        def _project_columns(basis, mask):
            # Basis vectors are the columns of ``basis``.
            cols = basis[:, mask]
            g = basis.grad[:, mask]
            # Indexed assignment writes into ``basis.grad``. The earlier
            # ``grad[:, mask].copy_(...)`` wrote into the temporary copy that
            # boolean-mask indexing returns, leaving the gradient unchanged.
            basis.grad[:, mask] = g - cols * (cols * g).sum(dim=0, keepdim=True)

        def _project_rows(basis, mask):
            # Basis vectors are the rows of ``basis`` (lora_A layout).
            rows = basis[mask, :]
            g = basis.grad[mask, :]
            basis.grad[mask, :] = g - (rows * g).sum(dim=1, keepdim=True) * rows

        for l in self.low_rank_layers:
            if isinstance(l, low_rank_linear):
                mask = (l.s != 0.0).bool()
                _project_columns(l.us, mask)
                _project_columns(l.vs, mask)
            elif isinstance(l, low_rank_CP):
                mask = (l.s != 0.0).bool()
                for u in l.us:
                    _project_columns(u, mask)
            elif isinstance(l, AdaLoraLayer):
                adapter_name = l.adapter_name
                if adapter_name is None:
                    # Factors stored directly on the layer.
                    lora_E, lora_A, lora_B = l.lora_E, l.lora_A, l.lora_B
                else:
                    # Factors stored in containers keyed by adapter name.
                    lora_E = l.lora_E[adapter_name]
                    lora_A = l.lora_A[adapter_name]
                    lora_B = l.lora_B[adapter_name]
                mask = (lora_E.squeeze() != 0.0).bool()
                _project_columns(lora_B, mask)
                _project_rows(lora_A, mask)
            else:
                raise TypeError(f"unsupported low-rank layer type: {type(l).__name__}")

    @torch.no_grad()
    def get_landing_field_stiefel(self, lamb=1e-2):
        """
        Replace the Euclidean gradient of every basis by the landing field used
        for the product of Stiefel manifolds, in place.

        For a basis ``X`` with gradient ``G`` the new gradient is
        ``G - 0.5 * X @ (X^T G + G^T X) + lamb * (X X^T X - X)``; the same
        expression is applied to every factor, including ``lora_A``.

        lamb: weight of the ``X X^T X - X`` term
        Raises TypeError for unsupported layer types.
        """

        def _overwrite_with_field(x):
            g = x.grad
            g.copy_(
                g
                - 0.5 * x @ (x.T @ g + g.T @ x)
                + lamb * (x @ x.T @ x - x)
            )

        for layer in self.low_rank_layers:
            if isinstance(layer, low_rank_linear):
                _overwrite_with_field(layer.us)
                _overwrite_with_field(layer.vs)
            elif isinstance(layer, low_rank_CP):
                for factor in layer.us:
                    _overwrite_with_field(factor)
            elif isinstance(layer, AdaLoraLayer):
                name = layer.adapter_name
                if name == None:
                    # Factors stored directly on the layer.
                    _overwrite_with_field(layer.lora_B)
                    _overwrite_with_field(layer.lora_A)
                else:
                    # Factors stored in containers keyed by adapter name.
                    _overwrite_with_field(layer.lora_B[name])
                    _overwrite_with_field(layer.lora_A[name])
            else:
                raise TypeError()

    @torch.no_grad()
    def retract_Oblique(self):
        """
        Rescale every active basis vector to unit Euclidean norm, in place.

        Only the vectors whose singular-value entry is non-zero are touched.
        Raises TypeError for unsupported layer types.
        """

        def _normalize_columns(mat, mask):
            # Basis vectors are the columns of ``mat``.
            picked = mat[:, mask]
            mat[:, mask] = picked / torch.linalg.norm(picked, dim=0)

        def _normalize_rows(mat, mask):
            # Basis vectors are the rows of ``mat`` (lora_A layout).
            picked = mat[mask, :]
            mat[mask, :] = (picked.T / torch.linalg.norm(picked, dim=1)).T

        for layer in self.low_rank_layers:
            if isinstance(layer, low_rank_linear):
                active = (layer.s != 0.0).bool()
                _normalize_columns(layer.us, active)
                _normalize_columns(layer.vs, active)
            elif isinstance(layer, low_rank_CP):
                active = (layer.s != 0.0).bool()
                for factor in layer.us:
                    _normalize_columns(factor, active)
            elif isinstance(layer, AdaLoraLayer):
                name = layer.adapter_name
                if name == None:
                    lora_E, lora_A, lora_B = layer.lora_E, layer.lora_A, layer.lora_B
                else:
                    lora_E = layer.lora_E[name]
                    lora_A = layer.lora_A[name]
                    lora_B = layer.lora_B[name]
                active = (lora_E.squeeze() != 0.0).bool()
                _normalize_columns(lora_B, active)
                _normalize_rows(lora_A, active)
            else:
                raise TypeError()

    @torch.no_grad()
    def retract_Stiefel(self):
        """
        Replace every basis by the Q factor of its reduced QR decomposition,
        in place.

        ``lora_A`` is transposed before the decomposition and the Q factor is
        transposed back. Raises TypeError for unsupported layer types.
        """

        def _qr_columns(mat):
            q_factor, _ = torch.linalg.qr(mat, mode="reduced")
            mat.copy_(q_factor)

        def _qr_rows(mat):
            q_factor, _ = torch.linalg.qr(mat.T, mode="reduced")
            mat.copy_(q_factor.T)

        for layer in self.low_rank_layers:
            if isinstance(layer, low_rank_linear):
                _qr_columns(layer.us)
                _qr_columns(layer.vs)
            elif isinstance(layer, low_rank_CP):
                for factor in layer.us:
                    _qr_columns(factor)
            elif isinstance(layer, AdaLoraLayer):
                name = layer.adapter_name
                if name == None:
                    # Factors stored directly on the layer.
                    _qr_columns(layer.lora_B)
                    _qr_rows(layer.lora_A)
                else:
                    # Factors stored in containers keyed by adapter name.
                    _qr_columns(layer.lora_B[name])
                    _qr_rows(layer.lora_A[name])
            else:
                raise TypeError()

    @torch.no_grad()
    def full_loss(self):
        """
        Accumulate the loss over ``self.train_dataset`` and return it together
        with an accuracy placeholder.

        Each item's loss is divided by ``self.n_batches`` before being added.
        returns: ``(total_loss, total_accuracy)``; the accuracy is always 0.0.

        NOTE(review): the loop iterates the dataset itself rather than
        ``self.train_loader`` and does not call ``_prepare_inputs`` -- confirm
        that the items are already model-ready batches.
        """
        running_loss = 0.0
        accuracy_placeholder = 0.0
        for item in self.train_dataset:
            with torch.cuda.amp.autocast():
                item_loss = self.compute_loss(self.model, item)
            running_loss = running_loss + float(item_loss.item()) / (self.n_batches)
        return running_loss, accuracy_placeholder

    @torch.no_grad()
    def armijo_line_search(
        self, d_it, s_it, fun_s_it, grad_s_it, start_step=1.0, batch=None
    ):
        """
        Backtracking (Armijo) line search along ``d_it`` starting from ``s_it``.

        The step starts at ``start_step`` and is halved while
        ``f(s_it + step * d_it) - fun_s_it >= 0.1 * step * <d_it, grad_s_it>``,
        for at most 16 halvings (the counter only guards numerical stability).
        Every trial point is written into the low-rank layers before its loss
        is evaluated, so on return the layers hold the returned point.

        d_it: search direction
        s_it: current evaluation point (global s vector)
        fun_s_it: loss value at ``s_it``
        grad_s_it: hypergradient with respect to s at ``s_it``
        start_step: initial step size
        batch: batch on which the loss is evaluated; ``None`` falls back to the
            full loss
        returns: ``(step, s_it_d)`` with ``s_it_d = s_it + step * d_it``
        """
        # Snapshot the start point: ``update_state`` copies into the layer
        # parameters, which would otherwise change ``s_it`` if it shares
        # storage with them.
        s_start = s_it.clone()
        # Loop invariant, computed once.
        slope = torch.dot(d_it, grad_s_it)

        def _loss_at(candidate):
            # The loss depends on s only through the layers, so the candidate
            # is written back first. The initial trial used to skip this step.
            self.low_rank_layers = update_state(candidate, self.low_rank_layers)
            value = self.batch_loss(batch)
            # ``batch_loss`` returns a (loss, accuracy) tuple on the full-loss path.
            return value[0] if isinstance(value, tuple) else value

        step = start_step
        s_it_d = s_start + step * d_it
        fun_s_it_d = _loss_at(s_it_d)
        counter = 0
        while fun_s_it_d - fun_s_it >= 1e-1 * step * slope and counter <= 15:
            step = step / 2
            s_it_d = s_start + step * d_it
            fun_s_it_d = _loss_at(s_it_d)
            counter += 1
        return step, s_it_d

    @torch.no_grad()
    def batch_loss(self, data):
        """
        Evaluate the loss on ``data``, or over the whole dataset when it is None.

        data: inputs for ``compute_loss``, or None
        returns: the value of ``compute_loss`` for a given batch; for None the
            result of ``full_loss``, which is a ``(loss, accuracy)`` tuple.
        """
        if data is None:
            return self.full_loss()
        return self.compute_loss(self.model, data)

    def get_U_star(self):
        """
        Computes U_star(s) for the hypergradient calculation.

        Switches every low-rank layer to its lower level and trains for
        ``self.max_epochs_ll`` epochs over ``self.train_loader`` with
        ``self.optimizer_UV``. Depending on ``self.riemannian`` the gradient is
        adjusted before the step and/or the bases are retracted after it:

        - "oblique": Riemannian gradient, step, retraction to unit norm
        - "stiefel_landing": landing field, step
        - "stiefel": step, QR retraction
        - falsy value: plain optimizer step

        The scheduler is stepped once per epoch, the model is put back in eval
        mode, and the summed loss of the last epoch is printed.
        """
        self.model.train()
        for layer in self.low_rank_layers:
            layer.activate_lower_level()
        # The mode does not change during training, so resolve it once.
        mode = self.riemannian.lower() if self.riemannian else None
        # Defined up front so the final print also works when no epoch runs.
        loss_fn = 0.0
        for _ in range(self.max_epochs_ll):
            loss_fn = 0.0
            for batch in self.train_loader:
                batch = self._prepare_inputs(batch)
                # Drop the gradients left over from the previous step.
                for p in self.model.parameters():
                    if p.requires_grad:
                        p.grad = None
                with torch.cuda.amp.autocast():
                    loss = self.compute_loss(self.model, batch)
                self.accelerator.backward(loss)
                loss_fn += float(loss.item())
                if mode is None:
                    self.optimizer_UV.step()
                elif mode == "oblique":
                    self.get_Rgrad_Oblique()
                    self.optimizer_UV.step()
                    self.retract_Oblique()
                elif mode == "stiefel_landing":
                    self.get_landing_field_stiefel()
                    self.optimizer_UV.step()
                elif mode == "stiefel":
                    self.optimizer_UV.step()
                    self.retract_Stiefel()
                # NOTE(review): any other mode string takes no optimizer step,
                # as before -- confirm whether that should be an error instead.
            self.scheduler_UV.step()
        self.model.eval()
        print(f"finish lower level with loss: {loss_fn}")

    def create_datasets(self):
        """
        Build the dataloaders used by the optimisation loops.

        Stores the train dataloader in `self.train_loader` and, when
        `self.eval_dataset` is set, the matching eval dataloader in
        `self.eval_loader` (None otherwise).
        """
        self.train_loader = self.get_train_dataloader()
        if self.eval_dataset is None:
            self.eval_loader = None
        else:
            self.eval_loader = self.get_eval_dataloader(self.eval_dataset)

    def get_hypergradient(self, model, inputs, already_optimal=False):
        """
        Computes the hypergradient nabla_s L^* on the batch `inputs`.

        model: module whose trainable parameters get their gradients cleared
            before the backward pass (the forward pass itself uses self.model).
        inputs: batch passed to `self.compute_loss`; it is also stored in
            `self.fw_state["current_batch"]`.
        already_optimal: when False, `self.get_U_star()` is run first.

        Returns the tuple (loss as float, None, gradient collected from the
        low-rank layers by `get_grad_s`).
        """

        def _drop_trainable_grads(module):
            # Reset .grad only on parameters that currently require grad.
            for param in module.parameters():
                if param.requires_grad:
                    param.grad = None

        if not already_optimal:
            self.get_U_star()

        for low_rank_layer in self.low_rank_layers:
            low_rank_layer.activate_upper_level()
        _drop_trainable_grads(model)

        # accuracy_fn is never updated; it is only reported alongside the loss.
        loss_fn, accuracy_fn = 0.0, 0.0
        with torch.cuda.amp.autocast():
            loss = self.compute_loss(self.model, inputs)
        self.accelerator.backward(loss)
        loss_fn += float(loss.item())

        for low_rank_layer in self.low_rank_layers:
            low_rank_layer.get_hypergradient()
        print(f"calculated U^*, metrics: {loss_fn,accuracy_fn}")

        self.fw_state["current_batch"] = inputs
        h_grad = get_grad_s(self.low_rank_layers)
        # Gradients have been collected; clear them on the trainer's model.
        _drop_trainable_grads(self.model)
        return loss_fn, None, h_grad

    @torch.no_grad()
    def initialize_fw(self):
        """
        Fill `self.fw_state` with the quantities needed by the first step.

        Reads the current s from the low-rank layers, takes the first batch of
        `self.train_loader`, computes loss and hypergradient on it, and then
        chooses between the Frank-Wolfe direction and the away-step direction.
        Writes the keys "s_it", "current_batch", "step", "d_it", "fun_s_it"
        and "grad_s_it" of `self.fw_state`.
        """
        s_it = get_s(self.low_rank_layers)
        self.fw_state["s_it"] = s_it
        first_batch = self._prepare_inputs(next(iter(self.train_loader)))
        self.fw_state["current_batch"] = first_batch

        # The method runs under no_grad, so autograd is re-enabled locally.
        with torch.enable_grad():
            fun_s_it, _, grad_s_it = self.get_hypergradient(self.model, first_batch)

        eps = self.fw_state["eps"]

        def _vertex(position, value):
            # Vector filled with eps (instead of zero) except for one entry.
            point = eps * torch.ones(s_it.shape[0], device=s_it.device)
            point[position].copy_(value)
            return point

        # Frank-Wolfe vertex: coordinate with the smallest gradient entry.
        index_min = torch.argmin(grad_s_it)
        if torch.sign(grad_s_it[index_min]) < 0:
            value_min = self.fw_state["tau"] - eps
        else:
            value_min = eps

        # Away vertex: largest gradient entry among the non-zero entries of s.
        index_max = torch.argmax(grad_s_it * torch.abs(torch.sign(s_it)))
        if torch.sign(grad_s_it[index_max]) > 0:
            value_max = self.fw_state["tau"] - eps
        else:
            value_max = eps

        s_it_min = _vertex(index_min, value_min)
        s_it_max = _vertex(index_max, value_max)

        away_direction = s_it - s_it_max
        fw_direction = s_it_min - s_it

        # Keep the direction with the smaller directional derivative.
        fw_slope = torch.dot(grad_s_it, fw_direction)
        away_slope = torch.dot(grad_s_it, away_direction)
        if fw_slope <= away_slope:
            d_it = fw_direction
            step = 1.0
        else:
            d_it = away_direction
            step = calculate_max_step_as(
                s_it, d_it, tau=self.fw_state["tau"], eps=self.fw_state["eps"]
            )

        self.fw_state["step"] = step
        self.fw_state["d_it"] = d_it
        self.fw_state["fun_s_it"] = fun_s_it
        self.fw_state["grad_s_it"] = grad_s_it

    @torch.no_grad()
    def training_step(
        self, model: torch.nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]
    ) -> torch.Tensor:
        """
        Perform one outer iteration driven by the state stored in `self.fw_state`.

        While the continuation test below holds, the step:
          1. runs the Armijo line search along the stored direction `d_it`,
             evaluated on the stored batch `fw_state["current_batch"]`;
          2. writes the resulting `s_it` into the low-rank layers;
          3. inside the pruning window, prunes `s_it` and refreshes `self.mask`;
          4. recomputes loss and hypergradient on the new batch `inputs`;
          5. picks the next direction (Frank-Wolfe or away step) and its
             starting step, and stores everything back in `self.fw_state`.
        Otherwise it requests the end of the epoch and closes the wandb run.

        model: forwarded to `self.get_hypergradient`.
        inputs: raw batch; passed through `self._prepare_inputs` first.

        Returns the loss value `fun_s_it`. NOTE(review): `get_hypergradient`
        in this file returns it as a Python float, not a tensor as the
        annotation says - confirm what the caller expects.
        """
        #### evaluate inside training ###
        # Evaluation runs whenever the iteration counter is a multiple of 5.
        if self.fw_state["it"]%5 == 0:
            self.print_evaluation()
        ###############################

        # Load the state left by initialize_fw or by the previous step.
        s_it, grad_s_it, d_it,fun_s_it,step = (
            self.fw_state["s_it"],
            self.fw_state["grad_s_it"],
            self.fw_state["d_it"],
            self.fw_state["fun_s_it"],
            self.fw_state["step"]
        )
        inputs = self._prepare_inputs(inputs)
        # Continue while (iteration budget left AND the directional derivative
        # <grad, d> is still below -tol) OR the pruning schedule
        # (it_init_prune + pruning_steps + 1 iterations) has not completed yet.
        if (
            (self.fw_state["it"] < self.total_epochs
            and torch.dot(grad_s_it, d_it) <= -self.fw_state["tol"]) or self.fw_state["it"] <= self.fw_state["it_init_prune"] + self.fw_state["pruning_steps"] + 1
        ):
            self.fw_state["it"] += 1
            ####### loss cumulator
            # Armijo line search
            # The search uses the batch on which grad_s_it / fun_s_it were
            # computed (stored as "current_batch"), not the new `inputs`.
            print(f"armijo...")
            step, s_it = self.armijo_line_search(
                d_it,
                s_it,
                fun_s_it,
                grad_s_it,
                start_step=step,
                batch=self.fw_state["current_batch"],
            )
            print(f"end armijo!")
            #### update state
            self.low_rank_layers = update_state(s_it, self.low_rank_layers)
            ### pruning
            # Active only for it_init_prune <= it < it_init_prune + pruning_steps + 1.
            if (
                self.fw_state["it"] >= self.fw_state["it_init_prune"]
                and self.fw_state["it"]
                < self.fw_state["it_init_prune"] + self.fw_state["pruning_steps"] + 1
            ):
                print(f"in train prune...")
                with torch.no_grad():
                    s_it, self.low_rank_layers = in_train_prune(
                        s_it,
                        self.fw_state["final_cr"],
                        n_steps=self.fw_state["pruning_steps"],
                        it=self.fw_state["it"] - self.fw_state["it_init_prune"],
                        low_rank_layers=self.low_rank_layers,
                    )
                    # Mask of the entries of s that are still non-zero after pruning.
                    self.mask = (s_it != 0.0).bool()
                    print_sparsity(s_it)
                    print_total_params(self.low_rank_layers)
            # gradient - approximation and new loss value
            # The method runs under no_grad, so autograd is re-enabled locally.
            with torch.enable_grad():
                fun_s_it, _, grad_s_it = self.get_hypergradient(
                    model=model, inputs=inputs
                )
            # minimize the linearized function
            # unit simplex: vertex on the coordinate corresponding to the min gradient index
            index_min = torch.argmin(grad_s_it)
            # Away-step candidate: largest gradient entry among the non-zero entries of s_it.
            index_max = torch.argmax(grad_s_it * torch.abs(torch.sign(s_it)))
            sign_min = (
                self.fw_state["tau"] - self.fw_state["eps"]
                if torch.sign(grad_s_it[index_min]) < 0
                else self.fw_state["eps"]
            )  # 0.
            sign_max = (
                self.fw_state["tau"] - self.fw_state["eps"]
                if torch.sign(grad_s_it[index_max]) > 0
                else self.fw_state["eps"]
            )  # 0.
            # Both vertices are filled with eps and created on the device of s_it.
            s_it_min = self.fw_state["eps"] * torch.ones(
                s_it.shape[0], device=s_it.device
            )  ### TODO: keep checking on which device these tensors end up
            s_it_min[index_min].copy_(sign_min)
            s_it_max = self.fw_state["eps"] * torch.ones(
                s_it.shape[0], device=s_it.device
            )
            s_it_max[index_max].copy_(sign_max)
            # Zero the vertex entries at positions excluded by self.mask.
            # NOTE(review): self.mask is only assigned in the pruning branch
            # above within this view - presumably it is initialised elsewhere.
            s_it_min.mul_(self.mask)
            s_it_max.mul_(self.mask)

            d_AS = s_it - s_it_max
            d_FW = s_it_min - s_it
            # descent direction: keep the one with the smaller directional derivative
            if torch.dot(grad_s_it, d_FW) <= torch.dot(grad_s_it, d_AS):
                d_it = d_FW
                step = 1.0
            else:
                d_it = d_AS
                step = calculate_max_step_as(
                    s_it, d_it, self.fw_state["tau"], self.fw_state["eps"]
                )
            print(
                f"s satisfied constraints {torch.sum(torch.abs(s_it))<=self.fw_state['scaled_tau'],torch.sum(s_it<0.0)==0}"
            )
            print("-" * 100)
            #### update fw state
            self.fw_state["current_batch"] = inputs
            self.fw_state["s_it"] = s_it
            self.fw_state["grad_s_it"] = grad_s_it
            self.fw_state["d_it"] = d_it
            self.fw_state["step"] = step
            self.fw_state["fun_s_it"] = fun_s_it
            print(f'grad norm, d_it norm {torch.linalg.norm(grad_s_it),torch.linalg.norm(d_it)}')
            return fun_s_it
        else:
            print(f'converged! grad norm, d_it norm {torch.linalg.norm(grad_s_it),torch.linalg.norm(d_it)}')
            # NOTE: this branch is reached only when the pruning clause of the
            # test above is false, i.e. it > it_init_prune + pruning_steps + 1,
            # so the check below always holds here.
            if self.fw_state["it"] >= self.fw_state["it_init_prune"] + self.fw_state["pruning_steps"] + 1:  ### let the pruning end and then stop training
                self.control.should_epoch_stop = True
                wandb.finish()
            return self.fw_state["fun_s_it"]
        
    @torch.no_grad()
    def print_evaluation(self):
        """
        Evaluate the model and report the metrics.

        Runs `self.evaluate` on `self.eval_dataset` (plus the MNLI mismatched
        split when it can be resolved), then for every task:
          - writes each metric to `self.tb_writter` when that writer is set,
            using the trainer's current global step as x value;
          - logs each metric through the module logger;
          - calls `self.log_metrics` / `self.save_metrics`.
        All metrics are finally sent to wandb in one call, keyed by task name.
        """
        self.model.eval()
        logger.info("*** Evaluate ***")

        # Loop to handle MNLI double evaluation (matched, mis-matched)
        tasks = [self.data_args.task_name]
        eval_datasets = [self.eval_dataset]
        if self.data_args.task_name == "mnli":
            # NOTE(review): at module level `datasets` is bound by
            # `import datasets`, and a module is not subscriptable, so unless
            # the name is rebound elsewhere in this file the lookup below
            # cannot succeed. The mismatched evaluation is skipped with a
            # warning in that case instead of aborting training; the split
            # should be handed to the trainer explicitly.
            try:
                mismatched = datasets["validation_mismatched"]
            except (NameError, TypeError, KeyError):
                logger.warning(
                    "MNLI mismatched validation split is not available; "
                    "skipping the mnli-mm evaluation."
                )
            else:
                tasks.append("mnli-mm")
                eval_datasets.append(mismatched)

        # The original passed the constant `self.args.num_train_epochs` as
        # TensorBoard step, which put every periodic evaluation on the same
        # x value; the trainer's running step counter is used instead.
        tb_step = self.state.global_step

        wandb_logs = {}
        for eval_dataset, task in zip(eval_datasets, tasks):
            metrics = self.evaluate(eval_dataset=eval_dataset)

            max_val_samples = (
                self.data_args.max_val_samples
                if self.data_args.max_val_samples is not None
                else len(eval_dataset)
            )
            metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
            for key, value in metrics.items():
                if self.tb_writter:
                    self.tb_writter.add_scalar("Eval_%s/%s" % (task, key), value, tb_step)
                logger.info("{task} {key}: {value}:".format(task=task, key=key, value=value))
            self.log_metrics("Eval_%s" % task, metrics)
            self.save_metrics("Eval_%s" % task, metrics)
            wandb_logs[task] = metrics
        wandb.log(wandb_logs)
    