import torch
import logging
try:
    import deepspeed
    from deepspeed import DeepSpeedEngine
except:
    deepspeed = None
    DeepSpeedEngine = None
from federatedscope.register import register_trainer

from federatedscope.glue.trainer.trainer import GLUETrainer
from copy import deepcopy
from federatedscope.core.auxiliaries.utils import param2tensor, merge_param_dict

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class SVDTrainer(GLUETrainer):
    """GLUE trainer that exchanges rebalanced LoRA factors.

    Upload (``get_model_para``): every ``lora_A`` matrix is replaced by
    ``(B^T B)^{1/2} @ A``, where ``B`` is its paired ``lora_B``.

    Download (``update``): the received ``lora_A`` is loaded, and the local
    ``lora_B`` is re-projected as ``B_old @ A_old @ pinv(A_new)``. This is the
    least-squares ``B`` that keeps the local product ``B @ A`` unchanged
    given the new ``A``. Received ``lora_B`` values are therefore replaced
    whenever their paired ``lora_A`` is present in the message.
    """

    def update(self, model_parameters, strict=False):
        """
            Called by the FL client to update the model parameters
        Arguments:
            model_parameters (dict): parameter name -> value (anything
                ``param2tensor`` accepts). The dict is not modified.
            strict (bool): forwarded to ``load_state_dict``.

        For each ``lora_B`` key, the paired ``lora_A`` key is the same key
        with ``lora_A`` in place of ``lora_B``. If that ``lora_A`` was
        received, the loaded ``lora_B`` is ``B_old @ A_old @ pinv(A_new)``.
        This is computed in float32 on the local weights' device and cast
        back to the local ``lora_B`` dtype. A ``lora_B`` without a received
        partner is loaded as received, and a warning is logged.

        The parameters are then passed through
        ``_param_filter(..., filter_keywords=['classifier'])`` and merged
        into the current state dict before loading.
        """
        # Snapshot the local (pre-update) weights once; state_dict() builds
        # a fresh dict on every call.
        local_state = self.ctx.model.state_dict()
        # Work on a new dict so the caller's message content is untouched.
        received = {
            key: param2tensor(value)
            for key, value in model_parameters.items()
        }

        for b_key in [key for key in received if 'lora_B' in key]:
            a_key = b_key.replace('lora_B', 'lora_A')
            if a_key not in received:
                logger.warning(
                    "SVDTrainer.update: no '%s' received for '%s'; "
                    "loading the received lora_B as-is.", a_key, b_key)
                continue
            old_B = local_state[b_key]
            old_A = local_state[a_key]
            # The message may live on another device/dtype than the model.
            new_A = received[a_key].to(device=old_A.device,
                                       dtype=torch.float32)
            # X = C @ pinv(A_new) is the least-squares solution of
            # X @ A_new = C with C = B_old @ A_old. float32 avoids fp16
            # pinv/matmul limitations on CPU and mixed-dtype matmuls.
            new_B = old_B.float() @ old_A.float() @ torch.linalg.pinv(new_A)
            received[b_key] = new_B.to(old_B.dtype)

        merged_param = merge_param_dict(
            local_state.copy(),
            self._param_filter(received, filter_keywords=['classifier']))
        self.ctx.model.load_state_dict(merged_param, strict=strict)

    def safe_matrix_sqrt(self, G, eps=1e-12):
        """Return the symmetric square root of a square matrix ``G``.

        ``G`` is upcast to float64 and symmetrized as ``(G + G^T) / 2``.
        Its eigendecomposition ``V diag(e) V^T`` is taken, negative
        eigenvalues (numerical noise) are clamped to zero, and
        ``V diag(sqrt(e + eps)) V^T`` is returned in ``G``'s original dtype.

        Arguments:
            G (torch.Tensor): square 2-D matrix.
            eps (float): added to each eigenvalue before the square root.
        """
        out_dtype = G.dtype
        G = G.double()
        G = 0.5 * (G + G.T)
        e, V = torch.linalg.eigh(G)
        e_pos = torch.clamp(e, min=0.0)
        return (V * torch.sqrt(e_pos + eps)).matmul(V.T).to(out_dtype)

    def get_model_para(self):
        """Return the filtered parameters to upload, with LoRA A rebalanced.

        Each ``lora_A`` entry is replaced by ``(B^T B)^{1/2} @ A``, where
        ``B`` is the entry under the same key with ``lora_B`` in place of
        ``lora_A``. The product is computed in float32 and cast back to
        ``A``'s dtype.

        Only entries of a shallow copy of the state dict are replaced; the
        model's own weights are not modified. The result is passed through
        ``_param_filter``.
        """
        model_para = self.ctx.model.state_dict().copy()
        for a_key in [key for key in model_para if 'lora_A' in key]:
            lora_a = model_para[a_key]
            lora_b = model_para[a_key.replace('lora_A', 'lora_B')].float()
            g_sqrt = self.safe_matrix_sqrt(lora_b.T @ lora_b, eps=1e-12)
            model_para[a_key] = (g_sqrt @ lora_a.float()).to(lora_a.dtype)

        # Every configuration (multi-process, shared local model, DeepSpeed
        # or not) previously took an identical branch, so none is needed.
        return self._param_filter(model_para)

def call_svd_trainer(trainer_type):
    """Trainer builder hook registered under ``'svdtrainer'``.

    Returns the ``SVDTrainer`` class when ``trainer_type`` equals
    ``'svdtrainer'`` and ``None`` for any other value.
    """
    if trainer_type != 'svdtrainer':
        return None
    return SVDTrainer


# Register call_svd_trainer as the builder for the trainer name 'svdtrainer'.
register_trainer('svdtrainer', call_svd_trainer)
