import torch
import torch.nn.functional as F

from .utils import conjugate_gradient
from .constants import GPU_MAYBE


def matrix_vector_product_linear(W, L, p):
    """Apply the (element-wise / diagonal) R2G2 system matrix of a linear layer to ``p``.

    Args:
        W: layer input, shape (batch, dim_in).
        L: weight variances, shape (dim_in, dim_out).
        p: current conjugate-gradient direction, shape (batch, dim_out).

    Returns:
        ``(W**2 @ L) * p`` with shape (batch, dim_out): every entry of ``p`` is scaled by
        the variance ``sum_n W[b, n]**2 * L[n, m]`` of the matching pre-activation.
    """
    preactivation_var = torch.einsum("ij,jk->ik", W.square(), L)
    return preactivation_var * p


def dot_product_linear(a, b):
    """Inner product handed to ``conjugate_gradient`` for the linear layer.

    The matching matrix-vector product (``matrix_vector_product_linear``) only scales
    entries independently, so the "dot product" here is a plain element-wise product
    with the same shape as the inputs (no reduction).
    """
    elementwise_product = a * b
    return elementwise_product


def conv2d_forward_batch(input, weight, stride, padding):
    """Convolve a batch with one shared kernel using im2col (``F.unfold``) and a matmul.

    Args:
        input: (batch, c_in, h, w) tensor.
        weight: (c_out, c_in, kh, kw) kernel shared by every example.
        stride, padding: ints applied to both spatial dimensions.

    Returns:
        output: (batch, c_out, h_out, w_out) convolution result.
        input_patches: (batch, h_out * w_out, c_in * kh * kw) unfolded input.
        output_patches: (batch, h_out * w_out, c_out) result before reshaping.
    """
    n_out_ch, _, k_h, k_w = weight.shape
    n_batch, in_h, in_w = input.shape[0], input.shape[2], input.shape[3]

    def spatial_out(size, k):
        # Standard convolution output length, truncated to an int.
        return int((size + 2 * padding - k) / stride + 1)

    h_out = spatial_out(in_h, k_h)
    w_out = spatial_out(in_w, k_w)

    columns = F.unfold(input, (k_h, k_w), padding=padding, stride=stride)  # (b, c_in*kh*kw, L)
    input_patches = columns.transpose(1, 2)
    kernel_matrix = weight.view(weight.size(0), -1).t()  # (c_in*kh*kw, c_out)
    output_patches = input_patches.matmul(kernel_matrix)
    output = output_patches.transpose(1, 2).view(n_batch, n_out_ch, h_out, w_out)

    return output, input_patches, output_patches


def conv2d_forward_obs(input, weight, stride, padding):
    """Convolve every example with its own kernel ("per-observation" weights).

    Args:
        input: (batch, c_in, h, w) tensor.
        weight: (batch, c_out, c_in, kh, kw) tensor; ``weight[b]`` is applied to ``input[b]``.
        stride, padding: ints applied to both spatial dimensions.

    Returns:
        output: (batch, c_out, h_out, w_out) convolution result.
        input_patches: (batch, c_in, L, kh * kw) unfolded input, L = number of blocks.
        output_patches: (batch, c_out, L) result before reshaping.
    """
    n_batch, n_out, n_in, k_h, k_w = weight.shape
    in_h, in_w = input.shape[2], input.shape[3]
    window = k_h * k_w

    h_out = int((in_h + 2 * padding - k_h) / stride + 1)
    w_out = int((in_w + 2 * padding - k_w) / stride + 1)

    columns = F.unfold(input, (k_h, k_w), padding=padding, stride=stride)
    n_blocks = columns.shape[2]
    input_patches = columns.view(n_batch, n_in, window, n_blocks).transpose(2, 3)  # bilk
    per_example_kernels = weight.view(n_batch, n_out, n_in, window)  # boik

    # Contract over the kernel window k, then add up the contribution of each input channel i.
    per_channel = torch.einsum("boik,bilk->boil", per_example_kernels, input_patches)
    output_patches = per_channel.sum(dim=2)  # bol
    output = output_patches.view(n_batch, n_out, h_out, w_out)

    return output, input_patches, output_patches


class Conv2D_func(torch.autograd.Function):
    """2-D convolution (cross-correlation) via ``F.unfold`` with a hand-written backward.

    ``apply(input, weight, stride, padding)`` takes ``input`` of shape
    (batch, c_in, h, w), ``weight`` of shape (c_out, c_in, kh, kw) and integer
    ``stride`` / ``padding`` applied to both spatial dimensions. No bias.
    """

    @staticmethod
    def forward(ctx, input, weight, stride, padding):
        ctx.stride = stride
        ctx.padding = padding

        output, input_patches, _ = conv2d_forward_batch(input, weight, stride, padding)

        # backward only needs the input shape, the weight and the unfolded input;
        # output_patches is not needed, so it is not kept alive.
        ctx.save_for_backward(input, weight, input_patches)
        return output

    @staticmethod
    def backward(ctx, grad_out):
        """Return gradients for (input, weight, stride, padding).

        The forward is ``output_patches = input_patches @ W.T`` where
        ``input_patches = unfold(input).T`` and ``W`` is the (c_out, c_in*kh*kw) weight
        matrix. The input gradient is therefore ``fold(grad_patches @ W)``, because
        ``F.fold`` is the adjoint of ``F.unfold``. Unlike a "full correlation with the
        rotated kernel", this is correct for any stride, any padding and non-square kernels.
        """
        input, weight, input_patches = ctx.saved_tensors
        stride = ctx.stride
        padding = ctx.padding
        out_channels, in_channels, h_weight, w_weight = weight.shape
        batch_size, h_input, w_input = input.shape[0], input.shape[2], input.shape[3]

        # reshape, not view: the incoming gradient is not guaranteed to be contiguous.
        grad_out_flat = grad_out.reshape(batch_size, out_channels, -1)  # (b, c_out, L)
        weight_flat = weight.reshape(out_channels, -1)  # (c_out, c_in*kh*kw)

        grad_input = grad_weight = None

        if ctx.needs_input_grad[0]:
            grad_input_patches = grad_out_flat.transpose(1, 2).matmul(weight_flat)  # (b, L, c_in*kh*kw)
            # Scatter-add each patch gradient back to the pixels it was taken from.
            grad_input = F.fold(
                grad_input_patches.transpose(1, 2),
                output_size=(h_input, w_input),
                kernel_size=(h_weight, w_weight),
                padding=padding,
                stride=stride,
            )

        if ctx.needs_input_grad[1]:
            grad_weight = grad_out_flat.matmul(input_patches)  # (b, c_out, c_in*kh*kw)
            grad_weight = grad_weight.sum(dim=0).view(weight.shape)

        return grad_input, grad_weight, None, None


class Conv2D(torch.nn.Module):
    """Bias-free 2-D convolution layer backed by the custom ``Conv2D_func`` autograd function.

    Args:
        in_channels, out_channels: channel counts.
        kernel_size: int (square kernel) or a pair ``(kh, kw)``; stored as given.
        stride: int stride for both spatial dimensions (default 1).
        padding: int zero-padding for both spatial dimensions (default 1).
        marker: optional tag string (e.g. ``"bottom"`` / ``"top"``).

    The kernel lives in ``weights`` with shape (out_channels, in_channels, kh, kw) and is
    initialised uniformly in [-0.1, 0.1].
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=1,
        marker=None,
    ):
        super().__init__()
        # Keep the argument exactly as passed; only the local copy is expanded to two dims.
        self.kernel_size = kernel_size
        kernel_hw = [kernel_size] * 2 if isinstance(kernel_size, int) else kernel_size

        self.size = [out_channels, in_channels, *kernel_hw]
        initial = torch.empty(self.size).uniform_(-0.1, 0.1)
        self.weights = torch.nn.Parameter(initial)

        self.stride = stride
        self.padding = padding
        self.marker = marker

    def forward(self, x):
        """Convolve ``x`` with ``self.weights`` using the stored stride and padding."""
        return Conv2D_func.apply(x, self.weights, self.stride, self.padding)


class CustomConv2D(torch.nn.Module):
    """Bias-free 2-D convolution layer that delegates to PyTorch's ``F.conv2d``.

    Args:
        in_channels, out_channels: channel counts.
        kernel_size: int (square kernel) or a pair ``(kh, kw)``; stored as given.
        stride: stride passed to ``F.conv2d`` (default 1).
        padding: padding passed to ``F.conv2d`` (default 0).
        marker: optional tag string.

    ``weights`` has shape (out_channels, in_channels, kh, kw), initialised uniformly in
    [-0.1, 0.1].
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        marker=None,
    ):
        super().__init__()
        self.kernel_size = kernel_size
        kernel_hw = [kernel_size] * 2 if isinstance(kernel_size, int) else kernel_size

        self.size = [out_channels, in_channels, *kernel_hw]
        self.weights = torch.nn.Parameter(torch.empty(self.size).uniform_(-0.1, 0.1))

        self.stride, self.padding, self.marker = stride, padding, marker

    def forward(self, x):
        """Convolve ``x`` with ``self.weights`` using the stored stride and padding."""
        return F.conv2d(x, self.weights, stride=self.stride, padding=self.padding)


def matrix_vector_product_conv2d(W, L, p):
    """Apply the R2G2 system matrix of a convolution to ``p``.

    Args:
        W: unfolded input patches, shape (batch, c_in, L, k) ("bilk", k = kh*kw).
        L: weight variances, shape (c_out, c_in, k).
        p: current conjugate-gradient direction, shape (batch, c_out, L).

    Returns:
        (batch, c_out, L) tensor equal to ``sum_i X_i diag(L[o, i]) X_i^T p[b, o]`` for
        each (b, o), where ``X_i = W[b, i]`` is the (L, k) patch matrix of input channel i.
    """
    patches_t = W.transpose(2, 3)  # (b, i, k, l)
    weighted = torch.einsum("bol,bilk->boik", p, W)  # X_i^T p
    weighted.mul_(L.unsqueeze(0))  # scale by the variances, broadcast over batch
    return torch.einsum("boik,bikl->bol", weighted, patches_t)


def dot_product_conv2d(a, b):
    """Batched inner product along dim 2, keeping that dim with size 1.

    For (batch, c_out, L) inputs this returns (batch, c_out, 1): one dot product per
    (batch, out_channel) system in the convolutional conjugate-gradient solve.
    """
    return torch.sum(torch.mul(a, b), dim=2, keepdim=True)


class R2G2MeanFieldConv2D(torch.autograd.Function):
    """Custom autograd Function: per-example convolution with weights ``eps * sqrt(w_var)``.

    Only ``forward`` is implemented. No ``backward`` is defined, so requesting a
    gradient through this Function will raise (PyTorch's base ``Function.backward``
    is not implemented). TODO: implement ``backward`` or remove this class.
    """

    @staticmethod
    def forward(ctx, x, w_var, eps, stride, padding):
        """Return the convolution of ``x`` with the sampled per-example weights.

        Args:
            x: input images forwarded to ``conv2d_forward_obs``.
            w_var: weight variances; presumably (c_out, c_in, kh, kw) — TODO confirm.
            eps: noise multiplied with ``w_var.sqrt().unsqueeze(0)``; the product must be
                the 5-D per-example weight ``conv2d_forward_obs`` expects, presumably
                (batch, c_out, c_in, kh, kw) — confirm against callers.
            stride, padding: ints forwarded to ``conv2d_forward_obs``.
        """
        ctx.stride = stride
        ctx.padding = padding

        w_std = w_var.sqrt()
        w = eps * w_std.unsqueeze(0)
        # conv2d_forward_obs returns (output, input_patches, output_patches). The previous
        # four-way unpack raised ValueError on every call.
        z, x_patches, z_patches = conv2d_forward_obs(x, w, stride, padding)
        ctx.save_for_backward(x, w_var, x_patches, z_patches)

        return z


class GaussianLinear(torch.nn.Module):
    """Mean-field Gaussian linear layer ``z = x @ W`` with ``W ~ N(mean_params, softplus(var_params))``.

    Weights are stored as (dim_in, dim_out); there is no bias. In training mode the
    noise is sampled according to ``reparam``:

    * ``"rt"``   - an independent weight sample per example (reparameterization trick).
    * ``"lrt"``  - local reparameterization: sample the pre-activation directly with
      variance ``x**2 @ w_var``.
    * ``"r2g2"`` - the forward noise is drawn as selected by ``fwd`` (``"rt"`` or
      ``"lrt"``). The gradient flows through a weight perturbation obtained from a
      conjugate-gradient solve, via a stop-gradient trick.

    Any other ``reparam`` raises ``NotImplementedError`` in training mode. In eval mode
    a single weight sample is shared by the whole batch.
    """

    def __init__(self, dim_in, dim_out, reparam="r2g2", fwd="lrt", marker=None):
        super(GaussianLinear, self).__init__()

        # Learnable parameters
        self.mean_params = torch.nn.Parameter(
            torch.Tensor(size=(dim_in, dim_out)).uniform_(-0.1, 0.1)
        )

        # Inverse softplus of 1e-4, so the initial weight variance is 1e-4.
        # torch.expm1 avoids the float32 cancellation in exp(1e-4) - 1.
        self.var_params = torch.nn.Parameter(
            torch.log(torch.expm1(torch.ones(size=(dim_in, dim_out)) * 1e-4))
        )

        self.reparam = reparam
        self.fwd = fwd
        self.marker = marker

    def w_var(self):
        """Return the (non-negative) weight variances ``softplus(var_params)``."""
        return F.softplus(self.var_params)

    def forward(self, x):
        """Sample pre-activations for ``x`` of shape (batch, dim_in); returns (batch, dim_out).

        Raises:
            NotImplementedError: in training mode for an unknown ``reparam``, or for
                ``reparam="r2g2"`` with an unknown ``fwd``.
        """
        w_mean = self.mean_params
        w_var = self.w_var()

        if self.training:
            if self.reparam == "rt":
                batch_size = x.shape[0]
                w_std = w_var.sqrt()
                eps = torch.empty([batch_size, *(w_std.size())], device=x.device).normal_(0.0, 1.0)
                w = w_mean.unsqueeze(0) + (eps * w_std.unsqueeze(0))
                z = torch.einsum("bnm,bn->bm", w, x)

            elif self.reparam == "lrt":
                z_mean = F.linear(x, w_mean.T)
                z_var = F.linear(x.pow(2), w_var.T)
                z_std = z_var.sqrt()
                eps = torch.empty(z_std.size(), device=x.device).normal_(0.0, 1.0)
                z = z_mean + z_std * eps

            elif self.reparam == "r2g2":
                z_std_eps = self._r2g2_z_std_eps(x, w_var)
                z_mean = F.linear(x, w_mean.T)
                z = z_mean + z_std_eps

            else:
                raise NotImplementedError

        else:
            # One weight sample shared by the whole batch.
            w_std = w_var.sqrt()
            eps = torch.empty(w_std.size(), device=x.device).normal_(0.0, 1.0)
            w = w_mean + w_std * eps
            z = F.linear(x, w.T)

        return z

    def _r2g2_z_std_eps(self, x, w_var):
        """Return the R2G2 noise term of the pre-activation, shape (batch, dim_out).

        Its value is a forward noise sample drawn as selected by ``self.fwd``. Its
        gradient is that of ``x @ (sqrt(w_var) * r2g2_eps)``, where ``r2g2_eps`` comes
        from a conjugate-gradient solve on detached tensors.

        Raises:
            NotImplementedError: if ``self.fwd`` is neither ``"rt"`` nor ``"lrt"``.
        """
        with torch.no_grad():
            x_detach = x.detach()
            w_var_detach = w_var.detach()
            w_std_detach = w_var_detach.sqrt()

            # Forward noise sample of the pre-activation, (batch, dim_out).
            if self.fwd == "rt":
                batch_size = x.shape[0]
                eps = torch.empty([batch_size, *(w_std_detach.size())], device=x.device).normal_(0.0, 1.0)
                fwd_z_std_eps = torch.einsum("bnm,bn->bm", w_std_detach * eps, x_detach)
            elif self.fwd == "lrt":
                z_std_detach = F.linear(x_detach.pow(2), w_var_detach.T).sqrt()
                eps = torch.empty(z_std_detach.size(), device=x.device).normal_(0.0, 1.0)
                fwd_z_std_eps = z_std_detach * eps
            else:
                # Without this, fwd_z_std_eps / r2g2_eps would be unbound (UnboundLocalError).
                raise NotImplementedError(f"unsupported fwd={self.fwd!r} for reparam='r2g2'")

            # Conditional eps for each (batch, dim_out). matrix_vector_product_linear is an
            # element-wise scaling, which is presumably why a single CG iteration is used.
            r2g2_beta = conjugate_gradient(
                W=x_detach,
                V=w_var_detach,
                b=fwd_z_std_eps,
                matrix_vector_product_function=matrix_vector_product_linear,
                dot_product_function=dot_product_linear,
                iters=1,
            )
            r2g2_eps = w_std_detach * r2g2_beta.unsqueeze(1)  # bnm, bm -> bnm
            r2g2_eps *= x_detach.unsqueeze(2)

        w_std = w_var.sqrt()
        r2g2_w = w_std.unsqueeze(0) * r2g2_eps
        r2g2_z_std_eps = torch.einsum("bnm,bn->bm", r2g2_w, x)
        # Stop-gradient: the value equals fwd_z_std_eps, the gradient flows through r2g2_z_std_eps.
        return (fwd_z_std_eps - r2g2_z_std_eps).detach() + r2g2_z_std_eps

    def get_params(self):
        """Return a (dim_in * dim_out, 2) tensor of flattened [mean, std] pairs."""
        w_mean = self.mean_params
        w_var = self.w_var()
        w_std = w_var.sqrt()

        w_mean = w_mean.reshape(shape=[-1])
        w_std = w_std.reshape(shape=[-1])

        return torch.stack([w_mean, w_std], dim=-1)


class GaussianConv2d(torch.nn.Module):
    """Mean-field Gaussian 2-D convolution (no bias) with softplus-parameterised variances.

    Weights have shape (out_channels, in_channels, kh, kw). In training mode:

    * ``reparam="rt"``   - an independent weight sample per example.
    * ``reparam="r2g2"`` - forward noise from per-example weight samples. The gradient
      flows through a weight perturbation obtained from a conjugate-gradient solve,
      via a stop-gradient trick.

    Any other ``reparam`` raises ``NotImplementedError`` in training mode. In eval mode
    a single weight sample is shared by the whole batch.

    Args:
        in_channels, out_channels: channel counts.
        kernel_size: int (square kernel) or a pair ``(kh, kw)``; stored as given.
        padding: int zero-padding for both spatial dimensions (default 1).
        reparam: ``"rt"`` or ``"r2g2"``.
        marker: optional tag string (e.g. ``"bottom"`` / ``"top"``).
        stride: int stride for both spatial dimensions (default 1, the previously
            hard-coded value).
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        padding=1,
        reparam="r2g2",
        marker=None,
        stride=1,
    ):
        super(GaussianConv2d, self).__init__()
        self.kernel_size = kernel_size
        if isinstance(kernel_size, int):
            kernel_size = [kernel_size, kernel_size]

        self.size = [out_channels, in_channels, *kernel_size]
        self.mean_params = torch.nn.Parameter(
            torch.Tensor(size=self.size).uniform_(-0.1, 0.1)
        )

        # Inverse softplus of 1e-4, so the initial weight variance is 1e-4.
        # torch.expm1 avoids the float32 cancellation in exp(1e-4) - 1.
        self.var_params = torch.nn.Parameter(
            torch.log(torch.expm1(torch.ones(size=self.size) * 1e-4))
        )

        self.stride = stride
        self.padding = padding
        self.reparam = reparam
        self.marker = marker

    def w_var(self):
        """Return the (non-negative) weight variances ``softplus(var_params)``."""
        return F.softplus(self.var_params)

    def forward(self, x):
        """Sample pre-activations for a (batch, in_channels, h, w) input.

        Raises:
            NotImplementedError: in training mode when ``reparam`` is not "rt" or "r2g2".
        """
        w_mean = self.mean_params
        w_var = self.w_var()

        if self.training:
            if self.reparam == "rt":
                # Independent weight sample for every example in the batch.
                batch_size = x.shape[0]
                w_std = w_var.sqrt()
                eps = torch.empty([batch_size, *(w_std.size())], device=x.device).normal_(0.0, 1.0)
                w = w_mean.unsqueeze(0) + (eps * w_std.unsqueeze(0))
                z, _, _ = conv2d_forward_obs(x, w, self.stride, self.padding)

            elif self.reparam == "r2g2":
                z_mean, _, _ = conv2d_forward_batch(x, w_mean, self.stride, self.padding)
                z = z_mean + self._r2g2_z_std_eps(x, w_var)

            else:
                raise NotImplementedError

        else:
            # One weight sample shared by the whole batch.
            w_std = w_var.sqrt()
            eps = torch.empty(w_std.size(), device=x.device).normal_(0.0, 1.0)
            w = w_mean + w_std * eps
            z, _, _ = conv2d_forward_batch(x, w, self.stride, self.padding)

        return z

    def _r2g2_z_std_eps(self, x, w_var):
        """Return the R2G2 noise term of the pre-activation, shape (batch, c_out, h_out, w_out).

        Its value is the convolution of ``x`` with per-example weights ``eps * sqrt(w_var)``.
        Its gradient is that of the convolution with ``sqrt(w_var) * r2g2_eps``, where
        ``r2g2_eps`` comes from a conjugate-gradient solve on detached tensors.
        """
        batch_size = x.shape[0]
        out_channels, in_channels, h_weight, w_weight = w_var.shape

        with torch.no_grad():
            # Forward noise: convolve with per-example samples of the zero-mean weight noise.
            x_detach = x.detach()
            w_var_detach = w_var.detach()
            eps = torch.empty([batch_size, *(w_var.size())], device=x.device).normal_(0.0, 1.0)
            w_detach = eps * w_var_detach.sqrt().unsqueeze(0)
            fwd_z_std_eps, x_patches, output_patches = conv2d_forward_obs(
                x_detach,
                w_detach,
                self.stride,
                self.padding
            )

            # Conditional eps for each (batch, out_channel). The iteration cap is the
            # smaller of the number of patches (blocks) and the per-output-channel weight count.
            blocks = x_patches.shape[2]
            iters = min(blocks, in_channels * h_weight * w_weight)
            w_var_view = w_var_detach.view(out_channels, in_channels, h_weight * w_weight)
            r2g2_beta = conjugate_gradient(
                W=x_patches,
                V=w_var_view,
                b=output_patches,
                matrix_vector_product_function=matrix_vector_product_conv2d,
                dot_product_function=dot_product_conv2d,
                iters=iters,
            )  # bol
            r2g2_eps_view = torch.einsum("bol,bilk->boik", r2g2_beta, x_patches)
            r2g2_eps_view *= w_var_view.sqrt().unsqueeze(0)
            r2g2_eps = r2g2_eps_view.view(batch_size, out_channels, in_channels, h_weight, w_weight)

        # Recompute the noise term with gradients enabled, through r2g2_eps.
        w_std = w_var.sqrt()
        r2g2_w = r2g2_eps * w_std.unsqueeze(0)
        r2g2_z_std_eps, _, _ = conv2d_forward_obs(x, r2g2_w, self.stride, self.padding)
        # Stop-gradient: the value equals fwd_z_std_eps, the gradient flows through r2g2_z_std_eps.
        return (fwd_z_std_eps - r2g2_z_std_eps).detach() + r2g2_z_std_eps

    def get_params(self):
        """Return a (num_weights, 2) tensor of flattened [mean, std] pairs."""
        w_mean = self.mean_params
        w_var = self.w_var()
        w_std = w_var.sqrt()

        w_mean = w_mean.reshape(shape=[-1])

        w_std = w_std.reshape(shape=[-1])

        return torch.stack([w_mean, w_std], dim=-1)


class MnistBNN(torch.nn.Module):
    """Fully connected Bayesian network: Flatten -> 784-1024-1024-10 ``GaussianLinear`` with ReLUs.

    Every ``GaussianLinear`` uses the given ``reparam`` / ``fwd``. The first is tagged
    ``"bottom"`` and the last ``"top"``. The network is moved to ``GPU_MAYBE`` at
    construction.
    """

    def __init__(self, reparam="r2g2", fwd="lrt"):
        super().__init__()

        widths = [28 * 28, 1024, 1024, 10]
        tags = ["bottom", None, "top"]

        layers = [torch.nn.Flatten()]
        for depth, (d_in, d_out, tag) in enumerate(zip(widths[:-1], widths[1:], tags)):
            if depth > 0:
                # ReLU between consecutive Gaussian layers, none after the output layer.
                layers.append(torch.nn.ReLU())
            layers.append(
                GaussianLinear(dim_in=d_in, dim_out=d_out, reparam=reparam, fwd=fwd, marker=tag)
            )

        self.net = torch.nn.Sequential(*layers)

        self.gpu_maybe = GPU_MAYBE
        self.net.to(self.gpu_maybe)

    def forward(self, x):
        """Return the logits produced by ``self.net``."""
        return self.net(x)


class CifarBNN(torch.nn.Module):
    """VGG-style Bayesian CNN: eight 3x3 ``GaussianConv2d`` layers and three ``GaussianLinear``.

    conv1-conv4 always use ``reparam="rt"``. conv5-conv8 use the constructor's
    ``reparam``; conv5 is tagged ``"bottom"`` and conv8 ``"top"``. The linear head uses
    ``"lrt"`` and outputs 10 values. The network is moved to ``GPU_MAYBE`` at construction.
    """

    def __init__(self, reparam="r2g2"):
        super().__init__()

        # (in_channels, out_channels, reparam, marker, followed_by_max_pool)
        conv_plan = [
            (3, 64, "rt", None, True),            # conv1
            (64, 128, "rt", None, True),          # conv2
            (128, 256, "rt", None, False),        # conv3
            (256, 256, "rt", None, True),         # conv4
            (256, 512, reparam, "bottom", False), # conv5
            (512, 512, reparam, None, True),      # conv6
            (512, 512, reparam, None, False),     # conv7
            (512, 512, reparam, "top", True),     # conv8
        ]

        layers = []
        for c_in, c_out, conv_reparam, tag, pool in conv_plan:
            layers += [
                GaussianConv2d(
                    in_channels=c_in,
                    out_channels=c_out,
                    kernel_size=3,
                    padding=1,
                    reparam=conv_reparam,
                    marker=tag,
                ),
                torch.nn.ReLU(),
            ]
            if pool:
                layers.append(torch.nn.MaxPool2d(kernel_size=2, stride=2))

        layers.append(torch.nn.Flatten())
        layers += [
            GaussianLinear(dim_in=512, dim_out=512, reparam="lrt"),  # linear1
            torch.nn.ReLU(),
            GaussianLinear(dim_in=512, dim_out=512, reparam="lrt"),  # linear2
            torch.nn.ReLU(),
            GaussianLinear(dim_in=512, dim_out=10, reparam="lrt"),   # linear3
        ]

        self.net = torch.nn.Sequential(*layers)

        self.gpu_maybe = GPU_MAYBE
        self.net.to(self.gpu_maybe)

    def forward(self, x):
        """Return the logits produced by ``self.net``."""
        return self.net(x)


class VGGCNN(torch.nn.Module):
    """Deterministic VGG-style CNN: eight 3x3 ``Conv2D`` layers plus a three-layer MLP head.

    conv1 is tagged ``"bottom"`` and conv8 ``"top"``. The head's first ``Linear`` takes
    512 features after ``Flatten`` (presumably 32x32 inputs after five 2x2 max-pools;
    confirm against the data pipeline). ``classes`` sets the output width. The network
    is moved to ``GPU_MAYBE`` at construction.
    """

    def __init__(self, classes=10):
        super().__init__()

        # (in_channels, out_channels, marker, followed_by_max_pool)
        conv_plan = [
            (3, 64, "bottom", True),   # conv1
            (64, 128, None, True),     # conv2
            (128, 256, None, False),   # conv3
            (256, 256, None, True),    # conv4
            (256, 512, None, False),   # conv5
            (512, 512, None, True),    # conv6
            (512, 512, None, False),   # conv7
            (512, 512, "top", True),   # conv8
        ]

        layers = []
        for c_in, c_out, tag, pool in conv_plan:
            layers += [
                Conv2D(
                    in_channels=c_in,
                    out_channels=c_out,
                    kernel_size=3,
                    padding=1,
                    marker=tag,
                ),
                torch.nn.ReLU(),
            ]
            if pool:
                layers.append(torch.nn.MaxPool2d(kernel_size=2, stride=2))

        layers.append(torch.nn.Flatten())
        layers += [
            torch.nn.Linear(in_features=512, out_features=512),      # linear1
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=512, out_features=512),      # linear2
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=512, out_features=classes),  # linear3
        ]

        self.net = torch.nn.Sequential(*layers)

        self.gpu_maybe = GPU_MAYBE
        self.net.to(self.gpu_maybe)

    def forward(self, x):
        """Return the logits produced by ``self.net``."""
        return self.net(x)


class BNNVarEstimator:
    """Accumulate statistics of ``var_params`` gradients for layers tagged with ``marker``.

    A module of ``model.net`` is tracked when it has a ``marker`` attribute equal to
    ``marker`` and a ``var_params`` attribute. Call :meth:`update` after each backward
    pass. :meth:`get_mean` and :meth:`get_var` then return the element-wise mean and
    population variance of the recorded gradients, averaged over every tracked element.
    """

    def __init__(self, model, marker):
        device = model.gpu_maybe

        self.count = 0
        self.marker = marker
        tracked = self._tracked_var_params(model)
        # Running sums of gradients and of squared gradients, one tensor per tracked layer.
        self.first_moment_params = [torch.zeros_like(p.data, device=device) for p in tracked]
        self.second_moment_params = [torch.zeros_like(p.data, device=device) for p in tracked]

    def _tracked_var_params(self, model):
        """Return ``var_params`` of every module in ``model.net`` tagged ``self.marker``, in module order."""
        return [
            module.var_params
            for module in model.net.modules()
            if hasattr(module, "marker") and module.marker == self.marker and hasattr(module, "var_params")
        ]

    def update(self, model):
        """Add the current ``var_params`` gradients of ``model`` to the running sums.

        Raises:
            ValueError: if the number of tracked layers differs from construction time.
            RuntimeError: if a tracked ``var_params`` has no gradient (``.grad is None``).
        """
        grads = [p.grad for p in self._tracked_var_params(model)]
        if len(grads) != len(self.first_moment_params):
            raise ValueError(
                f"expected {len(self.first_moment_params)} layers with marker={self.marker!r}, "
                f"found {len(grads)}"
            )
        # Validate before touching the sums so a failure never leaves them half-updated.
        if any(g is None for g in grads):
            raise RuntimeError(
                f"a var_params tensor tagged {self.marker!r} has no gradient; call backward() first"
            )

        self.first_moment_params = [s + g for s, g in zip(self.first_moment_params, grads)]
        self.second_moment_params = [s + g ** 2 for s, g in zip(self.second_moment_params, grads)]
        self.count += 1

    def _require_samples(self):
        """Raise a clear error instead of dividing by zero or concatenating nothing."""
        if not self.first_moment_params:
            raise RuntimeError(f"no modules with marker={self.marker!r} and var_params are tracked")
        if self.count == 0:
            raise RuntimeError("no gradients accumulated yet; call update() first")

    def get_var(self):
        """Return the per-element gradient variance, averaged over all tracked elements."""
        self._require_samples()
        variances = [
            # E[g^2] - E[g]^2 can dip slightly below 0 from float rounding; variance is >= 0.
            (sq_sum / self.count - (g_sum / self.count) ** 2).clamp_min(0.0).flatten()
            for g_sum, sq_sum in zip(self.first_moment_params, self.second_moment_params)
        ]
        return torch.cat(variances).mean().item()

    def get_mean(self):
        """Return the per-element gradient mean, averaged over all tracked elements."""
        self._require_samples()
        means = [(g_sum / self.count).flatten() for g_sum in self.first_moment_params]
        return torch.cat(means).mean().item()
