import torch
from torch.autograd import Variable


class GradNormBalancer:
    """Balance multi-task losses with GradNorm-style learned task weights.

    Every task head in ``decoders`` must expose a learnable ``weight``
    parameter. :meth:`step` only populates ``.grad`` on the modules; running
    the optimizer is left to the caller.

    Args:
        alpha: Exponent applied to the relative inverse training rates
            (strength of the restoring force).
        compute_cnumber: If True, ``step`` also computes the condition number
            of the matrix of weighted per-task gradients w.r.t. the encoder
            and shared-layer parameters and stores it in ``self.cnumber``.
            This is expensive and meant for diagnostics only.
    """

    def __init__(self, alpha=2, compute_cnumber=False):
        self.alpha = alpha
        # Renormalized task weights (detached) from the latest ``step`` call.
        self.weights = None
        # Per-task loss values (Python floats) from the latest ``step`` call.
        self.losses = None
        # Per-task losses recorded on the first ``step`` call; the reference
        # L_i(0) for the inverse training rates.
        self.initial_losses = None
        self.compute_cnumber = compute_cnumber
        # Condition number from the latest ``step`` (only if compute_cnumber).
        self.cnumber = None

    @staticmethod
    def _flat_grad(parameters):
        """Return the detached, concatenated ``.grad`` of ``parameters``.

        Parameters without a gradient are skipped; returns None when no
        parameter has one.
        """
        parts = [
            p.grad.detach().flatten().clone()
            for p in parameters
            if p.grad is not None
        ]
        return torch.cat(parts) if parts else None

    def _condition_number(self, hrepr, targets, encoder, layer, decoders,
                          criteria, weights):
        """Return the condition number of the weighted per-task gradient matrix.

        Column ``i`` is the gradient of ``weights[i] * loss_i`` w.r.t. the
        encoder parameters followed by the shared-layer parameters. Gradients
        on encoder, layer and the current decoder are cleared after each task.
        """
        shared = layer(hrepr)
        columns = []
        for idx, (task_id, decoder) in enumerate(decoders.items()):
            loss = weights[idx] * criteria[task_id](decoder(shared), targets[task_id])
            loss.backward(retain_graph=True)
            parts = [
                g
                for g in (self._flat_grad(encoder.parameters()),
                          self._flat_grad(layer.parameters()))
                if g is not None
            ]
            columns.append(torch.cat(parts, dim=0))
            encoder.zero_grad()
            layer.zero_grad()
            decoder.zero_grad()
        matrix = torch.stack(columns, dim=-1)
        _, singulars, _ = torch.svd(matrix, compute_uv=False)
        return singulars[0] / singulars[-1]

    def step(self, input, targets, encoder, layer, decoders, criteria):
        """Run one GradNorm balancing step and populate parameter gradients.

        Args:
            input: Batch fed to ``encoder``.
            targets: Mapping from task id to target, indexed with the keys of
                ``decoders``.
            encoder: Module producing the representation fed to ``layer``.
            layer: Shared module whose parameter gradients are measured per
                task.
            decoders: Mapping (presumably ``nn.ModuleDict``) from task id to
                task head; each head must expose a ``weight`` parameter.
            criteria: Mapping from task id to a loss callable
                ``(prediction, target)``.

        After the call:
          * each ``decoder.weight.grad`` holds the gradient of the GradNorm
            objective w.r.t. that task weight;
          * encoder, layer and decoder parameters hold the gradient of the
            weighted sum of task losses (weights renormalized to sum to the
            number of tasks and treated as constants).

        Raises:
            RuntimeError: if ``layer`` receives no gradient from some task.
        """
        encoder.zero_grad()
        layer.zero_grad()
        decoders.zero_grad()

        hrepr = encoder(input)

        # Detached copy: the per-task measurement passes stop at the shared
        # layer and never reach the encoder.
        detached_hrepr = hrepr.detach().clone().requires_grad_(True)
        detached_shared = layer(detached_hrepr)

        task_grads = []
        losses = []
        for task_id, decoder in decoders.items():
            # NOTE: task heads are assumed not to share parameters with each
            # other or with ``layer``; the per-task gradient is read from
            # ``layer`` right after each backward pass.
            loss = criteria[task_id](decoder(detached_shared), targets[task_id])
            layer.zero_grad()
            loss.backward(retain_graph=True)
            task_grad = self._flat_grad(layer.parameters())
            if task_grad is None:
                raise RuntimeError(
                    f"shared layer received no gradient for task {task_id!r}"
                )
            task_grads.append(task_grad)
            losses.append(loss)

        # The measurement passes left the last task's unweighted gradient on
        # ``layer`` and unweighted task gradients on every decoder; clear them
        # so they do not leak into the weighted update below.
        layer.zero_grad()
        decoders.zero_grad()

        num_tasks = len(task_grads)
        # Renormalize so the weights sum to the number of tasks. This tensor
        # stays attached to the ``decoder.weight`` parameters, so the GradNorm
        # objective below back-propagates into them.
        weights = torch.stack([d.weight for d in decoders.values()])
        weights = weights * num_tasks / torch.sum(weights)
        fixed_weights = weights.detach()
        self.weights = fixed_weights.clone()

        if self.compute_cnumber:
            # Computationally expensive; for demonstration/diagnostics only.
            self.cnumber = self._condition_number(
                hrepr, targets, encoder, layer, decoders, criteria, fixed_weights
            )

        # Columns: per-task shared-layer gradients scaled by the task weights.
        weighted_grads = torch.stack(task_grads, dim=-1) * weights.view(1, -1)

        self.losses = [loss.item() for loss in losses]
        if self.initial_losses is None:
            self.initial_losses = [loss.detach().clone() for loss in losses]

        # Inverse training rates L_i(t) / L_i(0), made relative to their mean
        # and raised to alpha (restoring force).
        ratios = torch.stack([
            loss.detach() / initial
            for loss, initial in zip(losses, self.initial_losses)
        ])
        rel_rates = (ratios / torch.mean(ratios)).pow(self.alpha)

        norms = torch.norm(weighted_grads, dim=0, p=2)
        # The target norm is a constant: no gradient flows through it.
        mean_norm = torch.mean(norms).detach()

        grad_loss = torch.sum(torch.abs(norms - rel_rates * mean_norm))
        grad_loss.backward()

        # Weighted training loss with constant weights, so only the encoder,
        # shared layer and decoders receive gradient from it.
        shared = layer(hrepr)
        loss = 0.0
        for idx, (task_id, decoder) in enumerate(decoders.items()):
            loss = loss + fixed_weights[idx] * criteria[task_id](decoder(shared), targets[task_id])
        loss.backward()
            
