import torch

class HomoscedasticUncertaintyBalancer:
    """Combine per-task losses using a log-variance read from each decoder.

    For every task the balanced loss is
    ``exp(-decoder.log_var) * criterion(decoder(h), target) + 0.5 * decoder.log_var``
    where ``h`` is the encoder output. The sum over all tasks is
    back-propagated in :meth:`step`.

    Attributes set by this class:
        weights: Initialised to ``None``; not assigned by the code in this class.
        losses: List of Python floats, one balanced loss per task, from the
            most recent :meth:`step` call (``None`` before the first call).
        compute_cnumber: Whether :meth:`step` also computes ``cnumber``.
        cnumber: Ratio of the largest to the smallest singular value of the
            matrix whose columns are the per-task gradients with respect to
            the encoder parameters (``None`` until computed).
    """

    def __init__(self, compute_cnumber=False) -> None:
        """Create a balancer.

        Args:
            compute_cnumber: If true, every :meth:`step` call also computes
                the condition number of the per-task encoder gradients.
        """
        self.weights = None
        self.losses = None
        self.compute_cnumber = compute_cnumber
        self.cnumber = None

    @staticmethod
    def _weighted_loss(decoder, criterion, hrepr, target):
        """Return ``exp(-log_var) * criterion(decoder(hrepr), target) + 0.5 * log_var``."""
        precision = torch.exp(-decoder.log_var)
        return precision * criterion(decoder(hrepr), target) + 0.5 * decoder.log_var

    @staticmethod
    def _condition_number(losses, encoder):
        """Return sigma_max / sigma_min of the per-task encoder-gradient matrix."""
        params = [p for p in encoder.parameters() if p.requires_grad]
        columns = []
        for loss in losses:
            # autograd.grad leaves every ``.grad`` buffer untouched, so this
            # diagnostic cannot leak into (or wipe) the gradients that the
            # final backward pass produces. The graph is retained because
            # the summed loss is back-propagated afterwards.
            task_grads = torch.autograd.grad(
                loss, params, retain_graph=True, allow_unused=True
            )
            # A parameter that a task does not reach contributes zeros, which
            # keeps all columns the same length.
            columns.append(
                torch.cat(
                    [
                        (torch.zeros_like(p) if g is None else g).flatten()
                        for p, g in zip(params, task_grads)
                    ]
                )
            )
        # Shape (num_encoder_param_elements, num_tasks). It is deliberately not
        # squeezed: with a single task the matrix must stay 2-D for the SVD.
        grad_matrix = torch.stack(columns, dim=-1)
        singulars = torch.linalg.svdvals(grad_matrix)  # descending order
        return singulars[0] / singulars[-1]

    def step(self, input, targets, encoder, decoders, criteria, **kwargs):
        """Run one forward/backward pass over all tasks.

        Gradients of ``encoder`` and ``decoders`` are cleared, the balanced
        loss of every task is computed from a single encoder forward pass,
        and the sum of those losses is back-propagated. No optimizer step is
        taken here. The per-task loss values are stored in ``self.losses``
        and, when ``self.compute_cnumber`` is true, the gradient condition
        number is stored in ``self.cnumber``.

        Args:
            input: Batch passed to ``encoder``.
            targets: Mapping from task key to the target for that task.
            encoder: Shared module; must provide ``zero_grad()`` and
                ``parameters()`` and be callable on ``input``.
            decoders: Mapping-like container of task heads providing
                ``items()`` and ``zero_grad()``; every head must expose a
                ``log_var`` tensor.
            criteria: Mapping from task key to a loss callable
                ``criterion(prediction, target)``.
            **kwargs: Accepted for call-site compatibility and ignored.

        Raises:
            ValueError: If ``decoders`` contains no task heads.
        """
        encoder.zero_grad()
        decoders.zero_grad()

        hrepr = encoder(input)
        losses = [
            self._weighted_loss(decoder, criteria[task], hrepr, targets[task])
            for task, decoder in decoders.items()
        ]
        if not losses:
            raise ValueError("decoders must contain at least one task head")

        if self.compute_cnumber:
            self.cnumber = self._condition_number(losses, encoder)

        total_loss = sum(losses)
        total_loss.backward()

        self.losses = [value.item() for value in losses]