import torch
from federatedscope.core.aggregators import ClientsAvgAggregator
from copy import deepcopy
from scipy.optimize import linear_sum_assignment  # pip install scipy


class SVDAggregator(ClientsAvgAggregator):
    """
    Aggregator for LoRA adapters that combines client updates with a
    truncated SVD instead of element-wise FedAvg averaging.

    The aggregation path is ``_para_weighted_avg``, which overrides the hook
    of ``ClientsAvgAggregator``.  ``_para_aligned_avg``, ``_new_avg`` and
    ``_new_avg2`` implement alternative strategies; they are not invoked by
    ``_para_weighted_avg`` and can be swapped in for experimentation.

    Every strategy pairs a ``lora_A`` entry with the ``lora_B`` entry whose
    key is obtained by replacing ``'lora_A'`` with ``'lora_B'``.  Because the
    code forms ``B @ A``, A is treated as ``(r, d_in)`` and B as
    ``(d_out, r)``; the rank ``r`` is read from ``lora_A``'s first dimension.
    """
    def __init__(self, model=None, device='cpu', config=None):
        # ``super(ClientsAvgAggregator, self)`` starts the MRO lookup *after*
        # ClientsAvgAggregator, so its __init__ is bypassed; the attributes
        # used by this aggregator are assigned explicitly below.
        super(ClientsAvgAggregator, self).__init__()
        self.model = model
        self.device = device
        self.cfg = config

    def _para_weighted_avg(self, models, recover_fun=None):
        """
        Aggregate LoRA factors through a shared, SVD-derived A basis.

        For each ``lora_A`` key, the clients' A matrices (``(r, d_in)`` each)
        are concatenated row-wise into an ``(n_clients * r, d_in)`` matrix.
        Its top-``r`` right singular vectors (``torch.svd_lowrank``) form the
        new A, which therefore has orthonormal rows.  Each client's B is then
        re-expressed in that basis as ``B_i @ A_i @ A_new^T`` and the results
        are averaged.  ``B_new @ A_new`` is consequently the mean of the
        clients' ``B_i @ A_i`` projected onto the row space of ``A_new``.

        Every client contributes equally: sample sizes are not used despite
        the inherited method name.  Entries other than ``lora_A`` keys and
        their paired ``lora_B`` keys are taken unchanged from the first
        client.

        All arithmetic is done in float32, and each result is cast back to
        the dtype of the parameter it replaces.  This keeps the B matmul from
        mixing dtypes and preserves fp32/bf16/fp16 models alike.

        Args:
            models: list of ``(sample_size, state_dict)`` tuples, one per
                client; state-dict values are expected to be tensors.
            recover_fun: accepted for signature compatibility with
                ``ClientsAvgAggregator``; it is not used here.

        Returns:
            dict: a deep copy of the first client's state dict with the
            aggregated LoRA factors written in.
        """
        _, avg_model = deepcopy(models[0])
        a_keys = [key for key in avg_model if 'lora_A' in key]
        for a_key in a_keys:
            rank = avg_model[a_key].shape[0]
            stacked_a = torch.cat(
                [state[a_key].float() for _, state in models], dim=0)
            # svd_lowrank returns V of shape (d_in, q); V^T is the new A.
            _, _, V = torch.svd_lowrank(stacked_a, q=rank, niter=3)
            new_a = V.t()
            avg_model[a_key] = new_a.to(avg_model[a_key].dtype)

            b_key = a_key.replace('lora_A', 'lora_B')
            if b_key not in avg_model:
                continue
            # Re-express each client's update B_i @ A_i in the shared basis.
            projected_b = [
                state[b_key].float() @ state[a_key].float() @ new_a.t()
                for _, state in models
            ]
            mean_b = torch.stack(projected_b, dim=0).mean(dim=0)
            avg_model[b_key] = mean_b.to(avg_model[b_key].dtype)

        return avg_model

    def _para_aligned_avg(self, models, recover_fun=None):
        """
        Align every client's parameters to the first client by a row
        permutation, then return their sample-size-weighted average.

        The permutation for each client is the Hungarian assignment
        (``scipy.optimize.linear_sum_assignment``).  It minimises the total
        squared L2 distance between the client's rows and the first client's
        rows; trailing dimensions are flattened for the distance.

        * ``lora_B`` keys whose paired ``lora_A`` exists are not matched on
          their own rows, which index output features.  Instead, their
          columns are permuted with the permutation found for the paired
          ``lora_A``, so each client's ``B @ A`` is preserved.
        * Other tensors with two or more dimensions are permuted along
          dimension 0.
        * 0-D and 1-D tensors are averaged without alignment.

        Args:
            models: list of ``(sample_size, state_dict)`` tuples; the first
                client is the alignment reference.
            recover_fun: optional callable applied to each averaged tensor.

        Returns:
            dict: deep copy of the first client's state dict whose entries
            are replaced by the float32 aligned averages (passed through
            ``recover_fun`` when given).
        """
        total_samples = sum(size for size, _ in models)
        ref_size, ref_state = models[0]
        _, avg_model = deepcopy(models[0])

        for key in avg_model:
            # Client 0 is the reference, so it needs no permutation.
            aligned_sum = ref_state[key].float() * (ref_size / total_samples)
            for sample_size, client_state in models[1:]:
                aligned = self._align_to_reference(key, client_state, ref_state)
                aligned_sum += aligned * (sample_size / total_samples)
            avg_model[key] = (recover_fun(aligned_sum)
                              if recover_fun else aligned_sum)

        return avg_model

    @staticmethod
    def _match_rows(mat, reference):
        """
        Return the index tensor ``perm`` for which ``mat[perm]`` best matches
        ``reference`` row by row (minimum total squared L2 distance).
        """
        flat_mat = mat.flatten(1)
        flat_ref = reference.flatten(1)
        # cost[i, j] = ||mat_i - ref_j||^2, moved to NumPy for SciPy.
        cost = ((flat_mat.unsqueeze(1) - flat_ref.unsqueeze(0)) ** 2).sum(-1)
        row_ind, col_ind = linear_sum_assignment(cost.cpu().numpy())
        # Row row_ind[k] of mat is assigned to reference row col_ind[k];
        # reorder so that position j holds the mat row assigned to ref row j.
        perm = torch.as_tensor(row_ind)[torch.argsort(torch.as_tensor(col_ind))]
        return perm.to(mat.device)

    @classmethod
    def _align_to_reference(cls, key, client_state, ref_state):
        """
        Return ``client_state[key]`` as float32, permuted to line up with
        ``ref_state[key]``.

        ``lora_B`` entries with a paired ``lora_A`` get their columns
        permuted by the ``lora_A`` row matching.  Other tensors with two or
        more dimensions get their rows matched directly.  Lower-rank tensors
        are returned unpermuted.
        """
        mat = client_state[key].float()
        if 'lora_B' in key:
            a_key = key.replace('lora_B', 'lora_A')
            if a_key in client_state and a_key in ref_state:
                # LoRA rank symmetry: B @ A == B[:, perm] @ A[perm].
                perm = cls._match_rows(client_state[a_key].float(),
                                       ref_state[a_key].float())
                return mat[:, perm]
        if mat.dim() >= 2:
            return mat[cls._match_rows(mat, ref_state[key].float())]
        return mat

    def _new_avg(self, models, recover_fun=None):
        """
        Average the full LoRA updates ``B_i @ A_i`` and re-factorise them.

        For every ``lora_A``/``lora_B`` pair, the per-client products are
        averaged uniformly (sample sizes are not used).  The mean is then
        approximated with a rank-``r`` ``torch.svd_lowrank`` as
        ``U diag(S) V^T``.  The singular values are split evenly between the
        factors: ``A = diag(sqrt(S)) V^T`` and ``B = U diag(sqrt(S))``.

        Args:
            models: list of ``(sample_size, state_dict)`` tuples.
            recover_fun: not used; kept for a uniform signature.

        Returns:
            dict: deep copy of the first client's state dict with each LoRA
            pair replaced, cast back to the original parameter dtypes.
        """
        _, avg_model = deepcopy(models[0])
        for a_key in [key for key in avg_model if 'lora_A' in key]:
            b_key = a_key.replace('lora_A', 'lora_B')
            rank = avg_model[a_key].shape[0]
            updates = [state[b_key].float() @ state[a_key].float()
                       for _, state in models]
            mean_update = torch.stack(updates, dim=0).mean(dim=0)
            U, S, V = torch.svd_lowrank(mean_update, q=rank, niter=3)
            root_s = S.sqrt()
            # Scale rows of V^T / columns of U by sqrt(S).
            new_a = root_s.unsqueeze(1) * V.t()
            new_b = U * root_s.unsqueeze(0)
            avg_model[a_key] = new_a.to(avg_model[a_key].dtype)
            avg_model[b_key] = new_b.to(avg_model[b_key].dtype)

        return avg_model

    def _new_avg2(self, models, recover_fun=None):
        """
        Jointly factorise the clients' ``[A_i | B_i^T]`` rows with one SVD.

        For every LoRA pair, client ``i`` contributes the ``r`` rows of
        ``[A_i | B_i^T]`` (shape ``(r, d_in + d_out)``).  The rows of all
        clients are stacked and approximated with a rank-``r``
        ``torch.svd_lowrank`` as ``U diag(S) V^T``.  The first ``d_in``
        columns of ``V^T`` become the new A.  The new B is
        ``mean_i(V_B diag(S) diag(S) U_i^T U_i)``, where ``V_B`` is the
        transpose of the remaining columns of ``V^T`` and ``U_i`` are
        client ``i``'s ``r`` rows of ``U``.

        Args:
            models: list of ``(sample_size, state_dict)`` tuples.
            recover_fun: not used; kept for a uniform signature.

        Returns:
            dict: deep copy of the first client's state dict with each LoRA
            pair replaced, cast back to the original parameter dtypes.
        """
        _, avg_model = deepcopy(models[0])
        for a_key in [key for key in avg_model if 'lora_A' in key]:
            b_key = a_key.replace('lora_A', 'lora_B')
            rank, in_dim = avg_model[a_key].shape
            # Concatenate along dim=1 so row j of client i is
            # [A_i[j] | B_i[:, j]], keeping each rank-j pair on one row.
            joint = torch.cat(
                [torch.cat([state[a_key], state[b_key].t()], dim=1)
                 for _, state in models],
                dim=0,
            ).float()
            U, S, V = torch.svd_lowrank(joint, q=rank, niter=3)
            Vt = V.t()
            S = torch.diag(S)
            new_a = Vt[:, :in_dim]
            b_basis = Vt[:, in_dim:].t()
            # U has n_clients * rank rows; split into per-client blocks.
            per_client_b = [b_basis @ S @ S @ u_i.t() @ u_i
                            for u_i in U.split(rank, dim=0)]
            new_b = torch.stack(per_client_b, dim=0).mean(dim=0)
            avg_model[a_key] = new_a.to(avg_model[a_key].dtype)
            avg_model[b_key] = new_b.to(avg_model[b_key].dtype)

        return avg_model