import torch
from torch.autograd import Variable

from .solver import MinNormSolver

class MGDABalancer:
    """Combine per-task encoder gradients using scales from ``MinNormSolver``.

    Every task loss is back-propagated separately through the shared
    encoder, the resulting encoder gradients are flattened into one row per
    task, and ``MinNormSolver.apply`` supplies one scale per row. The
    encoder's ``.grad`` tensors are then overwritten with the scaled sum.

    Attributes:
        scale_decoder_grad: When truthy, each task head's gradients are
            multiplied by that task's scale at the end of ``step``.
        compute_cnumber: When truthy, ``step`` also stores the ratio of the
            largest to the smallest singular value of the scaled gradient
            matrix in ``cnumber``.
        losses: Python floats of the task losses from the last ``step``.
        weights: Scales returned by ``MinNormSolver.apply`` in the last
            ``step``.
        cnumber: See ``compute_cnumber``; ``None`` until it is computed.
    """

    def __init__(self, scale_decoder_grad=False, compute_cnumber=False) -> None:
        self.scale_decoder_grad = scale_decoder_grad
        self.losses = None
        self.weights = None
        self.compute_cnumber = compute_cnumber
        self.cnumber = None

    def step(self, input, targets, encoder, decoders, criteria, **kwargs):
        """Run forward/backward for all tasks and set the balanced gradients.

        No optimizer is stepped here; the method only populates ``.grad``.

        Args:
            input: Batch passed to ``encoder``.
            targets: Per-task targets, indexed by the keys of ``decoders``.
            encoder: Shared module; must provide ``zero_grad()`` and
                ``parameters()``.
            decoders: Container of task heads providing ``items()``,
                ``values()`` and ``zero_grad()`` (presumably an
                ``nn.ModuleDict`` -- confirm against the caller).
            criteria: Per-task loss callables, indexed by the same keys and
                called as ``criterion(prediction, target)``.
            **kwargs: Accepted and ignored.

        Side effects:
            Overwrites the encoder parameter gradients with the scaled sum
            of the task gradients, leaves each head with its own task
            gradient (optionally scaled), and updates ``self.losses``,
            ``self.weights`` and, if requested, ``self.cnumber``.
        """
        encoder.zero_grad()
        decoders.zero_grad()

        hrepr = encoder(input)

        losses, task_grads = [], []
        for task_id, decoder in decoders.items():
            # Be sure that your task heads don't have overlapping parameters!
            # Otherwise these parameters will be updated using the latest
            # backward gradient.
            encoder.zero_grad()
            losses.append(criteria[task_id](decoder(hrepr), targets[task_id]))
            # The encoder graph is shared by all tasks, so it must survive
            # this backward pass.
            losses[-1].backward(retain_graph=True)
            # Snapshot this task's encoder gradient before the next
            # zero_grad() wipes it.
            task_grads.append(
                torch.cat(
                    [
                        p.grad.detach().flatten().clone()
                        for p in encoder.parameters()
                        if p.grad is not None
                    ]
                )
            )

        # One row per task: (num_tasks, num_encoder_gradient_elements).
        grads = torch.stack(task_grads, dim=0)
        scales, _ = MinNormSolver.apply(grads)
        self.weights = scales
        grads = grads * scales.view(-1, 1)

        if self.compute_cnumber:
            # svdvals returns the singular values in descending order, so
            # first/last is the largest/smallest ratio.
            singulars = torch.linalg.svdvals(grads)
            self.cnumber = singulars[0] / singulars[-1]

        grad = torch.sum(grads, dim=0)

        # Scatter the flat combined gradient back into the parameters, in
        # the same order (and with the same None-skipping) used to build it.
        offset = 0
        for p in encoder.parameters():
            if p.grad is None:
                continue
            end = offset + p.grad.numel()
            p.grad.copy_(grad[offset:end].view_as(p.grad))
            offset = end

        if self.scale_decoder_grad:
            for i, decoder in enumerate(decoders.values()):
                for p in decoder.parameters():
                    if p.grad is not None:
                        p.grad.mul_(self.weights[i])

        self.losses = [value.item() for value in losses]


class MGDAUBBalancer:
    """Balance tasks using gradients taken w.r.t. the shared representation.

    The task heads are run on a detached copy of the encoder output, so each
    task's backward pass stops at that copy. The per-task gradients of the
    copy are flattened into one row per task, ``MinNormSolver.apply``
    supplies one scale per row, and the scaled sum is back-propagated
    through the encoder once.

    Attributes:
        scale_decoder_grad: When truthy, each task head's gradients are
            multiplied by that task's scale at the end of ``step``.
        compute_cnumber: When truthy, ``step`` also back-propagates every
            scaled task gradient through the encoder and stores the ratio of
            the largest to the smallest singular value of the resulting
            parameter-gradient matrix in ``cnumber``.
        losses: Python floats of the task losses from the last ``step``.
        weights: Scales returned by ``MinNormSolver.apply`` in the last
            ``step``.
        cnumber: See ``compute_cnumber``; ``None`` until it is computed.
    """

    def __init__(self, scale_decoder_grad=False, compute_cnumber=False) -> None:
        self.scale_decoder_grad = scale_decoder_grad
        self.weights = None
        self.losses = None
        self.compute_cnumber = compute_cnumber
        self.cnumber = None

    def step(self, input, targets, encoder, decoders, criteria, **kwargs):
        """Run forward/backward for all tasks and set the balanced gradients.

        No optimizer is stepped here; the method only populates ``.grad``.

        Args:
            input: Batch passed to ``encoder``.
            targets: Per-task targets, indexed by the keys of ``decoders``.
            encoder: Shared module; must provide ``zero_grad()`` and
                ``parameters()``.
            decoders: Container of task heads providing ``items()``,
                ``values()`` and ``zero_grad()`` (presumably an
                ``nn.ModuleDict`` -- confirm against the caller).
            criteria: Per-task loss callables, indexed by the same keys and
                called as ``criterion(prediction, target)``.
            **kwargs: Accepted and ignored.

        Side effects:
            Populates the encoder gradients from the scaled sum of the
            representation gradients, leaves each head with its own task
            gradient (optionally scaled), and updates ``self.losses``,
            ``self.weights`` and, if requested, ``self.cnumber``.
        """
        encoder.zero_grad()
        decoders.zero_grad()
        hrepr = encoder(input)

        # Detached leaf copy of the representation: the heads' backward
        # passes stop here instead of running through the encoder.
        head_input = hrepr.detach().clone().requires_grad_(True)
        losses = [
            criteria[task_id](decoder(head_input), targets[task_id])
            for task_id, decoder in decoders.items()
        ]

        task_grads = []
        for loss in losses:
            # Be sure that your task heads don't have overlapping parameters!
            # Otherwise these parameters will be updated using the latest
            # backward gradient.
            loss.backward(retain_graph=False)
            task_grads.append(head_input.grad.detach().clone())
            # Reset so the next task's gradient is not accumulated on top.
            head_input.grad.zero_()

        self.losses = [value.item() for value in losses]

        # One row per task: (num_tasks, num_representation_elements).
        grads = torch.stack(task_grads, dim=0)
        grads = grads.reshape(grads.shape[0], -1)

        scales, _ = MinNormSolver.apply(grads)
        self.weights = scales

        grads = grads * scales.view(-1, 1)

        if self.compute_cnumber:
            # Computationally expensive, use it for demonstration only.
            # This code can't be executed with real training pipelines.
            wgrads = []
            for task_grad in grads:
                # The graph is needed again for the remaining tasks and for
                # the final backward pass below.
                hrepr.backward(task_grad.view_as(hrepr), retain_graph=True)
                wgrads.append(
                    torch.cat(
                        [
                            p.grad.detach().flatten().clone()
                            for p in encoder.parameters()
                            if p.grad is not None
                        ]
                    )
                )
                encoder.zero_grad()
            # svdvals returns the singular values in descending order, so
            # first/last is the largest/smallest ratio.
            singulars = torch.linalg.svdvals(torch.stack(wgrads, dim=-1))
            self.cnumber = singulars[0] / singulars[-1]

        grad = torch.sum(grads, dim=0)

        # Single encoder backward pass with the combined gradient.
        hrepr.backward(grad.view_as(hrepr))

        if self.scale_decoder_grad:
            for i, decoder in enumerate(decoders.values()):
                for p in decoder.parameters():
                    if p.grad is not None:
                        p.grad.mul_(self.weights[i])