import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import norm, CUDA
import math


class Anchored:
    """Mixin interface for layers that carry per-anchor offsets.

    Both methods are placeholders; concrete layers must override them.
    """

    def reset_anchor(self, n_anchor, data_per_anchor):
        """Configure the anchor layout for upcoming batches (must be overridden)."""
        raise NotImplementedError

    def reg_bias_align(self):
        """Return the anchor-alignment regulariser (must be overridden)."""
        raise NotImplementedError


class ConvTranspose2d(nn.ConvTranspose2d):
    """``nn.ConvTranspose2d`` carrying ``in_shape`` / ``out_shape`` placeholders.

    Both attributes start as ``None``; ``Sequential.infer_shapes`` in this
    file fills in ``None`` shapes on layers that expose these attributes.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``nn.ConvTranspose2d`` unchanged."""
        super().__init__(*args, **kwargs)
        # Per-sample shapes (batch dimension excluded), unknown at construction.
        self.in_shape = self.out_shape = None



class Linear(nn.Linear, Anchored):
    """Linear layer with a bank of learnable anchor offsets added to its output.

    The regular ``nn.Linear`` bias is disabled by default; instead, a
    parameter ``a`` of shape ``(n_anchor, out_features)`` holds one offset
    vector per anchor, and ``forward`` adds the offset of the anchor each
    sample belongs to.
    """

    def __init__(self, *args, anchor_scale=1., adaptive_scale=False, n_anchor=50, **kwargs):
        """Build the layer and initialise the anchor bank.

        Args:
            *args, **kwargs: forwarded to ``nn.Linear``; ``bias`` defaults to
                ``False`` unless given explicitly.
            anchor_scale: standard deviation of the normal initialisation of
                the anchors (before the optional adaptive rescaling).
            adaptive_scale: if true, the std becomes
                ``anchor_scale * sqrt(2) / sqrt(in_features)``.
            n_anchor: number of anchor vectors stored in ``self.a``.
        """
        kwargs.setdefault('bias', False)
        super().__init__(*args, **kwargs)
        self.in_shape = (self.in_features,)
        self.out_shape = (self.out_features,)
        self.n_anchor = n_anchor
        # Filled in by reset_anchor(); forward() with anchors needs them.
        self.anchor_idx = None
        self.data_per_anchor = None
        self.batch_size = None

        self.a = nn.Parameter(torch.empty(self.n_anchor, *self.out_shape))
        if adaptive_scale:
            scale = anchor_scale * math.sqrt(2) / math.sqrt(self.in_features)
        else:
            scale = anchor_scale
        nn.init.normal_(self.a, std=scale)

    def reset_anchor(self, anchor_idx, data_per_anchor):
        """Select which anchors the next batches use and how many samples each gets.

        Args:
            anchor_idx: index into the first dimension of ``self.a``.
            data_per_anchor: number of consecutive samples sharing one anchor.
        """
        self.anchor_idx = anchor_idx
        self.data_per_anchor = data_per_anchor
        # NOTE(review): uses the total anchor count, not the number of
        # selected indices -- confirm this is what consumers of batch_size expect.
        self.batch_size = self.n_anchor * self.data_per_anchor

    # TODO: forward hook doesn't work due to recursion
    def forward(self, z, no_a=False):
        """Apply the linear map and, unless ``no_a`` is set, add the anchor offsets.

        With anchors enabled, each selected anchor row is repeated
        ``data_per_anchor`` times consecutively, so ``z`` is expected to be
        ordered anchor-major (assuming ``anchor_idx`` is a 1-D index).

        Raises:
            AssertionError: if anchors are requested before ``reset_anchor``.
        """
        if no_a:
            return super().forward(z)

        assert self.anchor_idx is not None, "reset_anchor() must be called before forward()"
        # (k, out) -> (k, 1, out) -> (k, d, out) -> (k * d, out)
        a_batch = self.a[self.anchor_idx].unsqueeze(1).repeat(1, self.data_per_anchor, 1).view(-1, *self.out_shape)
        return super().forward(z) + a_batch

    def reg_bias_align(self):
        """Return the summed variance (across anchors) of the anchors mapped through the weight.

        Computes ``zeta = a @ W`` with ``W`` detached, so gradients reach only
        the anchors, then sums ``var(zeta, dim=0)`` over the input features.
        """
        zeta = F.linear(self.a, self.weight.detach().t())
        reg = torch.sum(torch.var(zeta, dim=0))

        return reg

    def pinverse(self, x, n_steps=100, lr=0.01, exact=True):
        """Map outputs ``x`` back to inputs, ignoring the anchor offsets.

        Args:
            x: tensor of shape ``(N, out_features)``.
            n_steps: number of outer L-BFGS steps (iterative mode only).
            lr: L-BFGS learning rate (iterative mode only).
            exact: if true, use the Moore-Penrose pseudo-inverse of the
                weight; otherwise minimise the mean row-wise distance
                ``||W z + b - x||`` from a small random start.

        Returns:
            Detached tensor of shape ``(N, in_features)``.
        """
        # Work on detached parameters: the projection must neither build a
        # graph through the layer nor leak gradients into weight.grad.
        weight = self.weight.detach()
        bias = None if self.bias is None else self.bias.detach()
        target = x.detach()

        if exact:
            if bias is not None:
                # Undo the affine shift before applying the pseudo-inverse.
                target = target - bias
            return torch.mm(target, torch.pinverse(weight).t()).detach()

        # Start on the same device/dtype as the weight so F.linear is valid.
        z = (0.1 * torch.randn(x.shape[0], self.in_features,
                               device=weight.device, dtype=weight.dtype)).requires_grad_(True)
        prj_opt = torch.optim.LBFGS([z], lr=lr)
        lr_sched = torch.optim.lr_scheduler.ReduceLROnPlateau(prj_opt, factor=0.5)

        def closure():
            # L-BFGS evaluates the closure several times per step, so loss and
            # gradient have to be recomputed from the current z on every call.
            prj_opt.zero_grad()
            dist = torch.mean(torch.norm(F.linear(z, weight, bias) - target, dim=1))
            dist.backward()
            return dist

        for _ in range(n_steps):
            dist = prj_opt.step(closure)
            lr_sched.step(dist.item())
        return z.detach()

    def set_bias(self, **pinv_kwargs):
        """Store ``self.b``: minus the pseudo-inverse image of the mean anchor.

        ``pinv_kwargs`` are forwarded to :meth:`pinverse`. The result is kept
        as a plain tensor attribute of shape ``(1, in_features)``.
        """
        prototype_anchor = self.a.mean(dim=0, keepdim=True).detach()
        self.b = -self.pinverse(prototype_anchor, **pinv_kwargs)


class Sequential(nn.Sequential):
    """``nn.Sequential`` that knows its per-sample input/output shapes.

    Every direct child is additionally tagged with a ``name`` attribute of
    the form ``"<key> <ClassName>"``.
    """

    def __init__(self, *args, in_shape=None, out_shape=None):
        """Wrap the given modules and remember the container's shapes.

        Args:
            *args: forwarded to ``nn.Sequential``.
            in_shape: per-sample input shape (no batch dimension).
            out_shape: per-sample output shape (no batch dimension).
        """
        super().__init__(*args)
        self.in_shape, self.out_shape = in_shape, out_shape

        for key, child in self.named_children():
            child.name = f"{key} {type(child).__name__}"

    def infer_shapes(self):
        """Push a zero batch of size one through the layers to fill in / verify shapes.

        Only layers exposing both ``in_shape`` and ``out_shape`` take part;
        any other layer is skipped and the probe tensor is not passed
        through it. A ``None`` shape is filled in, a preset one is asserted
        against the observed shape, and finally the container's own
        ``out_shape`` is asserted against the last observed shape.
        """
        probe = torch.zeros(1, *self.in_shape, device=CUDA())

        for layer in self._modules.values():
            if not (hasattr(layer, 'in_shape') and hasattr(layer, 'out_shape')):
                continue

            seen_in = probe.shape[1:]
            if layer.in_shape is None:
                layer.in_shape = seen_in
            else:
                assert layer.in_shape == seen_in

            # Anchored layers accept no_a; everything else rejects the
            # keyword with a TypeError and is called plainly.
            try:
                probe = layer(probe, no_a=True)
            except TypeError:
                probe = layer(probe)

            seen_out = probe.shape[1:]
            if layer.out_shape is None:
                layer.out_shape = seen_out
            else:
                assert layer.out_shape == seen_out

        assert self.out_shape == probe.shape[1:]



class Reshape(nn.Module):
    """Invertible reshape between two per-sample shapes (batch dimension inferred)."""

    def __init__(self, in_shape, out_shape):
        """Store the per-sample shapes.

        Args:
            in_shape: per-sample shape expected by ``forward`` / produced by ``inverse``.
            out_shape: per-sample shape produced by ``forward``.
        """
        super().__init__()
        self.in_shape = in_shape
        self.out_shape = out_shape

    def forward(self, z):
        """Return ``z`` laid out as ``(-1, *out_shape)``.

        ``reshape`` is used instead of ``view`` so that non-contiguous
        inputs (e.g. after ``transpose``/``permute``) work too; contiguous
        inputs still come back as a view without a copy.
        """
        return z.reshape(-1, *self.out_shape)

    def inverse(self, x):
        """Return ``x`` laid out as ``(-1, *in_shape)``; non-contiguous input is accepted."""
        return x.reshape(-1, *self.in_shape)


class LeakyReLU(nn.LeakyReLU):
    """``nn.LeakyReLU`` with an analytic inverse."""

    def inverse(self, x):
        """Undo the activation by applying a leaky ReLU with the reciprocal slope."""
        reciprocal_slope = 1. / self.negative_slope
        return F.leaky_relu(x, negative_slope=reciprocal_slope)



class Tanh(nn.Tanh):
    """``nn.Tanh`` with a numerically guarded inverse (atanh)."""

    def inverse(self, x, eps=1e-8):
        """Return ``0.5 * (log(1 + x + eps) - log(1 - x + eps))``.

        ``eps`` is added after ``1 + x`` / ``1 - x`` so that the logarithm
        argument stays positive at ``x = -1`` / ``x = 1``.
        """
        log_plus = torch.log(1 + x + eps)
        log_minus = torch.log(1 - x + eps)
        return 0.5 * (log_plus - log_minus)



class BatchNorm1d(nn.BatchNorm1d):
    """``nn.BatchNorm1d`` with an inverse based on the running statistics."""

    def inverse(self, x):
        """Undo eval-mode normalisation for a 2-D input ``(N, num_features)``.

        Computes ``(x - bias) / weight * sqrt(running_var + eps) + running_mean``
        when the layer is affine, and the same without the bias/weight step
        otherwise.

        Raises:
            AssertionError: if the module is in training mode or ``x`` is not 2-D.
        """
        assert not self.training
        assert len(x.shape) == 2 # Only support linear
        shape = (1, self.num_features)
        std = torch.sqrt(self.running_var.view(shape) + self.eps)
        mu = self.running_mean.view(shape)

        if not self.affine:
            return x * std + mu

        unscaled = (x - self.bias.view(shape)) / self.weight.view(shape)
        return unscaled * std + mu


class BatchNorm2d(nn.BatchNorm2d):
    """``nn.BatchNorm2d`` with an inverse based on the running statistics."""

    def inverse(self, x):
        """Undo eval-mode normalisation, broadcasting the statistics as ``(1, C, 1, 1)``.

        Computes ``(x - bias) / weight * sqrt(running_var + eps) + running_mean``
        when the layer is affine, and the same without the bias/weight step
        otherwise.

        Raises:
            AssertionError: if the module is in training mode.
        """
        assert not self.training
        shape = (1, self.num_features, 1, 1)
        std = torch.sqrt(self.running_var.view(shape) + self.eps)
        mu = self.running_mean.view(shape)

        if not self.affine:
            return x * std + mu

        unscaled = (x - self.bias.view(shape)) / self.weight.view(shape)
        return unscaled * std + mu


