#%%
import torch, pickle
from peft_utils.peft_layers import low_rank_linear
from peft_utils.peft_layers import low_rank_CP
# from peft.tuners.adalora.layer import AdaLoraLayer
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import inspect
import torch.nn.functional as F

from accelerate import Accelerator

import wandb


def set_debug_apis(state: bool = False):
    """Globally enable or disable PyTorch autograd anomaly detection.

    ``torch.autograd.profiler.profile`` and ``torch.autograd.profiler.emit_nvtx``
    are context managers: constructing them without entering them (as this
    helper previously did) has no effect, so those calls were removed. To
    profile, wrap the code of interest in ``with torch.autograd.profiler.profile(...):``.

    state: True to turn anomaly detection on, False to turn it off.
    """
    torch.autograd.set_detect_anomaly(mode=state)

from packaging import version

# Native AMP (torch.cuda.amp.autocast) ships with PyTorch >= 1.6.
# The flag is always defined so callers can test it on any torch version.
_is_native_amp_available = version.parse(torch.__version__) >= version.parse("1.6")
if _is_native_amp_available:
    from torch.cuda.amp import autocast


@torch.no_grad()
def basis_norm_check(low_rank_layers):
    """
    Print how far each layer's bases are from the Oblique-manifold constraint.

    On the Oblique manifold every basis column has unit Euclidean norm. Only
    columns whose entry in ``l.s`` is non-zero (i.e. not pruned) are checked;
    the printed value per basis is ``|| colnorms(B[:, mask]) - 1 ||_2``.
    Layers that are neither low_rank_linear nor low_rank_CP are ignored.
    """
    for l in low_rank_layers:
        if isinstance(l, low_rank_linear):
            mask = (l.s != 0.0).bool()
            # Previously the second value re-checked ``l.us``; it must check ``l.vs``.
            u_dist = torch.linalg.norm(torch.linalg.norm(l.us[:, mask], dim=0) - 1)
            v_dist = torch.linalg.norm(torch.linalg.norm(l.vs[:, mask], dim=0) - 1)
            print(f"check U,V distance to constraint: {u_dist, v_dist}")
        elif isinstance(l, low_rank_CP):
            mask = (l.s != 0.0).bool()
            print(
                f"check us distance to constraint: {[torch.linalg.norm(torch.linalg.norm(u[:,mask],dim = 0)-1) for u in l.us]}"
            )


@torch.no_grad()
def stiefel_constraint_check(low_rank_layers):
    """
    Print how far each layer's bases are from the Stiefel-manifold constraint.

    For a basis B with k columns the printed distance is ``||B^T B - I_k||_F``
    (zero when the columns are orthonormal). Layers that are neither
    low_rank_linear nor low_rank_CP are ignored.
    """

    def _gap(basis):
        identity = torch.eye(basis.shape[1], device=basis.device)
        return torch.linalg.norm(basis.T @ basis - identity)

    for layer in low_rank_layers:
        if isinstance(layer, low_rank_linear):
            u_gap, v_gap = _gap(layer.us), _gap(layer.vs)
            print(f"check U,V distance to constraint: {u_gap, v_gap}")
        elif isinstance(layer, low_rank_CP):
            gaps = [_gap(factor) for factor in layer.us]
            print(f"check us distance to constraint: {gaps}")
        


def load_trainer(path):
    """
    Load and return a pickled trainer object from ``path``.

    The file handle is closed even if unpickling fails.
    SECURITY: ``pickle.load`` can execute arbitrary code; only load files
    from a trusted source.
    """
    with open(path, "rb") as file:
        return pickle.load(file)


@torch.no_grad()
def prune(s, cr, s_measure=None):
    """
    Zero out the ``ceil(cr * len(s))`` entries of ``s`` with the smallest
    measure and return the pruned copy (``s`` itself is not modified).

    s: global s vector (1-D tensor)
    cr: fraction of entries to prune, a float or 0-dim tensor; at most
        ``len(s) - 1`` entries are targeted so at least one entry survives
        (ties at the threshold value may prune more)
    s_measure: optional 1-D tensor, same length as ``s``, used for ranking;
        defaults to ``s`` itself
    """
    # as_tensor keeps tensor inputs untouched (same dtype as before) and
    # makes plain Python floats work with torch.ceil.
    cr = torch.as_tensor(cr)
    n_to_cut = min(int(torch.ceil(cr * len(s))), len(s) - 1)
    if n_to_cut <= 0:
        # Nothing to prune (cr == 0, a single entry, or an empty vector).
        return s.clone()
    measure = s if s_measure is None else s_measure
    s_sort, _ = torch.sort(measure, descending=False)
    # Entries <= the n_to_cut-th smallest value are removed, i.e. exactly
    # n_to_cut entries when there are no ties (previously off by one).
    threshold = s_sort[n_to_cut - 1]
    return (measure > threshold) * s

@torch.no_grad()
def get_s(low_rank_layers):
    """
    Concatenate every layer's ``l.s`` into one global s vector.

    low_rank_layers: list of low_rank_linear / low_rank_CP layers
    returns: a new tensor (never an alias of a layer's parameter), or None
        when ``low_rank_layers`` is empty
    raises TypeError: if a layer has an unsupported type
    """
    pieces = []
    for l in low_rank_layers:
        if not isinstance(l, (low_rank_linear, low_rank_CP)):
            raise TypeError(
                f"expected low_rank_linear or low_rank_CP, got {type(l).__name__}"
            )
        pieces.append(l.s)
    if not pieces:
        return None
    # One concatenation instead of repeated pairwise concats (quadratic copying).
    return torch.cat(pieces)

@torch.no_grad()
def get_grad_s(low_rank_layers):
    """
    Concatenate every layer's ``l.s.grad`` into one global gradient vector,
    aligned entry-by-entry with the vector returned by ``get_s``.

    low_rank_layers: list of low_rank_linear / low_rank_CP layers
    returns: a new tensor, or None when ``low_rank_layers`` is empty
    raises TypeError: if a layer has an unsupported type
    raises RuntimeError: if a layer's ``s.grad`` is None (previously such a
        layer was silently skipped when first, misaligning the vector)
    """
    pieces = []
    for l in low_rank_layers:
        if not isinstance(l, (low_rank_linear, low_rank_CP)):
            raise TypeError(
                f"expected low_rank_linear or low_rank_CP, got {type(l).__name__}"
            )
        if l.s.grad is None:
            raise RuntimeError(
                "a low-rank layer has s.grad == None; compute the hypergradient first"
            )
        pieces.append(l.s.grad)
    if not pieces:
        return None
    return torch.cat(pieces)


@torch.no_grad()
def update_state(s, low_rank_layers):
    """
    Write consecutive slices of the global vector ``s`` back into each layer.

    Layer k receives the ``len(l.s)`` entries that follow those consumed by
    layers 0..k-1 (in-place ``copy_``).

    s: global s vector
    low_rank_layers: list of low_rank_linear / low_rank_CP layers
    returns: the same ``low_rank_layers`` list
    raises TypeError: on a layer of unsupported type
    """
    offset = 0
    for layer in low_rank_layers:
        if not (isinstance(layer, low_rank_linear) or isinstance(layer, low_rank_CP)):
            raise TypeError()
        width = layer.s.shape[0]
        layer.s.copy_(s[offset : offset + width])
        offset += width
    return low_rank_layers


@torch.no_grad()
def freeze_zero_entries_s(low_rank_layers,riemannian = ''):
    """
    Zero, in place, the basis columns whose entry in ``l.s`` is 0.0.

    Applied only for Euclidean training, i.e. when ``riemannian`` is falsy
    ('' or the trainer's default False). Previously only the exact value ''
    was accepted, so the trainer's default ``riemannian=False`` skipped it.
    For manifold modes the columns are left untouched, since zeroing them
    would break the manifold constraint.

    low_rank_layers: list of low-rank layers; other layer types are ignored
    riemannian: falsy for Euclidean, otherwise a manifold mode string
    """
    if riemannian:
        return
    for l in low_rank_layers:
        if isinstance(l, low_rank_linear):
            keep = l.s != 0.0
            # Broadcasting over the last dim == einsum("ij,j->ij"), without
            # mixing bool and float dtypes inside einsum.
            l.us.mul_(keep.to(l.us.dtype))
            l.vs.mul_(keep.to(l.vs.dtype))
        elif isinstance(l, low_rank_CP):
            keep = l.s != 0.0
            for u in l.us:
                u.mul_(keep.to(u.dtype))


@torch.no_grad()
def in_train_prune(s_it, final_cr, n_steps, it, low_rank_layers,riemannian = ''):
    """
    One step of a linear pruning schedule: prune ``it / n_steps`` of
    ``final_cr``, write the result into the layers and zero the bases of
    pruned entries (Euclidean mode only).

    s_it : current global s vector to prune
    final_cr : final compression ratio
    n_steps: number of steps in which to fully prune; if <= 0 the whole
        ``final_cr`` is applied at once (previously a ZeroDivisionError)
    it: current pruning iteration (0-based)
    low_rank_layers: list of low_rank layers
    riemannian: forwarded to ``freeze_zero_entries_s``
    returns: (pruned s_it, updated low_rank_layers)
    """
    if n_steps > 0:
        cr = it * torch.tensor(final_cr / n_steps)
    else:
        cr = torch.tensor(final_cr)
    s_pruned = prune(s_it, cr)
    low_rank_layers = update_state(s_pruned, low_rank_layers)
    freeze_zero_entries_s(low_rank_layers, riemannian)
    return s_pruned, low_rank_layers


@torch.no_grad()
def print_sparsity(s):
    """
    Print the fraction of entries of the global s vector that are exactly 0.0.

    s: 1-D tensor
    """
    n_zero = torch.sum(s == 0.0)
    ratio = n_zero / s.shape[0]
    print(f"sparsity: {ratio}")

@torch.no_grad()
def print_total_params(low_rank_layers):
    """
    Print the number of parameters kept by the (pruned) low-rank layers.

    For each layer with active rank r (non-zero entries of ``l.s``):
      - low_rank_linear: r * (rows(us) + rows(vs) + 1)
      - low_rank_CP:     r * (sum of rows of every factor in us + 1)
    where the ``+ 1`` counts the s entry itself. Other layer types are ignored.
    """
    total_params = 0
    for layer in low_rank_layers:
        if isinstance(layer, low_rank_linear):
            per_rank = layer.us.shape[0] + layer.vs.shape[0] + 1
        elif isinstance(layer, low_rank_CP):
            per_rank = sum(factor.shape[0] for factor in layer.us) + 1
        else:
            continue
        active_rank = torch.sum(layer.s != 0.0)
        total_params += active_rank * per_rank
    print(f"total_params: {total_params}")


@torch.no_grad()
def calculate_max_step_as(s, d_AS, tau, eps, max_halvings=200):
    """
    Compute the step length for the away-step direction ``d_AS``.

    Start from the largest beta that keeps every decreasing entry of
    ``s + beta * d_AS`` at or above ``eps`` (0.0 if no entry decreases), then
    halve beta until ``sum(|s + beta * d_AS|) < tau``.

    s: current s (1-D tensor)
    d_AS : away step direction (same shape as s)
    tau : constraint ball norm (L1 radius)
    eps : lower bound on the entries of s
    max_halvings: safety cap on the halving loop. Previously, if ``s`` itself
        violated the L1 bound, the loop never terminated. When the cap is hit,
        the last (tiny) beta is returned and the constraint may still be violated.
    returns: beta (0-dim tensor, or the float 0.0 when no entry of d_AS is negative)
    """
    print("max length AS search...")
    decreasing = d_AS < 0
    if bool(torch.any(decreasing)):
        # Vectorised form of min_i (eps - s_i) / d_i over d_i < 0.
        beta = torch.min((eps - s[decreasing]) / d_AS[decreasing])
    else:
        beta = 0.0
    s_try = s + beta * d_AS
    halvings = 0
    while torch.sum(torch.abs(s_try)) >= tau and halvings < max_halvings:
        beta = beta / 2
        s_try = s + beta * d_AS
        halvings += 1
    print("done!")
    return beta


@torch.no_grad()
def collect_lr_layers(model, instance_modules):
    """
    Return every module in ``model`` (including ``model`` itself) that is an
    instance of one of the types in ``instance_modules``.

    Modules are gathered in the order ``nn.Module.apply`` visits them
    (children before their parent).
    """
    found = []

    def _visit(module):
        wanted = tuple(instance_modules)
        if not isinstance(module, wanted):
            return
        found.append(module)

    model.apply(_visit)
    return found


class adalora_bilevel_trainer:

    """
    Custom trainer class for the bilevel low-rank Frank-Wolfe (FW) optimizer.

    Upper level: the global vector ``s`` (all layers' ``l.s`` concatenated,
    see ``get_s``) is updated with Frank-Wolfe / away steps, a stochastic
    Armijo line search and an in-training pruning schedule. ``__init__``
    checks the constraints ``sum(|s|) <= tau`` and ``s >= eps`` and rescales
    ``s`` if they fail.
    Lower level (``get_U_star``): for ``max_epochs_ll`` epochs the parameters
    held by ``optimizer_UV`` are trained, optionally with an Oblique / Stiefel
    manifold treatment of the layer bases.
    """

    # Accepted (case-insensitive) values of ``riemannian`` when it is truthy.
    _RIEMANNIAN_MODES = ("oblique", "stiefel_landing", "stiefel")

    def __init__(
        self,
        model,
        train_dataset,
        eval_dataset,
        low_rank_layers,
        max_epochs_ll=2,
        tau=10,
        optimizer_and_scheduler=None,
        eps=1e-4,
        riemannian=False,
        final_cr=0.0,
        pruning_steps=10,
        task_name='resnet',
        epochs=10
    ):
        """
        model: network containing ``low_rank_layers``
        train_dataset: training DataLoader (must expose ``batch_size``)
        eval_dataset: validation loader yielding (inputs, targets)
        low_rank_layers: list of low_rank_linear / low_rank_CP layers
        max_epochs_ll: lower-level epochs per hypergradient evaluation
        tau: L1-ball radius constraining s
        optimizer_and_scheduler: dict with keys 'optimizer' and 'scheduler'
        eps: lower bound enforced on the entries of s
        riemannian: falsy for Euclidean updates, or one of 'oblique',
            'stiefel_landing', 'stiefel' (case-insensitive)
        final_cr: target compression ratio reached by pruning
        pruning_steps: number of FW iterations over which pruning is spread
        task_name: used in the wandb project name
        epochs: maximum number of Frank-Wolfe iterations
        """
        super().__init__()
        if optimizer_and_scheduler is None:
            # The old default ([]) could never be indexed by 'optimizer'.
            raise TypeError(
                "optimizer_and_scheduler must be a dict with 'optimizer' and 'scheduler' keys"
            )
        self.model = model
        self.low_rank_layers = low_rank_layers
        s_temp = get_s(self.low_rank_layers)
        #######--------- initial constraint
        with torch.no_grad():
            print(f'initial check constraints {torch.sum(torch.abs(s_temp)) <= tau, torch.sum(s_temp < eps) == 0}')
            if torch.sum(torch.abs(s_temp)) >= tau or torch.sum(s_temp < eps) != 0:
                print('constraints not satisfied')
                l1_norm = torch.sum(torch.abs(s_temp))
                if l1_norm > 0:  # an all-zero s would otherwise become NaN (0/0)
                    s_temp /= (l1_norm / (0.5 * tau))
                s_temp += eps
            self.low_rank_layers = update_state(s_temp, self.low_rank_layers)
            print(f'updates s satisfies constraints: {torch.sum(torch.abs(s_temp)) <= tau, torch.sum(s_temp < eps) == 0}')
        #######
        self.criterion = torch.nn.CrossEntropyLoss()
        self.problem_dim = s_temp.shape[0]
        self.device = s_temp.device
        self.batch_size = train_dataset.batch_size
        self.max_epochs_ll = max_epochs_ll
        self.riemannian = riemannian
        self.optimizer_UV = optimizer_and_scheduler['optimizer']
        self.scheduler_UV = optimizer_and_scheduler['scheduler']
        # Number of Frank-Wolfe iterations (a per-batch count was computed
        # here before but was immediately overwritten, so it was dropped).
        self.total_epochs = epochs
        self.should_epoch_stop = False
        set_debug_apis(False)
        print(f"problem dimension: {self.problem_dim}")
        ##### initialization Frank-Wolfe variables #########
        # True where s is still active; updated after each pruning step.
        self.mask = torch.ones(self.problem_dim, device=self.device).bool()
        self.fw_state = {
            "s_it": s_temp,
            "it": torch.tensor(0, device=self.device, dtype=torch.int32),
            "it_init_prune": torch.tensor(1, device=self.device, dtype=torch.int32),
            "tau": torch.tensor(tau, device=self.device),
            "scaled_tau": tau
            + torch.tensor(self.problem_dim * eps, device=self.device),
            "eps": torch.tensor(eps, device=self.device),
            "tol": 1e-5,
            "final_cr": final_cr,
            "pruning_steps": pruning_steps,
        }
        #### dataset creation
        self.train_loader = train_dataset
        self.val_loader = eval_dataset
        #### initialize accelerator
        self.accelerator = Accelerator(device_placement=False)
        self.model, self.optimizer_UV, self.train_loader, self.scheduler_UV = self.accelerator.prepare(
            self.model, self.optimizer_UV, self.train_loader, self.scheduler_UV
        )
        self.initialize_fw()
        wandb.init(
            project=f'blo_{task_name}',
            config={'task': task_name}
        )
        print(f'TRAINER INITIALIZED!\n')
        print(f'training for {self.total_epochs} franke-wolfe steps')

    # ------------------------------------------------------------------ #
    # Manifold helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    @torch.no_grad()
    def _project_oblique_grad(basis, mask):
        """In place on ``basis.grad[:, mask]``: G <- G - B diag(diag(B^T G)),
        i.e. remove from each selected gradient column its component along the
        matching basis column (tangent projection for unit-norm columns)."""
        b = basis[:, mask]
        g = basis.grad[:, mask]
        # diag(B^T G) is the column-wise dot product; B @ diag(c) == B * c.
        # Index *assignment* is required: basis.grad[:, mask] with a boolean
        # mask is a copy, so the former ``.copy_`` into it was a no-op.
        basis.grad[:, mask] = g - b * torch.sum(b * g, dim=0)

    @torch.no_grad()
    def get_Rgrad_Oblique(self):
        """
        Replace, in place, the Euclidean gradients of the bases by their
        Riemannian gradients on the Oblique manifold (ambient metric), for the
        columns whose entry in ``l.s`` is non-zero.
        """
        for l in self.low_rank_layers:
            if isinstance(l, low_rank_linear):
                mask = (l.s != 0.0).bool()
                self._project_oblique_grad(l.us, mask)
                self._project_oblique_grad(l.vs, mask)
            elif isinstance(l, low_rank_CP):
                mask = (l.s != 0.0).bool()
                for u in l.us:
                    self._project_oblique_grad(u, mask)
            else:
                raise TypeError()

    @staticmethod
    @torch.no_grad()
    def _landing_field(basis, lamb):
        """In place on ``basis.grad``: G - 0.5 B (B^T G + G^T B) + lamb (B B^T B - B)."""
        g = basis.grad
        g.copy_(
            g
            - 0.5 * basis @ (basis.T @ g + g.T @ basis)
            + lamb * (basis @ basis.T @ basis - basis)
        )

    @torch.no_grad()
    def get_landing_field_stiefel(self, lamb=1e-2):
        """
        Replace, in place, the Euclidean gradients of the bases by the landing
        field on the product of Stiefel manifolds; ``lamb`` weights the term
        ``B B^T B - B`` that pulls the bases back towards orthonormality.
        """
        for l in self.low_rank_layers:
            if isinstance(l, low_rank_linear):
                self._landing_field(l.us, lamb)
                self._landing_field(l.vs, lamb)
            elif isinstance(l, low_rank_CP):
                for u in l.us:
                    self._landing_field(u, lamb)
            else:
                raise TypeError()

    @torch.no_grad()
    def retract_Oblique(self):
        """Normalise to unit Euclidean norm the basis columns whose s entry is non-zero."""
        for l in self.low_rank_layers:
            if isinstance(l, low_rank_linear):
                mask = (l.s != 0.0).bool()
                l.us[:, mask] = l.us[:, mask] / torch.linalg.norm(l.us[:, mask], dim=0)
                l.vs[:, mask] = l.vs[:, mask] / torch.linalg.norm(l.vs[:, mask], dim=0)
            elif isinstance(l, low_rank_CP):
                mask = (l.s != 0.0).bool()
                for u in l.us:
                    u[:, mask] = u[:, mask] / torch.linalg.norm(u[:, mask], dim=0)
            else:
                raise TypeError()

    @staticmethod
    @torch.no_grad()
    def _qr_retract(basis):
        """Overwrite the first min(rows, cols) columns of ``basis`` with Q of its reduced QR."""
        k = min(basis.shape[0], basis.shape[1])
        basis[:, :k].copy_(torch.linalg.qr(basis, mode="reduced")[0])

    @torch.no_grad()
    def retract_Stiefel(self):
        """
        QR-retract every basis onto the Stiefel manifold. The same slicing is
        now used for low_rank_linear as for low_rank_CP, so wide bases
        (rows < cols) no longer fail the copy.
        """
        for l in self.low_rank_layers:
            if isinstance(l, low_rank_linear):
                self._qr_retract(l.us)
                self._qr_retract(l.vs)
            elif isinstance(l, low_rank_CP):
                for u in l.us:
                    self._qr_retract(u)
            else:
                raise TypeError()

    # ------------------------------------------------------------------ #
    # Losses
    # ------------------------------------------------------------------ #
    def compute_loss(self, model, batch):
        """Cross-entropy of ``model`` on ``batch = (inputs, labels)``, moved to ``self.device``."""
        inputs, labels = batch
        inputs = inputs.to(self.device)
        labels = labels.to(self.device)
        return self.criterion(model(inputs), labels)

    @torch.no_grad()
    def full_loss(self):
        """Mean of the per-batch losses over ``self.train_loader`` at the current state.
        Returns 0.0 for an empty loader."""
        total, n_batches = 0.0, 0
        for batch in self.train_loader:
            total += float(self.compute_loss(self.model, batch))
            n_batches += 1
        return total / max(n_batches, 1)

    @torch.no_grad()
    def batch_loss(self, data):
        """Loss on ``data`` if given, otherwise the mean loss over the whole training loader."""
        if data is not None:
            return self.compute_loss(self.model, data)
        return self.full_loss()

    @torch.no_grad()
    def armijo_line_search(
        self, d_it, s_it, fun_s_it, grad_s_it, start_step=1.0, batch=None
    ):
        """
        Stochastic backtracking Armijo line search along ``d_it``.

        Starting from ``start_step``, the step is halved (at most 16 times)
        while ``L(s_it + step * d_it) - fun_s_it >= 0.1 * step * <d_it, grad_s_it>``.
        Each trial point is written into the layers before its loss is
        evaluated (the first trial used to be evaluated at the old state), so
        on return the layers hold ``s_it_d``.

        d_it: descent direction
        s_it: current point
        fun_s_it: loss at s_it on ``batch``
        grad_s_it: hypergradient at s_it
        batch: batch used for the loss (None -> loss over the whole training loader)
        returns: (accepted step, s_it + step * d_it)
        """
        slope = torch.dot(d_it, grad_s_it)
        step = start_step
        s_it_d = s_it + step * d_it
        self.low_rank_layers = update_state(s_it_d, self.low_rank_layers)
        fun_s_it_d = self.batch_loss(batch)
        counter = 0
        while fun_s_it_d - fun_s_it >= 1e-1 * step * slope and counter <= 15:
            step = step / 2
            s_it_d = s_it + step * d_it
            self.low_rank_layers = update_state(s_it_d, self.low_rank_layers)
            fun_s_it_d = self.batch_loss(batch)
            counter += 1
        return step, s_it_d

    # ------------------------------------------------------------------ #
    # Lower / upper level
    # ------------------------------------------------------------------ #
    @staticmethod
    def _zero_grads(model):
        """Set ``.grad = None`` on every trainable parameter of ``model``."""
        for p in model.parameters():
            if p.requires_grad:
                p.grad = None

    def _lower_level_step(self):
        """One optimizer_UV step with the manifold treatment selected by ``self.riemannian``."""
        if not self.riemannian:
            self.optimizer_UV.step()
            return
        mode = self.riemannian.lower()
        if mode == "oblique":
            self.get_Rgrad_Oblique()
            self.optimizer_UV.step()
            self.retract_Oblique()
        elif mode == "stiefel_landing":
            self.get_landing_field_stiefel()
            self.optimizer_UV.step()
        elif mode == "stiefel":
            self.optimizer_UV.step()
            self.retract_Stiefel()
        else:
            # Previously an unknown mode silently skipped every optimizer step.
            raise ValueError(
                f"unknown riemannian mode {self.riemannian!r}; expected one of {self._RIEMANNIAN_MODES}"
            )

    def get_U_star(self):
        """
        Lower-level solve: train for ``max_epochs_ll`` epochs over
        ``train_loader`` with ``optimizer_UV`` (scheduler stepped once per
        epoch), after calling ``activate_lower_level`` on every low-rank layer.
        Leaves the model in eval mode and prints the summed loss of the last epoch.
        """
        self.model.train()
        for layer in self.low_rank_layers:
            layer.activate_lower_level()
        loss_fn = 0.0  # defined even when max_epochs_ll == 0
        for _ in range(self.max_epochs_ll):
            loss_fn = 0.0
            for batch in self.train_loader:
                self._zero_grads(self.model)
                with torch.cuda.amp.autocast():
                    loss = self.compute_loss(self.model, batch)
                self.accelerator.backward(loss)
                loss_fn += float(loss.item())
                self._lower_level_step()
            self.scheduler_UV.step()
        self.model.eval()
        print(f"finish lower level with loss: {loss_fn}")

    def get_hypergradient(self, model, inputs, already_optimal=False):
        """
        Compute the hypergradient w.r.t. s on the batch ``inputs``.

        Unless ``already_optimal``, the lower level is solved first
        (``get_U_star``). Then ``activate_upper_level`` is called on every
        layer, the loss is back-propagated and each layer's
        ``get_hypergradient`` hook is run. ``inputs`` is stored as
        ``fw_state['current_batch']`` and gradients are cleared afterwards.

        returns: (loss value as float, None, global hypergradient vector)
        """
        if not already_optimal:
            self.get_U_star()
        for layer in self.low_rank_layers:
            layer.activate_upper_level()
        self._zero_grads(model)
        loss_fn = 0.0
        accuracy_fn = 0.0  # placeholder: accuracy is not computed here
        with torch.cuda.amp.autocast():
            loss = self.compute_loss(self.model, inputs)
        self.accelerator.backward(loss)
        with torch.no_grad():
            loss_fn += float(loss.item())
        for l in self.low_rank_layers:
            l.get_hypergradient()
        print(f"calculated U^*, metrics: {loss_fn,accuracy_fn}")
        self.fw_state["current_batch"] = inputs
        h_grad = get_grad_s(self.low_rank_layers)
        self._zero_grads(self.model)
        return loss_fn, None, h_grad

    # ------------------------------------------------------------------ #
    # Frank-Wolfe
    # ------------------------------------------------------------------ #
    @torch.no_grad()
    def _fw_direction(self, s_it, grad_s_it):
        """
        Choose between the Frank-Wolfe and the away-step direction.

        The FW vertex puts ``tau - eps`` (or ``eps``) at argmin(grad) and the
        away vertex at argmax(grad) over non-zero entries of s, all other
        entries being ``eps``; both vertices are multiplied by ``self.mask``.
        returns: (d_it, initial step) - step 1.0 for FW, otherwise the
            maximal feasible away step from ``calculate_max_step_as``.
        """
        tau, eps = self.fw_state["tau"], self.fw_state["eps"]
        index_min = torch.argmin(grad_s_it)
        index_max = torch.argmax(grad_s_it * torch.abs(torch.sign(s_it)))
        sign_min = tau - eps if torch.sign(grad_s_it[index_min]) < 0 else eps
        sign_max = tau - eps if torch.sign(grad_s_it[index_max]) > 0 else eps
        s_it_min = eps * torch.ones(s_it.shape[0], device=s_it.device)
        s_it_min[index_min] = sign_min
        s_it_max = eps * torch.ones(s_it.shape[0], device=s_it.device)
        s_it_max[index_max] = sign_max
        s_it_min.mul_(self.mask)
        s_it_max.mul_(self.mask)
        d_AS = s_it - s_it_max
        d_FW = s_it_min - s_it
        if torch.dot(grad_s_it, d_FW) <= torch.dot(grad_s_it, d_AS):
            return d_FW, 1.0
        return d_AS, calculate_max_step_as(s_it, d_AS, tau, eps)

    @torch.no_grad()
    def initialize_fw(self):
        """Compute the first hypergradient, direction and step, and store them in ``fw_state``."""
        s_it = self.fw_state["s_it"]
        inputs = next(iter(self.train_loader))
        self.fw_state["current_batch"] = inputs
        with torch.enable_grad():
            fun_s_it, _, grad_s_it = self.get_hypergradient(self.model, inputs)
        # self.mask is all True here, so masking the vertices changes nothing.
        d_it, step = self._fw_direction(s_it, grad_s_it)
        self.fw_state["step"] = step
        self.fw_state["d_it"] = d_it
        self.fw_state["fun_s_it"] = fun_s_it
        self.fw_state["grad_s_it"] = grad_s_it

    @torch.no_grad()
    def training_step(
        self, model: torch.nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]
    ) -> torch.Tensor:
        """
        One Frank-Wolfe iteration on s (evaluation is logged every 5 iterations).

        It runs Armijo along the stored direction, then applies the pruning
        schedule and recomputes the hypergradient on ``inputs``, and picks the
        next direction. Once converged (FW gap below ``tol`` or iteration
        budget exhausted) and the pruning schedule is over, it sets
        ``should_epoch_stop`` and finishes the wandb run (once).
        returns: the current upper-level loss value.
        """
        if self.fw_state["it"] % 5 == 0:
            self.print_evaluation()

        fw = self.fw_state
        s_it, grad_s_it, d_it, fun_s_it, step = (
            fw["s_it"], fw["grad_s_it"], fw["d_it"], fw["fun_s_it"], fw["step"]
        )
        end_of_pruning = fw["it_init_prune"] + fw["pruning_steps"] + 1
        keep_going = (
            fw["it"] < self.total_epochs
            and torch.dot(grad_s_it, d_it) <= -fw["tol"]
        ) or fw["it"] <= end_of_pruning
        if not keep_going:
            print(f'converged! grad norm, d_it norm {torch.linalg.norm(grad_s_it),torch.linalg.norm(d_it)}')
            # Let the pruning schedule finish, then stop training.
            if fw["it"] >= end_of_pruning and not self.should_epoch_stop:
                self.should_epoch_stop = True
                wandb.finish()
            return fw["fun_s_it"]

        fw["it"] += 1
        print(f"armijo...")
        step, s_it = self.armijo_line_search(
            d_it,
            s_it,
            fun_s_it,
            grad_s_it,
            start_step=step,
            batch=fw["current_batch"],
        )
        print(f"end armijo!")
        self.low_rank_layers = update_state(s_it, self.low_rank_layers)
        ### pruning
        if fw["it"] >= fw["it_init_prune"] and fw["it"] < end_of_pruning:
            print(f"in train prune...")
            s_it, self.low_rank_layers = in_train_prune(
                s_it,
                fw["final_cr"],
                n_steps=fw["pruning_steps"],
                it=fw["it"] - fw["it_init_prune"],
                low_rank_layers=self.low_rank_layers,
                riemannian=self.riemannian
            )
            self.mask = (s_it != 0.0).bool()
            print_sparsity(s_it)
            print_total_params(self.low_rank_layers)
        # new hypergradient and loss value
        with torch.enable_grad():
            fun_s_it, _, grad_s_it = self.get_hypergradient(
                model=model, inputs=inputs
            )
        d_it, step = self._fw_direction(s_it, grad_s_it)
        print(
            f"s satisfied constraints {torch.sum(torch.abs(s_it))<=self.fw_state['scaled_tau'],torch.sum(s_it<0.0)==0}"
        )
        print("-" * 100)
        #### update fw state
        fw["current_batch"] = inputs
        fw["s_it"] = s_it
        fw["grad_s_it"] = grad_s_it
        fw["d_it"] = d_it
        fw["step"] = step
        fw["fun_s_it"] = fun_s_it
        print(f'grad norm, d_it norm {torch.linalg.norm(grad_s_it),torch.linalg.norm(d_it)}')
        return fun_s_it

    def train(self):
        """Call ``training_step`` on every batch for up to ``total_epochs`` passes,
        stopping both loops as soon as ``should_epoch_stop`` is set (previously
        only the inner loop was left, so later epochs kept calling training_step)."""
        for _ in range(self.total_epochs):
            for data in self.train_loader:
                self.training_step(self.model, data)
                if self.should_epoch_stop:
                    break
            if self.should_epoch_stop:
                break

    @torch.no_grad()
    def print_evaluation(self):
        """
        Log validation accuracy (%) and loss to wandb.

        NOTE(review): 'val_loss' is the sum of per-batch mean losses, not an
        average. Nothing is logged for an empty validation loader.
        """
        self.model.eval()
        total = 0.0
        correct = 0.0
        total_loss = 0.0
        for inputs, targets in self.val_loader:
            inputs = inputs.to(self.device)
            targets = targets.to(self.device)
            outputs = self.model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()
            total_loss += self.criterion(outputs, targets).item()
        if total == 0:
            return
        accuracy = 100 * correct / total
        wandb.log({'val_acc': accuracy, 'val_loss': total_loss})
    


class adalora_bilevel_trainer_stablediff:

    """
    Custom trainer class for the bilevel low-rank Frank-Wolfe optimizer, applied to a
    Stable Diffusion pipeline given as ``(unet, vae, text_encoder)``.

    Upper level: the vector ``s`` gathered from ``low_rank_layers`` (via ``get_s``) is
    optimised with an away-step Frank-Wolfe method. The constraint set is
    ``sum(|s|) <= tau`` and ``s >= eps``.

    Lower level: the low-rank bases are trained with ``optimizer_UV`` for
    ``max_epochs_ll`` epochs every time a hypergradient is computed.
    """

    def __init__(
        self,
        models,
        train_dataset,
        low_rank_layers,
        max_epochs_ll=2,
        tau=10,
        optimizer_and_scheduler=None,
        eps=1e-4,
        riemannian=False,
        final_cr=0.0,
        pruning_steps=10,
        task_name='stable_diff',
        epochs=10,
        accelerator=None,
        args=None,
    ):
        """
        Args:
            models: ``(unet, vae, text_encoder)``, unpacked in this order by ``compute_loss``.
            train_dataset: iterable of batches exposing ``batch_size`` (e.g. a dataloader).
            low_rank_layers: ``low_rank_linear`` / ``low_rank_CP`` layers whose ``s`` is optimised.
            max_epochs_ll: number of lower-level epochs per hypergradient evaluation.
            tau: l1 budget of the constraint set.
            optimizer_and_scheduler: dict with keys ``'optimizer'``, ``'scheduler'`` and
                ``'noise_scheduler'``.
            eps: lower bound enforced on the entries of ``s``.
            riemannian: ``False``, or one of ``'oblique'``, ``'stiefel'``, ``'stiefel_landing'``
                (case-insensitive), selecting how the bases are updated in ``get_U_star``.
            final_cr: forwarded to ``in_train_prune``.
            pruning_steps: forwarded to ``in_train_prune``; also sets the pruning window.
            task_name: used to name the wandb project.
            epochs: number of Frank-Wolfe steps (stored as ``self.total_epochs``).
            accelerator: object providing ``accumulate``, ``backward`` and ``log``
                (an ``accelerate.Accelerator``).
            args: training arguments. ``compute_loss`` reads ``with_prior_preservation`` and
                ``prior_loss_weight``.

        Raises:
            ValueError: if ``optimizer_and_scheduler`` is not provided.
        """
        super().__init__()
        if optimizer_and_scheduler is None:
            raise ValueError(
                "optimizer_and_scheduler must be a dict with keys "
                "'optimizer', 'scheduler' and 'noise_scheduler'"
            )
        self.models = models
        self.low_rank_layers = low_rank_layers
        s_temp = get_s(self.low_rank_layers)
        #######--------- initial constraint
        with torch.no_grad():
            print(f'initial check constraints {torch.sum(torch.abs(s_temp))<= tau,torch.sum(s_temp<eps)==0}')
            if torch.sum(torch.abs(s_temp))>= tau or torch.sum(s_temp<eps)!=0:
                print(f'constraints not satisfied')
                # Rescale s so that sum(|s|) == tau / 2, then shift it by eps.
                # This makes it feasible when s is non-negative.
                s_temp/= (torch.sum(torch.abs(s_temp))/(0.5*tau))
                s_temp+= eps
            self.low_rank_layers = update_state(s_temp,self.low_rank_layers)
            print(f'updates s satisfies constraints: {torch.sum(torch.abs(s_temp))<=tau,torch.sum(s_temp<eps)==0}')
        #######
        self.args = args
        self.problem_dim = s_temp.shape[0]
        self.device = s_temp.device
        self.batch_size = train_dataset.batch_size
        self.max_epochs_ll = max_epochs_ll
        # The number of Frank-Wolfe steps is fixed directly by `epochs`.
        self.total_epochs = epochs
        self.riemannian = riemannian
        self.optimizer_UV = optimizer_and_scheduler['optimizer']
        self.scheduler_UV = optimizer_and_scheduler['scheduler']
        self.noise_scheduler = optimizer_and_scheduler['noise_scheduler']
        self.should_epoch_stop = False
        set_debug_apis(False)
        print(f"problem dimension: {self.problem_dim}")
        ##### initialization Frank-Wolfe variables #########
        # Active entries of s. Recomputed from the non-zero pattern of s after in-train pruning.
        # NOTE(review): assumes get_s returns a vector of length problem_dim after update_state.
        self.mask = torch.ones(self.problem_dim, device=self.device).bool()
        self.fw_state = {
            "s_it": s_temp,
            "it": torch.tensor(0, device=self.device, dtype=torch.int32),
            "it_init_prune": torch.tensor(1, device=self.device, dtype=torch.int32),
            "tau": torch.tensor(tau, device=self.device),
            # l1 budget enlarged by problem_dim * eps; used in the constraint printouts
            "scaled_tau": tau
            + torch.tensor(self.problem_dim * eps, device=self.device),
            "eps": torch.tensor(eps, device=self.device),
            "tol": 1e-5,
            "final_cr": final_cr,
            "pruning_steps": pruning_steps,
        }
        #### dataset creation
        self.train_loader = train_dataset
        #### initialize accelerator
        self.accelerator = accelerator
        self.initialize_fw()
        wandb.init(
            project=f'blo_{task_name}',
            config={'task': task_name}
        )
        print(f'TRAINER INITIALIZED!\n')
        print(f'training for {self.total_epochs} franke-wolfe steps')

    @staticmethod
    def _bases(layer):
        """
        Return the basis tensors of ``layer``.

        For ``low_rank_linear`` these are ``(us, vs)``; for ``low_rank_CP`` they are the
        entries of ``us``.

        Raises:
            TypeError: for any other layer type.
        """
        if isinstance(layer, low_rank_linear):
            return (layer.us, layer.vs)
        if isinstance(layer, low_rank_CP):
            return tuple(layer.us)
        raise TypeError(f"unsupported low-rank layer type: {type(layer).__name__}")

    @staticmethod
    def _oblique_tangent(x, g):
        """
        Return ``g - x @ diag(diag(x^T g))``, i.e. subtract from each column of ``g`` its
        component along the matching column of ``x``.

        The projection ``(x*g).sum(0)`` avoids forming the full ``x^T g`` matrix.
        """
        return g - x * torch.sum(x * g, dim=0, keepdim=True)

    @staticmethod
    def _landing_field(x, g, lamb):
        """
        Return ``g - 0.5 * x (x^T g + g^T x) + lamb * x (x^T x - I)``.

        The last term is evaluated as ``x @ (x^T x) - x`` to avoid an (m x m) intermediate.
        """
        return g - 0.5 * x @ (x.T @ g + g.T @ x) + lamb * (x @ (x.T @ x) - x)

    @torch.no_grad()
    def get_Rgrad_Oblique(self):
        """
        Replace the Euclidean gradients of the bases with their Riemannian gradients on the
        oblique manifold, w.r.t. the ambient metric (equivalent to the product metric of
        spheres).

        Only columns whose ``s`` entry is non-zero are modified.
        """
        for l in self.low_rank_layers:
            bases = self._bases(l)
            mask = l.s != 0.0
            for u in bases:
                # Boolean-mask indexing returns a copy. Calling ``copy_`` on ``u.grad[:, mask]``
                # would only overwrite that temporary, so write back with index assignment.
                u.grad[:, mask] = self._oblique_tangent(u[:, mask], u.grad[:, mask])

    @torch.no_grad()
    def get_landing_field_stiefel(self, lamb=1e-2):
        """
        Overwrite the Euclidean gradients of the bases with the landing field on the product
        of Stiefel manifolds.

        ``lamb`` weights the term ``x (x^T x - I)``, which measures the distance from
        orthonormality.
        """
        for l in self.low_rank_layers:
            for u in self._bases(l):
                u.grad.copy_(self._landing_field(u, u.grad, lamb))

    @torch.no_grad()
    def retract_Oblique(self):
        """
        Renormalise every active column (non-zero ``s`` entry) of every basis to unit
        Euclidean norm.
        """
        for l in self.low_rank_layers:
            bases = self._bases(l)
            mask = l.s != 0.0
            for u in bases:
                u[:, mask] = u[:, mask] / torch.linalg.norm(u[:, mask], dim=0)

    @torch.no_grad()
    def retract_Stiefel(self):
        """
        Replace every basis in place by the Q factor of its reduced QR decomposition.

        For CP factors, only the first ``min(rows, cols)`` columns are overwritten.
        """
        for l in self.low_rank_layers:
            if isinstance(l, low_rank_linear):
                l.us.copy_(torch.linalg.qr(l.us, mode="reduced")[0])
                l.vs.copy_(torch.linalg.qr(l.vs, mode="reduced")[0])
            elif isinstance(l, low_rank_CP):
                for u in l.us:
                    # reduced QR yields only min(rows, cols) orthonormal columns
                    k = min(u.shape[0], u.shape[1])
                    u[:, :k].copy_(torch.linalg.qr(u, mode="reduced")[0])
            else:
                raise TypeError(f"unsupported low-rank layer type: {type(l).__name__}")

    def compute_loss(self, models, batch):
        """
        Denoising MSE loss of the unet on ``batch`` (epsilon- or v-prediction target,
        depending on ``noise_scheduler.config.prediction_type``).

        ``batch`` must provide ``"pixel_values"`` and ``"input_ids"``. When
        ``args.with_prior_preservation`` is set, predictions and targets are split in two
        halves along dim 0. Presumably these are the instance and prior examples; confirm
        with the data collator. The prior half's loss is added with weight
        ``args.prior_loss_weight``.

        NOTE(review): the loss is returned after leaving ``accelerator.accumulate(unet)``, and
        callers call ``accelerator.backward`` outside that context. Confirm this is the
        intended gradient-accumulation usage.

        Raises:
            ValueError: for an unknown ``prediction_type``.
        """
        unet, vae, text_encoder = models
        noise_scheduler = self.noise_scheduler
        args = self.args
        with self.accelerator.accumulate(unet):
            # Convert images to latent space
            latents = vae.encode(batch["pixel_values"].to(dtype=torch.float32)).latent_dist.sample()
            latents = latents * 0.18215  # presumably the SD VAE latent scaling factor

            # Sample noise that we'll add to the latents
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]
            # Sample a random timestep for each image
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
            )
            timesteps = timesteps.long()

            # Add noise to the latents according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Get the text embedding for conditioning
            encoder_hidden_states = text_encoder(batch["input_ids"])[0]

            # Predict the noise residual
            model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

            # Get the target for loss depending on the prediction type
            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(latents, noise, timesteps)
            else:
                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

            if args.with_prior_preservation:
                # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
                model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
                target, target_prior = torch.chunk(target, 2, dim=0)

                # Compute instance loss
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                # Compute prior loss
                prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")

                # Add the prior loss to the instance loss.
                loss = loss + args.prior_loss_weight * prior_loss
            else:
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

        return loss

    @torch.no_grad()
    def armijo_line_search(
        self, d_it, s_it, fun_s_it, grad_s_it, start_step=1.0, batch=None
    ):
        """
        Stochastic backtracking (Armijo) line search along ``d_it``.

        Before every loss evaluation, the trial point ``s_it + step * d_it`` is written into
        the layers with ``update_state``. ``step`` is halved while the condition
        ``L(trial) - fun_s_it >= 0.1 * step * <d_it, grad_s_it>`` holds, at most 16 times
        (the counter is kept for numerical robustness).

        Args:
            d_it: search direction.
            s_it: current iterate.
            fun_s_it: loss at ``s_it``.
            grad_s_it: hypergradient at ``s_it``.
            start_step: initial step size.
            batch: batch on which the loss is evaluated (``None`` falls back to
                ``self.full_loss()``).

        Returns:
            ``(step, s_it_d)``: the last tried step and its point. On return, the layers hold
            ``s_it_d``.
        """
        step = start_step
        s_it_d = s_it + step * d_it
        # The first trial point must be loaded into the layers too. Otherwise the loss is
        # evaluated at the current state instead of at s_it + step * d_it.
        self.low_rank_layers = update_state(s_it_d, self.low_rank_layers)
        fun_s_it_d = self.batch_loss(batch)
        slope = torch.dot(d_it, grad_s_it)  # directional derivative, loop-invariant
        counter = 0
        while fun_s_it_d - fun_s_it >= 1e-1 * step * slope and counter <= 15:
            step = step / 2
            s_it_d = s_it + step * d_it
            self.low_rank_layers = update_state(s_it_d, self.low_rank_layers)
            fun_s_it_d = self.batch_loss(batch)
            counter += 1
        return step, s_it_d

    @torch.no_grad()
    def batch_loss(self, data):
        """
        Return ``compute_loss`` on ``data``, or ``self.full_loss()`` when ``data`` is None.

        NOTE(review): ``full_loss`` is not defined in this class, so ``data=None`` only works
        if a subclass provides it.
        """
        if data is not None:
            return self.compute_loss(self.models, data)
        return self.full_loss()

    def get_U_star(self):
        """
        Approximate the lower-level solution U*(s) for the hypergradient calculation.

        The bases are trained for ``self.max_epochs_ll`` epochs over ``self.train_loader``
        with ``self.optimizer_UV``, using the update rule selected by ``self.riemannian``.
        ``self.scheduler_UV`` is stepped once per epoch. The models are left in eval mode.
        The summed loss of the last epoch is printed and logged.

        Raises:
            ValueError: if ``self.riemannian`` is a string naming an unknown update rule.
        """
        mode = self.riemannian.lower() if self.riemannian else None
        if mode not in (None, "oblique", "stiefel_landing", "stiefel"):
            # An unknown mode would otherwise silently skip every optimizer step.
            raise ValueError(f"unknown riemannian mode: {self.riemannian!r}")
        for m in self.models:
            m.train()
        for layer in self.low_rank_layers:
            layer.activate_lower_level()
        loss_fn = 0.0  # defined even when max_epochs_ll == 0
        for _ in range(self.max_epochs_ll):
            loss_fn = 0.0
            for batch in self.train_loader:
                for m in self.models:
                    for p in m.parameters():
                        if p.requires_grad:
                            p.grad = None
                with torch.cuda.amp.autocast():
                    loss = self.compute_loss(self.models, batch)
                self.accelerator.backward(loss)
                loss_fn += float(loss.item())
                if mode == "oblique":
                    self.get_Rgrad_Oblique()
                    self.optimizer_UV.step()
                    self.retract_Oblique()
                elif mode == "stiefel_landing":
                    self.get_landing_field_stiefel()
                    self.optimizer_UV.step()
                elif mode == "stiefel":
                    self.optimizer_UV.step()
                    self.retract_Stiefel()
                else:
                    self.optimizer_UV.step()
            self.scheduler_UV.step()
        for m in self.models:
            m.eval()
        print(f"finish lower level with loss: {loss_fn}")
        self.accelerator.log({'loss': loss_fn})

    def get_hypergradient(self, models, inputs, already_optimal=False):
        """
        Compute the hypergradient of the loss on ``inputs`` with respect to ``s``.

        Unless ``already_optimal`` is True, the lower level is first solved with
        ``get_U_star``. The layers are then switched to upper-level mode, the loss is
        back-propagated, and each layer's ``get_hypergradient`` is called. The result is
        collected with ``get_grad_s``. ``fw_state["current_batch"]`` is set to ``inputs``.

        Returns:
            ``(loss, None, h_grad)``: the scalar loss on ``inputs``, ``None`` in the middle
            slot, and the hypergradient vector.
        """
        if not already_optimal:
            self.get_U_star()
        for layer in self.low_rank_layers:
            layer.activate_upper_level()
        for m in models:
            for p in m.parameters():
                if p.requires_grad:
                    p.grad = None
        loss_fn = 0.0
        accuracy_fn = 0.0  # not computed for this task; printed as a placeholder
        with torch.cuda.amp.autocast():
            loss = self.compute_loss(models, inputs)
        self.accelerator.backward(loss)
        with torch.no_grad():
            loss_fn += float(loss.item())
        for l in self.low_rank_layers:
            l.get_hypergradient()
        print(f"calculated U^*, metrics: {loss_fn,accuracy_fn}")
        self.fw_state["current_batch"] = inputs
        h_grad = get_grad_s(self.low_rank_layers)
        for m in models:
            m.zero_grad()
        return loss_fn, None, h_grad

    def _fw_direction(self, s_it, grad_s_it):
        """
        Return ``(d_it, step)``: the Frank-Wolfe or away-step direction at ``s_it`` and its
        initial step size.

        FW vertex: ``tau - eps`` on the coordinate with the smallest gradient if that gradient
        is negative; ``eps`` on every other coordinate.

        Away vertex: ``tau - eps`` on the coordinate maximising ``grad * |sign(s)|`` if that
        gradient is positive; ``eps`` elsewhere.

        Both vertices are multiplied by ``self.mask`` so that pruned coordinates stay zero. The
        FW direction (step 1.0) is chosen when its directional derivative is not larger than
        the away direction's. Otherwise the away direction is used with the step returned by
        ``calculate_max_step_as``.
        """
        tau, eps = self.fw_state["tau"], self.fw_state["eps"]
        # unit simplex: vector of the +- canonical basis corresponding to the min-gradient index
        index_min = torch.argmin(grad_s_it)
        index_max = torch.argmax(grad_s_it * torch.abs(torch.sign(s_it)))
        sign_min = tau - eps if torch.sign(grad_s_it[index_min]) < 0 else eps
        sign_max = tau - eps if torch.sign(grad_s_it[index_max]) > 0 else eps
        # eps instead of zero keeps the vertices inside {s >= eps}
        s_it_min = eps * torch.ones(s_it.shape[0], device=s_it.device)
        s_it_min[index_min] = sign_min
        s_it_max = eps * torch.ones(s_it.shape[0], device=s_it.device)
        s_it_max[index_max] = sign_max
        s_it_min.mul_(self.mask)
        s_it_max.mul_(self.mask)

        d_AS = s_it - s_it_max
        d_FW = s_it_min - s_it
        if torch.dot(grad_s_it, d_FW) <= torch.dot(grad_s_it, d_AS):
            return d_FW, 1.0
        return d_AS, calculate_max_step_as(s_it, d_AS, tau=tau, eps=eps)

    @torch.no_grad()
    def initialize_fw(self):
        """
        Compute the first hypergradient on one batch of ``self.train_loader`` and the
        initial direction.

        The loss, hypergradient, direction, step and batch are stored in ``self.fw_state``.
        """
        s_it = self.fw_state["s_it"]
        inputs = next(iter(self.train_loader))
        self.fw_state["current_batch"] = inputs
        #### hypergradient calculation
        with torch.enable_grad():
            fun_s_it, _, grad_s_it = self.get_hypergradient(self.models, inputs)
        # self.mask is all-ones here, so masking the vertices has no effect at init
        d_it, step = self._fw_direction(s_it, grad_s_it)
        self.fw_state["step"] = step
        self.fw_state["d_it"] = d_it
        self.fw_state["fun_s_it"] = fun_s_it
        self.fw_state["grad_s_it"] = grad_s_it

    @torch.no_grad()
    def training_step(
        self, model: torch.nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]
    ) -> torch.Tensor:
        """
        Perform one Frank-Wolfe iteration on ``s``.

        The iteration runs while ``it < total_epochs`` and the stored direction still gives
        descent (``<grad, d> <= -tol``), or while the pruning window is not over. One
        iteration:

        1. Armijo line search along the stored direction on the previously stored batch.
        2. Write the new ``s`` into the layers.
        3. In-train pruning when ``it_init_prune <= it < it_init_prune + pruning_steps + 1``.
        4. Recompute the hypergradient on ``inputs`` and pick the next direction.

        Otherwise, convergence is printed. Once the pruning window is over,
        ``should_epoch_stop`` is set and the wandb run is finished (only once).

        Args:
            model: models forwarded to ``get_hypergradient``. ``compute_loss`` unpacks them as
                ``(unet, vae, text_encoder)``.
            inputs: batch on which the new hypergradient is evaluated.

        Returns:
            The loss at the new iterate, or the last stored loss when not iterating.
        """
        s_it = self.fw_state["s_it"]
        grad_s_it = self.fw_state["grad_s_it"]
        d_it = self.fw_state["d_it"]
        fun_s_it = self.fw_state["fun_s_it"]
        step = self.fw_state["step"]
        pruning_end = (
            self.fw_state["it_init_prune"] + self.fw_state["pruning_steps"] + 1
        )

        keep_iterating = (
            self.fw_state["it"] < self.total_epochs
            and torch.dot(grad_s_it, d_it) <= -self.fw_state["tol"]
        ) or self.fw_state["it"] <= pruning_end
        if not keep_iterating:
            print(f'converged! grad norm, d_it norm {torch.linalg.norm(grad_s_it),torch.linalg.norm(d_it)}')
            # let the pruning end and then stop training
            if self.fw_state["it"] >= pruning_end and not self.should_epoch_stop:
                self.should_epoch_stop = True
                wandb.finish()
            return self.fw_state["fun_s_it"]

        self.fw_state["it"] += 1
        print(f"armijo...")
        step, s_it = self.armijo_line_search(
            d_it,
            s_it,
            fun_s_it,
            grad_s_it,
            start_step=step,
            batch=self.fw_state["current_batch"],
        )
        print(f"end armijo!")
        self.low_rank_layers = update_state(s_it, self.low_rank_layers)

        if self.fw_state["it_init_prune"] <= self.fw_state["it"] < pruning_end:
            print(f"in train prune...")
            s_it, self.low_rank_layers = in_train_prune(
                s_it,
                self.fw_state["final_cr"],
                n_steps=self.fw_state["pruning_steps"],
                it=self.fw_state["it"] - self.fw_state["it_init_prune"],
                low_rank_layers=self.low_rank_layers,
                riemannian=self.riemannian,
            )
            self.mask = (s_it != 0.0).bool()
            print_sparsity(s_it)
            print_total_params(self.low_rank_layers)

        # hypergradient and loss at the new iterate
        with torch.enable_grad():
            fun_s_it, _, grad_s_it = self.get_hypergradient(
                models=model, inputs=inputs
            )
        d_it, step = self._fw_direction(s_it, grad_s_it)
        print(
            f"s satisfied constraints {torch.sum(torch.abs(s_it))<=self.fw_state['scaled_tau'],torch.sum(s_it<0.0)==0}"
        )
        print("-" * 100)
        #### update fw state
        self.fw_state["current_batch"] = inputs
        self.fw_state["s_it"] = s_it
        self.fw_state["grad_s_it"] = grad_s_it
        self.fw_state["d_it"] = d_it
        self.fw_state["step"] = step
        self.fw_state["fun_s_it"] = fun_s_it
        print(f'grad norm, d_it norm {torch.linalg.norm(grad_s_it),torch.linalg.norm(d_it)}')
        return fun_s_it

    
