import torch
import torch.nn as nn


class Curvature(nn.Module):
    """Learnable factors that "warp" (precondition) per-tensor gradients.

    For every tensor in ``theta_params`` the module keeps up to three learnable
    factors. Each is initialised so that the warp starts as the identity:

    * 1-D tensors: an element-wise scale ``Mo`` of the same shape (ones).
    * 2-D tensors of shape ``(out, in)``: ``Mo`` (``out x out``) and ``Mi``
      (``in x in``). The warped gradient is ``Mo @ G @ Mi.T``.
    * 4-D tensors of shape ``(cout, cin, kh, kw)``: ``Mo`` (``cout x cout``),
      ``Mi`` (``cin x cin``) and ``Mf`` (``kh*kw x kh*kw``). Each factor acts
      on its own axis of the gradient viewed as ``(cout, cin, kh*kw)``.

    Tensors of any other rank get no factors (``None`` in all three lists).
    Their gradients are passed through unchanged. Index ``i`` of ``Mf``,
    ``Mi`` and ``Mo`` therefore always corresponds to the ``i``-th entry of
    ``theta_params``.

    The factors are created without an explicit ``device``/``dtype``, i.e. with
    torch's defaults. Move the module with ``.to(...)`` to match the gradients.

    Args:
        theta_params: iterable of tensors whose gradients will later be
            warped. Only their shapes are read.
    """

    def __init__(self, theta_params):
        super().__init__()

        self.Mf = nn.ParameterList([])
        self.Mi = nn.ParameterList([])
        self.Mo = nn.ParameterList([])
        for theta in theta_params:
            ndim = len(theta.shape)
            mf = mi = mo = None
            if ndim == 1:
                mo = nn.Parameter(torch.ones(theta.shape))
            elif ndim == 2:
                mi = nn.Parameter(torch.eye(theta.shape[1]))
                mo = nn.Parameter(torch.eye(theta.shape[0]))
            elif ndim == 4:
                mf = nn.Parameter(torch.eye(theta.shape[2] * theta.shape[3]))
                mi = nn.Parameter(torch.eye(theta.shape[1]))
                mo = nn.Parameter(torch.eye(theta.shape[0]))
            # Always append exactly one entry per tensor, even for unsupported
            # ranks. Skipping a tensor would shift every later index and
            # misalign the factors with the gradients in warp_grads.
            self.Mf.append(mf)
            self.Mi.append(mi)
            self.Mo.append(mo)

    def warp_grads(self, gradients):
        """Apply the learned factors to a sequence of gradients.

        Args:
            gradients: sequence of gradient tensors, in the same order as the
                ``theta_params`` given to the constructor. ``None`` entries are
                allowed.

        Returns:
            A list with one entry per input gradient, in the same order:

            * 1-D: ``Mo * g``.
            * 2-D: ``Mo @ g @ Mi.T``.
            * 4-D: ``Mo``/``Mi``/``Mf`` applied to the output-channel,
              input-channel and flattened spatial axes, reshaped back to
              ``g.shape``.
            * ``None`` stays ``None``; any other rank is returned unchanged.
        """
        warped_grads = []

        for i, grads in enumerate(gradients):
            if grads is None:
                # Keep the output aligned with the inputs.
                warped_grads.append(None)
                continue

            ndim = len(grads.shape)
            if ndim == 1:
                g = self.Mo[i] * grads
            elif ndim == 2:
                # (out, in) -> Mi @ G.T : (in, out) -> Mo @ (.).T : (out, in)
                g = torch.matmul(self.Mi[i], torch.transpose(grads, 1, 0))
                g = torch.matmul(self.Mo[i], torch.transpose(g, 1, 0))
            elif ndim == 4:
                cout = grads.shape[0]
                cin = grads.shape[1]
                cd = grads.shape[2] * grads.shape[3]
                # Mf on the flattened kernel positions: result is (cd, cout*cin).
                g = torch.transpose(torch.reshape(grads, (cout * cin, -1)), 1, 0)
                g = torch.matmul(self.Mf[i], g)
                # Mi on input channels: (cd, cout, cin) -> (cin, cout*cd).
                g = torch.reshape(torch.transpose(g.view(cd, cout, cin), 2, 0), (cin, cout * cd))
                g = torch.matmul(self.Mi[i], g)
                # Mo on output channels: (cin, cout, cd) -> (cout, cin*cd).
                g = torch.reshape(torch.transpose(g.view(cin, cout, cd), 1, 0), (cout, cin * cd))
                g = torch.matmul(self.Mo[i], g)
                g = g.view(grads.shape)
            else:
                # No factors exist for this rank (see __init__): identity warp.
                g = grads
            warped_grads.append(g)

        return warped_grads
