import torch
import torch.nn as nn

from geotorch import orthogonal

import utils as u


def rotate(points, rotation, total_size):
    """Rotate the feature vectors of ``points`` with ``rotation``.

    Every feature vector ``x`` stored along the last dimension of ``points`` is
    mapped to ``rotation @ x``. When ``total_size`` differs from
    ``points.size(-1)``, only the first ``rotation.size(1)`` features are
    rotated and the remaining features are passed through unchanged.

    Args:
        points: tensor whose last dimension holds the features. Any number of
            leading (batch) dimensions is accepted, including none.
        rotation: 2-D matrix of shape ``(m, k)``. In this file it is always
            square (``input_size x input_size``).
        total_size: feature size for which the whole last dimension is rotated.

    Returns:
        The rotated tensor. Its shape equals ``points.shape`` when ``rotation``
        is square.
    """
    if total_size != points.size(-1):
        k = rotation.size(1)
        # Rotate only the leading k features; keep the tail as-is.
        points_lo, points_hi = points[..., :k], points[..., k:]
        rotated_lo = torch.einsum('ij,...j->...i', rotation, points_lo)
        return torch.cat((rotated_lo, points_hi), dim=-1)
    return torch.einsum('ij,...j->...i', rotation, points)


def rotate_back(points, rotation, total_size):
    """Apply the transpose of ``rotation`` to ``points`` via ``rotate``.

    For an orthogonal ``rotation`` the transpose is its inverse, so this undoes
    ``rotate(points, rotation, total_size)``.
    """
    inverse = torch.t(rotation)
    return rotate(points, inverse, total_size)


def get_grad(tensor):
    """Return ``tensor.grad``, or ``0.`` if the attribute is absent or ``None``."""
    grad = getattr(tensor, 'grad', None)
    if grad is None:
        return 0.
    return grad


class RotoGrad(nn.Module):
    """Hold one learnable orthogonal rotation per task.

    Task ``i`` owns the parameter ``rotation_{i}``, initialised to the
    ``input_size x input_size`` identity. It is re-parameterised with geotorch's
    ``orthogonal`` (matrix-exponential trivialization) so it can be optimised
    without explicit constraints. ``self[i]`` returns a small module that applies
    task ``i``'s rotation. In training mode that module also records the gradient
    flowing back into the rotated features, which ``callback`` later uses to
    update the rotations.
    """

    _id = 0  # class-wide counter giving each instance a distinct ``my_id``

    def __init__(self, num_tasks, input_size, *args):
        # ``*args`` is accepted but ignored.
        super().__init__()

        # Parameterize rotations so we can run unconstrained optimization
        for i in range(num_tasks):
            self.register_parameter(f'rotation_{i}', nn.Parameter(torch.eye(input_size), requires_grad=True))
            orthogonal(self, f'rotation_{i}', triv='expm')  # uses exponential map (alternative: cayley)

        # Parameters
        self.num_tasks = num_tasks
        self.input_size = input_size
        self.alpha = 1.  # NOTE(review): not read in this file
        self.my_id = RotoGrad._id
        self.coop = False  # NOTE(review): not read in this file

        # Per-task gradients w.r.t. the rotated features. They are written by the
        # hooks registered in RotateModule.forward and read by ``callback``.
        self.grads = [None for _ in range(num_tasks)]
        self.R_hook1 = [None for _ in range(num_tasks)]  # NOTE(review): not read in this file
        self.R_hook2 = [None for _ in range(num_tasks)]  # NOTE(review): not read in this file

        RotoGrad._id += 1

    @property
    def rotation(self):
        """List of the current rotation matrices of all tasks, in task order."""
        return [self._rotation_at(i) for i in range(self.num_tasks)]

    def _rotation_at(self, index):
        """Return the current rotation matrix of task ``index`` (non-negative int).

        Only this task's parametrization is evaluated. Going through
        ``self.rotation`` instead would compute every task's rotation.
        """
        return getattr(self, f'rotation_{index}')

    def __len__(self):
        return self.num_tasks

    def __getitem__(self, item):
        """Return a module that rotates inputs with task ``item``'s rotation.

        ``item`` must be an integer (anything accepted by ``operator.index``).
        Negative values count from the end, as for a list.

        Raises:
            IndexError: if ``item`` is out of range. This also terminates
                iteration, which Python performs through ``__getitem__`` here.
        """
        import operator  # local import: only needed to normalise the index

        index = operator.index(item)
        if not -self.num_tasks <= index < self.num_tasks:
            raise IndexError(f'task index {item} out of range for {self.num_tasks} tasks')
        index %= self.num_tasks

        class RotateModule(nn.Module):
            def __init__(self, parent):
                super().__init__()

                # Store the parent in a plain list so nn.Module does not register it
                # as a submodule (and so does not report its parameters as ours).
                self.parent = [parent]

            def hook(self, g):
                # Record the gradient w.r.t. the rotated features for ``callback``.
                self.p.grads[index] = g.clone()

            @property
            def p(self):
                return self.parent[0]

            @property
            def R(self):
                return self.p._rotation_at(index)

            def rotate(self, z):
                return rotate(z, self.R, self.p.input_size)

            def rotate_back(self, z):
                return rotate_back(z, self.R, self.p.input_size)

            def forward(self, z):
                # Use a detached copy so the task loss does not backpropagate into
                # the rotation. In this file the rotations only receive gradients
                # from ``callback``.
                R = self.R.clone().detach()
                new_z = rotate(z, R, self.p.input_size)
                if self.p.training:
                    new_z.register_hook(self.hook)

                return new_z

        return RotateModule(self)

    def callback(self, losses, grads, input):
        """Backpropagate the rotation-alignment loss of every task.

        First a reference direction is built from ``grads``. Each task gradient
        is rescaled by ``u.divide(mean_norm, ||g||)`` and the results are
        averaged. ``u.divide`` is assumed to be a safe ``mean_norm / ||g||``
        (TODO confirm). Then, for each task ``i``, ``R_i @ reference`` is pulled
        towards the gradient recorded by task ``i``'s hook. This is done by
        calling ``backward()`` on the mean squared distance, which accumulates
        into the rotation parameters' ``.grad``.

        Args:
            losses: not used here.
            grads: per-task gradient tensors supplied by the caller. Per the
                original note these are already rotated. They are expected to be
                2-D, given the ``'bi,bi->b'`` einsum.
            input: not used here.

        Returns:
            ``grads``, unchanged.

        Raises:
            RuntimeError: if no gradient has been recorded for some task. Hooks
                only fire for training-mode forward passes through ``self[i]``.
        """
        old_grads = grads  # these grads are already rotated, we have to recover the originals
        task_grads = self.grads

        missing = [i for i, g in enumerate(task_grads) if g is None]
        if missing:
            raise RuntimeError(
                f'No gradient recorded for task(s) {missing}; run a training-mode '
                f'forward pass through RotoGrad[i] and backward before callback().'
            )

        # Compute the reference vector
        num_grads = len(old_grads)
        mean_grad = sum(old_grads).detach().clone() / num_grads
        mean_norm = mean_grad.norm(p=2)
        rescaled = [g * u.divide(mean_norm, g.norm(p=2)) for g in old_grads]
        mean_grad = sum(rescaled).detach().clone() / num_grads

        for i, grad in enumerate(task_grads):
            R = self._rotation_at(i)
            diff = rotate(mean_grad, R, self.input_size) - grad
            loss_rotograd = torch.einsum('bi,bi->b', diff, diff)
            loss_rotograd.mean().backward()

        return old_grads
