import torch

from .solver import RandomProjectionSolver


class PCGradBalancer:
    """Combine per-task encoder gradients through a projection solver.

    For every task head in ``decoders`` a separate backward pass is run
    through the shared ``encoder``. The flattened encoder gradients are
    stacked into a ``(num_tasks, num_encoder_grad_elements)`` matrix and
    passed to ``solver.apply`` (by default ``RandomProjectionSolver``;
    presumably it resolves conflicts between task gradients — confirm
    against ``.solver``). The rows of the solver output are summed and
    written back into the encoder's ``.grad`` tensors. Decoder gradients
    are left as produced by their own task's backward pass.

    Attributes:
        compute_cnumber: When truthy, ``step`` also stores in ``cnumber``
            the ratio of the largest to the smallest singular value of
            the solver output.
        solver: Object exposing ``apply(grads)``; ``None`` selects
            ``RandomProjectionSolver``.
        cnumber: Condition-number tensor from the last ``step`` that
            computed it, else ``None``.
        losses: Per-task loss values (Python floats) from the last
            ``step``, in ``decoders`` iteration order, else ``None``.
    """

    def __init__(self, compute_cnumber: bool = False, solver=None) -> None:
        self.compute_cnumber = compute_cnumber
        # None -> fall back to RandomProjectionSolver at step() time.
        self.solver = solver
        self.cnumber = None
        self.losses = None

    def step(self, input, targets, encoder, decoders, criteria, **kwargs):
        """Run one backward pass per task and overwrite encoder gradients.

        No optimizer step is taken here; call the optimizer afterwards.

        Args:
            input: Batch fed to ``encoder``.
            targets: Mapping from task id to that task's target.
            encoder: Shared module whose gradients are combined.
            decoders: Task-id -> head mapping supporting ``items()`` and
                ``zero_grad()`` (e.g. ``torch.nn.ModuleDict``).
            criteria: Mapping from task id to a loss callable
                ``criterion(prediction, target)``.
            **kwargs: Ignored (presumably kept so all balancers share one
                call signature).

        Raises:
            ValueError: If ``decoders`` yields no task heads.
        """
        encoder.zero_grad()
        decoders.zero_grad()

        losses, grads = [], []
        hrepr = encoder(input)

        for task_id, decoder in decoders.items():
            # Be sure that your task heads don't have overlapping parameters!
            # Decoder gradients are zeroed only once above, so a parameter
            # shared between heads would accumulate the gradients of every
            # task that uses it. Only the encoder is reset per task, which
            # isolates each task's encoder gradient.
            encoder.zero_grad()
            loss = criteria[task_id](decoder(hrepr), targets[task_id])
            # retain_graph: the encoder graph behind ``hrepr`` is reused by
            # the remaining tasks.
            loss.backward(retain_graph=True)
            losses.append(loss)
            # torch.cat always allocates a new tensor, so this snapshot is
            # unaffected by the next zero_grad(); no clone() needed.
            grads.append(
                torch.cat(
                    [
                        p.grad.detach().reshape(-1)
                        for p in encoder.parameters()
                        if p.grad is not None
                    ]
                )
            )

        if not grads:
            raise ValueError("decoders must contain at least one task head")

        self.losses = [loss.item() for loss in losses]
        # One row per task.
        grads = torch.stack(grads, dim=0)

        solver = RandomProjectionSolver if self.solver is None else self.solver
        proj_grads = solver.apply(grads)

        if self.compute_cnumber:
            # Flatten any extra dims into one row per task. Unlike
            # squeeze(), this never drops the task axis, so a single task
            # still yields a 2-D matrix that the SVD accepts.
            mat = proj_grads.detach().reshape(proj_grads.shape[0], -1)
            singulars = torch.linalg.svdvals(mat)  # descending order
            self.cnumber = singulars[0] / singulars[-1]

        combined = proj_grads.detach().sum(0)

        # Scatter the combined flat gradient back into the encoder, using
        # the same parameter order and None-skipping as the flattening
        # above (the last task's grads determine which entries exist).
        offset = 0
        for p in encoder.parameters():
            if p.grad is None:
                continue
            numel = p.grad.numel()
            # copy_ keeps each grad's own storage and dtype, unlike
            # rebinding ``p.grad.data`` to a view of ``combined``.
            p.grad.copy_(combined[offset:offset + numel].view_as(p.grad))
            offset += numel