import torch
from torch.autograd import Variable

from .solver import ProcrustesSolver

class ThetaAlignedBalancer:
    """Balance multi-task gradients in the space of shared-encoder parameters.

    For every task head, the task loss is back-propagated separately. The
    resulting encoder-parameter gradients are flattened into one column per
    task, and the stacked matrix is passed to ``ProcrustesSolver.apply``. The
    sum of the solver's output columns is written back into the encoder's
    ``.grad`` fields.

    Attributes populated by :meth:`step`:
        singulars: third output of ``ProcrustesSolver.apply``.
        weights: second output of the solver, summed over its last dimension.
        losses: per-task loss values as Python floats.
        cnumber: ratio of the largest to the smallest singular value of the
            solver's output gradient matrix (only when ``compute_cnumber``).
    """

    def __init__(self, scale_decoder_grad=False, unit_scale=False, compute_cnumber=False) -> None:
        # When True, each decoder's gradients are multiplied by its task weight.
        self.scale_decoder_grad = scale_decoder_grad
        # Forwarded verbatim to ProcrustesSolver.apply; see the solver for its meaning.
        self.unit_scale = unit_scale
        self.singulars = None
        self.weights = None
        self.losses = None
        # Diagnostic only: costs one extra SVD per step.
        self.compute_cnumber = compute_cnumber
        self.cnumber = None

    @staticmethod
    def _flat_grad(module):
        """Concatenate the flattened ``.grad`` of every parameter of *module* that has one."""
        # torch.cat already allocates a new tensor, so no explicit clone is needed.
        return torch.cat(
            [p.grad.detach().flatten() for p in module.parameters() if p.grad is not None]
        )

    def step(self, input, targets, encoder, decoders, criteria, **kwargs):
        """Compute per-task losses and fill ``.grad`` with balanced gradients.

        Args:
            input: batch fed to ``encoder``.
            targets: per-task targets, indexed with the same keys as ``decoders``.
            encoder: shared module. Its parameters' ``.grad`` is overwritten
                with the sum of the solver's output gradient columns.
            decoders: per-task heads supporting ``zero_grad()``, ``items()`` and
                ``values()`` (e.g. an ``nn.ModuleDict``).
            criteria: per-task loss callables ``criterion(prediction, target)``,
                indexed with the same keys as ``decoders``.
            **kwargs: accepted and ignored.

        Only ``.grad`` fields and the ``singulars``/``weights``/``losses``/
        ``cnumber`` attributes are updated; parameter values are not changed.
        """
        encoder.zero_grad()
        decoders.zero_grad()

        hrepr = encoder(input)
        tasks = list(decoders.items())
        losses, task_grads = [], []
        for i, (task_id, decoder) in enumerate(tasks):
            # Task heads must not share parameters: decoder grads are zeroed
            # only once above, so a shared head parameter would accumulate
            # gradients from several tasks. With scale_decoder_grad, it would
            # also be rescaled once per head.
            # The encoder's grads are reset per task so each column is per-task.
            encoder.zero_grad()
            loss = criteria[task_id](decoder(hrepr), targets[task_id])
            # The encoder graph is needed again for every task except the last.
            loss.backward(retain_graph=i < len(tasks) - 1)
            losses.append(loss)
            task_grads.append(self._flat_grad(encoder))
        # Shape (num_encoder_grad_elements, num_tasks): one column per task.
        grads = torch.stack(task_grads, dim=-1)

        # The solver receives a leading batch dimension of size 1.
        grads, weights, singulars = ProcrustesSolver.apply(
            grads.unsqueeze(0), self.unit_scale
        )
        grad, weights = grads[0].sum(-1), weights.sum(-1)

        self.singulars, self.weights = singulars, weights

        if self.compute_cnumber is True:
            # Index the batch dim instead of squeeze(), so the matrix stays 2-D
            # (and torch.svd keeps working) even with a single task.
            _, _singulars, _ = torch.svd(grads[0], compute_uv=False)
            # Largest / smallest singular value; 1.0 means perfectly conditioned.
            self.cnumber = _singulars[0] / _singulars[-1]

        # Scatter the combined flat gradient back into the encoder parameters,
        # visiting them in the same order _flat_grad gathered them.
        offset = 0
        for p in encoder.parameters():
            if p.grad is None:
                continue
            numel = p.grad.numel()
            p.grad.copy_(grad[offset:offset + numel].view_as(p.grad))
            offset += numel

        if self.scale_decoder_grad is True:
            for i, decoder in enumerate(decoders.values()):
                for p in decoder.parameters():
                    if p.grad is not None:
                        p.grad.mul_(weights[i])

        self.losses = [loss.item() for loss in losses]
        

class ZAlignedBalancer:
    """Balance multi-task gradients in the space of the shared representation.

    The encoder output is copied into a detached leaf tensor. Each task loss is
    back-propagated only up to that leaf, giving one gradient per task with
    respect to the representation. These are stacked, passed to
    ``ProcrustesSolver.apply``, and summed. The sum is then back-propagated
    once through the encoder.

    Attributes populated by :meth:`step`:
        singulars: third output of ``ProcrustesSolver.apply``.
        weights: second output of the solver, summed over its last dimension.
        losses: per-task loss values as Python floats.
        cnumber: ratio of the largest to the smallest singular value of the
            aligned gradients mapped to encoder-parameter space (only when
            ``compute_cnumber``).
    """

    def __init__(self, scale_decoder_grad=False, unit_scale=False, compute_cnumber=False) -> None:
        # When True, each decoder's gradients are multiplied by its task weight.
        self.scale_decoder_grad = scale_decoder_grad
        # Forwarded verbatim to ProcrustesSolver.apply; see the solver for its meaning.
        self.unit_scale = unit_scale
        self.singulars = None
        self.weights = None
        self.losses = None
        # Diagnostic only: very expensive (extra encoder backward per task).
        self.compute_cnumber = compute_cnumber
        self.cnumber = None

    @staticmethod
    def _flat_grad(module):
        """Concatenate the flattened ``.grad`` of every parameter of *module* that has one."""
        # torch.cat already allocates a new tensor, so no explicit clone is needed.
        return torch.cat(
            [p.grad.detach().flatten() for p in module.parameters() if p.grad is not None]
        )

    def step(self, input, targets, encoder, decoders, criteria, **kwargs):
        """Compute per-task losses and fill ``.grad`` with balanced gradients.

        Args:
            input: batch fed to ``encoder``.
            targets: per-task targets, indexed with the same keys as ``decoders``.
            encoder: shared module; receives the balanced gradient via a single
                ``hrepr.backward``.
            decoders: per-task heads supporting ``zero_grad()``, ``items()`` and
                ``values()`` (e.g. an ``nn.ModuleDict``).
            criteria: per-task loss callables ``criterion(prediction, target)``,
                indexed with the same keys as ``decoders``.
            **kwargs: accepted and ignored.

        Only ``.grad`` fields and the ``singulars``/``weights``/``losses``/
        ``cnumber`` attributes are updated; parameter values are not changed.
        """
        encoder.zero_grad()
        decoders.zero_grad()
        hrepr = encoder(input)

        # Detached leaf copy of the representation: task losses back-propagate
        # only up to here, so the per-task gradient w.r.t. the representation
        # can be read off without touching the encoder.
        _hrepr = hrepr.detach().clone().requires_grad_(True)
        losses = [
            criteria[task_id](decoder(_hrepr), targets[task_id])
            for task_id, decoder in decoders.items()
        ]

        task_grads = []
        for loss in losses:
            # Task heads must not share parameters: decoder grads are zeroed
            # only once above, so a shared head parameter would accumulate
            # gradients from several tasks. With scale_decoder_grad, it would
            # also be rescaled once per head.
            loss.backward(retain_graph=False)
            task_grads.append(_hrepr.grad.detach().clone())
            # Reset so the next task's gradient is not accumulated onto this one.
            _hrepr.grad.zero_()
        self.losses = [loss.item() for loss in losses]

        grads = torch.stack(task_grads, dim=-1)
        # Collapse every dim between the first (presumably the batch -- confirm
        # with callers) and the task dim: (B, *feat, T) -> (B, prod(feat), T).
        grads = grads.reshape(
            grads.shape[0], grads.shape[1:-1].numel(), grads.shape[-1]
        )

        grads, weights, singulars = ProcrustesSolver.apply(
            grads, self.unit_scale
        )

        if self.compute_cnumber is True:
            # Computationally expensive: one extra backward pass through the
            # encoder per task. Use for demonstration/diagnostics only; not
            # suitable for real training pipelines.
            wgrads = []
            for t in range(grads.shape[-1]):
                # reshape_as (rather than view_as) tolerates a non-contiguous
                # solver output.
                hrepr.backward(grads[:, :, t].reshape_as(hrepr), retain_graph=True)
                wgrads.append(self._flat_grad(encoder))
                encoder.zero_grad()
            wgrads = torch.stack(wgrads, dim=-1)
            _, _singulars, _ = torch.svd(wgrads, compute_uv=False)
            # Largest / smallest singular value; 1.0 means perfectly conditioned.
            self.cnumber = _singulars[0] / _singulars[-1]

        grad, weights = grads.sum(-1).reshape_as(hrepr), weights.sum(-1)
        self.singulars, self.weights = singulars, weights

        # Single backward through the encoder with the combined gradient.
        hrepr.backward(grad)

        if self.scale_decoder_grad is True:
            for i, decoder in enumerate(decoders.values()):
                for p in decoder.parameters():
                    if p.grad is not None:
                        p.grad.mul_(weights[i])