import torch.nn as nn
import torch as th
import numpy as np
from torch.autograd import Function
import nn as nn_modules
import utils
from nn.residual import ResidualBlock, SkipConnection
from nn.encoder import AggressiveDownConv
from nn.encoder import AggressiveConvTo1x1
from nn.decoder import AggressiveUpConv
from utils.utils import LambdaModule, ForcedAlpha, PrintShape
from nn.eprop_gate_l0rd import EpropGateL0rd
from nn.vae import VariationalFunction
from einops import rearrange, repeat, reduce

from typing import Union, Tuple



class HyperFlow(nn.Module):
    """Warp a feature map with a flow field produced by a per-sample coordinate MLP.

    The MLP has no parameters of its own: its weights are passed to
    ``forward`` as one flat vector per batch element.  It maps normalised
    pixel coordinates in ``[-1, 1]`` to a 2-channel offset, which is added to
    the identity sampling grid and applied with ``grid_sample``.

    Expected layout of one weight vector (``H = mlp_hidden``,
    ``L = mlp_layers``), in order::

        w_in   : 2 * H
        b_in   : H
        (L - 1) x [ w_hidden : H * H,  b_hidden : H ]
        w_out  : 2 * H
        b_out  : 2
    """

    def __init__(
        self, 
        mlp_layers: int,
        mlp_hidden: int
    ):
        """
        Args:
            mlp_layers: number of tanh layers (input layer plus
                ``mlp_layers - 1`` hidden layers).
            mlp_hidden: width of every tanh layer.
        """
        super(HyperFlow, self).__init__()
        self.mlp_layers = mlp_layers
        self.mlp_hidden = mlp_hidden

        # Learnable scale of the predicted offsets; starts at 1e-16 so the
        # initial warp is (numerically) the identity.
        self.alpha = nn.Parameter(th.zeros(1)+1e-16)

    def num_weights(self):
        """Return the length of the flat weight vector ``forward`` expects per sample."""
        mlp_layers = self.mlp_layers
        mlp_hidden = self.mlp_hidden

        return 4 * mlp_hidden + (mlp_layers - 1) * mlp_hidden**2 + mlp_hidden * mlp_layers + 2
        
    def forward(self, input: th.Tensor, weights):
        """Warp ``input`` with the flow generated from ``weights``.

        Args:
            input: 4D tensor ``(batch, channels, height, width)``.
            weights: tensor ``(batch, num_weights())`` holding the flattened
                MLP weights of every sample (layout in the class docstring).

        Returns:
            Tuple ``(warped, raw_flow)``: ``warped`` is ``input`` resampled
            with ``grid_sample``; ``raw_flow`` is the ``(batch, 2, height,
            width)`` MLP output after scaling by ``alpha``.
        """

        size       = input.shape[-2:]
        batch_size = input.shape[0]
        mlp_layers = self.mlp_layers
        mlp_hidden = self.mlp_hidden

        grid_x = th.arange(end=size[0], device=input.device)
        grid_y = th.arange(end=size[1], device=input.device)

        # Normalise to [-1, 1]. The max(..., 1) guard keeps a 1-pixel axis
        # from evaluating 0 / 0 and filling the sampling grid with NaN.
        grid_x = (grid_x / max(size[0] - 1, 1)) * 2 - 1
        grid_y = (grid_y / max(size[1] - 1, 1)) * 2 - 1

        # grid_x varies along the height axis, grid_y along the width axis.
        grid_x = grid_x.view(1, 1, -1, 1).expand(batch_size, 1, *size)
        grid_y = grid_y.view(1, 1, 1, -1).expand(batch_size, 1, *size)

        # Channel 0 runs along the width and channel 1 along the height,
        # i.e. the (x, y) order used by grid_sample. cat materialises the
        # expanded views, so no explicit clone is required.
        base_grid = th.cat((grid_y, grid_x), dim=1)

        # Fold the batch into the channel axis so that one grouped 1x1
        # convolution (groups=batch_size) applies each sample's own weights.
        grid      = base_grid.reshape(1, -1, *size)
        
        w_in = weights[:,:2*mlp_hidden].reshape(mlp_hidden * batch_size, 2, 1, 1)
        b_in = weights[:,2*mlp_hidden:3*mlp_hidden].reshape(mlp_hidden * batch_size)
        # Rescale by sqrt(6 / (fan_in + fan_out)) with fan_in=2, fan_out=H.
        w_in = w_in * np.sqrt(6) / (np.sqrt(2 + mlp_hidden))
        out  = th.tanh(nn.functional.conv2d(grid, w_in, bias=b_in, groups=batch_size))
        
        for l in range(mlp_layers - 1):
            start_index = 3*mlp_hidden + l * (mlp_hidden**2 + mlp_hidden)
            end_index   = start_index + mlp_hidden**2

            w_hidden = weights[:,start_index:end_index]
            w_hidden = w_hidden.reshape(mlp_hidden*batch_size, mlp_hidden, 1, 1)
            b_hidden = weights[:,end_index:end_index+mlp_hidden].reshape(mlp_hidden * batch_size)
            w_hidden = w_hidden * np.sqrt(6) / (np.sqrt(2*mlp_hidden))

            out = th.tanh(nn.functional.conv2d(out, w_hidden, bias=b_hidden, groups=batch_size))

        start_index = 3*mlp_hidden + (mlp_layers - 1) * (mlp_hidden**2 + mlp_hidden)
        end_index   = start_index + mlp_hidden*2

        w_out = weights[:,start_index:end_index].reshape(2*batch_size, mlp_hidden, 1, 1)
        b_out = weights[:,end_index:].reshape(batch_size*2)
        w_out = w_out * np.sqrt(6) / (np.sqrt(2 + mlp_hidden))

        raw_flow = nn.functional.conv2d(out, w_out, bias=b_out, groups=batch_size)
        raw_flow = raw_flow.reshape(batch_size, 2, *size) * self.alpha

        # Both offset components are divided by the input width (last dim).
        # NOTE(review): confirm this is intended for non-square inputs.
        flow     = base_grid + raw_flow/input.shape[-1]

        # grid_sample takes the grid channels-last: (batch, height, width, 2).
        flow     = flow.permute(0, 2, 3, 1)

        return nn.functional.grid_sample(input, flow, align_corners=True), raw_flow

class FlowBackgroundEnhancer(nn.Module):
    """Encoder / flow-warp / decoder module operating on the last two input frames.

    Each call to ``forward`` keeps a two-frame window, encodes both frames,
    predicts flow-MLP weights from the pair, warps the encoding of the newer
    frame with ``HyperFlow`` and decodes the warped latent into an image
    (passed through a sigmoid).  ``level`` (0, 1 or 2) selects how many of
    the down/up-sampling stages are active.
    """

    def __init__(
        self, 
        input_size: Tuple[int, int], 
        img_channels: int, 
        level1_channels,
        latent_channels,
        batch_size,
        reg_lambda,
        vae_factor,
        mlp_layers,
        mlp_hidden,
        deepth
    ):
        """
        Args:
            input_size: (height, width); the latent grid is ``input_size // 16``.
            img_channels: channel count of the input / output images.
            level1_channels: channel count after the level-2 down stage.
            latent_channels: channel count of the latent grid.
            batch_size: used for the latent buffer and handed to EpropGateL0rd.
            reg_lambda: handed to EpropGateL0rd.
            vae_factor: not referenced by this constructor.
            mlp_layers: handed to HyperFlow.
            mlp_hidden: handed to HyperFlow.
            deepth: number of ResidualBlocks in every stage.
        """
        super().__init__()

        latent_size   = [input_size[0] // 16, input_size[1] // 16]
        flow_channels = latent_channels * 2
        self.input_size = input_size

        def residual_stack(channels, **block_kwargs):
            # `deepth` residual blocks that keep the channel count unchanged.
            return [ResidualBlock(channels, channels, **block_kwargs) for _ in range(deepth)]

        # NOTE: sub-modules are created in a fixed order; it determines the
        # parameter / state-dict ordering.
        self.register_buffer('init', th.zeros(1).long())
        self.alpha = nn.Parameter(th.zeros(1)+1e-16)
        self.hyper_flow = HyperFlow(mlp_layers, mlp_hidden)

        self.level = 1

        # ---- encoder stages -------------------------------------------
        self.down_level2 = nn.Sequential(
            AggressiveDownConv(img_channels, level1_channels, alpha = 1e-16),
            *residual_stack(level1_channels, alpha_residual = True),
        )

        self.down_level1 = nn.Sequential(
            AggressiveDownConv(level1_channels, latent_channels, alpha = 1),
            *residual_stack(latent_channels),
        )

        # Keeps only the second (most recent) of the two stacked frames.
        self.down_level0 = nn.Sequential(
            LambdaModule(lambda t: rearrange(t, '(n b) c h w -> n b c h w', n = 2)[1]),
            *residual_stack(latent_channels),
        )

        # ---- flow-weight predictor ------------------------------------
        # Both frames are concatenated along the channel axis.
        # NOTE(review): `latent_size * 2` repeats the list ([h, w, h, w]);
        # it does not double the entries. Confirm against AggressiveConvTo1x1.
        self.flow = nn.Sequential(
            LambdaModule(lambda t: rearrange(t, '(n b) c h w -> b (n c) h w', n = 2)),
            *residual_stack(flow_channels),
            AggressiveConvTo1x1(flow_channels, latent_size * 2),
            LambdaModule(lambda t: rearrange(t, 'b c 1 1 -> b c')),
            EpropGateL0rd(
                num_inputs  = flow_channels,
                num_hidden  = flow_channels,
                num_outputs = self.hyper_flow.num_weights(),
                batch_size  = batch_size,
                reg_lambda  = reg_lambda,
            ),
        )

        # ---- decoder stages -------------------------------------------
        self.up_level0 = nn.Sequential(*residual_stack(latent_channels))

        self.up_level1 = nn.Sequential(
            *residual_stack(latent_channels),
            AggressiveUpConv(latent_channels, level1_channels, alpha = 1),
        )

        self.up_level2 = nn.Sequential(
            *residual_stack(level1_channels, alpha_residual = True),
            AggressiveUpConv(level1_channels, img_channels, alpha = 1e-16),
        )

        # Per-level adapters, indexed by `self.level`.
        level_channels = (latent_channels, level1_channels, img_channels)
        self.to_channels = nn.ModuleList(
            [SkipConnection(img_channels, channels) for channels in level_channels]
        )
        self.to_img = nn.ModuleList(
            [SkipConnection(channels, img_channels) for channels in level_channels]
        )

        self.mask = nn.Parameter(th.ones(1, 1, *input_size) * 10)

        # Last warped latent; non-persistent, so not part of the state dict.
        self.register_buffer(
            'latent',
            th.zeros((batch_size, latent_channels, *latent_size)),
            persistent=False,
        )
        self.frames = None

    def get_openings(self):
        """Return ``openings`` of the last module in ``self.flow`` as a Python number."""
        gate = self.flow[-1]
        return gate.openings.item()

    def get_init(self):
        """Return the ``init`` counter as a Python int."""
        return self.init.item()

    def step_init(self):
        """Increase the ``init`` counter by one."""
        self.init = self.init + 1

    def detach(self):
        """Detach the stored latent from the autograd graph."""
        self.latent = self.latent.detach()

    def reset_state(self):
        """Zero the stored latent and forget the frame window."""
        self.latent = th.zeros_like(self.latent)
        self.frames = None

    def set_level(self, level):
        """Select the active resolution level (index into ``to_channels`` / ``to_img``)."""
        self.level = level

    def encoder(self, input):
        """Encode the stacked frames and warp the resulting latent.

        Returns whatever ``self.hyper_flow`` returns: ``(warped_latent, raw_flow)``.
        """
        level    = self.level
        features = self.to_channels[level](input)

        if level >= 2:
            features = self.down_level2(features)
        if level >= 1:
            features = self.down_level1(features)

        # The flow weights are computed from both frames before down_level0
        # reduces the features to the second frame only.
        flow_weights = self.flow(features)
        current      = self.down_level0(features)

        return self.hyper_flow(current, flow_weights)

    def get_last_latent_gird(self):
        """Return the latent stored by the last ``forward`` call."""
        return self.latent

    def decoder(self, latent):
        """Decode ``latent`` through the active up stages; output is sigmoid-squashed."""
        level    = self.level
        features = self.up_level0(latent)

        if level >= 1:
            features = self.up_level1(features)
        if level >= 2:
            features = self.up_level2(features)

        return th.sigmoid(self.to_img[level](features))

    def forward(self, input: th.Tensor, error: th.Tensor = None, mask: th.Tensor = None):
        """Process one frame.

        ``error`` and the incoming ``mask`` are not read; ``mask`` is rebuilt
        from the learned ``self.mask`` parameter.

        Returns:
            * ``init < 1``:  ``(mask, background, raw_flow)``
            * ``init < 2``:  ``(mask, zeros_like(background), zeros_like(latent), background)``
            * otherwise:     ``(mask, background, latent * alpha, background)``
        """
        # Two-frame sliding window; the first call duplicates the frame.
        if self.frames is None:
            self.frames = [input, input]
        else:
            self.frames = [*self.frames[1:], input]

        stacked = rearrange(th.stack(self.frames), 'n b c h w -> (n b) c h w')
        self.latent, raw_flow = self.encoder(stacked)

        # Average-pool the learned mask to the input resolution, then
        # broadcast it over the batch.
        height, width = input.shape[2], input.shape[3]
        mask = reduce(self.mask, '1 1 (h h2) (w w2) -> 1 1 h w', 'mean', h = height, w = width)
        mask = repeat(mask,      '1 1 h w -> b 1 h w', b = input.shape[0]) * 0.1

        background = self.decoder(self.latent)

        stage = self.get_init()

        if stage < 1:
            return mask, background, raw_flow

        if stage < 2:
            return mask, th.zeros_like(background), th.zeros_like(self.latent), background

        return mask, background, self.latent * self.alpha, background
